flwr-nightly 1.13.0.dev20241028__py3-none-any.whl → 1.13.0.dev20241030__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's release page on the registry for more details.

Files changed (36)
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/log.py +46 -17
  3. flwr/common/constant.py +6 -0
  4. flwr/common/date.py +3 -3
  5. flwr/common/logger.py +103 -0
  6. flwr/common/serde.py +22 -0
  7. flwr/proto/driver_pb2.py +24 -23
  8. flwr/proto/driver_pb2.pyi +0 -5
  9. flwr/proto/driver_pb2_grpc.py +69 -0
  10. flwr/proto/driver_pb2_grpc.pyi +27 -0
  11. flwr/proto/exec_pb2.py +6 -6
  12. flwr/proto/exec_pb2.pyi +8 -2
  13. flwr/proto/log_pb2.py +29 -0
  14. flwr/proto/log_pb2.pyi +39 -0
  15. flwr/proto/log_pb2_grpc.py +4 -0
  16. flwr/proto/log_pb2_grpc.pyi +4 -0
  17. flwr/server/app.py +1 -5
  18. flwr/server/driver/driver.py +14 -0
  19. flwr/server/driver/grpc_driver.py +8 -15
  20. flwr/server/driver/inmemory_driver.py +3 -11
  21. flwr/server/run_serverapp.py +3 -4
  22. flwr/server/serverapp/app.py +148 -18
  23. flwr/server/superlink/driver/driver_servicer.py +36 -9
  24. flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -2
  25. flwr/server/superlink/linkstate/linkstate.py +35 -0
  26. flwr/server/superlink/linkstate/sqlite_linkstate.py +50 -0
  27. flwr/simulation/run_simulation.py +2 -1
  28. flwr/superexec/deployment.py +22 -40
  29. flwr/superexec/exec_servicer.py +23 -62
  30. flwr/superexec/executor.py +3 -4
  31. flwr/superexec/simulation.py +4 -7
  32. {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/METADATA +1 -1
  33. {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/RECORD +36 -32
  34. {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/LICENSE +0 -0
  35. {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/WHEEL +0 -0
  36. {flwr_nightly-1.13.0.dev20241028.dist-info → flwr_nightly-1.13.0.dev20241030.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/server/app.py CHANGED
@@ -382,9 +382,7 @@ def _flwr_serverapp_scheduler(
382
382
 
383
383
  log(
384
384
  INFO,
385
- "Launching `flwr-serverapp` subprocess with run-id %d. "
386
- "Connects to SuperLink on %s",
387
- pending_run_id,
385
+ "Launching `flwr-serverapp` subprocess. Connects to SuperLink on %s",
388
386
  driver_api_address,
389
387
  )
390
388
  # Start ServerApp subprocess
@@ -392,8 +390,6 @@ def _flwr_serverapp_scheduler(
392
390
  "flwr-serverapp",
393
391
  "--superlink",
394
392
  driver_api_address,
395
- "--run-id",
396
- str(pending_run_id),
397
393
  ]
398
394
  if ssl_ca_certfile:
399
395
  command.append("--root-certificates")
@@ -26,6 +26,20 @@ from flwr.common.typing import Run
26
26
  class Driver(ABC):
27
27
  """Abstract base Driver class for the Driver API."""
28
28
 
29
+ @abstractmethod
30
+ def init_run(self, run_id: int) -> None:
31
+ """Request a run to the SuperLink with a given `run_id`.
32
+
33
+ If a Run with the specified `run_id` exists, a local Run
34
+ object will be created. It enables further functionality
35
+ in the driver, such as sending `Messages`.
36
+
37
+ Parameters
38
+ ----------
39
+ run_id : int
40
+ The `run_id` of the Run this Driver object operates in.
41
+ """
42
+
29
43
  @property
30
44
  @abstractmethod
31
45
  def run(self) -> Run:
@@ -60,8 +60,6 @@ class GrpcDriver(Driver):
60
60
 
61
61
  Parameters
62
62
  ----------
63
- run_id : int
64
- The identifier of the run.
65
63
  driver_service_address : str (default: "[::]:9091")
66
64
  The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
67
65
  root_certificates : Optional[bytes] (default: None)
@@ -72,11 +70,9 @@ class GrpcDriver(Driver):
72
70
 
73
71
  def __init__( # pylint: disable=too-many-arguments
74
72
  self,
75
- run_id: int,
76
73
  driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
77
74
  root_certificates: Optional[bytes] = None,
78
75
  ) -> None:
79
- self._run_id = run_id
80
76
  self._addr = driver_service_address
81
77
  self._cert = root_certificates
82
78
  self._run: Optional[Run] = None
@@ -116,15 +112,17 @@ class GrpcDriver(Driver):
116
112
  channel.close()
117
113
  log(DEBUG, "[Driver] Disconnected")
118
114
 
119
- def _init_run(self) -> None:
115
+ def init_run(self, run_id: int) -> None:
116
+ """Initialize the run."""
120
117
  # Check if is initialized
121
118
  if self._run is not None:
122
119
  return
120
+
123
121
  # Get the run info
124
- req = GetRunRequest(run_id=self._run_id)
122
+ req = GetRunRequest(run_id=run_id)
125
123
  res: GetRunResponse = self._stub.GetRun(req)
126
124
  if not res.HasField("run"):
127
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
125
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
128
126
  self._run = Run(
129
127
  run_id=res.run.run_id,
130
128
  fab_id=res.run.fab_id,
@@ -136,7 +134,6 @@ class GrpcDriver(Driver):
136
134
  @property
137
135
  def run(self) -> Run:
138
136
  """Run information."""
139
- self._init_run()
140
137
  return Run(**vars(self._run))
141
138
 
142
139
  @property
@@ -150,7 +147,7 @@ class GrpcDriver(Driver):
150
147
  # Check if the message is valid
151
148
  if not (
152
149
  # Assume self._run being initialized
153
- message.metadata.run_id == self._run_id
150
+ message.metadata.run_id == cast(Run, self._run).run_id
154
151
  and message.metadata.src_node_id == self.node.node_id
155
152
  and message.metadata.message_id == ""
156
153
  and message.metadata.reply_to_message == ""
@@ -171,7 +168,6 @@ class GrpcDriver(Driver):
171
168
  This method constructs a new `Message` with given content and metadata.
172
169
  The `run_id` and `src_node_id` will be set automatically.
173
170
  """
174
- self._init_run()
175
171
  if ttl:
176
172
  warnings.warn(
177
173
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -182,7 +178,7 @@ class GrpcDriver(Driver):
182
178
 
183
179
  ttl_ = DEFAULT_TTL if ttl is None else ttl
184
180
  metadata = Metadata(
185
- run_id=self._run_id,
181
+ run_id=cast(Run, self._run).run_id,
186
182
  message_id="", # Will be set by the server
187
183
  src_node_id=self.node.node_id,
188
184
  dst_node_id=dst_node_id,
@@ -195,10 +191,9 @@ class GrpcDriver(Driver):
195
191
 
196
192
  def get_node_ids(self) -> list[int]:
197
193
  """Get node IDs."""
198
- self._init_run()
199
194
  # Call GrpcDriverStub method
200
195
  res: GetNodesResponse = self._stub.GetNodes(
201
- GetNodesRequest(run_id=self._run_id)
196
+ GetNodesRequest(run_id=cast(Run, self._run).run_id)
202
197
  )
203
198
  return [node.node_id for node in res.nodes]
204
199
 
@@ -208,7 +203,6 @@ class GrpcDriver(Driver):
208
203
  This method takes an iterable of messages and sends each message
209
204
  to the node specified in `dst_node_id`.
210
205
  """
211
- self._init_run()
212
206
  # Construct TaskIns
213
207
  task_ins_list: list[TaskIns] = []
214
208
  for msg in messages:
@@ -230,7 +224,6 @@ class GrpcDriver(Driver):
230
224
  This method is used to collect messages from the SuperLink that correspond to a
231
225
  set of given message IDs.
232
226
  """
233
- self._init_run()
234
227
  # Pull TaskRes
235
228
  res: PullTaskResResponse = self._stub.PullTaskRes(
236
229
  PullTaskResRequest(node=self.node, task_ids=message_ids)
@@ -35,8 +35,6 @@ class InMemoryDriver(Driver):
35
35
 
36
36
  Parameters
37
37
  ----------
38
- run_id : int
39
- The identifier of the run.
40
38
  state_factory : StateFactory
41
39
  A StateFactory embedding a state that this driver can interface with.
42
40
  pull_interval : float (default=0.1)
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
45
43
 
46
44
  def __init__(
47
45
  self,
48
- run_id: int,
49
46
  state_factory: LinkStateFactory,
50
47
  pull_interval: float = 0.1,
51
48
  ) -> None:
52
- self._run_id = run_id
53
49
  self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
55
51
  self.pull_interval = pull_interval
56
52
  self.node = Node(node_id=0, anonymous=True)
57
53
 
58
54
  def _check_message(self, message: Message) -> None:
59
- self._init_run()
60
55
  # Check if the message is valid
61
56
  if not (
62
57
  message.metadata.run_id == cast(Run, self._run).run_id
@@ -67,19 +62,18 @@ class InMemoryDriver(Driver):
67
62
  ):
68
63
  raise ValueError(f"Invalid message: {message}")
69
64
 
70
- def _init_run(self) -> None:
65
+ def init_run(self, run_id: int) -> None:
71
66
  """Initialize the run."""
72
67
  if self._run is not None:
73
68
  return
74
- run = self.state.get_run(self._run_id)
69
+ run = self.state.get_run(run_id)
75
70
  if run is None:
76
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
71
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
77
72
  self._run = run
78
73
 
79
74
  @property
80
75
  def run(self) -> Run:
81
76
  """Run ID."""
82
- self._init_run()
83
77
  return Run(**vars(cast(Run, self._run)))
84
78
 
85
79
  def create_message( # pylint: disable=too-many-arguments,R0917
@@ -95,7 +89,6 @@ class InMemoryDriver(Driver):
95
89
  This method constructs a new `Message` with given content and metadata.
96
90
  The `run_id` and `src_node_id` will be set automatically.
97
91
  """
98
- self._init_run()
99
92
  if ttl:
100
93
  warnings.warn(
101
94
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -119,7 +112,6 @@ class InMemoryDriver(Driver):
119
112
 
120
113
  def get_node_ids(self) -> list[int]:
121
114
  """Get node IDs."""
122
- self._init_run()
123
115
  return list(self.state.get_nodes(cast(Run, self._run).run_id))
124
116
 
125
117
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -171,11 +171,11 @@ def run_server_app() -> None:
171
171
  if app_path is None:
172
172
  # User provided `--run-id`, but not `app_dir`
173
173
  driver = GrpcDriver(
174
- run_id=args.run_id,
175
174
  driver_service_address=args.superlink,
176
175
  root_certificates=root_certificates,
177
176
  )
178
177
  flwr_dir = get_flwr_dir(args.flwr_dir)
178
+ driver.init_run(args.run_id)
179
179
  run_ = driver.run
180
180
  if not run_.fab_hash:
181
181
  raise ValueError("FAB hash not provided.")
@@ -193,7 +193,6 @@ def run_server_app() -> None:
193
193
  # User provided `app_dir`, but not `--run-id`
194
194
  # Create run if run_id is not provided
195
195
  driver = GrpcDriver(
196
- run_id=0, # Will be overwritten
197
196
  driver_service_address=args.superlink,
198
197
  root_certificates=root_certificates,
199
198
  )
@@ -204,8 +203,8 @@ def run_server_app() -> None:
204
203
  # Create run
205
204
  req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
206
205
  res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
207
- # Overwrite driver._run_id
208
- driver._run_id = res.run_id # pylint: disable=W0212
206
+ # Fetch full `Run` using `run_id`
207
+ driver.init_run(res.run_id) # pylint: disable=W0212
209
208
 
210
209
  # Obtain server app reference and the run config
211
210
  server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
@@ -16,17 +16,53 @@
16
16
 
17
17
  import argparse
18
18
  import sys
19
- from logging import DEBUG, INFO, WARN
19
+ from logging import DEBUG, ERROR, INFO, WARN
20
20
  from os.path import isfile
21
21
  from pathlib import Path
22
+ from queue import Queue
23
+ from time import sleep
22
24
  from typing import Optional
23
25
 
24
- from flwr.common.logger import log
26
+ from flwr.cli.config_utils import get_fab_metadata
27
+ from flwr.cli.install import install_from_fab
28
+ from flwr.common.config import (
29
+ get_flwr_dir,
30
+ get_fused_config_from_dir,
31
+ get_project_config,
32
+ get_project_dir,
33
+ )
34
+ from flwr.common.constant import Status, SubStatus
35
+ from flwr.common.logger import (
36
+ log,
37
+ mirror_output_to_queue,
38
+ restore_output,
39
+ start_log_uploader,
40
+ stop_log_uploader,
41
+ )
42
+ from flwr.common.serde import (
43
+ context_from_proto,
44
+ context_to_proto,
45
+ fab_from_proto,
46
+ run_from_proto,
47
+ run_status_to_proto,
48
+ )
49
+ from flwr.common.typing import RunStatus
50
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
51
+ PullServerAppInputsRequest,
52
+ PullServerAppInputsResponse,
53
+ PushServerAppOutputsRequest,
54
+ )
55
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
25
56
  from flwr.server.driver.grpc_driver import GrpcDriver
57
+ from flwr.server.run_serverapp import run as run_
26
58
 
27
59
 
28
60
  def flwr_serverapp() -> None:
29
61
  """Run process-isolated Flower ServerApp."""
62
+ # Capture stdout/stderr
63
+ log_queue: Queue[Optional[str]] = Queue()
64
+ mirror_output_to_queue(log_queue)
65
+
30
66
  log(INFO, "Starting Flower ServerApp")
31
67
 
32
68
  parser = argparse.ArgumentParser(
@@ -38,11 +74,10 @@ def flwr_serverapp() -> None:
38
74
  help="Address of SuperLink's DriverAPI",
39
75
  )
40
76
  parser.add_argument(
41
- "--run-id",
42
- type=int,
43
- required=False,
44
- help="Id of the Run this process should start. If not supplied, this "
45
- "function will request a pending run to the LinkState.",
77
+ "--run-once",
78
+ action="store_true",
79
+ help="When set, this process will start a single ServerApp "
80
+ "for a pending Run. If no pending run the process will exit. ",
46
81
  )
47
82
  parser.add_argument(
48
83
  "--flwr-dir",
@@ -75,18 +110,20 @@ def flwr_serverapp() -> None:
75
110
 
76
111
  log(
77
112
  DEBUG,
78
- "Staring isolated `ServerApp` connected to SuperLink DriverAPI at %s "
79
- "for run-id %s",
113
+ "Staring isolated `ServerApp` connected to SuperLink DriverAPI at %s",
80
114
  args.superlink,
81
- args.run_id,
82
115
  )
83
116
  run_serverapp(
84
117
  superlink=args.superlink,
85
- run_id=args.run_id,
118
+ log_queue=log_queue,
119
+ run_once=args.run_once,
86
120
  flwr_dir_=args.flwr_dir,
87
121
  certificates=certificates,
88
122
  )
89
123
 
124
+ # Restore stdout/stderr
125
+ restore_output()
126
+
90
127
 
91
128
  def _try_obtain_certificates(
92
129
  args: argparse.Namespace,
@@ -121,21 +158,114 @@ def _try_obtain_certificates(
121
158
  return root_certificates
122
159
 
123
160
 
124
- def run_serverapp( # pylint: disable=R0914
161
+ def run_serverapp( # pylint: disable=R0914, disable=W0212
125
162
  superlink: str,
126
- run_id: Optional[int] = None,
163
+ log_queue: Queue[Optional[str]],
164
+ run_once: bool,
127
165
  flwr_dir_: Optional[str] = None,
128
166
  certificates: Optional[bytes] = None,
129
167
  ) -> None:
130
168
  """Run Flower ServerApp process."""
131
- _ = GrpcDriver(
132
- run_id=run_id if run_id else 0,
169
+ driver = GrpcDriver(
133
170
  driver_service_address=superlink,
134
171
  root_certificates=certificates,
135
172
  )
136
173
 
137
- log(INFO, "%s", flwr_dir_)
174
+ # Resolve directory where FABs are installed
175
+ flwr_dir = get_flwr_dir(flwr_dir_)
176
+ log_uploader = None
177
+
178
+ while True:
179
+
180
+ try:
181
+ # Pull ServerAppInputs from LinkState
182
+ req = PullServerAppInputsRequest()
183
+ res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
184
+ if not res.HasField("run"):
185
+ sleep(3)
186
+ run_status = None
187
+ continue
188
+
189
+ context = context_from_proto(res.context)
190
+ run = run_from_proto(res.run)
191
+ fab = fab_from_proto(res.fab)
192
+
193
+ driver.init_run(run.run_id)
194
+
195
+ # Start log uploader for this run
196
+ log_uploader = start_log_uploader(
197
+ log_queue=log_queue,
198
+ node_id=0,
199
+ run_id=run.run_id,
200
+ stub=driver._stub,
201
+ )
202
+
203
+ log(DEBUG, "ServerApp process starts FAB installation.")
204
+ install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)
205
+
206
+ fab_id, fab_version = get_fab_metadata(fab.content)
207
+
208
+ app_path = str(get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir))
209
+ config = get_project_config(app_path)
210
+
211
+ # Obtain server app reference and the run config
212
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
213
+ server_app_run_config = get_fused_config_from_dir(
214
+ Path(app_path), run.override_config
215
+ )
216
+
217
+ # Update run_config in context
218
+ context.run_config = server_app_run_config
219
+
220
+ log(
221
+ DEBUG,
222
+ "Flower will load ServerApp `%s` in %s",
223
+ server_app_attr,
224
+ app_path,
225
+ )
226
+
227
+ # Change status to Running
228
+ run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
229
+ driver._stub.UpdateRunStatus(
230
+ UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
231
+ )
232
+
233
+ # Load and run the ServerApp with the Driver
234
+ updated_context = run_(
235
+ driver=driver,
236
+ server_app_dir=app_path,
237
+ server_app_attr=server_app_attr,
238
+ context=context,
239
+ )
240
+
241
+ # Send resulting context
242
+ context_proto = context_to_proto(updated_context)
243
+ out_req = PushServerAppOutputsRequest(
244
+ run_id=run.run_id, context=context_proto
245
+ )
246
+ _ = driver._stub.PushServerAppOutputs(out_req)
247
+
248
+ run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
249
+
250
+ except Exception as ex: # pylint: disable=broad-exception-caught
251
+ exc_entity = "ServerApp"
252
+ log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
253
+ run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
254
+
255
+ finally:
256
+ if run_status:
257
+ run_status_proto = run_status_to_proto(run_status)
258
+ driver._stub.UpdateRunStatus(
259
+ UpdateRunStatusRequest(
260
+ run_id=run.run_id, run_status=run_status_proto
261
+ )
262
+ )
138
263
 
139
- # Then, GetServerInputs
264
+ # Stop log uploader for this run
265
+ if log_uploader:
266
+ stop_log_uploader(log_queue, log_uploader)
267
+ log_uploader = None
140
268
 
141
- # Then, run ServerApp
269
+ # Stop the loop if `flwr-serverapp` is expected to process a single run
270
+ if run_once:
271
+ break
@@ -31,6 +31,7 @@ from flwr.common.serde import (
31
31
  context_to_proto,
32
32
  fab_from_proto,
33
33
  fab_to_proto,
34
+ run_status_from_proto,
34
35
  run_to_proto,
35
36
  user_config_from_proto,
36
37
  )
@@ -49,12 +50,18 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
49
50
  PushTaskInsResponse,
50
51
  )
51
52
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
53
+ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
54
+ PushLogsRequest,
55
+ PushLogsResponse,
56
+ )
52
57
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
53
58
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
54
59
  CreateRunRequest,
55
60
  CreateRunResponse,
56
61
  GetRunRequest,
57
62
  GetRunResponse,
63
+ UpdateRunStatusRequest,
64
+ UpdateRunStatusResponse,
58
65
  )
59
66
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
60
67
  from flwr.server.superlink.ffs.ffs import Ffs
@@ -212,11 +219,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
212
219
 
213
220
  # Lock access to LinkState, preventing obtaining the same pending run_id
214
221
  with self.lock:
215
- # If run_id is provided, use it, otherwise use the pending run_id
216
- if _has_field(request, "run_id"):
217
- run_id: Optional[int] = request.run_id
218
- else:
219
- run_id = state.get_pending_run_id()
222
+ # Attempt getting the run_id of a pending run
223
+ run_id = state.get_pending_run_id()
220
224
  # If there's no pending run, return an empty response
221
225
  if run_id is None:
222
226
  return PullServerAppInputsResponse()
@@ -228,14 +232,12 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
228
232
  if run and run.fab_hash:
229
233
  if result := ffs.get(run.fab_hash):
230
234
  fab = Fab(run.fab_hash, result[0])
231
- if run and fab:
235
+ if run and fab and serverapp_ctxt:
232
236
  # Update run status to STARTING
233
237
  if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
234
238
  log(INFO, "Starting run %d", run_id)
235
239
  return PullServerAppInputsResponse(
236
- context=(
237
- context_to_proto(serverapp_ctxt) if serverapp_ctxt else None
238
- ),
240
+ context=context_to_proto(serverapp_ctxt),
239
241
  run=run_to_proto(run),
240
242
  fab=fab_to_proto(fab),
241
243
  )
@@ -253,6 +255,31 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
253
255
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
254
256
  return PushServerAppOutputsResponse()
255
257
 
258
+ def UpdateRunStatus(
259
+ self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
260
+ ) -> UpdateRunStatusResponse:
261
+ """Update the status of a run."""
262
+ log(DEBUG, "ControlServicer.UpdateRunStatus")
263
+ state = self.state_factory.state()
264
+
265
+ # Update the run status
266
+ state.update_run_status(
267
+ run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
268
+ )
269
+ return UpdateRunStatusResponse()
270
+
271
+ def PushLogs(
272
+ self, request: PushLogsRequest, context: grpc.ServicerContext
273
+ ) -> PushLogsResponse:
274
+ """Push logs."""
275
+ log(DEBUG, "DriverServicer.PushLogs")
276
+ state = self.state_factory.state()
277
+
278
+ # Add logs to LinkState
279
+ merged_logs = "".join(request.logs)
280
+ state.add_serverapp_log(request.run_id, merged_logs)
281
+ return PushLogsResponse()
282
+
256
283
 
257
284
  def _raise_if(validation_error: bool, detail: str) -> None:
258
285
  if validation_error:
@@ -17,7 +17,8 @@
17
17
 
18
18
  import threading
19
19
  import time
20
- from dataclasses import dataclass
20
+ from bisect import bisect_right
21
+ from dataclasses import dataclass, field
21
22
  from logging import ERROR, WARNING
22
23
  from typing import Optional
23
24
  from uuid import UUID, uuid4
@@ -43,7 +44,7 @@ from .utils import (
43
44
 
44
45
 
45
46
  @dataclass
46
- class RunRecord:
47
+ class RunRecord: # pylint: disable=R0902
47
48
  """The record of a specific run, including its status and timestamps."""
48
49
 
49
50
  run: Run
@@ -52,6 +53,8 @@ class RunRecord:
52
53
  starting_at: str = ""
53
54
  running_at: str = ""
54
55
  finished_at: str = ""
56
+ logs: list[tuple[float, str]] = field(default_factory=list)
57
+ log_lock: threading.Lock = field(default_factory=threading.Lock)
55
58
 
56
59
 
57
60
  class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
@@ -511,3 +514,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
511
514
  if run_id not in self.run_ids:
512
515
  raise ValueError(f"Run {run_id} not found")
513
516
  self.contexts[run_id] = context
517
+
518
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
519
+ """Add a log entry to the serverapp logs for the specified `run_id`."""
520
+ if run_id not in self.run_ids:
521
+ raise ValueError(f"Run {run_id} not found")
522
+ run = self.run_ids[run_id]
523
+ with run.log_lock:
524
+ run.logs.append((now().timestamp(), log_message))
525
+
526
+ def get_serverapp_log(
527
+ self, run_id: int, after_timestamp: Optional[float]
528
+ ) -> tuple[str, float]:
529
+ """Get the serverapp logs for the specified `run_id`."""
530
+ if run_id not in self.run_ids:
531
+ raise ValueError(f"Run {run_id} not found")
532
+ run = self.run_ids[run_id]
533
+ if after_timestamp is None:
534
+ after_timestamp = 0.0
535
+ with run.log_lock:
536
+ # Find the index where the timestamp would be inserted
537
+ index = bisect_right(run.logs, (after_timestamp, ""))
538
+ latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
539
+ return "".join(log for _, log in run.logs[index:]), latest_timestamp
@@ -299,3 +299,38 @@ class LinkState(abc.ABC): # pylint: disable=R0904
299
299
  context : Context
300
300
  The context to be associated with the specified `run_id`.
301
301
  """
302
+
303
+ @abc.abstractmethod
304
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
305
+ """Add a log entry to the ServerApp logs for the specified `run_id`.
306
+
307
+ Parameters
308
+ ----------
309
+ run_id : int
310
+ The identifier of the run for which to add a log entry.
311
+ log_message : str
312
+ The log entry to be added to the ServerApp logs.
313
+ """
314
+
315
+ @abc.abstractmethod
316
+ def get_serverapp_log(
317
+ self, run_id: int, after_timestamp: Optional[float]
318
+ ) -> tuple[str, float]:
319
+ """Get the ServerApp logs for the specified `run_id`.
320
+
321
+ Parameters
322
+ ----------
323
+ run_id : int
324
+ The identifier of the run for which to retrieve the ServerApp logs.
325
+
326
+ after_timestamp : Optional[float]
327
+ Retrieve logs after this timestamp. If set to `None`, retrieve all logs.
328
+
329
+ Returns
330
+ -------
331
+ tuple[str, float]
332
+ A tuple containing:
333
+ - The ServerApp logs associated with the specified `run_id`.
334
+ - The timestamp of the latest log entry in the returned logs.
335
+ Returns `0` if no logs are returned.
336
+ """