flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package registry page for more details.

Files changed (92)
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +34 -88
  8. flwr/client/app.py +23 -20
  9. flwr/client/clientapp/app.py +22 -18
  10. flwr/client/nodestate/__init__.py +25 -0
  11. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  12. flwr/client/nodestate/nodestate.py +30 -0
  13. flwr/client/nodestate/nodestate_factory.py +37 -0
  14. flwr/client/{node_state.py → run_info_store.py} +4 -3
  15. flwr/client/supernode/app.py +6 -8
  16. flwr/common/args.py +83 -0
  17. flwr/common/config.py +10 -0
  18. flwr/common/constant.py +39 -5
  19. flwr/common/context.py +9 -4
  20. flwr/common/date.py +3 -3
  21. flwr/common/logger.py +108 -1
  22. flwr/common/object_ref.py +47 -16
  23. flwr/common/serde.py +24 -0
  24. flwr/common/telemetry.py +0 -6
  25. flwr/common/typing.py +10 -1
  26. flwr/proto/exec_pb2.py +14 -17
  27. flwr/proto/exec_pb2.pyi +14 -22
  28. flwr/proto/log_pb2.py +29 -0
  29. flwr/proto/log_pb2.pyi +39 -0
  30. flwr/proto/log_pb2_grpc.py +4 -0
  31. flwr/proto/log_pb2_grpc.pyi +4 -0
  32. flwr/proto/message_pb2.py +8 -8
  33. flwr/proto/message_pb2.pyi +4 -1
  34. flwr/proto/run_pb2.py +32 -27
  35. flwr/proto/run_pb2.pyi +26 -0
  36. flwr/proto/serverappio_pb2.py +52 -0
  37. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  38. flwr/proto/serverappio_pb2_grpc.py +376 -0
  39. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  40. flwr/proto/simulationio_pb2.py +38 -0
  41. flwr/proto/simulationio_pb2.pyi +65 -0
  42. flwr/proto/simulationio_pb2_grpc.py +205 -0
  43. flwr/proto/simulationio_pb2_grpc.pyi +81 -0
  44. flwr/server/app.py +272 -105
  45. flwr/server/driver/driver.py +15 -1
  46. flwr/server/driver/grpc_driver.py +25 -36
  47. flwr/server/driver/inmemory_driver.py +6 -16
  48. flwr/server/run_serverapp.py +29 -23
  49. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  50. flwr/server/serverapp/app.py +214 -0
  51. flwr/server/strategy/aggregate.py +4 -4
  52. flwr/server/strategy/fedadam.py +11 -1
  53. flwr/server/superlink/driver/__init__.py +1 -1
  54. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  55. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  56. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  57. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  58. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  59. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  60. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  61. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  62. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  63. flwr/server/superlink/linkstate/__init__.py +28 -0
  64. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
  65. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
  66. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  67. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
  68. flwr/server/superlink/{state → linkstate}/utils.py +81 -30
  69. flwr/server/superlink/simulation/__init__.py +15 -0
  70. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  71. flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
  72. flwr/simulation/__init__.py +5 -1
  73. flwr/simulation/app.py +273 -345
  74. flwr/simulation/legacy_app.py +382 -0
  75. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  76. flwr/simulation/run_simulation.py +57 -131
  77. flwr/simulation/simulationio_connection.py +86 -0
  78. flwr/superexec/app.py +6 -134
  79. flwr/superexec/deployment.py +61 -66
  80. flwr/superexec/exec_grpc.py +15 -8
  81. flwr/superexec/exec_servicer.py +36 -65
  82. flwr/superexec/executor.py +26 -7
  83. flwr/superexec/simulation.py +54 -107
  84. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
  85. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
  86. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
  87. flwr/client/node_state_tests.py +0 -66
  88. flwr/proto/driver_pb2.py +0 -42
  89. flwr/proto/driver_pb2_grpc.py +0 -239
  90. flwr/proto/driver_pb2_grpc.pyi +0 -94
  91. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
  92. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
@@ -28,7 +28,7 @@ from typing import Callable, Optional
28
28
 
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
31
- from flwr.client.node_state import NodeState
31
+ from flwr.client.run_info_store import DeprecatedRunInfoStore
32
32
  from flwr.common.constant import (
33
33
  NUM_PARTITIONS_KEY,
34
34
  PARTITION_ID_KEY,
@@ -40,7 +40,7 @@ from flwr.common.message import Error
40
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
41
41
  from flwr.common.typing import Run
42
42
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
- from flwr.server.superlink.state import State, StateFactory
43
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
44
44
 
45
45
  from .backend import Backend, error_messages_backends, supported_backends
46
46
 
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
48
48
 
49
49
 
50
50
  def _register_nodes(
51
- num_nodes: int, state_factory: StateFactory
51
+ num_nodes: int, state_factory: LinkStateFactory
52
52
  ) -> NodeToPartitionMapping:
53
53
  """Register nodes with the StateFactory and create node-id:partition-id mapping."""
54
54
  nodes_mapping: NodeToPartitionMapping = {}
@@ -60,16 +60,16 @@ def _register_nodes(
60
60
  return nodes_mapping
61
61
 
62
62
 
63
- def _register_node_states(
63
+ def _register_node_info_stores(
64
64
  nodes_mapping: NodeToPartitionMapping,
65
65
  run: Run,
66
66
  app_dir: Optional[str] = None,
67
- ) -> dict[int, NodeState]:
68
- """Create NodeState objects and pre-register the context for the run."""
69
- node_states: dict[int, NodeState] = {}
67
+ ) -> dict[int, DeprecatedRunInfoStore]:
68
+ """Create DeprecatedRunInfoStore objects and register the context for the run."""
69
+ node_info_store: dict[int, DeprecatedRunInfoStore] = {}
70
70
  num_partitions = len(set(nodes_mapping.values()))
71
71
  for node_id, partition_id in nodes_mapping.items():
72
- node_states[node_id] = NodeState(
72
+ node_info_store[node_id] = DeprecatedRunInfoStore(
73
73
  node_id=node_id,
74
74
  node_config={
75
75
  PARTITION_ID_KEY: partition_id,
@@ -78,18 +78,18 @@ def _register_node_states(
78
78
  )
79
79
 
80
80
  # Pre-register Context objects
81
- node_states[node_id].register_context(
81
+ node_info_store[node_id].register_context(
82
82
  run_id=run.run_id, run=run, app_dir=app_dir
83
83
  )
84
84
 
85
- return node_states
85
+ return node_info_store
86
86
 
87
87
 
88
88
  # pylint: disable=too-many-arguments,too-many-locals
89
89
  def worker(
90
90
  taskins_queue: "Queue[TaskIns]",
91
91
  taskres_queue: "Queue[TaskRes]",
92
- node_states: dict[int, NodeState],
92
+ node_info_store: dict[int, DeprecatedRunInfoStore],
93
93
  backend: Backend,
94
94
  f_stop: threading.Event,
95
95
  ) -> None:
@@ -103,7 +103,7 @@ def worker(
103
103
  node_id = task_ins.task.consumer.node_id
104
104
 
105
105
  # Retrieve context
106
- context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
106
+ context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
107
107
 
108
108
  # Convert TaskIns to Message
109
109
  message = message_from_taskins(task_ins)
@@ -112,7 +112,7 @@ def worker(
112
112
  out_mssg, updated_context = backend.process_message(message, context)
113
113
 
114
114
  # Update Context
115
- node_states[node_id].update_context(
115
+ node_info_store[node_id].update_context(
116
116
  task_ins.run_id, context=updated_context
117
117
  )
118
118
  except Empty:
@@ -145,7 +145,7 @@ def worker(
145
145
 
146
146
 
147
147
  def add_taskins_to_queue(
148
- state: State,
148
+ state: LinkState,
149
149
  queue: "Queue[TaskIns]",
150
150
  nodes_mapping: NodeToPartitionMapping,
151
151
  f_stop: threading.Event,
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
160
160
 
161
161
 
162
162
  def put_taskres_into_state(
163
- state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
163
+ state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
164
164
  ) -> None:
165
165
  """Put TaskRes into State from a queue."""
166
166
  while not f_stop.is_set():
@@ -177,8 +177,8 @@ def run_api(
177
177
  app_fn: Callable[[], ClientApp],
178
178
  backend_fn: Callable[[], Backend],
179
179
  nodes_mapping: NodeToPartitionMapping,
180
- state_factory: StateFactory,
181
- node_states: dict[int, NodeState],
180
+ state_factory: LinkStateFactory,
181
+ node_info_stores: dict[int, DeprecatedRunInfoStore],
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
@@ -223,7 +223,7 @@ def run_api(
223
223
  worker,
224
224
  taskins_queue,
225
225
  taskres_queue,
226
- node_states,
226
+ node_info_stores,
227
227
  backend,
228
228
  f_stop,
229
229
  )
@@ -264,7 +264,7 @@ def start_vce(
264
264
  client_app: Optional[ClientApp] = None,
265
265
  client_app_attr: Optional[str] = None,
266
266
  num_supernodes: Optional[int] = None,
267
- state_factory: Optional[StateFactory] = None,
267
+ state_factory: Optional[LinkStateFactory] = None,
268
268
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
269
269
  ) -> None:
270
270
  """Start Fleet API with the Simulation Engine."""
@@ -303,7 +303,7 @@ def start_vce(
303
303
  if not state_factory:
304
304
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
305
305
  # Create an empty in-memory state factory
306
- state_factory = StateFactory(":flwr-in-memory-state:")
306
+ state_factory = LinkStateFactory(":flwr-in-memory-state:")
307
307
  log(INFO, "Created new %s.", state_factory.__class__.__name__)
308
308
 
309
309
  if num_supernodes:
@@ -312,8 +312,8 @@ def start_vce(
312
312
  num_nodes=num_supernodes, state_factory=state_factory
313
313
  )
314
314
 
315
- # Construct mapping of NodeStates
316
- node_states = _register_node_states(
315
+ # Construct mapping of DeprecatedRunInfoStore
316
+ node_info_stores = _register_node_info_stores(
317
317
  nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
318
318
  )
319
319
 
@@ -376,7 +376,7 @@ def start_vce(
376
376
  backend_fn,
377
377
  nodes_mapping,
378
378
  state_factory,
379
- node_states,
379
+ node_info_stores,
380
380
  f_stop,
381
381
  )
382
382
  except LoadClientAppError as loadapp_ex:
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower LinkState."""
16
+
17
+
18
+ from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
19
+ from .linkstate import LinkState as LinkState
20
+ from .linkstate_factory import LinkStateFactory as LinkStateFactory
21
+ from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
22
+
23
+ __all__ = [
24
+ "InMemoryLinkState",
25
+ "LinkState",
26
+ "LinkStateFactory",
27
+ "SqliteLinkState",
28
+ ]
@@ -12,31 +12,53 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """In-memory State implementation."""
15
+ """In-memory LinkState implementation."""
16
16
 
17
17
 
18
18
  import threading
19
19
  import time
20
+ from bisect import bisect_right
21
+ from dataclasses import dataclass, field
20
22
  from logging import ERROR, WARNING
21
23
  from typing import Optional
22
24
  from uuid import UUID, uuid4
23
25
 
24
- from flwr.common import log, now
26
+ from flwr.common import Context, log, now
25
27
  from flwr.common.constant import (
26
28
  MESSAGE_TTL_TOLERANCE,
27
29
  NODE_ID_NUM_BYTES,
28
30
  RUN_ID_NUM_BYTES,
31
+ Status,
29
32
  )
30
- from flwr.common.typing import Run, UserConfig
33
+ from flwr.common.record import ConfigsRecord
34
+ from flwr.common.typing import Run, RunStatus, UserConfig
31
35
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
32
- from flwr.server.superlink.state.state import State
36
+ from flwr.server.superlink.linkstate.linkstate import LinkState
33
37
  from flwr.server.utils import validate_task_ins_or_res
34
38
 
35
- from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
39
+ from .utils import (
40
+ generate_rand_int_from_bytes,
41
+ has_valid_sub_status,
42
+ is_valid_transition,
43
+ )
44
+
45
+
46
@dataclass
class RunRecord:  # pylint: disable=R0902
    """Bookkeeping entry for one run held by the in-memory LinkState.

    Bundles the `Run` description with its current `RunStatus`, the
    ISO-format timestamps recorded as the run changes status, and the
    ServerApp log entries captured for it.
    """

    # Description of the run (run_id, FAB id/version/hash, override config)
    run: Run
    # Current status of the run
    status: RunStatus
    # ISO-format timestamps; an empty string means the status was not reached
    pending_at: str = field(default="")
    starting_at: str = field(default="")
    running_at: str = field(default="")
    finished_at: str = field(default="")
    # (unix timestamp, message) pairs, appended in arrival order
    logs: list[tuple[float, str]] = field(default_factory=list)
    # Serializes access to `logs`
    log_lock: threading.Lock = field(default_factory=threading.Lock)
37
58
 
38
- class InMemoryState(State): # pylint: disable=R0902,R0904
39
- """In-memory State implementation."""
59
+
60
+ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
61
+ """In-memory LinkState implementation."""
40
62
 
41
63
  def __init__(self) -> None:
42
64
 
@@ -44,8 +66,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
44
66
  self.node_ids: dict[int, tuple[float, float]] = {}
45
67
  self.public_key_to_node_id: dict[bytes, int] = {}
46
68
 
47
- # Map run_id to (fab_id, fab_version)
48
- self.run_ids: dict[int, Run] = {}
69
+ # Map run_id to RunRecord
70
+ self.run_ids: dict[int, RunRecord] = {}
71
+ self.contexts: dict[int, Context] = {}
72
+ self.federation_options: dict[int, ConfigsRecord] = {}
49
73
  self.task_ins_store: dict[UUID, TaskIns] = {}
50
74
  self.task_res_store: dict[UUID, TaskRes] = {}
51
75
 
@@ -64,8 +88,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
64
88
  return None
65
89
  # Validate run_id
66
90
  if task_ins.run_id not in self.run_ids:
67
- log(ERROR, "`run_id` is invalid")
91
+ log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
68
92
  return None
93
+ # Validate source node ID
94
+ if task_ins.task.producer.node_id != 0:
95
+ log(
96
+ ERROR,
97
+ "Invalid source node ID for TaskIns: %s",
98
+ task_ins.task.producer.node_id,
99
+ )
100
+ return None
101
+ # Validate destination node ID
102
+ if not task_ins.task.consumer.anonymous:
103
+ if task_ins.task.consumer.node_id not in self.node_ids:
104
+ log(
105
+ ERROR,
106
+ "Invalid destination node ID for TaskIns: %s",
107
+ task_ins.task.consumer.node_id,
108
+ )
109
+ return None
69
110
 
70
111
  # Create task_id
71
112
  task_id = uuid4()
@@ -215,21 +256,6 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
215
256
  task_res_list.append(task_res)
216
257
  replied_task_ids.add(reply_to)
217
258
 
218
- # Check if the node is offline
219
- for task_id in task_ids - replied_task_ids:
220
- task_ins = self.task_ins_store.get(task_id)
221
- if task_ins is None:
222
- continue
223
- node_id = task_ins.task.consumer.node_id
224
- online_until, _ = self.node_ids[node_id]
225
- # Generate a TaskRes containing an error reply if the node is offline.
226
- if online_until < time.time():
227
- err_taskres = make_node_unavailable_taskres(
228
- ref_taskins=task_ins,
229
- )
230
- self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
231
- task_res_list.append(err_taskres)
232
-
233
259
  # Mark all of them as delivered
234
260
  delivered_at = now().isoformat()
235
261
  for task_res in task_res_list:
@@ -277,7 +303,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
277
303
  def create_node(
278
304
  self, ping_interval: float, public_key: Optional[bytes] = None
279
305
  ) -> int:
280
- """Create, store in state, and return `node_id`."""
306
+ """Create, store in the link state, and return `node_id`."""
281
307
  # Sample a random int64 as node_id
282
308
  node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
283
309
 
@@ -338,12 +364,14 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
338
364
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
339
365
  return self.public_key_to_node_id.get(node_public_key)
340
366
 
367
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
341
368
  def create_run(
342
369
  self,
343
370
  fab_id: Optional[str],
344
371
  fab_version: Optional[str],
345
372
  fab_hash: Optional[str],
346
373
  override_config: UserConfig,
374
+ federation_options: ConfigsRecord,
347
375
  ) -> int:
348
376
  """Create a new run for the specified `fab_hash`."""
349
377
  # Sample a random int64 as run_id
@@ -351,13 +379,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
351
379
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
352
380
 
353
381
  if run_id not in self.run_ids:
354
- self.run_ids[run_id] = Run(
355
- run_id=run_id,
356
- fab_id=fab_id if fab_id else "",
357
- fab_version=fab_version if fab_version else "",
358
- fab_hash=fab_hash if fab_hash else "",
359
- override_config=override_config,
382
+ run_record = RunRecord(
383
+ run=Run(
384
+ run_id=run_id,
385
+ fab_id=fab_id if fab_id else "",
386
+ fab_version=fab_version if fab_version else "",
387
+ fab_hash=fab_hash if fab_hash else "",
388
+ override_config=override_config,
389
+ ),
390
+ status=RunStatus(
391
+ status=Status.PENDING,
392
+ sub_status="",
393
+ details="",
394
+ ),
395
+ pending_at=now().isoformat(),
360
396
  )
397
+ self.run_ids[run_id] = run_record
398
+
399
+ # Record federation options. Leave empty if not passed
400
+ self.federation_options[run_id] = federation_options
361
401
  return run_id
362
402
  log(ERROR, "Unexpected run creation failure.")
363
403
  return 0
@@ -365,7 +405,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
365
405
  def store_server_private_public_key(
366
406
  self, private_key: bytes, public_key: bytes
367
407
  ) -> None:
368
- """Store `server_private_key` and `server_public_key` in state."""
408
+ """Store `server_private_key` and `server_public_key` in the link state."""
369
409
  with self.lock:
370
410
  if self.server_private_key is None and self.server_public_key is None:
371
411
  self.server_private_key = private_key
@@ -382,12 +422,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
382
422
  return self.server_public_key
383
423
 
384
424
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
    """Store a set of `node_public_keys` in the link state.

    The given keys replace any previously stored keys. A private copy is
    kept because `store_node_public_key` adds to the stored set in place;
    storing the caller's object directly would let those later additions
    leak into (and be affected by) the caller's set.
    """
    with self.lock:
        self.node_public_keys = set(public_keys)
388
428
 
389
429
def store_node_public_key(self, public_key: bytes) -> None:
    """Add one `node_public_key` to the keys kept in the link state."""
    with self.lock:
        known_keys = self.node_public_keys
        known_keys.add(public_key)
393
433
 
@@ -395,13 +435,88 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
395
435
  """Retrieve all currently stored `node_public_keys` as a set."""
396
436
  return self.node_public_keys
397
437
 
438
def get_run_ids(self) -> set[int]:
    """Return a new set holding the ID of every registered run.

    The set is built while `self.lock` is held and is independent of the
    internal mapping, so callers may modify or iterate it freely.
    """
    with self.lock:
        registered = set(self.run_ids)
    return registered
442
+
398
443
def get_run(self, run_id: int) -> Optional[Run]:
    """Look up the `Run` stored for `run_id`.

    Returns the run description, or `None` (after logging an error) when no
    run with this ID exists.
    """
    with self.lock:
        if run_id in self.run_ids:
            return self.run_ids[run_id].run
        log(ERROR, "`run_id` is invalid")
        return None
450
+
451
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Return the current `RunStatus` of each requested run.

    IDs that do not belong to a known run are left out of the result rather
    than raising an error.
    """
    statuses: dict[int, RunStatus] = {}
    with self.lock:
        for rid in set(run_ids):
            if rid not in self.run_ids:
                continue
            statuses[rid] = self.run_ids[rid].status
    return statuses
459
+
460
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    The update is rejected (an error is logged and `False` is returned) when
    the run does not exist, when moving from the current status to
    `new_status` is not an allowed transition, or when `new_status` carries
    a sub-status that is not valid for its status. On success, the
    timestamp field matching the new status is set and `True` is returned.
    """
    with self.lock:
        # Check if the run_id exists
        if run_id not in self.run_ids:
            log(ERROR, "`run_id` is invalid")
            return False

        run_record = self.run_ids[run_id]
        current_status = run_record.status

        # Check if the status transition is valid
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check if the sub-status is valid. The status being stored is the
        # one that must be validated; checking `current_status` instead let
        # an invalid sub-status be written.
        if not has_valid_sub_status(new_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                new_status.sub_status,
                new_status.status,
            )
            return False

        # Record when the run entered its new status
        if new_status.status == Status.STARTING:
            run_record.starting_at = now().isoformat()
        elif new_status.status == Status.RUNNING:
            run_record.running_at = now().isoformat()
        elif new_status.status == Status.FINISHED:
            run_record.finished_at = now().isoformat()
        run_record.status = new_status
        return True
499
+
500
def get_pending_run_id(self) -> Optional[int]:
    """Get the `run_id` of a run with `Status.PENDING` status, if any.

    Runs are scanned in insertion order and the first pending one is
    returned; `None` is returned when there is none. The scan holds
    `self.lock`, as the other run accessors do, so that a concurrent
    insertion into `self.run_ids` cannot break the iteration midway
    ("dictionary changed size during iteration").
    """
    with self.lock:
        return next(
            (
                run_id
                for run_id, run_rec in self.run_ids.items()
                if run_rec.status.status == Status.PENDING
            ),
            None,
        )
512
+
513
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
    """Return the federation options recorded for `run_id`.

    Logs an error and returns `None` when the run is unknown.
    """
    with self.lock:
        if run_id in self.run_ids:
            return self.federation_options[run_id]
        log(ERROR, "`run_id` is invalid")
        return None
405
520
 
406
521
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
407
522
  """Acknowledge a ping received from a node, serving as a heartbeat."""
@@ -410,3 +525,36 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
410
525
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
411
526
  return True
412
527
  return False
528
+
529
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Return the ServerApp `Context` saved for `run_id`, or `None` if unset."""
    contexts = self.contexts
    return contexts[run_id] if run_id in contexts else None
532
+
533
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Save `context` as the ServerApp `Context` of an existing run.

    Raises
    ------
    ValueError
        If no run with `run_id` exists.
    """
    known_run = run_id in self.run_ids
    if not known_run:
        raise ValueError(f"Run {run_id} not found")
    self.contexts[run_id] = context
538
+
539
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Append `log_message` to the ServerApp log of run `run_id`.

    Each entry is stored as a ``(unix_timestamp, message)`` pair. The
    timestamp is taken while the run's ``log_lock`` is held.

    Raises
    ------
    ValueError
        If no run with `run_id` exists.
    """
    if run_id not in self.run_ids:
        raise ValueError(f"Run {run_id} not found")
    record = self.run_ids[run_id]
    with record.log_lock:
        entry = (now().timestamp(), log_message)
        record.logs.append(entry)
546
+
547
def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Return the ServerApp log written for `run_id` after a timestamp.

    Parameters
    ----------
    run_id : int
        The run whose log is requested.
    after_timestamp : Optional[float]
        Lower bound for the returned entries; `None` is treated as ``0.0``.

    Returns
    -------
    tuple[str, float]
        The concatenated messages, and the timestamp of the newest stored
        entry -- or ``0.0`` when no entry qualifies.

    Raises
    ------
    ValueError
        If no run with `run_id` exists.
    """
    if run_id not in self.run_ids:
        raise ValueError(f"Run {run_id} not found")
    record = self.run_ids[run_id]
    cutoff = 0.0 if after_timestamp is None else after_timestamp
    with record.log_lock:
        entries = record.logs
        # Entries are compared as (timestamp, message) tuples: anything with a
        # timestamp strictly below `cutoff` is skipped, and entries exactly at
        # `cutoff` are kept unless their message is empty. Assumes `entries`
        # is sorted by timestamp (appended in time order) -- TODO confirm.
        start = bisect_right(entries, (cutoff, ""))
        if start >= len(entries):
            return "", 0.0
        tail = entries[start:]
        return "".join(message for _, message in tail), entries[-1][0]