flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +18 -83
- flwr/client/app.py +13 -14
- flwr/client/clientapp/app.py +1 -2
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +39 -4
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +103 -0
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +9 -0
- flwr/proto/exec_pb2.py +6 -6
- flwr/proto/exec_pb2.pyi +8 -2
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +171 -0
- flwr/proto/simulationio_pb2_grpc.pyi +68 -0
- flwr/server/app.py +247 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +26 -33
- flwr/server/driver/inmemory_driver.py +6 -14
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +270 -0
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
- flwr/server/superlink/{state → linkstate}/utils.py +84 -2
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
- flwr/simulation/__init__.py +2 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +60 -65
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +34 -63
- flwr/superexec/executor.py +22 -4
- flwr/superexec/simulation.py +13 -8
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
|
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
38
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
|
-
from flwr.server.superlink.
|
|
40
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
44
44
|
"""Fleet API servicer."""
|
|
45
45
|
|
|
46
|
-
def __init__(
|
|
46
|
+
def __init__(
|
|
47
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
48
|
+
) -> None:
|
|
47
49
|
self.state_factory = state_factory
|
|
48
50
|
self.ffs_factory = ffs_factory
|
|
49
51
|
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self, state:
|
|
87
|
+
def __init__(self, state: LinkState):
|
|
88
88
|
self.state = state
|
|
89
89
|
|
|
90
90
|
self.node_public_keys = state.get_node_public_keys()
|
|
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
43
|
)
|
|
44
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
45
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
|
-
from flwr.server.superlink.
|
|
46
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def create_node(
|
|
50
50
|
request: CreateNodeRequest, # pylint: disable=unused-argument
|
|
51
|
-
state:
|
|
51
|
+
state: LinkState,
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
54
|
# Create node
|
|
@@ -56,7 +56,7 @@ def create_node(
|
|
|
56
56
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def delete_node(request: DeleteNodeRequest, state:
|
|
59
|
+
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
|
60
60
|
"""."""
|
|
61
61
|
# Validate node_id
|
|
62
62
|
if request.node.anonymous or request.node.node_id == 0:
|
|
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
69
69
|
|
|
70
70
|
def ping(
|
|
71
71
|
request: PingRequest, # pylint: disable=unused-argument
|
|
72
|
-
state:
|
|
72
|
+
state: LinkState, # pylint: disable=unused-argument
|
|
73
73
|
) -> PingResponse:
|
|
74
74
|
"""."""
|
|
75
75
|
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
76
76
|
return PingResponse(success=res)
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
def pull_task_ins(request: PullTaskInsRequest, state:
|
|
79
|
+
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
|
80
80
|
"""Pull TaskIns handler."""
|
|
81
81
|
# Get node_id if client node is not anonymous
|
|
82
82
|
node = request.node # pylint: disable=no-member
|
|
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
|
|
|
92
92
|
return response
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def push_task_res(request: PushTaskResRequest, state:
|
|
95
|
+
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
|
96
96
|
"""Push TaskRes handler."""
|
|
97
97
|
# pylint: disable=no-member
|
|
98
98
|
task_res: TaskRes = request.task_res_list[0]
|
|
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
def get_run(
|
|
116
|
-
request: GetRunRequest, state:
|
|
116
|
+
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
117
|
) -> GetRunResponse:
|
|
118
118
|
"""Get run information."""
|
|
119
119
|
run = state.get_run(request.run_id)
|
|
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
40
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
41
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
42
42
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
44
44
|
|
|
45
45
|
try:
|
|
46
46
|
from starlette.applications import Starlette
|
|
@@ -90,7 +90,7 @@ def rest_request_response(
|
|
|
90
90
|
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
91
91
|
"""Create Node."""
|
|
92
92
|
# Get state from app
|
|
93
|
-
state:
|
|
93
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
94
94
|
|
|
95
95
|
# Handle message
|
|
96
96
|
return message_handler.create_node(request=request, state=state)
|
|
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
|
100
100
|
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
101
101
|
"""Delete Node Id."""
|
|
102
102
|
# Get state from app
|
|
103
|
-
state:
|
|
103
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
104
104
|
|
|
105
105
|
# Handle message
|
|
106
106
|
return message_handler.delete_node(request=request, state=state)
|
|
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
110
110
|
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
111
111
|
"""Pull TaskIns."""
|
|
112
112
|
# Get state from app
|
|
113
|
-
state:
|
|
113
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
114
114
|
|
|
115
115
|
# Handle message
|
|
116
116
|
return message_handler.pull_task_ins(request=request, state=state)
|
|
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
|
121
121
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
122
122
|
"""Push TaskRes."""
|
|
123
123
|
# Get state from app
|
|
124
|
-
state:
|
|
124
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
125
125
|
|
|
126
126
|
# Handle message
|
|
127
127
|
return message_handler.push_task_res(request=request, state=state)
|
|
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
|
131
131
|
async def ping(request: PingRequest) -> PingResponse:
|
|
132
132
|
"""Ping."""
|
|
133
133
|
# Get state from app
|
|
134
|
-
state:
|
|
134
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
135
135
|
|
|
136
136
|
# Handle message
|
|
137
137
|
return message_handler.ping(request=request, state=state)
|
|
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
|
|
|
141
141
|
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
142
142
|
"""GetRun."""
|
|
143
143
|
# Get state from app
|
|
144
|
-
state:
|
|
144
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
145
145
|
|
|
146
146
|
# Handle message
|
|
147
147
|
return message_handler.get_run(request=request, state=state)
|
|
@@ -28,7 +28,7 @@ from typing import Callable, Optional
|
|
|
28
28
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
31
|
-
from flwr.client.
|
|
31
|
+
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
32
|
from flwr.common.constant import (
|
|
33
33
|
NUM_PARTITIONS_KEY,
|
|
34
34
|
PARTITION_ID_KEY,
|
|
@@ -40,7 +40,7 @@ from flwr.common.message import Error
|
|
|
40
40
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
41
41
|
from flwr.common.typing import Run
|
|
42
42
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
44
44
|
|
|
45
45
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
46
46
|
|
|
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
def _register_nodes(
|
|
51
|
-
num_nodes: int, state_factory:
|
|
51
|
+
num_nodes: int, state_factory: LinkStateFactory
|
|
52
52
|
) -> NodeToPartitionMapping:
|
|
53
53
|
"""Register nodes with the StateFactory and create node-id:partition-id mapping."""
|
|
54
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
@@ -60,16 +60,16 @@ def _register_nodes(
|
|
|
60
60
|
return nodes_mapping
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def
|
|
63
|
+
def _register_node_info_stores(
|
|
64
64
|
nodes_mapping: NodeToPartitionMapping,
|
|
65
65
|
run: Run,
|
|
66
66
|
app_dir: Optional[str] = None,
|
|
67
|
-
) -> dict[int,
|
|
68
|
-
"""Create
|
|
69
|
-
|
|
67
|
+
) -> dict[int, DeprecatedRunInfoStore]:
|
|
68
|
+
"""Create DeprecatedRunInfoStore objects and register the context for the run."""
|
|
69
|
+
node_info_store: dict[int, DeprecatedRunInfoStore] = {}
|
|
70
70
|
num_partitions = len(set(nodes_mapping.values()))
|
|
71
71
|
for node_id, partition_id in nodes_mapping.items():
|
|
72
|
-
|
|
72
|
+
node_info_store[node_id] = DeprecatedRunInfoStore(
|
|
73
73
|
node_id=node_id,
|
|
74
74
|
node_config={
|
|
75
75
|
PARTITION_ID_KEY: partition_id,
|
|
@@ -78,18 +78,18 @@ def _register_node_states(
|
|
|
78
78
|
)
|
|
79
79
|
|
|
80
80
|
# Pre-register Context objects
|
|
81
|
-
|
|
81
|
+
node_info_store[node_id].register_context(
|
|
82
82
|
run_id=run.run_id, run=run, app_dir=app_dir
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
return
|
|
85
|
+
return node_info_store
|
|
86
86
|
|
|
87
87
|
|
|
88
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
89
|
def worker(
|
|
90
90
|
taskins_queue: "Queue[TaskIns]",
|
|
91
91
|
taskres_queue: "Queue[TaskRes]",
|
|
92
|
-
|
|
92
|
+
node_info_store: dict[int, DeprecatedRunInfoStore],
|
|
93
93
|
backend: Backend,
|
|
94
94
|
f_stop: threading.Event,
|
|
95
95
|
) -> None:
|
|
@@ -103,7 +103,7 @@ def worker(
|
|
|
103
103
|
node_id = task_ins.task.consumer.node_id
|
|
104
104
|
|
|
105
105
|
# Retrieve context
|
|
106
|
-
context =
|
|
106
|
+
context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
|
|
107
107
|
|
|
108
108
|
# Convert TaskIns to Message
|
|
109
109
|
message = message_from_taskins(task_ins)
|
|
@@ -112,7 +112,7 @@ def worker(
|
|
|
112
112
|
out_mssg, updated_context = backend.process_message(message, context)
|
|
113
113
|
|
|
114
114
|
# Update Context
|
|
115
|
-
|
|
115
|
+
node_info_store[node_id].update_context(
|
|
116
116
|
task_ins.run_id, context=updated_context
|
|
117
117
|
)
|
|
118
118
|
except Empty:
|
|
@@ -145,7 +145,7 @@ def worker(
|
|
|
145
145
|
|
|
146
146
|
|
|
147
147
|
def add_taskins_to_queue(
|
|
148
|
-
state:
|
|
148
|
+
state: LinkState,
|
|
149
149
|
queue: "Queue[TaskIns]",
|
|
150
150
|
nodes_mapping: NodeToPartitionMapping,
|
|
151
151
|
f_stop: threading.Event,
|
|
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
|
|
|
160
160
|
|
|
161
161
|
|
|
162
162
|
def put_taskres_into_state(
|
|
163
|
-
state:
|
|
163
|
+
state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
|
|
164
164
|
) -> None:
|
|
165
165
|
"""Put TaskRes into State from a queue."""
|
|
166
166
|
while not f_stop.is_set():
|
|
@@ -177,8 +177,8 @@ def run_api(
|
|
|
177
177
|
app_fn: Callable[[], ClientApp],
|
|
178
178
|
backend_fn: Callable[[], Backend],
|
|
179
179
|
nodes_mapping: NodeToPartitionMapping,
|
|
180
|
-
state_factory:
|
|
181
|
-
|
|
180
|
+
state_factory: LinkStateFactory,
|
|
181
|
+
node_info_stores: dict[int, DeprecatedRunInfoStore],
|
|
182
182
|
f_stop: threading.Event,
|
|
183
183
|
) -> None:
|
|
184
184
|
"""Run the VCE."""
|
|
@@ -223,7 +223,7 @@ def run_api(
|
|
|
223
223
|
worker,
|
|
224
224
|
taskins_queue,
|
|
225
225
|
taskres_queue,
|
|
226
|
-
|
|
226
|
+
node_info_stores,
|
|
227
227
|
backend,
|
|
228
228
|
f_stop,
|
|
229
229
|
)
|
|
@@ -264,7 +264,7 @@ def start_vce(
|
|
|
264
264
|
client_app: Optional[ClientApp] = None,
|
|
265
265
|
client_app_attr: Optional[str] = None,
|
|
266
266
|
num_supernodes: Optional[int] = None,
|
|
267
|
-
state_factory: Optional[
|
|
267
|
+
state_factory: Optional[LinkStateFactory] = None,
|
|
268
268
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
|
269
269
|
) -> None:
|
|
270
270
|
"""Start Fleet API with the Simulation Engine."""
|
|
@@ -303,7 +303,7 @@ def start_vce(
|
|
|
303
303
|
if not state_factory:
|
|
304
304
|
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
|
305
305
|
# Create an empty in-memory state factory
|
|
306
|
-
state_factory =
|
|
306
|
+
state_factory = LinkStateFactory(":flwr-in-memory-state:")
|
|
307
307
|
log(INFO, "Created new %s.", state_factory.__class__.__name__)
|
|
308
308
|
|
|
309
309
|
if num_supernodes:
|
|
@@ -312,8 +312,8 @@ def start_vce(
|
|
|
312
312
|
num_nodes=num_supernodes, state_factory=state_factory
|
|
313
313
|
)
|
|
314
314
|
|
|
315
|
-
# Construct mapping of
|
|
316
|
-
|
|
315
|
+
# Construct mapping of DeprecatedRunInfoStore
|
|
316
|
+
node_info_stores = _register_node_info_stores(
|
|
317
317
|
nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
|
|
318
318
|
)
|
|
319
319
|
|
|
@@ -376,7 +376,7 @@ def start_vce(
|
|
|
376
376
|
backend_fn,
|
|
377
377
|
nodes_mapping,
|
|
378
378
|
state_factory,
|
|
379
|
-
|
|
379
|
+
node_info_stores,
|
|
380
380
|
f_stop,
|
|
381
381
|
)
|
|
382
382
|
except LoadClientAppError as loadapp_ex:
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower LinkState."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
|
|
19
|
+
from .linkstate import LinkState as LinkState
|
|
20
|
+
from .linkstate_factory import LinkStateFactory as LinkStateFactory
|
|
21
|
+
from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"InMemoryLinkState",
|
|
25
|
+
"LinkState",
|
|
26
|
+
"LinkStateFactory",
|
|
27
|
+
"SqliteLinkState",
|
|
28
|
+
]
|
|
@@ -12,31 +12,54 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""In-memory
|
|
15
|
+
"""In-memory LinkState implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
|
+
from bisect import bisect_right
|
|
21
|
+
from dataclasses import dataclass, field
|
|
20
22
|
from logging import ERROR, WARNING
|
|
21
23
|
from typing import Optional
|
|
22
24
|
from uuid import UUID, uuid4
|
|
23
25
|
|
|
24
|
-
from flwr.common import log, now
|
|
26
|
+
from flwr.common import Context, log, now
|
|
25
27
|
from flwr.common.constant import (
|
|
26
28
|
MESSAGE_TTL_TOLERANCE,
|
|
27
29
|
NODE_ID_NUM_BYTES,
|
|
28
30
|
RUN_ID_NUM_BYTES,
|
|
31
|
+
Status,
|
|
29
32
|
)
|
|
30
|
-
from flwr.common.
|
|
33
|
+
from flwr.common.record import ConfigsRecord
|
|
34
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
31
35
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
32
|
-
from flwr.server.superlink.
|
|
36
|
+
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
33
37
|
from flwr.server.utils import validate_task_ins_or_res
|
|
34
38
|
|
|
35
|
-
from .utils import
|
|
39
|
+
from .utils import (
|
|
40
|
+
generate_rand_int_from_bytes,
|
|
41
|
+
has_valid_sub_status,
|
|
42
|
+
is_valid_transition,
|
|
43
|
+
make_node_unavailable_taskres,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
|
|
48
|
+
class RunRecord: # pylint: disable=R0902
|
|
49
|
+
"""The record of a specific run, including its status and timestamps."""
|
|
36
50
|
|
|
51
|
+
run: Run
|
|
52
|
+
status: RunStatus
|
|
53
|
+
pending_at: str = ""
|
|
54
|
+
starting_at: str = ""
|
|
55
|
+
running_at: str = ""
|
|
56
|
+
finished_at: str = ""
|
|
57
|
+
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
58
|
+
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
37
59
|
|
|
38
|
-
|
|
39
|
-
|
|
60
|
+
|
|
61
|
+
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
62
|
+
"""In-memory LinkState implementation."""
|
|
40
63
|
|
|
41
64
|
def __init__(self) -> None:
|
|
42
65
|
|
|
@@ -44,8 +67,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
44
67
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
45
68
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
46
69
|
|
|
47
|
-
# Map run_id to
|
|
48
|
-
self.run_ids: dict[int,
|
|
70
|
+
# Map run_id to RunRecord
|
|
71
|
+
self.run_ids: dict[int, RunRecord] = {}
|
|
72
|
+
self.contexts: dict[int, Context] = {}
|
|
73
|
+
self.federation_options: dict[int, ConfigsRecord] = {}
|
|
49
74
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
50
75
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
51
76
|
|
|
@@ -64,8 +89,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
64
89
|
return None
|
|
65
90
|
# Validate run_id
|
|
66
91
|
if task_ins.run_id not in self.run_ids:
|
|
67
|
-
log(ERROR, "
|
|
92
|
+
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
68
93
|
return None
|
|
94
|
+
# Validate source node ID
|
|
95
|
+
if task_ins.task.producer.node_id != 0:
|
|
96
|
+
log(
|
|
97
|
+
ERROR,
|
|
98
|
+
"Invalid source node ID for TaskIns: %s",
|
|
99
|
+
task_ins.task.producer.node_id,
|
|
100
|
+
)
|
|
101
|
+
return None
|
|
102
|
+
# Validate destination node ID
|
|
103
|
+
if not task_ins.task.consumer.anonymous:
|
|
104
|
+
if task_ins.task.consumer.node_id not in self.node_ids:
|
|
105
|
+
log(
|
|
106
|
+
ERROR,
|
|
107
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
108
|
+
task_ins.task.consumer.node_id,
|
|
109
|
+
)
|
|
110
|
+
return None
|
|
69
111
|
|
|
70
112
|
# Create task_id
|
|
71
113
|
task_id = uuid4()
|
|
@@ -277,7 +319,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
277
319
|
def create_node(
|
|
278
320
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
279
321
|
) -> int:
|
|
280
|
-
"""Create, store in state, and return `node_id`."""
|
|
322
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
281
323
|
# Sample a random int64 as node_id
|
|
282
324
|
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
283
325
|
|
|
@@ -338,12 +380,14 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
338
380
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
339
381
|
return self.public_key_to_node_id.get(node_public_key)
|
|
340
382
|
|
|
383
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
341
384
|
def create_run(
|
|
342
385
|
self,
|
|
343
386
|
fab_id: Optional[str],
|
|
344
387
|
fab_version: Optional[str],
|
|
345
388
|
fab_hash: Optional[str],
|
|
346
389
|
override_config: UserConfig,
|
|
390
|
+
federation_options: ConfigsRecord,
|
|
347
391
|
) -> int:
|
|
348
392
|
"""Create a new run for the specified `fab_hash`."""
|
|
349
393
|
# Sample a random int64 as run_id
|
|
@@ -351,13 +395,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
351
395
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
352
396
|
|
|
353
397
|
if run_id not in self.run_ids:
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
398
|
+
run_record = RunRecord(
|
|
399
|
+
run=Run(
|
|
400
|
+
run_id=run_id,
|
|
401
|
+
fab_id=fab_id if fab_id else "",
|
|
402
|
+
fab_version=fab_version if fab_version else "",
|
|
403
|
+
fab_hash=fab_hash if fab_hash else "",
|
|
404
|
+
override_config=override_config,
|
|
405
|
+
),
|
|
406
|
+
status=RunStatus(
|
|
407
|
+
status=Status.PENDING,
|
|
408
|
+
sub_status="",
|
|
409
|
+
details="",
|
|
410
|
+
),
|
|
411
|
+
pending_at=now().isoformat(),
|
|
360
412
|
)
|
|
413
|
+
self.run_ids[run_id] = run_record
|
|
414
|
+
|
|
415
|
+
# Record federation options. Leave empty if not passed
|
|
416
|
+
self.federation_options[run_id] = federation_options
|
|
361
417
|
return run_id
|
|
362
418
|
log(ERROR, "Unexpected run creation failure.")
|
|
363
419
|
return 0
|
|
@@ -365,7 +421,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
365
421
|
def store_server_private_public_key(
|
|
366
422
|
self, private_key: bytes, public_key: bytes
|
|
367
423
|
) -> None:
|
|
368
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
424
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
369
425
|
with self.lock:
|
|
370
426
|
if self.server_private_key is None and self.server_public_key is None:
|
|
371
427
|
self.server_private_key = private_key
|
|
@@ -382,12 +438,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
382
438
|
return self.server_public_key
|
|
383
439
|
|
|
384
440
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
385
|
-
"""Store a set of `node_public_keys` in state."""
|
|
441
|
+
"""Store a set of `node_public_keys` in the link state."""
|
|
386
442
|
with self.lock:
|
|
387
443
|
self.node_public_keys = public_keys
|
|
388
444
|
|
|
389
445
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
390
|
-
"""Store a `node_public_key` in state."""
|
|
446
|
+
"""Store a `node_public_key` in the link state."""
|
|
391
447
|
with self.lock:
|
|
392
448
|
self.node_public_keys.add(public_key)
|
|
393
449
|
|
|
@@ -401,7 +457,77 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
401
457
|
if run_id not in self.run_ids:
|
|
402
458
|
log(ERROR, "`run_id` is invalid")
|
|
403
459
|
return None
|
|
404
|
-
return self.run_ids[run_id]
|
|
460
|
+
return self.run_ids[run_id].run
|
|
461
|
+
|
|
462
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
463
|
+
"""Retrieve the statuses for the specified runs."""
|
|
464
|
+
with self.lock:
|
|
465
|
+
return {
|
|
466
|
+
run_id: self.run_ids[run_id].status
|
|
467
|
+
for run_id in set(run_ids)
|
|
468
|
+
if run_id in self.run_ids
|
|
469
|
+
}
|
|
470
|
+
|
|
471
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
472
|
+
"""Update the status of the run with the specified `run_id`."""
|
|
473
|
+
with self.lock:
|
|
474
|
+
# Check if the run_id exists
|
|
475
|
+
if run_id not in self.run_ids:
|
|
476
|
+
log(ERROR, "`run_id` is invalid")
|
|
477
|
+
return False
|
|
478
|
+
|
|
479
|
+
# Check if the status transition is valid
|
|
480
|
+
current_status = self.run_ids[run_id].status
|
|
481
|
+
if not is_valid_transition(current_status, new_status):
|
|
482
|
+
log(
|
|
483
|
+
ERROR,
|
|
484
|
+
'Invalid status transition: from "%s" to "%s"',
|
|
485
|
+
current_status.status,
|
|
486
|
+
new_status.status,
|
|
487
|
+
)
|
|
488
|
+
return False
|
|
489
|
+
|
|
490
|
+
# Check if the sub-status is valid
|
|
491
|
+
if not has_valid_sub_status(current_status):
|
|
492
|
+
log(
|
|
493
|
+
ERROR,
|
|
494
|
+
'Invalid sub-status "%s" for status "%s"',
|
|
495
|
+
current_status.sub_status,
|
|
496
|
+
current_status.status,
|
|
497
|
+
)
|
|
498
|
+
return False
|
|
499
|
+
|
|
500
|
+
# Update the status
|
|
501
|
+
run_record = self.run_ids[run_id]
|
|
502
|
+
if new_status.status == Status.STARTING:
|
|
503
|
+
run_record.starting_at = now().isoformat()
|
|
504
|
+
elif new_status.status == Status.RUNNING:
|
|
505
|
+
run_record.running_at = now().isoformat()
|
|
506
|
+
elif new_status.status == Status.FINISHED:
|
|
507
|
+
run_record.finished_at = now().isoformat()
|
|
508
|
+
run_record.status = new_status
|
|
509
|
+
return True
|
|
510
|
+
|
|
511
|
+
def get_pending_run_id(self) -> Optional[int]:
|
|
512
|
+
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
513
|
+
pending_run_id = None
|
|
514
|
+
|
|
515
|
+
# Loop through all registered runs
|
|
516
|
+
for run_id, run_rec in self.run_ids.items():
|
|
517
|
+
# Break once a pending run is found
|
|
518
|
+
if run_rec.status.status == Status.PENDING:
|
|
519
|
+
pending_run_id = run_id
|
|
520
|
+
break
|
|
521
|
+
|
|
522
|
+
return pending_run_id
|
|
523
|
+
|
|
524
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
|
|
525
|
+
"""Retrieve the federation options for the specified `run_id`."""
|
|
526
|
+
with self.lock:
|
|
527
|
+
if run_id not in self.run_ids:
|
|
528
|
+
log(ERROR, "`run_id` is invalid")
|
|
529
|
+
return None
|
|
530
|
+
return self.federation_options[run_id]
|
|
405
531
|
|
|
406
532
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
407
533
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
@@ -410,3 +536,36 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
410
536
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
411
537
|
return True
|
|
412
538
|
return False
|
|
539
|
+
|
|
540
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
541
|
+
"""Get the context for the specified `run_id`."""
|
|
542
|
+
return self.contexts.get(run_id)
|
|
543
|
+
|
|
544
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
|
|
545
|
+
"""Set the context for the specified `run_id`."""
|
|
546
|
+
if run_id not in self.run_ids:
|
|
547
|
+
raise ValueError(f"Run {run_id} not found")
|
|
548
|
+
self.contexts[run_id] = context
|
|
549
|
+
|
|
550
|
+
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
551
|
+
"""Add a log entry to the serverapp logs for the specified `run_id`."""
|
|
552
|
+
if run_id not in self.run_ids:
|
|
553
|
+
raise ValueError(f"Run {run_id} not found")
|
|
554
|
+
run = self.run_ids[run_id]
|
|
555
|
+
with run.log_lock:
|
|
556
|
+
run.logs.append((now().timestamp(), log_message))
|
|
557
|
+
|
|
558
|
+
def get_serverapp_log(
|
|
559
|
+
self, run_id: int, after_timestamp: Optional[float]
|
|
560
|
+
) -> tuple[str, float]:
|
|
561
|
+
"""Get the serverapp logs for the specified `run_id`."""
|
|
562
|
+
if run_id not in self.run_ids:
|
|
563
|
+
raise ValueError(f"Run {run_id} not found")
|
|
564
|
+
run = self.run_ids[run_id]
|
|
565
|
+
if after_timestamp is None:
|
|
566
|
+
after_timestamp = 0.0
|
|
567
|
+
with run.log_lock:
|
|
568
|
+
# Find the index where the timestamp would be inserted
|
|
569
|
+
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
570
|
+
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
571
|
+
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|