flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +34 -88
- flwr/client/app.py +23 -20
- flwr/client/clientapp/app.py +22 -18
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/args.py +83 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +39 -5
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +10 -1
- flwr/proto/exec_pb2.py +14 -17
- flwr/proto/exec_pb2.pyi +14 -22
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +26 -0
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +272 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +25 -36
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +214 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
- flwr/server/superlink/{state → linkstate}/utils.py +81 -30
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +273 -345
- flwr/simulation/legacy_app.py +382 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +61 -66
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +36 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +54 -107
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""ServerAppIo gRPC API."""
|
|
16
16
|
|
|
17
17
|
from logging import INFO
|
|
18
18
|
from typing import Optional
|
|
@@ -21,37 +21,40 @@ import grpc
|
|
|
21
21
|
|
|
22
22
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
23
23
|
from flwr.common.logger import log
|
|
24
|
-
from flwr.proto.
|
|
25
|
-
|
|
24
|
+
from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
|
25
|
+
add_ServerAppIoServicer_to_server,
|
|
26
26
|
)
|
|
27
27
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
31
|
-
from .
|
|
31
|
+
from .serverappio_servicer import ServerAppIoServicer
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def
|
|
34
|
+
def run_serverappio_api_grpc(
|
|
35
35
|
address: str,
|
|
36
|
-
state_factory:
|
|
36
|
+
state_factory: LinkStateFactory,
|
|
37
37
|
ffs_factory: FfsFactory,
|
|
38
38
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
39
39
|
) -> grpc.Server:
|
|
40
|
-
"""Run
|
|
41
|
-
# Create
|
|
42
|
-
|
|
40
|
+
"""Run ServerAppIo API (gRPC, request-response)."""
|
|
41
|
+
# Create ServerAppIo API gRPC server
|
|
42
|
+
serverappio_servicer: grpc.Server = ServerAppIoServicer(
|
|
43
43
|
state_factory=state_factory,
|
|
44
44
|
ffs_factory=ffs_factory,
|
|
45
45
|
)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
servicer_and_add_fn=(
|
|
46
|
+
serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server
|
|
47
|
+
serverappio_grpc_server = generic_create_grpc_server(
|
|
48
|
+
servicer_and_add_fn=(
|
|
49
|
+
serverappio_servicer,
|
|
50
|
+
serverappio_add_servicer_to_server_fn,
|
|
51
|
+
),
|
|
49
52
|
server_address=address,
|
|
50
53
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
51
54
|
certificates=certificates,
|
|
52
55
|
)
|
|
53
56
|
|
|
54
|
-
log(INFO, "Flower ECE: Starting
|
|
55
|
-
|
|
57
|
+
log(INFO, "Flower ECE: Starting ServerAppIo API (gRPC-rere) on %s", address)
|
|
58
|
+
serverappio_grpc_server.start()
|
|
56
59
|
|
|
57
|
-
return
|
|
60
|
+
return serverappio_grpc_server
|
|
@@ -12,62 +12,80 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""ServerAppIo API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import threading
|
|
18
19
|
import time
|
|
19
|
-
from logging import DEBUG
|
|
20
|
+
from logging import DEBUG, INFO
|
|
20
21
|
from typing import Optional
|
|
21
22
|
from uuid import UUID
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
26
|
+
from flwr.common import ConfigsRecord
|
|
27
|
+
from flwr.common.constant import Status
|
|
25
28
|
from flwr.common.logger import log
|
|
26
29
|
from flwr.common.serde import (
|
|
30
|
+
context_from_proto,
|
|
31
|
+
context_to_proto,
|
|
27
32
|
fab_from_proto,
|
|
28
33
|
fab_to_proto,
|
|
34
|
+
run_status_from_proto,
|
|
35
|
+
run_to_proto,
|
|
29
36
|
user_config_from_proto,
|
|
30
|
-
user_config_to_proto,
|
|
31
|
-
)
|
|
32
|
-
from flwr.common.typing import Fab
|
|
33
|
-
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
34
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
35
|
-
GetNodesRequest,
|
|
36
|
-
GetNodesResponse,
|
|
37
|
-
PullTaskResRequest,
|
|
38
|
-
PullTaskResResponse,
|
|
39
|
-
PushTaskInsRequest,
|
|
40
|
-
PushTaskInsResponse,
|
|
41
37
|
)
|
|
38
|
+
from flwr.common.typing import Fab, RunStatus
|
|
39
|
+
from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
|
|
42
40
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
42
|
+
PushLogsRequest,
|
|
43
|
+
PushLogsResponse,
|
|
44
|
+
)
|
|
43
45
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
44
46
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
45
47
|
CreateRunRequest,
|
|
46
48
|
CreateRunResponse,
|
|
47
49
|
GetRunRequest,
|
|
48
50
|
GetRunResponse,
|
|
49
|
-
|
|
51
|
+
UpdateRunStatusRequest,
|
|
52
|
+
UpdateRunStatusResponse,
|
|
53
|
+
)
|
|
54
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
55
|
+
GetNodesRequest,
|
|
56
|
+
GetNodesResponse,
|
|
57
|
+
PullServerAppInputsRequest,
|
|
58
|
+
PullServerAppInputsResponse,
|
|
59
|
+
PullTaskResRequest,
|
|
60
|
+
PullTaskResResponse,
|
|
61
|
+
PushServerAppOutputsRequest,
|
|
62
|
+
PushServerAppOutputsResponse,
|
|
63
|
+
PushTaskInsRequest,
|
|
64
|
+
PushTaskInsResponse,
|
|
50
65
|
)
|
|
51
66
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
67
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
68
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
54
|
-
from flwr.server.superlink.
|
|
69
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
55
70
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
56
71
|
|
|
57
72
|
|
|
58
|
-
class
|
|
59
|
-
"""
|
|
73
|
+
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
74
|
+
"""ServerAppIo API servicer."""
|
|
60
75
|
|
|
61
|
-
def __init__(
|
|
76
|
+
def __init__(
|
|
77
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
78
|
+
) -> None:
|
|
62
79
|
self.state_factory = state_factory
|
|
63
80
|
self.ffs_factory = ffs_factory
|
|
81
|
+
self.lock = threading.RLock()
|
|
64
82
|
|
|
65
83
|
def GetNodes(
|
|
66
84
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
67
85
|
) -> GetNodesResponse:
|
|
68
86
|
"""Get available nodes."""
|
|
69
|
-
log(DEBUG, "
|
|
70
|
-
state:
|
|
87
|
+
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
88
|
+
state: LinkState = self.state_factory.state()
|
|
71
89
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
72
90
|
nodes: list[Node] = [
|
|
73
91
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -78,8 +96,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
78
96
|
self, request: CreateRunRequest, context: grpc.ServicerContext
|
|
79
97
|
) -> CreateRunResponse:
|
|
80
98
|
"""Create run ID."""
|
|
81
|
-
log(DEBUG, "
|
|
82
|
-
state:
|
|
99
|
+
log(DEBUG, "ServerAppIoServicer.CreateRun")
|
|
100
|
+
state: LinkState = self.state_factory.state()
|
|
83
101
|
if request.HasField("fab"):
|
|
84
102
|
fab = fab_from_proto(request.fab)
|
|
85
103
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
@@ -95,6 +113,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
95
113
|
request.fab_version,
|
|
96
114
|
fab_hash,
|
|
97
115
|
user_config_from_proto(request.override_config),
|
|
116
|
+
ConfigsRecord(),
|
|
98
117
|
)
|
|
99
118
|
return CreateRunResponse(run_id=run_id)
|
|
100
119
|
|
|
@@ -102,7 +121,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
102
121
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
|
103
122
|
) -> PushTaskInsResponse:
|
|
104
123
|
"""Push a set of TaskIns."""
|
|
105
|
-
log(DEBUG, "
|
|
124
|
+
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
106
125
|
|
|
107
126
|
# Set pushed_at (timestamp in seconds)
|
|
108
127
|
pushed_at = time.time()
|
|
@@ -116,7 +135,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
116
135
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
117
136
|
|
|
118
137
|
# Init state
|
|
119
|
-
state:
|
|
138
|
+
state: LinkState = self.state_factory.state()
|
|
120
139
|
|
|
121
140
|
# Store each TaskIns
|
|
122
141
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -132,17 +151,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
132
151
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
|
133
152
|
) -> PullTaskResResponse:
|
|
134
153
|
"""Pull a set of TaskRes."""
|
|
135
|
-
log(DEBUG, "
|
|
154
|
+
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
136
155
|
|
|
137
156
|
# Convert each task_id str to UUID
|
|
138
157
|
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
139
158
|
|
|
140
159
|
# Init state
|
|
141
|
-
state:
|
|
160
|
+
state: LinkState = self.state_factory.state()
|
|
142
161
|
|
|
143
162
|
# Register callback
|
|
144
163
|
def on_rpc_done() -> None:
|
|
145
|
-
log(
|
|
164
|
+
log(
|
|
165
|
+
DEBUG,
|
|
166
|
+
"ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
|
|
167
|
+
)
|
|
146
168
|
|
|
147
169
|
if context.is_active():
|
|
148
170
|
return
|
|
@@ -164,10 +186,10 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
164
186
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
165
187
|
) -> GetRunResponse:
|
|
166
188
|
"""Get run information."""
|
|
167
|
-
log(DEBUG, "
|
|
189
|
+
log(DEBUG, "ServerAppIoServicer.GetRun")
|
|
168
190
|
|
|
169
191
|
# Init state
|
|
170
|
-
state:
|
|
192
|
+
state: LinkState = self.state_factory.state()
|
|
171
193
|
|
|
172
194
|
# Retrieve run information
|
|
173
195
|
run = state.get_run(request.run_id)
|
|
@@ -175,21 +197,13 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
175
197
|
if run is None:
|
|
176
198
|
return GetRunResponse()
|
|
177
199
|
|
|
178
|
-
return GetRunResponse(
|
|
179
|
-
run=Run(
|
|
180
|
-
run_id=run.run_id,
|
|
181
|
-
fab_id=run.fab_id,
|
|
182
|
-
fab_version=run.fab_version,
|
|
183
|
-
override_config=user_config_to_proto(run.override_config),
|
|
184
|
-
fab_hash=run.fab_hash,
|
|
185
|
-
)
|
|
186
|
-
)
|
|
200
|
+
return GetRunResponse(run=run_to_proto(run))
|
|
187
201
|
|
|
188
202
|
def GetFab(
|
|
189
203
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
190
204
|
) -> GetFabResponse:
|
|
191
205
|
"""Get FAB from Ffs."""
|
|
192
|
-
log(DEBUG, "
|
|
206
|
+
log(DEBUG, "ServerAppIoServicer.GetFab")
|
|
193
207
|
|
|
194
208
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
195
209
|
if result := ffs.get(request.hash_str):
|
|
@@ -198,6 +212,78 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
198
212
|
|
|
199
213
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
200
214
|
|
|
215
|
+
def PullServerAppInputs(
|
|
216
|
+
self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
|
|
217
|
+
) -> PullServerAppInputsResponse:
|
|
218
|
+
"""Pull ServerApp process inputs."""
|
|
219
|
+
log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
|
|
220
|
+
# Init access to LinkState and Ffs
|
|
221
|
+
state = self.state_factory.state()
|
|
222
|
+
ffs = self.ffs_factory.ffs()
|
|
223
|
+
|
|
224
|
+
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
225
|
+
with self.lock:
|
|
226
|
+
# Attempt getting the run_id of a pending run
|
|
227
|
+
run_id = state.get_pending_run_id()
|
|
228
|
+
# If there's no pending run, return an empty response
|
|
229
|
+
if run_id is None:
|
|
230
|
+
return PullServerAppInputsResponse()
|
|
231
|
+
|
|
232
|
+
# Retrieve Context, Run and Fab for the run_id
|
|
233
|
+
serverapp_ctxt = state.get_serverapp_context(run_id)
|
|
234
|
+
run = state.get_run(run_id)
|
|
235
|
+
fab = None
|
|
236
|
+
if run and run.fab_hash:
|
|
237
|
+
if result := ffs.get(run.fab_hash):
|
|
238
|
+
fab = Fab(run.fab_hash, result[0])
|
|
239
|
+
if run and fab and serverapp_ctxt:
|
|
240
|
+
# Update run status to STARTING
|
|
241
|
+
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
242
|
+
log(INFO, "Starting run %d", run_id)
|
|
243
|
+
return PullServerAppInputsResponse(
|
|
244
|
+
context=context_to_proto(serverapp_ctxt),
|
|
245
|
+
run=run_to_proto(run),
|
|
246
|
+
fab=fab_to_proto(fab),
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
# Raise an exception if the Run or Fab is not found,
|
|
250
|
+
# or if the status cannot be updated to STARTING
|
|
251
|
+
raise RuntimeError(f"Failed to start run {run_id}")
|
|
252
|
+
|
|
253
|
+
def PushServerAppOutputs(
|
|
254
|
+
self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
|
|
255
|
+
) -> PushServerAppOutputsResponse:
|
|
256
|
+
"""Push ServerApp process outputs."""
|
|
257
|
+
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
258
|
+
state = self.state_factory.state()
|
|
259
|
+
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
|
+
return PushServerAppOutputsResponse()
|
|
261
|
+
|
|
262
|
+
def UpdateRunStatus(
|
|
263
|
+
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
264
|
+
) -> UpdateRunStatusResponse:
|
|
265
|
+
"""Update the status of a run."""
|
|
266
|
+
log(DEBUG, "ControlServicer.UpdateRunStatus")
|
|
267
|
+
state = self.state_factory.state()
|
|
268
|
+
|
|
269
|
+
# Update the run status
|
|
270
|
+
state.update_run_status(
|
|
271
|
+
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
272
|
+
)
|
|
273
|
+
return UpdateRunStatusResponse()
|
|
274
|
+
|
|
275
|
+
def PushLogs(
|
|
276
|
+
self, request: PushLogsRequest, context: grpc.ServicerContext
|
|
277
|
+
) -> PushLogsResponse:
|
|
278
|
+
"""Push logs."""
|
|
279
|
+
log(DEBUG, "ServerAppIoServicer.PushLogs")
|
|
280
|
+
state = self.state_factory.state()
|
|
281
|
+
|
|
282
|
+
# Add logs to LinkState
|
|
283
|
+
merged_logs = "".join(request.logs)
|
|
284
|
+
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
|
+
return PushLogsResponse()
|
|
286
|
+
|
|
201
287
|
|
|
202
288
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
203
289
|
if validation_error:
|
|
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
|
48
48
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
49
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
50
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
51
|
-
from flwr.server.superlink.
|
|
51
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
52
52
|
|
|
53
53
|
T = TypeVar("T", bound=GrpcMessage)
|
|
54
54
|
|
|
@@ -77,7 +77,9 @@ def _handle(
|
|
|
77
77
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
78
78
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
79
|
|
|
80
|
-
def __init__(
|
|
80
|
+
def __init__(
|
|
81
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
82
|
+
) -> None:
|
|
81
83
|
self.state_factory = state_factory
|
|
82
84
|
self.ffs_factory = ffs_factory
|
|
83
85
|
|
|
@@ -30,7 +30,7 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
|
|
30
30
|
add_FlowerServiceServicer_to_server,
|
|
31
31
|
)
|
|
32
32
|
from flwr.server.client_manager import ClientManager
|
|
33
|
-
from flwr.server.superlink.driver.
|
|
33
|
+
from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
|
|
34
34
|
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
|
|
35
35
|
GrpcAdapterServicer,
|
|
36
36
|
)
|
|
@@ -161,7 +161,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
|
161
161
|
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
162
|
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
163
163
|
tuple[FlowerServiceServicer, AddServicerToServerFn],
|
|
164
|
-
tuple[
|
|
164
|
+
tuple[ServerAppIoServicer, AddServicerToServerFn],
|
|
165
165
|
],
|
|
166
166
|
server_address: str,
|
|
167
167
|
max_concurrent_workers: int = 1000,
|
|
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
38
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
|
-
from flwr.server.superlink.
|
|
40
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
44
44
|
"""Fleet API servicer."""
|
|
45
45
|
|
|
46
|
-
def __init__(
|
|
46
|
+
def __init__(
|
|
47
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
48
|
+
) -> None:
|
|
47
49
|
self.state_factory = state_factory
|
|
48
50
|
self.ffs_factory = ffs_factory
|
|
49
51
|
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self, state:
|
|
87
|
+
def __init__(self, state: LinkState):
|
|
88
88
|
self.state = state
|
|
89
89
|
|
|
90
90
|
self.node_public_keys = state.get_node_public_keys()
|
|
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
43
|
)
|
|
44
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
45
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
|
-
from flwr.server.superlink.
|
|
46
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def create_node(
|
|
50
50
|
request: CreateNodeRequest, # pylint: disable=unused-argument
|
|
51
|
-
state:
|
|
51
|
+
state: LinkState,
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
54
|
# Create node
|
|
@@ -56,7 +56,7 @@ def create_node(
|
|
|
56
56
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def delete_node(request: DeleteNodeRequest, state:
|
|
59
|
+
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
|
60
60
|
"""."""
|
|
61
61
|
# Validate node_id
|
|
62
62
|
if request.node.anonymous or request.node.node_id == 0:
|
|
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
69
69
|
|
|
70
70
|
def ping(
|
|
71
71
|
request: PingRequest, # pylint: disable=unused-argument
|
|
72
|
-
state:
|
|
72
|
+
state: LinkState, # pylint: disable=unused-argument
|
|
73
73
|
) -> PingResponse:
|
|
74
74
|
"""."""
|
|
75
75
|
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
76
76
|
return PingResponse(success=res)
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
def pull_task_ins(request: PullTaskInsRequest, state:
|
|
79
|
+
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
|
80
80
|
"""Pull TaskIns handler."""
|
|
81
81
|
# Get node_id if client node is not anonymous
|
|
82
82
|
node = request.node # pylint: disable=no-member
|
|
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
|
|
|
92
92
|
return response
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def push_task_res(request: PushTaskResRequest, state:
|
|
95
|
+
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
|
96
96
|
"""Push TaskRes handler."""
|
|
97
97
|
# pylint: disable=no-member
|
|
98
98
|
task_res: TaskRes = request.task_res_list[0]
|
|
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
def get_run(
|
|
116
|
-
request: GetRunRequest, state:
|
|
116
|
+
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
117
|
) -> GetRunResponse:
|
|
118
118
|
"""Get run information."""
|
|
119
119
|
run = state.get_run(request.run_id)
|
|
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
40
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
41
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
42
42
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
44
44
|
|
|
45
45
|
try:
|
|
46
46
|
from starlette.applications import Starlette
|
|
@@ -90,7 +90,7 @@ def rest_request_response(
|
|
|
90
90
|
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
91
91
|
"""Create Node."""
|
|
92
92
|
# Get state from app
|
|
93
|
-
state:
|
|
93
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
94
94
|
|
|
95
95
|
# Handle message
|
|
96
96
|
return message_handler.create_node(request=request, state=state)
|
|
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
|
100
100
|
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
101
101
|
"""Delete Node Id."""
|
|
102
102
|
# Get state from app
|
|
103
|
-
state:
|
|
103
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
104
104
|
|
|
105
105
|
# Handle message
|
|
106
106
|
return message_handler.delete_node(request=request, state=state)
|
|
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
110
110
|
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
111
111
|
"""Pull TaskIns."""
|
|
112
112
|
# Get state from app
|
|
113
|
-
state:
|
|
113
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
114
114
|
|
|
115
115
|
# Handle message
|
|
116
116
|
return message_handler.pull_task_ins(request=request, state=state)
|
|
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
|
121
121
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
122
122
|
"""Push TaskRes."""
|
|
123
123
|
# Get state from app
|
|
124
|
-
state:
|
|
124
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
125
125
|
|
|
126
126
|
# Handle message
|
|
127
127
|
return message_handler.push_task_res(request=request, state=state)
|
|
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
|
131
131
|
async def ping(request: PingRequest) -> PingResponse:
|
|
132
132
|
"""Ping."""
|
|
133
133
|
# Get state from app
|
|
134
|
-
state:
|
|
134
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
135
135
|
|
|
136
136
|
# Handle message
|
|
137
137
|
return message_handler.ping(request=request, state=state)
|
|
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
|
|
|
141
141
|
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
142
142
|
"""GetRun."""
|
|
143
143
|
# Get state from app
|
|
144
|
-
state:
|
|
144
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
145
145
|
|
|
146
146
|
# Handle message
|
|
147
147
|
return message_handler.get_run(request=request, state=state)
|