flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241023__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (34) hide show
  1. flwr/client/app.py +13 -14
  2. flwr/client/node_state_tests.py +7 -8
  3. flwr/client/{node_state.py → run_info_store.py} +3 -3
  4. flwr/client/supernode/app.py +6 -8
  5. flwr/common/constant.py +31 -3
  6. flwr/common/typing.py +9 -0
  7. flwr/server/app.py +121 -10
  8. flwr/server/driver/inmemory_driver.py +2 -2
  9. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  10. flwr/server/serverapp/app.py +78 -0
  11. flwr/server/superlink/driver/driver_grpc.py +2 -2
  12. flwr/server/superlink/driver/driver_servicer.py +9 -7
  13. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  14. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  15. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  16. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  17. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  18. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  19. flwr/server/superlink/linkstate/__init__.py +28 -0
  20. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +109 -19
  21. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +59 -11
  22. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  23. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +136 -35
  24. flwr/server/superlink/{state → linkstate}/utils.py +57 -1
  25. flwr/simulation/app.py +1 -1
  26. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  27. flwr/simulation/run_simulation.py +15 -7
  28. flwr/superexec/app.py +9 -2
  29. flwr/superexec/simulation.py +1 -1
  30. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/METADATA +1 -1
  31. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/RECORD +34 -32
  32. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/entry_points.txt +1 -0
  33. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/LICENSE +0 -0
  34. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/WHEEL +0 -0
@@ -51,14 +51,16 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
51
51
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
52
  from flwr.server.superlink.ffs.ffs import Ffs
53
53
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
54
- from flwr.server.superlink.state import State, StateFactory
54
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
55
55
  from flwr.server.utils.validator import validate_task_ins_or_res
56
56
 
57
57
 
58
58
  class DriverServicer(driver_pb2_grpc.DriverServicer):
59
59
  """Driver API servicer."""
60
60
 
61
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
61
+ def __init__(
62
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
63
+ ) -> None:
62
64
  self.state_factory = state_factory
63
65
  self.ffs_factory = ffs_factory
64
66
 
@@ -67,7 +69,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
67
69
  ) -> GetNodesResponse:
68
70
  """Get available nodes."""
69
71
  log(DEBUG, "DriverServicer.GetNodes")
70
- state: State = self.state_factory.state()
72
+ state: LinkState = self.state_factory.state()
71
73
  all_ids: set[int] = state.get_nodes(request.run_id)
72
74
  nodes: list[Node] = [
73
75
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -79,7 +81,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
79
81
  ) -> CreateRunResponse:
80
82
  """Create run ID."""
81
83
  log(DEBUG, "DriverServicer.CreateRun")
82
- state: State = self.state_factory.state()
84
+ state: LinkState = self.state_factory.state()
83
85
  if request.HasField("fab"):
84
86
  fab = fab_from_proto(request.fab)
85
87
  ffs: Ffs = self.ffs_factory.ffs()
@@ -116,7 +118,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
116
118
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
117
119
 
118
120
  # Init state
119
- state: State = self.state_factory.state()
121
+ state: LinkState = self.state_factory.state()
120
122
 
121
123
  # Store each TaskIns
122
124
  task_ids: list[Optional[UUID]] = []
@@ -138,7 +140,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
138
140
  task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
139
141
 
140
142
  # Init state
141
- state: State = self.state_factory.state()
143
+ state: LinkState = self.state_factory.state()
142
144
 
143
145
  # Register callback
144
146
  def on_rpc_done() -> None:
@@ -167,7 +169,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
167
169
  log(DEBUG, "DriverServicer.GetRun")
168
170
 
169
171
  # Init state
170
- state: State = self.state_factory.state()
172
+ state: LinkState = self.state_factory.state()
171
173
 
172
174
  # Retrieve run information
173
175
  run = state.get_run(request.run_id)
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
48
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
49
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
50
  from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.state import StateFactory
51
+ from flwr.server.superlink.linkstate import LinkStateFactory
52
52
 
53
53
  T = TypeVar("T", bound=GrpcMessage)
54
54
 
@@ -77,7 +77,9 @@ def _handle(
77
77
  class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
78
78
  """Fleet API via GrpcAdapter servicer."""
79
79
 
80
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
80
+ def __init__(
81
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
+ ) -> None:
81
83
  self.state_factory = state_factory
82
84
  self.ffs_factory = ffs_factory
83
85
 
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
37
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
38
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
39
  from flwr.server.superlink.fleet.message_handler import message_handler
40
- from flwr.server.superlink.state import StateFactory
40
+ from flwr.server.superlink.linkstate import LinkStateFactory
41
41
 
42
42
 
43
43
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
44
44
  """Fleet API servicer."""
45
45
 
46
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
46
+ def __init__(
47
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
48
+ ) -> None:
47
49
  self.state_factory = state_factory
48
50
  self.ffs_factory = ffs_factory
49
51
 
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.state import State
48
+ from flwr.server.superlink.linkstate import LinkState
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: State):
87
+ def __init__(self, state: LinkState):
88
88
  self.state = state
89
89
 
90
90
  self.node_public_keys = state.get_node_public_keys()
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
43
  )
44
44
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
45
  from flwr.server.superlink.ffs.ffs import Ffs
46
- from flwr.server.superlink.state import State
46
+ from flwr.server.superlink.linkstate import LinkState
47
47
 
48
48
 
49
49
  def create_node(
50
50
  request: CreateNodeRequest, # pylint: disable=unused-argument
51
- state: State,
51
+ state: LinkState,
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
54
  # Create node
@@ -56,7 +56,7 @@ def create_node(
56
56
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
57
57
 
58
58
 
59
- def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
59
+ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
60
60
  """."""
61
61
  # Validate node_id
62
62
  if request.node.anonymous or request.node.node_id == 0:
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
69
69
 
70
70
  def ping(
71
71
  request: PingRequest, # pylint: disable=unused-argument
72
- state: State, # pylint: disable=unused-argument
72
+ state: LinkState, # pylint: disable=unused-argument
73
73
  ) -> PingResponse:
74
74
  """."""
75
75
  res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
76
76
  return PingResponse(success=res)
77
77
 
78
78
 
79
- def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
79
+ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
80
80
  """Pull TaskIns handler."""
81
81
  # Get node_id if client node is not anonymous
82
82
  node = request.node # pylint: disable=no-member
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
92
92
  return response
93
93
 
94
94
 
95
- def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
95
+ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
96
96
  """Push TaskRes handler."""
97
97
  # pylint: disable=no-member
98
98
  task_res: TaskRes = request.task_res_list[0]
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
113
113
 
114
114
 
115
115
  def get_run(
116
- request: GetRunRequest, state: State # pylint: disable=W0613
116
+ request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
117
  ) -> GetRunResponse:
118
118
  """Get run information."""
119
119
  run = state.get_run(request.run_id)
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
41
  from flwr.server.superlink.ffs.ffs import Ffs
42
42
  from flwr.server.superlink.fleet.message_handler import message_handler
43
- from flwr.server.superlink.state import State
43
+ from flwr.server.superlink.linkstate import LinkState
44
44
 
45
45
  try:
46
46
  from starlette.applications import Starlette
@@ -90,7 +90,7 @@ def rest_request_response(
90
90
  async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
91
91
  """Create Node."""
92
92
  # Get state from app
93
- state: State = app.state.STATE_FACTORY.state()
93
+ state: LinkState = app.state.STATE_FACTORY.state()
94
94
 
95
95
  # Handle message
96
96
  return message_handler.create_node(request=request, state=state)
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
100
100
  async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
101
101
  """Delete Node Id."""
102
102
  # Get state from app
103
- state: State = app.state.STATE_FACTORY.state()
103
+ state: LinkState = app.state.STATE_FACTORY.state()
104
104
 
105
105
  # Handle message
106
106
  return message_handler.delete_node(request=request, state=state)
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
110
110
  async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
111
111
  """Pull TaskIns."""
112
112
  # Get state from app
113
- state: State = app.state.STATE_FACTORY.state()
113
+ state: LinkState = app.state.STATE_FACTORY.state()
114
114
 
115
115
  # Handle message
116
116
  return message_handler.pull_task_ins(request=request, state=state)
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
121
121
  async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
122
122
  """Push TaskRes."""
123
123
  # Get state from app
124
- state: State = app.state.STATE_FACTORY.state()
124
+ state: LinkState = app.state.STATE_FACTORY.state()
125
125
 
126
126
  # Handle message
127
127
  return message_handler.push_task_res(request=request, state=state)
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
131
131
  async def ping(request: PingRequest) -> PingResponse:
132
132
  """Ping."""
133
133
  # Get state from app
134
- state: State = app.state.STATE_FACTORY.state()
134
+ state: LinkState = app.state.STATE_FACTORY.state()
135
135
 
136
136
  # Handle message
137
137
  return message_handler.ping(request=request, state=state)
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
141
141
  async def get_run(request: GetRunRequest) -> GetRunResponse:
142
142
  """GetRun."""
143
143
  # Get state from app
144
- state: State = app.state.STATE_FACTORY.state()
144
+ state: LinkState = app.state.STATE_FACTORY.state()
145
145
 
146
146
  # Handle message
147
147
  return message_handler.get_run(request=request, state=state)
@@ -28,7 +28,7 @@ from typing import Callable, Optional
28
28
 
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
31
- from flwr.client.node_state import NodeState
31
+ from flwr.client.run_info_store import DeprecatedRunInfoStore
32
32
  from flwr.common.constant import (
33
33
  NUM_PARTITIONS_KEY,
34
34
  PARTITION_ID_KEY,
@@ -40,7 +40,7 @@ from flwr.common.message import Error
40
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
41
41
  from flwr.common.typing import Run
42
42
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
- from flwr.server.superlink.state import State, StateFactory
43
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
44
44
 
45
45
  from .backend import Backend, error_messages_backends, supported_backends
46
46
 
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
48
48
 
49
49
 
50
50
  def _register_nodes(
51
- num_nodes: int, state_factory: StateFactory
51
+ num_nodes: int, state_factory: LinkStateFactory
52
52
  ) -> NodeToPartitionMapping:
53
53
  """Register nodes with the StateFactory and create node-id:partition-id mapping."""
54
54
  nodes_mapping: NodeToPartitionMapping = {}
@@ -60,16 +60,16 @@ def _register_nodes(
60
60
  return nodes_mapping
61
61
 
62
62
 
63
- def _register_node_states(
63
+ def _register_node_info_stores(
64
64
  nodes_mapping: NodeToPartitionMapping,
65
65
  run: Run,
66
66
  app_dir: Optional[str] = None,
67
- ) -> dict[int, NodeState]:
68
- """Create NodeState objects and pre-register the context for the run."""
69
- node_states: dict[int, NodeState] = {}
67
+ ) -> dict[int, DeprecatedRunInfoStore]:
68
+ """Create DeprecatedRunInfoStore objects and register the context for the run."""
69
+ node_info_store: dict[int, DeprecatedRunInfoStore] = {}
70
70
  num_partitions = len(set(nodes_mapping.values()))
71
71
  for node_id, partition_id in nodes_mapping.items():
72
- node_states[node_id] = NodeState(
72
+ node_info_store[node_id] = DeprecatedRunInfoStore(
73
73
  node_id=node_id,
74
74
  node_config={
75
75
  PARTITION_ID_KEY: partition_id,
@@ -78,18 +78,18 @@ def _register_node_states(
78
78
  )
79
79
 
80
80
  # Pre-register Context objects
81
- node_states[node_id].register_context(
81
+ node_info_store[node_id].register_context(
82
82
  run_id=run.run_id, run=run, app_dir=app_dir
83
83
  )
84
84
 
85
- return node_states
85
+ return node_info_store
86
86
 
87
87
 
88
88
  # pylint: disable=too-many-arguments,too-many-locals
89
89
  def worker(
90
90
  taskins_queue: "Queue[TaskIns]",
91
91
  taskres_queue: "Queue[TaskRes]",
92
- node_states: dict[int, NodeState],
92
+ node_info_store: dict[int, DeprecatedRunInfoStore],
93
93
  backend: Backend,
94
94
  f_stop: threading.Event,
95
95
  ) -> None:
@@ -103,7 +103,7 @@ def worker(
103
103
  node_id = task_ins.task.consumer.node_id
104
104
 
105
105
  # Retrieve context
106
- context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
106
+ context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
107
107
 
108
108
  # Convert TaskIns to Message
109
109
  message = message_from_taskins(task_ins)
@@ -112,7 +112,7 @@ def worker(
112
112
  out_mssg, updated_context = backend.process_message(message, context)
113
113
 
114
114
  # Update Context
115
- node_states[node_id].update_context(
115
+ node_info_store[node_id].update_context(
116
116
  task_ins.run_id, context=updated_context
117
117
  )
118
118
  except Empty:
@@ -145,7 +145,7 @@ def worker(
145
145
 
146
146
 
147
147
  def add_taskins_to_queue(
148
- state: State,
148
+ state: LinkState,
149
149
  queue: "Queue[TaskIns]",
150
150
  nodes_mapping: NodeToPartitionMapping,
151
151
  f_stop: threading.Event,
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
160
160
 
161
161
 
162
162
  def put_taskres_into_state(
163
- state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
163
+ state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
164
164
  ) -> None:
165
165
  """Put TaskRes into State from a queue."""
166
166
  while not f_stop.is_set():
@@ -177,8 +177,8 @@ def run_api(
177
177
  app_fn: Callable[[], ClientApp],
178
178
  backend_fn: Callable[[], Backend],
179
179
  nodes_mapping: NodeToPartitionMapping,
180
- state_factory: StateFactory,
181
- node_states: dict[int, NodeState],
180
+ state_factory: LinkStateFactory,
181
+ node_info_stores: dict[int, DeprecatedRunInfoStore],
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
@@ -223,7 +223,7 @@ def run_api(
223
223
  worker,
224
224
  taskins_queue,
225
225
  taskres_queue,
226
- node_states,
226
+ node_info_stores,
227
227
  backend,
228
228
  f_stop,
229
229
  )
@@ -264,7 +264,7 @@ def start_vce(
264
264
  client_app: Optional[ClientApp] = None,
265
265
  client_app_attr: Optional[str] = None,
266
266
  num_supernodes: Optional[int] = None,
267
- state_factory: Optional[StateFactory] = None,
267
+ state_factory: Optional[LinkStateFactory] = None,
268
268
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
269
269
  ) -> None:
270
270
  """Start Fleet API with the Simulation Engine."""
@@ -303,7 +303,7 @@ def start_vce(
303
303
  if not state_factory:
304
304
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
305
305
  # Create an empty in-memory state factory
306
- state_factory = StateFactory(":flwr-in-memory-state:")
306
+ state_factory = LinkStateFactory(":flwr-in-memory-state:")
307
307
  log(INFO, "Created new %s.", state_factory.__class__.__name__)
308
308
 
309
309
  if num_supernodes:
@@ -312,8 +312,8 @@ def start_vce(
312
312
  num_nodes=num_supernodes, state_factory=state_factory
313
313
  )
314
314
 
315
- # Construct mapping of NodeStates
316
- node_states = _register_node_states(
315
+ # Construct mapping of DeprecatedRunInfoStore
316
+ node_info_stores = _register_node_info_stores(
317
317
  nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
318
318
  )
319
319
 
@@ -376,7 +376,7 @@ def start_vce(
376
376
  backend_fn,
377
377
  nodes_mapping,
378
378
  state_factory,
379
- node_states,
379
+ node_info_stores,
380
380
  f_stop,
381
381
  )
382
382
  except LoadClientAppError as loadapp_ex:
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower LinkState."""
16
+
17
+
18
+ from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
19
+ from .linkstate import LinkState as LinkState
20
+ from .linkstate_factory import LinkStateFactory as LinkStateFactory
21
+ from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
22
+
23
+ __all__ = [
24
+ "InMemoryLinkState",
25
+ "LinkState",
26
+ "LinkStateFactory",
27
+ "SqliteLinkState",
28
+ ]
@@ -12,11 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """In-memory State implementation."""
15
+ """In-memory LinkState implementation."""
16
16
 
17
17
 
18
18
  import threading
19
19
  import time
20
+ from dataclasses import dataclass
20
21
  from logging import ERROR, WARNING
21
22
  from typing import Optional
22
23
  from uuid import UUID, uuid4
@@ -26,17 +27,35 @@ from flwr.common.constant import (
26
27
  MESSAGE_TTL_TOLERANCE,
27
28
  NODE_ID_NUM_BYTES,
28
29
  RUN_ID_NUM_BYTES,
30
+ Status,
29
31
  )
30
- from flwr.common.typing import Run, UserConfig
32
+ from flwr.common.typing import Run, RunStatus, UserConfig
31
33
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
32
- from flwr.server.superlink.state.state import State
34
+ from flwr.server.superlink.linkstate.linkstate import LinkState
33
35
  from flwr.server.utils import validate_task_ins_or_res
34
36
 
35
- from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
37
+ from .utils import (
38
+ generate_rand_int_from_bytes,
39
+ has_valid_sub_status,
40
+ is_valid_transition,
41
+ make_node_unavailable_taskres,
42
+ )
43
+
44
+
45
+ @dataclass
46
+ class RunRecord:
47
+ """The record of a specific run, including its status and timestamps."""
36
48
 
49
+ run: Run
50
+ status: RunStatus
51
+ pending_at: str = ""
52
+ starting_at: str = ""
53
+ running_at: str = ""
54
+ finished_at: str = ""
37
55
 
38
- class InMemoryState(State): # pylint: disable=R0902,R0904
39
- """In-memory State implementation."""
56
+
57
+ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
58
+ """In-memory LinkState implementation."""
40
59
 
41
60
  def __init__(self) -> None:
42
61
 
@@ -44,8 +63,8 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
44
63
  self.node_ids: dict[int, tuple[float, float]] = {}
45
64
  self.public_key_to_node_id: dict[bytes, int] = {}
46
65
 
47
- # Map run_id to (fab_id, fab_version)
48
- self.run_ids: dict[int, Run] = {}
66
+ # Map run_id to RunRecord
67
+ self.run_ids: dict[int, RunRecord] = {}
49
68
  self.task_ins_store: dict[UUID, TaskIns] = {}
50
69
  self.task_res_store: dict[UUID, TaskRes] = {}
51
70
 
@@ -277,7 +296,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
277
296
  def create_node(
278
297
  self, ping_interval: float, public_key: Optional[bytes] = None
279
298
  ) -> int:
280
- """Create, store in state, and return `node_id`."""
299
+ """Create, store in the link state, and return `node_id`."""
281
300
  # Sample a random int64 as node_id
282
301
  node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
283
302
 
@@ -351,13 +370,22 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
351
370
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
352
371
 
353
372
  if run_id not in self.run_ids:
354
- self.run_ids[run_id] = Run(
355
- run_id=run_id,
356
- fab_id=fab_id if fab_id else "",
357
- fab_version=fab_version if fab_version else "",
358
- fab_hash=fab_hash if fab_hash else "",
359
- override_config=override_config,
373
+ run_record = RunRecord(
374
+ run=Run(
375
+ run_id=run_id,
376
+ fab_id=fab_id if fab_id else "",
377
+ fab_version=fab_version if fab_version else "",
378
+ fab_hash=fab_hash if fab_hash else "",
379
+ override_config=override_config,
380
+ ),
381
+ status=RunStatus(
382
+ status=Status.PENDING,
383
+ sub_status="",
384
+ details="",
385
+ ),
386
+ pending_at=now().isoformat(),
360
387
  )
388
+ self.run_ids[run_id] = run_record
361
389
  return run_id
362
390
  log(ERROR, "Unexpected run creation failure.")
363
391
  return 0
@@ -365,7 +393,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
365
393
  def store_server_private_public_key(
366
394
  self, private_key: bytes, public_key: bytes
367
395
  ) -> None:
368
- """Store `server_private_key` and `server_public_key` in state."""
396
+ """Store `server_private_key` and `server_public_key` in the link state."""
369
397
  with self.lock:
370
398
  if self.server_private_key is None and self.server_public_key is None:
371
399
  self.server_private_key = private_key
@@ -382,12 +410,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
382
410
  return self.server_public_key
383
411
 
384
412
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
385
- """Store a set of `node_public_keys` in state."""
413
+ """Store a set of `node_public_keys` in the link state."""
386
414
  with self.lock:
387
415
  self.node_public_keys = public_keys
388
416
 
389
417
  def store_node_public_key(self, public_key: bytes) -> None:
390
- """Store a `node_public_key` in state."""
418
+ """Store a `node_public_key` in the link state."""
391
419
  with self.lock:
392
420
  self.node_public_keys.add(public_key)
393
421
 
@@ -401,7 +429,69 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
401
429
  if run_id not in self.run_ids:
402
430
  log(ERROR, "`run_id` is invalid")
403
431
  return None
404
- return self.run_ids[run_id]
432
+ return self.run_ids[run_id].run
433
+
434
+ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
435
+ """Retrieve the statuses for the specified runs."""
436
+ with self.lock:
437
+ return {
438
+ run_id: self.run_ids[run_id].status
439
+ for run_id in set(run_ids)
440
+ if run_id in self.run_ids
441
+ }
442
+
443
+ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
444
+ """Update the status of the run with the specified `run_id`."""
445
+ with self.lock:
446
+ # Check if the run_id exists
447
+ if run_id not in self.run_ids:
448
+ log(ERROR, "`run_id` is invalid")
449
+ return False
450
+
451
+ # Check if the status transition is valid
452
+ current_status = self.run_ids[run_id].status
453
+ if not is_valid_transition(current_status, new_status):
454
+ log(
455
+ ERROR,
456
+ 'Invalid status transition: from "%s" to "%s"',
457
+ current_status.status,
458
+ new_status.status,
459
+ )
460
+ return False
461
+
462
+ # Check if the sub-status is valid
463
+ if not has_valid_sub_status(current_status):
464
+ log(
465
+ ERROR,
466
+ 'Invalid sub-status "%s" for status "%s"',
467
+ current_status.sub_status,
468
+ current_status.status,
469
+ )
470
+ return False
471
+
472
+ # Update the status
473
+ run_record = self.run_ids[run_id]
474
+ if new_status.status == Status.STARTING:
475
+ run_record.starting_at = now().isoformat()
476
+ elif new_status.status == Status.RUNNING:
477
+ run_record.running_at = now().isoformat()
478
+ elif new_status.status == Status.FINISHED:
479
+ run_record.finished_at = now().isoformat()
480
+ run_record.status = new_status
481
+ return True
482
+
483
+ def get_pending_run_id(self) -> Optional[int]:
484
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
485
+ pending_run_id = None
486
+
487
+ # Loop through all registered runs
488
+ for run_id, run_rec in self.run_ids.items():
489
+ # Break once a pending run is found
490
+ if run_rec.status.status == Status.PENDING:
491
+ pending_run_id = run_id
492
+ break
493
+
494
+ return pending_run_id
405
495
 
406
496
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
407
497
  """Acknowledge a ping received from a node, serving as a heartbeat."""