flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241023__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +13 -14
- flwr/client/node_state_tests.py +7 -8
- flwr/client/{node_state.py → run_info_store.py} +3 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +31 -3
- flwr/common/typing.py +9 -0
- flwr/server/app.py +121 -10
- flwr/server/driver/inmemory_driver.py +2 -2
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +78 -0
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +9 -7
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +109 -19
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +59 -11
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +136 -35
- flwr/server/superlink/{state → linkstate}/utils.py +57 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +15 -7
- flwr/superexec/app.py +9 -2
- flwr/superexec/simulation.py +1 -1
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/RECORD +34 -32
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/WHEEL +0 -0
|
@@ -51,14 +51,16 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
51
51
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
52
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
53
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
54
|
-
from flwr.server.superlink.
|
|
54
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
55
55
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
59
59
|
"""Driver API servicer."""
|
|
60
60
|
|
|
61
|
-
def __init__(
|
|
61
|
+
def __init__(
|
|
62
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
63
|
+
) -> None:
|
|
62
64
|
self.state_factory = state_factory
|
|
63
65
|
self.ffs_factory = ffs_factory
|
|
64
66
|
|
|
@@ -67,7 +69,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
67
69
|
) -> GetNodesResponse:
|
|
68
70
|
"""Get available nodes."""
|
|
69
71
|
log(DEBUG, "DriverServicer.GetNodes")
|
|
70
|
-
state:
|
|
72
|
+
state: LinkState = self.state_factory.state()
|
|
71
73
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
72
74
|
nodes: list[Node] = [
|
|
73
75
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -79,7 +81,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
79
81
|
) -> CreateRunResponse:
|
|
80
82
|
"""Create run ID."""
|
|
81
83
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
82
|
-
state:
|
|
84
|
+
state: LinkState = self.state_factory.state()
|
|
83
85
|
if request.HasField("fab"):
|
|
84
86
|
fab = fab_from_proto(request.fab)
|
|
85
87
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
@@ -116,7 +118,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
116
118
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
117
119
|
|
|
118
120
|
# Init state
|
|
119
|
-
state:
|
|
121
|
+
state: LinkState = self.state_factory.state()
|
|
120
122
|
|
|
121
123
|
# Store each TaskIns
|
|
122
124
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -138,7 +140,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
138
140
|
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
139
141
|
|
|
140
142
|
# Init state
|
|
141
|
-
state:
|
|
143
|
+
state: LinkState = self.state_factory.state()
|
|
142
144
|
|
|
143
145
|
# Register callback
|
|
144
146
|
def on_rpc_done() -> None:
|
|
@@ -167,7 +169,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
167
169
|
log(DEBUG, "DriverServicer.GetRun")
|
|
168
170
|
|
|
169
171
|
# Init state
|
|
170
|
-
state:
|
|
172
|
+
state: LinkState = self.state_factory.state()
|
|
171
173
|
|
|
172
174
|
# Retrieve run information
|
|
173
175
|
run = state.get_run(request.run_id)
|
|
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
|
48
48
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
49
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
50
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
51
|
-
from flwr.server.superlink.
|
|
51
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
52
52
|
|
|
53
53
|
T = TypeVar("T", bound=GrpcMessage)
|
|
54
54
|
|
|
@@ -77,7 +77,9 @@ def _handle(
|
|
|
77
77
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
78
78
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
79
|
|
|
80
|
-
def __init__(
|
|
80
|
+
def __init__(
|
|
81
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
82
|
+
) -> None:
|
|
81
83
|
self.state_factory = state_factory
|
|
82
84
|
self.ffs_factory = ffs_factory
|
|
83
85
|
|
|
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
38
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
|
-
from flwr.server.superlink.
|
|
40
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
44
44
|
"""Fleet API servicer."""
|
|
45
45
|
|
|
46
|
-
def __init__(
|
|
46
|
+
def __init__(
|
|
47
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
48
|
+
) -> None:
|
|
47
49
|
self.state_factory = state_factory
|
|
48
50
|
self.ffs_factory = ffs_factory
|
|
49
51
|
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self, state:
|
|
87
|
+
def __init__(self, state: LinkState):
|
|
88
88
|
self.state = state
|
|
89
89
|
|
|
90
90
|
self.node_public_keys = state.get_node_public_keys()
|
|
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
43
|
)
|
|
44
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
45
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
|
-
from flwr.server.superlink.
|
|
46
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def create_node(
|
|
50
50
|
request: CreateNodeRequest, # pylint: disable=unused-argument
|
|
51
|
-
state:
|
|
51
|
+
state: LinkState,
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
54
|
# Create node
|
|
@@ -56,7 +56,7 @@ def create_node(
|
|
|
56
56
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
59
|
+
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
|
60
60
|
"""."""
|
|
61
61
|
# Validate node_id
|
|
62
62
|
if request.node.anonymous or request.node.node_id == 0:
|
|
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
69
69
|
|
|
70
70
|
def ping(
|
|
71
71
|
request: PingRequest, # pylint: disable=unused-argument
|
|
72
|
-
state:
|
|
72
|
+
state: LinkState, # pylint: disable=unused-argument
|
|
73
73
|
) -> PingResponse:
|
|
74
74
|
"""."""
|
|
75
75
|
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
76
76
|
return PingResponse(success=res)
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
79
|
+
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
|
80
80
|
"""Pull TaskIns handler."""
|
|
81
81
|
# Get node_id if client node is not anonymous
|
|
82
82
|
node = request.node # pylint: disable=no-member
|
|
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
|
|
|
92
92
|
return response
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
|
|
95
|
+
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
|
96
96
|
"""Push TaskRes handler."""
|
|
97
97
|
# pylint: disable=no-member
|
|
98
98
|
task_res: TaskRes = request.task_res_list[0]
|
|
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
def get_run(
|
|
116
|
-
request: GetRunRequest, state:
|
|
116
|
+
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
117
|
) -> GetRunResponse:
|
|
118
118
|
"""Get run information."""
|
|
119
119
|
run = state.get_run(request.run_id)
|
|
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
40
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
41
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
42
42
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
44
44
|
|
|
45
45
|
try:
|
|
46
46
|
from starlette.applications import Starlette
|
|
@@ -90,7 +90,7 @@ def rest_request_response(
|
|
|
90
90
|
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
91
91
|
"""Create Node."""
|
|
92
92
|
# Get state from app
|
|
93
|
-
state:
|
|
93
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
94
94
|
|
|
95
95
|
# Handle message
|
|
96
96
|
return message_handler.create_node(request=request, state=state)
|
|
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
|
100
100
|
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
101
101
|
"""Delete Node Id."""
|
|
102
102
|
# Get state from app
|
|
103
|
-
state:
|
|
103
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
104
104
|
|
|
105
105
|
# Handle message
|
|
106
106
|
return message_handler.delete_node(request=request, state=state)
|
|
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
110
110
|
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
111
111
|
"""Pull TaskIns."""
|
|
112
112
|
# Get state from app
|
|
113
|
-
state:
|
|
113
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
114
114
|
|
|
115
115
|
# Handle message
|
|
116
116
|
return message_handler.pull_task_ins(request=request, state=state)
|
|
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
|
121
121
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
122
122
|
"""Push TaskRes."""
|
|
123
123
|
# Get state from app
|
|
124
|
-
state:
|
|
124
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
125
125
|
|
|
126
126
|
# Handle message
|
|
127
127
|
return message_handler.push_task_res(request=request, state=state)
|
|
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
|
131
131
|
async def ping(request: PingRequest) -> PingResponse:
|
|
132
132
|
"""Ping."""
|
|
133
133
|
# Get state from app
|
|
134
|
-
state:
|
|
134
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
135
135
|
|
|
136
136
|
# Handle message
|
|
137
137
|
return message_handler.ping(request=request, state=state)
|
|
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
|
|
|
141
141
|
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
142
142
|
"""GetRun."""
|
|
143
143
|
# Get state from app
|
|
144
|
-
state:
|
|
144
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
145
145
|
|
|
146
146
|
# Handle message
|
|
147
147
|
return message_handler.get_run(request=request, state=state)
|
|
@@ -28,7 +28,7 @@ from typing import Callable, Optional
|
|
|
28
28
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
31
|
-
from flwr.client.
|
|
31
|
+
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
32
|
from flwr.common.constant import (
|
|
33
33
|
NUM_PARTITIONS_KEY,
|
|
34
34
|
PARTITION_ID_KEY,
|
|
@@ -40,7 +40,7 @@ from flwr.common.message import Error
|
|
|
40
40
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
41
41
|
from flwr.common.typing import Run
|
|
42
42
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
44
44
|
|
|
45
45
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
46
46
|
|
|
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
def _register_nodes(
|
|
51
|
-
num_nodes: int, state_factory:
|
|
51
|
+
num_nodes: int, state_factory: LinkStateFactory
|
|
52
52
|
) -> NodeToPartitionMapping:
|
|
53
53
|
"""Register nodes with the StateFactory and create node-id:partition-id mapping."""
|
|
54
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
@@ -60,16 +60,16 @@ def _register_nodes(
|
|
|
60
60
|
return nodes_mapping
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def _register_node_states(
|
|
63
|
+
def _register_node_info_stores(
|
|
64
64
|
nodes_mapping: NodeToPartitionMapping,
|
|
65
65
|
run: Run,
|
|
66
66
|
app_dir: Optional[str] = None,
|
|
67
|
-
) -> dict[int,
|
|
68
|
-
"""Create
|
|
69
|
-
|
|
67
|
+
) -> dict[int, DeprecatedRunInfoStore]:
|
|
68
|
+
"""Create DeprecatedRunInfoStore objects and register the context for the run."""
|
|
69
|
+
node_info_store: dict[int, DeprecatedRunInfoStore] = {}
|
|
70
70
|
num_partitions = len(set(nodes_mapping.values()))
|
|
71
71
|
for node_id, partition_id in nodes_mapping.items():
|
|
72
|
-
|
|
72
|
+
node_info_store[node_id] = DeprecatedRunInfoStore(
|
|
73
73
|
node_id=node_id,
|
|
74
74
|
node_config={
|
|
75
75
|
PARTITION_ID_KEY: partition_id,
|
|
@@ -78,18 +78,18 @@ def _register_node_states(
|
|
|
78
78
|
)
|
|
79
79
|
|
|
80
80
|
# Pre-register Context objects
|
|
81
|
-
|
|
81
|
+
node_info_store[node_id].register_context(
|
|
82
82
|
run_id=run.run_id, run=run, app_dir=app_dir
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
return
|
|
85
|
+
return node_info_store
|
|
86
86
|
|
|
87
87
|
|
|
88
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
89
|
def worker(
|
|
90
90
|
taskins_queue: "Queue[TaskIns]",
|
|
91
91
|
taskres_queue: "Queue[TaskRes]",
|
|
92
|
-
|
|
92
|
+
node_info_store: dict[int, DeprecatedRunInfoStore],
|
|
93
93
|
backend: Backend,
|
|
94
94
|
f_stop: threading.Event,
|
|
95
95
|
) -> None:
|
|
@@ -103,7 +103,7 @@ def worker(
|
|
|
103
103
|
node_id = task_ins.task.consumer.node_id
|
|
104
104
|
|
|
105
105
|
# Retrieve context
|
|
106
|
-
context =
|
|
106
|
+
context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
|
|
107
107
|
|
|
108
108
|
# Convert TaskIns to Message
|
|
109
109
|
message = message_from_taskins(task_ins)
|
|
@@ -112,7 +112,7 @@ def worker(
|
|
|
112
112
|
out_mssg, updated_context = backend.process_message(message, context)
|
|
113
113
|
|
|
114
114
|
# Update Context
|
|
115
|
-
|
|
115
|
+
node_info_store[node_id].update_context(
|
|
116
116
|
task_ins.run_id, context=updated_context
|
|
117
117
|
)
|
|
118
118
|
except Empty:
|
|
@@ -145,7 +145,7 @@ def worker(
|
|
|
145
145
|
|
|
146
146
|
|
|
147
147
|
def add_taskins_to_queue(
|
|
148
|
-
state:
|
|
148
|
+
state: LinkState,
|
|
149
149
|
queue: "Queue[TaskIns]",
|
|
150
150
|
nodes_mapping: NodeToPartitionMapping,
|
|
151
151
|
f_stop: threading.Event,
|
|
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
|
|
|
160
160
|
|
|
161
161
|
|
|
162
162
|
def put_taskres_into_state(
|
|
163
|
-
state:
|
|
163
|
+
state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
|
|
164
164
|
) -> None:
|
|
165
165
|
"""Put TaskRes into State from a queue."""
|
|
166
166
|
while not f_stop.is_set():
|
|
@@ -177,8 +177,8 @@ def run_api(
|
|
|
177
177
|
app_fn: Callable[[], ClientApp],
|
|
178
178
|
backend_fn: Callable[[], Backend],
|
|
179
179
|
nodes_mapping: NodeToPartitionMapping,
|
|
180
|
-
state_factory:
|
|
181
|
-
|
|
180
|
+
state_factory: LinkStateFactory,
|
|
181
|
+
node_info_stores: dict[int, DeprecatedRunInfoStore],
|
|
182
182
|
f_stop: threading.Event,
|
|
183
183
|
) -> None:
|
|
184
184
|
"""Run the VCE."""
|
|
@@ -223,7 +223,7 @@ def run_api(
|
|
|
223
223
|
worker,
|
|
224
224
|
taskins_queue,
|
|
225
225
|
taskres_queue,
|
|
226
|
-
|
|
226
|
+
node_info_stores,
|
|
227
227
|
backend,
|
|
228
228
|
f_stop,
|
|
229
229
|
)
|
|
@@ -264,7 +264,7 @@ def start_vce(
|
|
|
264
264
|
client_app: Optional[ClientApp] = None,
|
|
265
265
|
client_app_attr: Optional[str] = None,
|
|
266
266
|
num_supernodes: Optional[int] = None,
|
|
267
|
-
state_factory: Optional[
|
|
267
|
+
state_factory: Optional[LinkStateFactory] = None,
|
|
268
268
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
|
269
269
|
) -> None:
|
|
270
270
|
"""Start Fleet API with the Simulation Engine."""
|
|
@@ -303,7 +303,7 @@ def start_vce(
|
|
|
303
303
|
if not state_factory:
|
|
304
304
|
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
|
305
305
|
# Create an empty in-memory state factory
|
|
306
|
-
state_factory =
|
|
306
|
+
state_factory = LinkStateFactory(":flwr-in-memory-state:")
|
|
307
307
|
log(INFO, "Created new %s.", state_factory.__class__.__name__)
|
|
308
308
|
|
|
309
309
|
if num_supernodes:
|
|
@@ -312,8 +312,8 @@ def start_vce(
|
|
|
312
312
|
num_nodes=num_supernodes, state_factory=state_factory
|
|
313
313
|
)
|
|
314
314
|
|
|
315
|
-
# Construct mapping of
|
|
316
|
-
|
|
315
|
+
# Construct mapping of DeprecatedRunInfoStore
|
|
316
|
+
node_info_stores = _register_node_info_stores(
|
|
317
317
|
nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
|
|
318
318
|
)
|
|
319
319
|
|
|
@@ -376,7 +376,7 @@ def start_vce(
|
|
|
376
376
|
backend_fn,
|
|
377
377
|
nodes_mapping,
|
|
378
378
|
state_factory,
|
|
379
|
-
|
|
379
|
+
node_info_stores,
|
|
380
380
|
f_stop,
|
|
381
381
|
)
|
|
382
382
|
except LoadClientAppError as loadapp_ex:
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower LinkState."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
|
|
19
|
+
from .linkstate import LinkState as LinkState
|
|
20
|
+
from .linkstate_factory import LinkStateFactory as LinkStateFactory
|
|
21
|
+
from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"InMemoryLinkState",
|
|
25
|
+
"LinkState",
|
|
26
|
+
"LinkStateFactory",
|
|
27
|
+
"SqliteLinkState",
|
|
28
|
+
]
|
|
@@ -12,11 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""In-memory
|
|
15
|
+
"""In-memory LinkState implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
20
21
|
from logging import ERROR, WARNING
|
|
21
22
|
from typing import Optional
|
|
22
23
|
from uuid import UUID, uuid4
|
|
@@ -26,17 +27,35 @@ from flwr.common.constant import (
|
|
|
26
27
|
MESSAGE_TTL_TOLERANCE,
|
|
27
28
|
NODE_ID_NUM_BYTES,
|
|
28
29
|
RUN_ID_NUM_BYTES,
|
|
30
|
+
Status,
|
|
29
31
|
)
|
|
30
|
-
from flwr.common.typing import Run, UserConfig
|
|
32
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
31
33
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
32
|
-
from flwr.server.superlink.
|
|
34
|
+
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
33
35
|
from flwr.server.utils import validate_task_ins_or_res
|
|
34
36
|
|
|
35
|
-
from .utils import
|
|
37
|
+
from .utils import (
|
|
38
|
+
generate_rand_int_from_bytes,
|
|
39
|
+
has_valid_sub_status,
|
|
40
|
+
is_valid_transition,
|
|
41
|
+
make_node_unavailable_taskres,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class RunRecord:
    """Bundle a run together with its status and lifecycle timestamps.

    Only ``run`` and ``status`` are required. Each timestamp field is a string
    that defaults to ``""`` until a value is assigned to it.
    """

    # The run this record describes
    run: Run
    # The status currently stored for the run
    status: RunStatus

    # One timestamp string per lifecycle stage; "" means "not set"
    pending_at: str = ""
    starting_at: str = ""
    running_at: str = ""
    finished_at: str = ""
|
|
37
55
|
|
|
38
|
-
|
|
39
|
-
|
|
56
|
+
|
|
57
|
+
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
58
|
+
"""In-memory LinkState implementation."""
|
|
40
59
|
|
|
41
60
|
def __init__(self) -> None:
|
|
42
61
|
|
|
@@ -44,8 +63,8 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
44
63
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
45
64
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
46
65
|
|
|
47
|
-
# Map run_id to
|
|
48
|
-
self.run_ids: dict[int,
|
|
66
|
+
# Map run_id to RunRecord
|
|
67
|
+
self.run_ids: dict[int, RunRecord] = {}
|
|
49
68
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
50
69
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
51
70
|
|
|
@@ -277,7 +296,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
277
296
|
def create_node(
|
|
278
297
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
279
298
|
) -> int:
|
|
280
|
-
"""Create, store in state, and return `node_id`."""
|
|
299
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
281
300
|
# Sample a random int64 as node_id
|
|
282
301
|
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
283
302
|
|
|
@@ -351,13 +370,22 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
351
370
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
352
371
|
|
|
353
372
|
if run_id not in self.run_ids:
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
373
|
+
run_record = RunRecord(
|
|
374
|
+
run=Run(
|
|
375
|
+
run_id=run_id,
|
|
376
|
+
fab_id=fab_id if fab_id else "",
|
|
377
|
+
fab_version=fab_version if fab_version else "",
|
|
378
|
+
fab_hash=fab_hash if fab_hash else "",
|
|
379
|
+
override_config=override_config,
|
|
380
|
+
),
|
|
381
|
+
status=RunStatus(
|
|
382
|
+
status=Status.PENDING,
|
|
383
|
+
sub_status="",
|
|
384
|
+
details="",
|
|
385
|
+
),
|
|
386
|
+
pending_at=now().isoformat(),
|
|
360
387
|
)
|
|
388
|
+
self.run_ids[run_id] = run_record
|
|
361
389
|
return run_id
|
|
362
390
|
log(ERROR, "Unexpected run creation failure.")
|
|
363
391
|
return 0
|
|
@@ -365,7 +393,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
365
393
|
def store_server_private_public_key(
|
|
366
394
|
self, private_key: bytes, public_key: bytes
|
|
367
395
|
) -> None:
|
|
368
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
396
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
369
397
|
with self.lock:
|
|
370
398
|
if self.server_private_key is None and self.server_public_key is None:
|
|
371
399
|
self.server_private_key = private_key
|
|
@@ -382,12 +410,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
382
410
|
return self.server_public_key
|
|
383
411
|
|
|
384
412
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
385
|
-
"""Store a set of `node_public_keys` in state."""
|
|
413
|
+
"""Store a set of `node_public_keys` in the link state."""
|
|
386
414
|
with self.lock:
|
|
387
415
|
self.node_public_keys = public_keys
|
|
388
416
|
|
|
389
417
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
390
|
-
"""Store a `node_public_key` in state."""
|
|
418
|
+
"""Store a `node_public_key` in the link state."""
|
|
391
419
|
with self.lock:
|
|
392
420
|
self.node_public_keys.add(public_key)
|
|
393
421
|
|
|
@@ -401,7 +429,69 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
401
429
|
if run_id not in self.run_ids:
|
|
402
430
|
log(ERROR, "`run_id` is invalid")
|
|
403
431
|
return None
|
|
404
|
-
return self.run_ids[run_id]
|
|
432
|
+
return self.run_ids[run_id].run
|
|
433
|
+
|
|
434
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Look up the stored status of every requested run.

    Parameters
    ----------
    run_ids : set[int]
        The run IDs to look up. Duplicates are collapsed before the lookup.

    Returns
    -------
    dict[int, RunStatus]
        A mapping from run ID to its stored status. Requested IDs that are
        not registered are left out of the result.
    """
    statuses: dict[int, RunStatus] = {}
    with self.lock:
        for requested_id in set(run_ids):
            # Unknown IDs are skipped rather than reported
            if requested_id in self.run_ids:
                statuses[requested_id] = self.run_ids[requested_id].status
    return statuses
|
|
442
|
+
|
|
443
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run to update.
    new_status : RunStatus
        The status to store for the run.

    Returns
    -------
    bool
        `True` if the status was stored. `False` (after logging an error) if
        `run_id` is not registered, if the transition from the stored status
        to `new_status` is rejected by `is_valid_transition`, or if
        `new_status` is rejected by `has_valid_sub_status`.
    """
    with self.lock:
        # Check if the run_id exists
        if run_id not in self.run_ids:
            log(ERROR, "`run_id` is invalid")
            return False

        run_record = self.run_ids[run_id]

        # Check if the status transition is valid
        current_status = run_record.status
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check the sub-status of the *incoming* status. The stored status
        # was already accepted when it was written, so validating it again
        # would let an invalid `new_status` through unchecked.
        if not has_valid_sub_status(new_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                new_status.sub_status,
                new_status.status,
            )
            return False

        # Record when the run entered the new stage, then store the status
        if new_status.status == Status.STARTING:
            run_record.starting_at = now().isoformat()
        elif new_status.status == Status.RUNNING:
            run_record.running_at = now().isoformat()
        elif new_status.status == Status.FINISHED:
            run_record.finished_at = now().isoformat()
        run_record.status = new_status
        return True
|
|
482
|
+
|
|
483
|
+
def get_pending_run_id(self) -> Optional[int]:
    """Return the `run_id` of the first pending run, or `None` if there is none.

    Registered runs are scanned in the iteration order of `self.run_ids`, and
    the scan stops at the first run whose status is `Status.PENDING`.
    """
    # NOTE(review): unlike the sibling status methods, this reads
    # `self.run_ids` without holding `self.lock` — confirm that is intended.
    pending_ids = (
        candidate_id
        for candidate_id, record in self.run_ids.items()
        if record.status.status == Status.PENDING
    )
    # `next` with a default stops at the first match and yields None otherwise
    return next(pending_ids, None)
|
|
405
495
|
|
|
406
496
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
407
497
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|