flwr-nightly 1.13.0.dev20241028__py3-none-any.whl → 1.13.0.dev20241030__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/log.py +46 -17
- flwr/common/constant.py +6 -0
- flwr/common/date.py +3 -3
- flwr/common/logger.py +103 -0
- flwr/common/serde.py +22 -0
- flwr/proto/driver_pb2.py +24 -23
- flwr/proto/driver_pb2.pyi +0 -5
- flwr/proto/driver_pb2_grpc.py +69 -0
- flwr/proto/driver_pb2_grpc.pyi +27 -0
- flwr/proto/exec_pb2.py +6 -6
- flwr/proto/exec_pb2.pyi +8 -2
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/server/app.py +1 -5
- flwr/server/driver/driver.py +14 -0
- flwr/server/driver/grpc_driver.py +8 -15
- flwr/server/driver/inmemory_driver.py +3 -11
- flwr/server/run_serverapp.py +3 -4
- flwr/server/serverapp/app.py +148 -18
- flwr/server/superlink/driver/driver_servicer.py +36 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -2
- flwr/server/superlink/linkstate/linkstate.py +35 -0
- flwr/server/superlink/linkstate/sqlite_linkstate.py +50 -0
- flwr/simulation/run_simulation.py +2 -1
- flwr/superexec/deployment.py +22 -40
- flwr/superexec/exec_servicer.py +23 -62
- flwr/superexec/executor.py +3 -4
- flwr/superexec/simulation.py +4 -7
- {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/RECORD +36 -32
- {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -382,9 +382,7 @@ def _flwr_serverapp_scheduler(
|
|
|
382
382
|
|
|
383
383
|
log(
|
|
384
384
|
INFO,
|
|
385
|
-
"Launching `flwr-serverapp` subprocess
|
|
386
|
-
"Connects to SuperLink on %s",
|
|
387
|
-
pending_run_id,
|
|
385
|
+
"Launching `flwr-serverapp` subprocess. Connects to SuperLink on %s",
|
|
388
386
|
driver_api_address,
|
|
389
387
|
)
|
|
390
388
|
# Start ServerApp subprocess
|
|
@@ -392,8 +390,6 @@ def _flwr_serverapp_scheduler(
|
|
|
392
390
|
"flwr-serverapp",
|
|
393
391
|
"--superlink",
|
|
394
392
|
driver_api_address,
|
|
395
|
-
"--run-id",
|
|
396
|
-
str(pending_run_id),
|
|
397
393
|
]
|
|
398
394
|
if ssl_ca_certfile:
|
|
399
395
|
command.append("--root-certificates")
|
flwr/server/driver/driver.py
CHANGED
|
@@ -26,6 +26,20 @@ from flwr.common.typing import Run
|
|
|
26
26
|
class Driver(ABC):
|
|
27
27
|
"""Abstract base Driver class for the Driver API."""
|
|
28
28
|
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def init_run(self, run_id: int) -> None:
|
|
31
|
+
"""Request a run to the SuperLink with a given `run_id`.
|
|
32
|
+
|
|
33
|
+
If a Run with the specified `run_id` exists, a local Run
|
|
34
|
+
object will be created. It enables further functionality
|
|
35
|
+
in the driver, such as sending `Messages`.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
run_id : int
|
|
40
|
+
The `run_id` of the Run this Driver object operates in.
|
|
41
|
+
"""
|
|
42
|
+
|
|
29
43
|
@property
|
|
30
44
|
@abstractmethod
|
|
31
45
|
def run(self) -> Run:
|
|
@@ -60,8 +60,6 @@ class GrpcDriver(Driver):
|
|
|
60
60
|
|
|
61
61
|
Parameters
|
|
62
62
|
----------
|
|
63
|
-
run_id : int
|
|
64
|
-
The identifier of the run.
|
|
65
63
|
driver_service_address : str (default: "[::]:9091")
|
|
66
64
|
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
67
65
|
root_certificates : Optional[bytes] (default: None)
|
|
@@ -72,11 +70,9 @@ class GrpcDriver(Driver):
|
|
|
72
70
|
|
|
73
71
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
72
|
self,
|
|
75
|
-
run_id: int,
|
|
76
73
|
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
77
74
|
root_certificates: Optional[bytes] = None,
|
|
78
75
|
) -> None:
|
|
79
|
-
self._run_id = run_id
|
|
80
76
|
self._addr = driver_service_address
|
|
81
77
|
self._cert = root_certificates
|
|
82
78
|
self._run: Optional[Run] = None
|
|
@@ -116,15 +112,17 @@ class GrpcDriver(Driver):
|
|
|
116
112
|
channel.close()
|
|
117
113
|
log(DEBUG, "[Driver] Disconnected")
|
|
118
114
|
|
|
119
|
-
def _init_run(self) -> None:
|
|
115
|
+
def init_run(self, run_id: int) -> None:
|
|
116
|
+
"""Initialize the run."""
|
|
120
117
|
# Check if is initialized
|
|
121
118
|
if self._run is not None:
|
|
122
119
|
return
|
|
120
|
+
|
|
123
121
|
# Get the run info
|
|
124
|
-
req = GetRunRequest(run_id=self._run_id)
|
|
122
|
+
req = GetRunRequest(run_id=run_id)
|
|
125
123
|
res: GetRunResponse = self._stub.GetRun(req)
|
|
126
124
|
if not res.HasField("run"):
|
|
127
|
-
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
125
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
128
126
|
self._run = Run(
|
|
129
127
|
run_id=res.run.run_id,
|
|
130
128
|
fab_id=res.run.fab_id,
|
|
@@ -136,7 +134,6 @@ class GrpcDriver(Driver):
|
|
|
136
134
|
@property
|
|
137
135
|
def run(self) -> Run:
|
|
138
136
|
"""Run information."""
|
|
139
|
-
self._init_run()
|
|
140
137
|
return Run(**vars(self._run))
|
|
141
138
|
|
|
142
139
|
@property
|
|
@@ -150,7 +147,7 @@ class GrpcDriver(Driver):
|
|
|
150
147
|
# Check if the message is valid
|
|
151
148
|
if not (
|
|
152
149
|
# Assume self._run being initialized
|
|
153
|
-
message.metadata.run_id == self._run_id
|
|
150
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
154
151
|
and message.metadata.src_node_id == self.node.node_id
|
|
155
152
|
and message.metadata.message_id == ""
|
|
156
153
|
and message.metadata.reply_to_message == ""
|
|
@@ -171,7 +168,6 @@ class GrpcDriver(Driver):
|
|
|
171
168
|
This method constructs a new `Message` with given content and metadata.
|
|
172
169
|
The `run_id` and `src_node_id` will be set automatically.
|
|
173
170
|
"""
|
|
174
|
-
self._init_run()
|
|
175
171
|
if ttl:
|
|
176
172
|
warnings.warn(
|
|
177
173
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -182,7 +178,7 @@ class GrpcDriver(Driver):
|
|
|
182
178
|
|
|
183
179
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
184
180
|
metadata = Metadata(
|
|
185
|
-
run_id=self._run_id,
|
|
181
|
+
run_id=cast(Run, self._run).run_id,
|
|
186
182
|
message_id="", # Will be set by the server
|
|
187
183
|
src_node_id=self.node.node_id,
|
|
188
184
|
dst_node_id=dst_node_id,
|
|
@@ -195,10 +191,9 @@ class GrpcDriver(Driver):
|
|
|
195
191
|
|
|
196
192
|
def get_node_ids(self) -> list[int]:
|
|
197
193
|
"""Get node IDs."""
|
|
198
|
-
self._init_run()
|
|
199
194
|
# Call GrpcDriverStub method
|
|
200
195
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
201
|
-
GetNodesRequest(run_id=self._run_id)
|
|
196
|
+
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
202
197
|
)
|
|
203
198
|
return [node.node_id for node in res.nodes]
|
|
204
199
|
|
|
@@ -208,7 +203,6 @@ class GrpcDriver(Driver):
|
|
|
208
203
|
This method takes an iterable of messages and sends each message
|
|
209
204
|
to the node specified in `dst_node_id`.
|
|
210
205
|
"""
|
|
211
|
-
self._init_run()
|
|
212
206
|
# Construct TaskIns
|
|
213
207
|
task_ins_list: list[TaskIns] = []
|
|
214
208
|
for msg in messages:
|
|
@@ -230,7 +224,6 @@ class GrpcDriver(Driver):
|
|
|
230
224
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
231
225
|
set of given message IDs.
|
|
232
226
|
"""
|
|
233
|
-
self._init_run()
|
|
234
227
|
# Pull TaskRes
|
|
235
228
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
236
229
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
@@ -35,8 +35,6 @@ class InMemoryDriver(Driver):
|
|
|
35
35
|
|
|
36
36
|
Parameters
|
|
37
37
|
----------
|
|
38
|
-
run_id : int
|
|
39
|
-
The identifier of the run.
|
|
40
38
|
state_factory : StateFactory
|
|
41
39
|
A StateFactory embedding a state that this driver can interface with.
|
|
42
40
|
pull_interval : float (default=0.1)
|
|
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
|
|
|
45
43
|
|
|
46
44
|
def __init__(
|
|
47
45
|
self,
|
|
48
|
-
run_id: int,
|
|
49
46
|
state_factory: LinkStateFactory,
|
|
50
47
|
pull_interval: float = 0.1,
|
|
51
48
|
) -> None:
|
|
52
|
-
self._run_id = run_id
|
|
53
49
|
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
55
51
|
self.pull_interval = pull_interval
|
|
56
52
|
self.node = Node(node_id=0, anonymous=True)
|
|
57
53
|
|
|
58
54
|
def _check_message(self, message: Message) -> None:
|
|
59
|
-
self._init_run()
|
|
60
55
|
# Check if the message is valid
|
|
61
56
|
if not (
|
|
62
57
|
message.metadata.run_id == cast(Run, self._run).run_id
|
|
@@ -67,19 +62,18 @@ class InMemoryDriver(Driver):
|
|
|
67
62
|
):
|
|
68
63
|
raise ValueError(f"Invalid message: {message}")
|
|
69
64
|
|
|
70
|
-
def _init_run(self) -> None:
|
|
65
|
+
def init_run(self, run_id: int) -> None:
|
|
71
66
|
"""Initialize the run."""
|
|
72
67
|
if self._run is not None:
|
|
73
68
|
return
|
|
74
|
-
run = self.state.get_run(self._run_id)
|
|
69
|
+
run = self.state.get_run(run_id)
|
|
75
70
|
if run is None:
|
|
76
|
-
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
77
72
|
self._run = run
|
|
78
73
|
|
|
79
74
|
@property
|
|
80
75
|
def run(self) -> Run:
|
|
81
76
|
"""Run ID."""
|
|
82
|
-
self._init_run()
|
|
83
77
|
return Run(**vars(cast(Run, self._run)))
|
|
84
78
|
|
|
85
79
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
@@ -95,7 +89,6 @@ class InMemoryDriver(Driver):
|
|
|
95
89
|
This method constructs a new `Message` with given content and metadata.
|
|
96
90
|
The `run_id` and `src_node_id` will be set automatically.
|
|
97
91
|
"""
|
|
98
|
-
self._init_run()
|
|
99
92
|
if ttl:
|
|
100
93
|
warnings.warn(
|
|
101
94
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -119,7 +112,6 @@ class InMemoryDriver(Driver):
|
|
|
119
112
|
|
|
120
113
|
def get_node_ids(self) -> list[int]:
|
|
121
114
|
"""Get node IDs."""
|
|
122
|
-
self._init_run()
|
|
123
115
|
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
124
116
|
|
|
125
117
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -171,11 +171,11 @@ def run_server_app() -> None:
|
|
|
171
171
|
if app_path is None:
|
|
172
172
|
# User provided `--run-id`, but not `app_dir`
|
|
173
173
|
driver = GrpcDriver(
|
|
174
|
-
run_id=args.run_id,
|
|
175
174
|
driver_service_address=args.superlink,
|
|
176
175
|
root_certificates=root_certificates,
|
|
177
176
|
)
|
|
178
177
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
178
|
+
driver.init_run(args.run_id)
|
|
179
179
|
run_ = driver.run
|
|
180
180
|
if not run_.fab_hash:
|
|
181
181
|
raise ValueError("FAB hash not provided.")
|
|
@@ -193,7 +193,6 @@ def run_server_app() -> None:
|
|
|
193
193
|
# User provided `app_dir`, but not `--run-id`
|
|
194
194
|
# Create run if run_id is not provided
|
|
195
195
|
driver = GrpcDriver(
|
|
196
|
-
run_id=0, # Will be overwritten
|
|
197
196
|
driver_service_address=args.superlink,
|
|
198
197
|
root_certificates=root_certificates,
|
|
199
198
|
)
|
|
@@ -204,8 +203,8 @@ def run_server_app() -> None:
|
|
|
204
203
|
# Create run
|
|
205
204
|
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
|
|
206
205
|
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
207
|
-
#
|
|
208
|
-
driver.
|
|
206
|
+
# Fetch full `Run` using `run_id`
|
|
207
|
+
driver.init_run(res.run_id) # pylint: disable=W0212
|
|
209
208
|
|
|
210
209
|
# Obtain server app reference and the run config
|
|
211
210
|
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -16,17 +16,53 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import sys
|
|
19
|
-
from logging import DEBUG, INFO, WARN
|
|
19
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
20
|
from os.path import isfile
|
|
21
21
|
from pathlib import Path
|
|
22
|
+
from queue import Queue
|
|
23
|
+
from time import sleep
|
|
22
24
|
from typing import Optional
|
|
23
25
|
|
|
24
|
-
from flwr.common.logger import log
|
|
26
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
27
|
+
from flwr.cli.install import install_from_fab
|
|
28
|
+
from flwr.common.config import (
|
|
29
|
+
get_flwr_dir,
|
|
30
|
+
get_fused_config_from_dir,
|
|
31
|
+
get_project_config,
|
|
32
|
+
get_project_dir,
|
|
33
|
+
)
|
|
34
|
+
from flwr.common.constant import Status, SubStatus
|
|
35
|
+
from flwr.common.logger import (
|
|
36
|
+
log,
|
|
37
|
+
mirror_output_to_queue,
|
|
38
|
+
restore_output,
|
|
39
|
+
start_log_uploader,
|
|
40
|
+
stop_log_uploader,
|
|
41
|
+
)
|
|
42
|
+
from flwr.common.serde import (
|
|
43
|
+
context_from_proto,
|
|
44
|
+
context_to_proto,
|
|
45
|
+
fab_from_proto,
|
|
46
|
+
run_from_proto,
|
|
47
|
+
run_status_to_proto,
|
|
48
|
+
)
|
|
49
|
+
from flwr.common.typing import RunStatus
|
|
50
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
51
|
+
PullServerAppInputsRequest,
|
|
52
|
+
PullServerAppInputsResponse,
|
|
53
|
+
PushServerAppOutputsRequest,
|
|
54
|
+
)
|
|
55
|
+
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
25
56
|
from flwr.server.driver.grpc_driver import GrpcDriver
|
|
57
|
+
from flwr.server.run_serverapp import run as run_
|
|
26
58
|
|
|
27
59
|
|
|
28
60
|
def flwr_serverapp() -> None:
|
|
29
61
|
"""Run process-isolated Flower ServerApp."""
|
|
62
|
+
# Capture stdout/stderr
|
|
63
|
+
log_queue: Queue[Optional[str]] = Queue()
|
|
64
|
+
mirror_output_to_queue(log_queue)
|
|
65
|
+
|
|
30
66
|
log(INFO, "Starting Flower ServerApp")
|
|
31
67
|
|
|
32
68
|
parser = argparse.ArgumentParser(
|
|
@@ -38,11 +74,10 @@ def flwr_serverapp() -> None:
|
|
|
38
74
|
help="Address of SuperLink's DriverAPI",
|
|
39
75
|
)
|
|
40
76
|
parser.add_argument(
|
|
41
|
-
"--run-id",
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
"function will request a pending run to the LinkState.",
|
|
77
|
+
"--run-once",
|
|
78
|
+
action="store_true",
|
|
79
|
+
help="When set, this process will start a single ServerApp "
|
|
80
|
+
"for a pending Run. If no pending run the process will exit. ",
|
|
46
81
|
)
|
|
47
82
|
parser.add_argument(
|
|
48
83
|
"--flwr-dir",
|
|
@@ -75,18 +110,20 @@ def flwr_serverapp() -> None:
|
|
|
75
110
|
|
|
76
111
|
log(
|
|
77
112
|
DEBUG,
|
|
78
|
-
"Staring isolated `ServerApp` connected to SuperLink DriverAPI at %s "
|
|
79
|
-
"for run-id %s",
|
|
113
|
+
"Staring isolated `ServerApp` connected to SuperLink DriverAPI at %s",
|
|
80
114
|
args.superlink,
|
|
81
|
-
args.run_id,
|
|
82
115
|
)
|
|
83
116
|
run_serverapp(
|
|
84
117
|
superlink=args.superlink,
|
|
85
|
-
|
|
118
|
+
log_queue=log_queue,
|
|
119
|
+
run_once=args.run_once,
|
|
86
120
|
flwr_dir_=args.flwr_dir,
|
|
87
121
|
certificates=certificates,
|
|
88
122
|
)
|
|
89
123
|
|
|
124
|
+
# Restore stdout/stderr
|
|
125
|
+
restore_output()
|
|
126
|
+
|
|
90
127
|
|
|
91
128
|
def _try_obtain_certificates(
|
|
92
129
|
args: argparse.Namespace,
|
|
@@ -121,21 +158,114 @@ def _try_obtain_certificates(
|
|
|
121
158
|
return root_certificates
|
|
122
159
|
|
|
123
160
|
|
|
124
|
-
def run_serverapp( # pylint: disable=R0914
|
|
161
|
+
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
125
162
|
superlink: str,
|
|
126
|
-
|
|
163
|
+
log_queue: Queue[Optional[str]],
|
|
164
|
+
run_once: bool,
|
|
127
165
|
flwr_dir_: Optional[str] = None,
|
|
128
166
|
certificates: Optional[bytes] = None,
|
|
129
167
|
) -> None:
|
|
130
168
|
"""Run Flower ServerApp process."""
|
|
131
|
-
|
|
132
|
-
run_id=run_id if run_id else 0,
|
|
169
|
+
driver = GrpcDriver(
|
|
133
170
|
driver_service_address=superlink,
|
|
134
171
|
root_certificates=certificates,
|
|
135
172
|
)
|
|
136
173
|
|
|
137
|
-
|
|
174
|
+
# Resolve directory where FABs are installed
|
|
175
|
+
flwr_dir = get_flwr_dir(flwr_dir_)
|
|
176
|
+
log_uploader = None
|
|
177
|
+
|
|
178
|
+
while True:
|
|
179
|
+
|
|
180
|
+
try:
|
|
181
|
+
# Pull ServerAppInputs from LinkState
|
|
182
|
+
req = PullServerAppInputsRequest()
|
|
183
|
+
res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
|
|
184
|
+
if not res.HasField("run"):
|
|
185
|
+
sleep(3)
|
|
186
|
+
run_status = None
|
|
187
|
+
continue
|
|
188
|
+
|
|
189
|
+
context = context_from_proto(res.context)
|
|
190
|
+
run = run_from_proto(res.run)
|
|
191
|
+
fab = fab_from_proto(res.fab)
|
|
192
|
+
|
|
193
|
+
driver.init_run(run.run_id)
|
|
194
|
+
|
|
195
|
+
# Start log uploader for this run
|
|
196
|
+
log_uploader = start_log_uploader(
|
|
197
|
+
log_queue=log_queue,
|
|
198
|
+
node_id=0,
|
|
199
|
+
run_id=run.run_id,
|
|
200
|
+
stub=driver._stub,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
log(DEBUG, "ServerApp process starts FAB installation.")
|
|
204
|
+
install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)
|
|
205
|
+
|
|
206
|
+
fab_id, fab_version = get_fab_metadata(fab.content)
|
|
207
|
+
|
|
208
|
+
app_path = str(get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir))
|
|
209
|
+
config = get_project_config(app_path)
|
|
210
|
+
|
|
211
|
+
# Obtain server app reference and the run config
|
|
212
|
+
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
213
|
+
server_app_run_config = get_fused_config_from_dir(
|
|
214
|
+
Path(app_path), run.override_config
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Update run_config in context
|
|
218
|
+
context.run_config = server_app_run_config
|
|
219
|
+
|
|
220
|
+
log(
|
|
221
|
+
DEBUG,
|
|
222
|
+
"Flower will load ServerApp `%s` in %s",
|
|
223
|
+
server_app_attr,
|
|
224
|
+
app_path,
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
# Change status to Running
|
|
228
|
+
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
229
|
+
driver._stub.UpdateRunStatus(
|
|
230
|
+
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
# Load and run the ServerApp with the Driver
|
|
234
|
+
updated_context = run_(
|
|
235
|
+
driver=driver,
|
|
236
|
+
server_app_dir=app_path,
|
|
237
|
+
server_app_attr=server_app_attr,
|
|
238
|
+
context=context,
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
# Send resulting context
|
|
242
|
+
context_proto = context_to_proto(updated_context)
|
|
243
|
+
out_req = PushServerAppOutputsRequest(
|
|
244
|
+
run_id=run.run_id, context=context_proto
|
|
245
|
+
)
|
|
246
|
+
_ = driver._stub.PushServerAppOutputs(out_req)
|
|
247
|
+
|
|
248
|
+
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
249
|
+
|
|
250
|
+
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
251
|
+
exc_entity = "ServerApp"
|
|
252
|
+
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
253
|
+
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
|
|
254
|
+
|
|
255
|
+
finally:
|
|
256
|
+
if run_status:
|
|
257
|
+
run_status_proto = run_status_to_proto(run_status)
|
|
258
|
+
driver._stub.UpdateRunStatus(
|
|
259
|
+
UpdateRunStatusRequest(
|
|
260
|
+
run_id=run.run_id, run_status=run_status_proto
|
|
261
|
+
)
|
|
262
|
+
)
|
|
138
263
|
|
|
139
|
-
|
|
264
|
+
# Stop log uploader for this run
|
|
265
|
+
if log_uploader:
|
|
266
|
+
stop_log_uploader(log_queue, log_uploader)
|
|
267
|
+
log_uploader = None
|
|
140
268
|
|
|
141
|
-
|
|
269
|
+
# Stop the loop if `flwr-serverapp` is expected to process a single run
|
|
270
|
+
if run_once:
|
|
271
|
+
break
|
|
@@ -31,6 +31,7 @@ from flwr.common.serde import (
|
|
|
31
31
|
context_to_proto,
|
|
32
32
|
fab_from_proto,
|
|
33
33
|
fab_to_proto,
|
|
34
|
+
run_status_from_proto,
|
|
34
35
|
run_to_proto,
|
|
35
36
|
user_config_from_proto,
|
|
36
37
|
)
|
|
@@ -49,12 +50,18 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
49
50
|
PushTaskInsResponse,
|
|
50
51
|
)
|
|
51
52
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
53
|
+
from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
54
|
+
PushLogsRequest,
|
|
55
|
+
PushLogsResponse,
|
|
56
|
+
)
|
|
52
57
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
53
58
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
54
59
|
CreateRunRequest,
|
|
55
60
|
CreateRunResponse,
|
|
56
61
|
GetRunRequest,
|
|
57
62
|
GetRunResponse,
|
|
63
|
+
UpdateRunStatusRequest,
|
|
64
|
+
UpdateRunStatusResponse,
|
|
58
65
|
)
|
|
59
66
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
60
67
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
@@ -212,11 +219,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
212
219
|
|
|
213
220
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
214
221
|
with self.lock:
|
|
215
|
-
#
|
|
216
|
-
|
|
217
|
-
run_id: Optional[int] = request.run_id
|
|
218
|
-
else:
|
|
219
|
-
run_id = state.get_pending_run_id()
|
|
222
|
+
# Attempt getting the run_id of a pending run
|
|
223
|
+
run_id = state.get_pending_run_id()
|
|
220
224
|
# If there's no pending run, return an empty response
|
|
221
225
|
if run_id is None:
|
|
222
226
|
return PullServerAppInputsResponse()
|
|
@@ -228,14 +232,12 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
228
232
|
if run and run.fab_hash:
|
|
229
233
|
if result := ffs.get(run.fab_hash):
|
|
230
234
|
fab = Fab(run.fab_hash, result[0])
|
|
231
|
-
if run and fab:
|
|
235
|
+
if run and fab and serverapp_ctxt:
|
|
232
236
|
# Update run status to STARTING
|
|
233
237
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
234
238
|
log(INFO, "Starting run %d", run_id)
|
|
235
239
|
return PullServerAppInputsResponse(
|
|
236
|
-
context=(
|
|
237
|
-
context_to_proto(serverapp_ctxt) if serverapp_ctxt else None
|
|
238
|
-
),
|
|
240
|
+
context=context_to_proto(serverapp_ctxt),
|
|
239
241
|
run=run_to_proto(run),
|
|
240
242
|
fab=fab_to_proto(fab),
|
|
241
243
|
)
|
|
@@ -253,6 +255,31 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
253
255
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
254
256
|
return PushServerAppOutputsResponse()
|
|
255
257
|
|
|
258
|
+
def UpdateRunStatus(
|
|
259
|
+
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
260
|
+
) -> UpdateRunStatusResponse:
|
|
261
|
+
"""Update the status of a run."""
|
|
262
|
+
log(DEBUG, "ControlServicer.UpdateRunStatus")
|
|
263
|
+
state = self.state_factory.state()
|
|
264
|
+
|
|
265
|
+
# Update the run status
|
|
266
|
+
state.update_run_status(
|
|
267
|
+
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
268
|
+
)
|
|
269
|
+
return UpdateRunStatusResponse()
|
|
270
|
+
|
|
271
|
+
def PushLogs(
|
|
272
|
+
self, request: PushLogsRequest, context: grpc.ServicerContext
|
|
273
|
+
) -> PushLogsResponse:
|
|
274
|
+
"""Push logs."""
|
|
275
|
+
log(DEBUG, "DriverServicer.PushLogs")
|
|
276
|
+
state = self.state_factory.state()
|
|
277
|
+
|
|
278
|
+
# Add logs to LinkState
|
|
279
|
+
merged_logs = "".join(request.logs)
|
|
280
|
+
state.add_serverapp_log(request.run_id, merged_logs)
|
|
281
|
+
return PushLogsResponse()
|
|
282
|
+
|
|
256
283
|
|
|
257
284
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
258
285
|
if validation_error:
|
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
|
-
from dataclasses import dataclass
|
|
20
|
+
from bisect import bisect_right
|
|
21
|
+
from dataclasses import dataclass, field
|
|
21
22
|
from logging import ERROR, WARNING
|
|
22
23
|
from typing import Optional
|
|
23
24
|
from uuid import UUID, uuid4
|
|
@@ -43,7 +44,7 @@ from .utils import (
|
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
@dataclass
|
|
46
|
-
class RunRecord:
|
|
47
|
+
class RunRecord: # pylint: disable=R0902
|
|
47
48
|
"""The record of a specific run, including its status and timestamps."""
|
|
48
49
|
|
|
49
50
|
run: Run
|
|
@@ -52,6 +53,8 @@ class RunRecord:
|
|
|
52
53
|
starting_at: str = ""
|
|
53
54
|
running_at: str = ""
|
|
54
55
|
finished_at: str = ""
|
|
56
|
+
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
57
|
+
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
55
58
|
|
|
56
59
|
|
|
57
60
|
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
@@ -511,3 +514,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
511
514
|
if run_id not in self.run_ids:
|
|
512
515
|
raise ValueError(f"Run {run_id} not found")
|
|
513
516
|
self.contexts[run_id] = context
|
|
517
|
+
|
|
518
|
+
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
519
|
+
"""Add a log entry to the serverapp logs for the specified `run_id`."""
|
|
520
|
+
if run_id not in self.run_ids:
|
|
521
|
+
raise ValueError(f"Run {run_id} not found")
|
|
522
|
+
run = self.run_ids[run_id]
|
|
523
|
+
with run.log_lock:
|
|
524
|
+
run.logs.append((now().timestamp(), log_message))
|
|
525
|
+
|
|
526
|
+
def get_serverapp_log(
|
|
527
|
+
self, run_id: int, after_timestamp: Optional[float]
|
|
528
|
+
) -> tuple[str, float]:
|
|
529
|
+
"""Get the serverapp logs for the specified `run_id`."""
|
|
530
|
+
if run_id not in self.run_ids:
|
|
531
|
+
raise ValueError(f"Run {run_id} not found")
|
|
532
|
+
run = self.run_ids[run_id]
|
|
533
|
+
if after_timestamp is None:
|
|
534
|
+
after_timestamp = 0.0
|
|
535
|
+
with run.log_lock:
|
|
536
|
+
# Find the index where the timestamp would be inserted
|
|
537
|
+
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
538
|
+
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
539
|
+
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|
|
@@ -299,3 +299,38 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
299
299
|
context : Context
|
|
300
300
|
The context to be associated with the specified `run_id`.
|
|
301
301
|
"""
|
|
302
|
+
|
|
303
|
+
@abc.abstractmethod
|
|
304
|
+
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
305
|
+
"""Add a log entry to the ServerApp logs for the specified `run_id`.
|
|
306
|
+
|
|
307
|
+
Parameters
|
|
308
|
+
----------
|
|
309
|
+
run_id : int
|
|
310
|
+
The identifier of the run for which to add a log entry.
|
|
311
|
+
log_message : str
|
|
312
|
+
The log entry to be added to the ServerApp logs.
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
@abc.abstractmethod
|
|
316
|
+
def get_serverapp_log(
|
|
317
|
+
self, run_id: int, after_timestamp: Optional[float]
|
|
318
|
+
) -> tuple[str, float]:
|
|
319
|
+
"""Get the ServerApp logs for the specified `run_id`.
|
|
320
|
+
|
|
321
|
+
Parameters
|
|
322
|
+
----------
|
|
323
|
+
run_id : int
|
|
324
|
+
The identifier of the run for which to retrieve the ServerApp logs.
|
|
325
|
+
|
|
326
|
+
after_timestamp : Optional[float]
|
|
327
|
+
Retrieve logs after this timestamp. If set to `None`, retrieve all logs.
|
|
328
|
+
|
|
329
|
+
Returns
|
|
330
|
+
-------
|
|
331
|
+
tuple[str, float]
|
|
332
|
+
A tuple containing:
|
|
333
|
+
- The ServerApp logs associated with the specified `run_id`.
|
|
334
|
+
- The timestamp of the latest log entry in the returned logs.
|
|
335
|
+
Returns `0` if no logs are returned.
|
|
336
|
+
"""
|