flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic.

Files changed (81):
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +18 -83
  8. flwr/client/app.py +13 -14
  9. flwr/client/clientapp/app.py +1 -2
  10. flwr/client/{node_state.py → run_info_store.py} +4 -3
  11. flwr/client/supernode/app.py +6 -8
  12. flwr/common/constant.py +39 -4
  13. flwr/common/context.py +9 -4
  14. flwr/common/date.py +3 -3
  15. flwr/common/logger.py +103 -0
  16. flwr/common/serde.py +24 -0
  17. flwr/common/telemetry.py +0 -6
  18. flwr/common/typing.py +9 -0
  19. flwr/proto/exec_pb2.py +6 -6
  20. flwr/proto/exec_pb2.pyi +8 -2
  21. flwr/proto/log_pb2.py +29 -0
  22. flwr/proto/log_pb2.pyi +39 -0
  23. flwr/proto/log_pb2_grpc.py +4 -0
  24. flwr/proto/log_pb2_grpc.pyi +4 -0
  25. flwr/proto/message_pb2.py +8 -8
  26. flwr/proto/message_pb2.pyi +4 -1
  27. flwr/proto/serverappio_pb2.py +52 -0
  28. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  29. flwr/proto/serverappio_pb2_grpc.py +376 -0
  30. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  31. flwr/proto/simulationio_pb2.py +38 -0
  32. flwr/proto/simulationio_pb2.pyi +65 -0
  33. flwr/proto/simulationio_pb2_grpc.py +171 -0
  34. flwr/proto/simulationio_pb2_grpc.pyi +68 -0
  35. flwr/server/app.py +247 -105
  36. flwr/server/driver/driver.py +15 -1
  37. flwr/server/driver/grpc_driver.py +26 -33
  38. flwr/server/driver/inmemory_driver.py +6 -14
  39. flwr/server/run_serverapp.py +29 -23
  40. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  41. flwr/server/serverapp/app.py +270 -0
  42. flwr/server/strategy/fedadam.py +11 -1
  43. flwr/server/superlink/driver/__init__.py +1 -1
  44. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  45. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  46. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  47. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  48. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  49. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  50. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  51. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  52. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  53. flwr/server/superlink/linkstate/__init__.py +28 -0
  54. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
  55. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
  56. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  57. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
  58. flwr/server/superlink/{state → linkstate}/utils.py +84 -2
  59. flwr/server/superlink/simulation/__init__.py +15 -0
  60. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  61. flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
  62. flwr/simulation/__init__.py +2 -0
  63. flwr/simulation/app.py +1 -1
  64. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  65. flwr/simulation/run_simulation.py +57 -131
  66. flwr/simulation/simulationio_connection.py +86 -0
  67. flwr/superexec/app.py +6 -134
  68. flwr/superexec/deployment.py +60 -65
  69. flwr/superexec/exec_grpc.py +15 -8
  70. flwr/superexec/exec_servicer.py +34 -63
  71. flwr/superexec/executor.py +22 -4
  72. flwr/superexec/simulation.py +13 -8
  73. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
  74. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
  75. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
  76. flwr/client/node_state_tests.py +0 -66
  77. flwr/proto/driver_pb2.py +0 -42
  78. flwr/proto/driver_pb2_grpc.py +0 -239
  79. flwr/proto/driver_pb2_grpc.pyi +0 -94
  80. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
  81. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
37
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
38
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
39
  from flwr.server.superlink.fleet.message_handler import message_handler
40
- from flwr.server.superlink.state import StateFactory
40
+ from flwr.server.superlink.linkstate import LinkStateFactory
41
41
 
42
42
 
43
43
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
44
44
  """Fleet API servicer."""
45
45
 
46
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
46
+ def __init__(
47
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
48
+ ) -> None:
47
49
  self.state_factory = state_factory
48
50
  self.ffs_factory = ffs_factory
49
51
 
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.state import State
48
+ from flwr.server.superlink.linkstate import LinkState
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: State):
87
+ def __init__(self, state: LinkState):
88
88
  self.state = state
89
89
 
90
90
  self.node_public_keys = state.get_node_public_keys()
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
43
  )
44
44
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
45
  from flwr.server.superlink.ffs.ffs import Ffs
46
- from flwr.server.superlink.state import State
46
+ from flwr.server.superlink.linkstate import LinkState
47
47
 
48
48
 
49
49
  def create_node(
50
50
  request: CreateNodeRequest, # pylint: disable=unused-argument
51
- state: State,
51
+ state: LinkState,
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
54
  # Create node
@@ -56,7 +56,7 @@ def create_node(
56
56
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
57
57
 
58
58
 
59
- def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
59
+ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
60
60
  """."""
61
61
  # Validate node_id
62
62
  if request.node.anonymous or request.node.node_id == 0:
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
69
69
 
70
70
  def ping(
71
71
  request: PingRequest, # pylint: disable=unused-argument
72
- state: State, # pylint: disable=unused-argument
72
+ state: LinkState, # pylint: disable=unused-argument
73
73
  ) -> PingResponse:
74
74
  """."""
75
75
  res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
76
76
  return PingResponse(success=res)
77
77
 
78
78
 
79
- def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
79
+ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
80
80
  """Pull TaskIns handler."""
81
81
  # Get node_id if client node is not anonymous
82
82
  node = request.node # pylint: disable=no-member
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
92
92
  return response
93
93
 
94
94
 
95
- def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
95
+ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
96
96
  """Push TaskRes handler."""
97
97
  # pylint: disable=no-member
98
98
  task_res: TaskRes = request.task_res_list[0]
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
113
113
 
114
114
 
115
115
  def get_run(
116
- request: GetRunRequest, state: State # pylint: disable=W0613
116
+ request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
117
  ) -> GetRunResponse:
118
118
  """Get run information."""
119
119
  run = state.get_run(request.run_id)
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
41
  from flwr.server.superlink.ffs.ffs import Ffs
42
42
  from flwr.server.superlink.fleet.message_handler import message_handler
43
- from flwr.server.superlink.state import State
43
+ from flwr.server.superlink.linkstate import LinkState
44
44
 
45
45
  try:
46
46
  from starlette.applications import Starlette
@@ -90,7 +90,7 @@ def rest_request_response(
90
90
  async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
91
91
  """Create Node."""
92
92
  # Get state from app
93
- state: State = app.state.STATE_FACTORY.state()
93
+ state: LinkState = app.state.STATE_FACTORY.state()
94
94
 
95
95
  # Handle message
96
96
  return message_handler.create_node(request=request, state=state)
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
100
100
  async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
101
101
  """Delete Node Id."""
102
102
  # Get state from app
103
- state: State = app.state.STATE_FACTORY.state()
103
+ state: LinkState = app.state.STATE_FACTORY.state()
104
104
 
105
105
  # Handle message
106
106
  return message_handler.delete_node(request=request, state=state)
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
110
110
  async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
111
111
  """Pull TaskIns."""
112
112
  # Get state from app
113
- state: State = app.state.STATE_FACTORY.state()
113
+ state: LinkState = app.state.STATE_FACTORY.state()
114
114
 
115
115
  # Handle message
116
116
  return message_handler.pull_task_ins(request=request, state=state)
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
121
121
  async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
122
122
  """Push TaskRes."""
123
123
  # Get state from app
124
- state: State = app.state.STATE_FACTORY.state()
124
+ state: LinkState = app.state.STATE_FACTORY.state()
125
125
 
126
126
  # Handle message
127
127
  return message_handler.push_task_res(request=request, state=state)
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
131
131
  async def ping(request: PingRequest) -> PingResponse:
132
132
  """Ping."""
133
133
  # Get state from app
134
- state: State = app.state.STATE_FACTORY.state()
134
+ state: LinkState = app.state.STATE_FACTORY.state()
135
135
 
136
136
  # Handle message
137
137
  return message_handler.ping(request=request, state=state)
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
141
141
  async def get_run(request: GetRunRequest) -> GetRunResponse:
142
142
  """GetRun."""
143
143
  # Get state from app
144
- state: State = app.state.STATE_FACTORY.state()
144
+ state: LinkState = app.state.STATE_FACTORY.state()
145
145
 
146
146
  # Handle message
147
147
  return message_handler.get_run(request=request, state=state)
@@ -28,7 +28,7 @@ from typing import Callable, Optional
28
28
 
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
31
- from flwr.client.node_state import NodeState
31
+ from flwr.client.run_info_store import DeprecatedRunInfoStore
32
32
  from flwr.common.constant import (
33
33
  NUM_PARTITIONS_KEY,
34
34
  PARTITION_ID_KEY,
@@ -40,7 +40,7 @@ from flwr.common.message import Error
40
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
41
41
  from flwr.common.typing import Run
42
42
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
- from flwr.server.superlink.state import State, StateFactory
43
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
44
44
 
45
45
  from .backend import Backend, error_messages_backends, supported_backends
46
46
 
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
48
48
 
49
49
 
50
50
  def _register_nodes(
51
- num_nodes: int, state_factory: StateFactory
51
+ num_nodes: int, state_factory: LinkStateFactory
52
52
  ) -> NodeToPartitionMapping:
53
53
  """Register nodes with the StateFactory and create node-id:partition-id mapping."""
54
54
  nodes_mapping: NodeToPartitionMapping = {}
@@ -60,16 +60,16 @@ def _register_nodes(
60
60
  return nodes_mapping
61
61
 
62
62
 
63
- def _register_node_states(
63
+ def _register_node_info_stores(
64
64
  nodes_mapping: NodeToPartitionMapping,
65
65
  run: Run,
66
66
  app_dir: Optional[str] = None,
67
- ) -> dict[int, NodeState]:
68
- """Create NodeState objects and pre-register the context for the run."""
69
- node_states: dict[int, NodeState] = {}
67
+ ) -> dict[int, DeprecatedRunInfoStore]:
68
+ """Create DeprecatedRunInfoStore objects and register the context for the run."""
69
+ node_info_store: dict[int, DeprecatedRunInfoStore] = {}
70
70
  num_partitions = len(set(nodes_mapping.values()))
71
71
  for node_id, partition_id in nodes_mapping.items():
72
- node_states[node_id] = NodeState(
72
+ node_info_store[node_id] = DeprecatedRunInfoStore(
73
73
  node_id=node_id,
74
74
  node_config={
75
75
  PARTITION_ID_KEY: partition_id,
@@ -78,18 +78,18 @@ def _register_node_states(
78
78
  )
79
79
 
80
80
  # Pre-register Context objects
81
- node_states[node_id].register_context(
81
+ node_info_store[node_id].register_context(
82
82
  run_id=run.run_id, run=run, app_dir=app_dir
83
83
  )
84
84
 
85
- return node_states
85
+ return node_info_store
86
86
 
87
87
 
88
88
  # pylint: disable=too-many-arguments,too-many-locals
89
89
  def worker(
90
90
  taskins_queue: "Queue[TaskIns]",
91
91
  taskres_queue: "Queue[TaskRes]",
92
- node_states: dict[int, NodeState],
92
+ node_info_store: dict[int, DeprecatedRunInfoStore],
93
93
  backend: Backend,
94
94
  f_stop: threading.Event,
95
95
  ) -> None:
@@ -103,7 +103,7 @@ def worker(
103
103
  node_id = task_ins.task.consumer.node_id
104
104
 
105
105
  # Retrieve context
106
- context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
106
+ context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
107
107
 
108
108
  # Convert TaskIns to Message
109
109
  message = message_from_taskins(task_ins)
@@ -112,7 +112,7 @@ def worker(
112
112
  out_mssg, updated_context = backend.process_message(message, context)
113
113
 
114
114
  # Update Context
115
- node_states[node_id].update_context(
115
+ node_info_store[node_id].update_context(
116
116
  task_ins.run_id, context=updated_context
117
117
  )
118
118
  except Empty:
@@ -145,7 +145,7 @@ def worker(
145
145
 
146
146
 
147
147
  def add_taskins_to_queue(
148
- state: State,
148
+ state: LinkState,
149
149
  queue: "Queue[TaskIns]",
150
150
  nodes_mapping: NodeToPartitionMapping,
151
151
  f_stop: threading.Event,
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
160
160
 
161
161
 
162
162
  def put_taskres_into_state(
163
- state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
163
+ state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
164
164
  ) -> None:
165
165
  """Put TaskRes into State from a queue."""
166
166
  while not f_stop.is_set():
@@ -177,8 +177,8 @@ def run_api(
177
177
  app_fn: Callable[[], ClientApp],
178
178
  backend_fn: Callable[[], Backend],
179
179
  nodes_mapping: NodeToPartitionMapping,
180
- state_factory: StateFactory,
181
- node_states: dict[int, NodeState],
180
+ state_factory: LinkStateFactory,
181
+ node_info_stores: dict[int, DeprecatedRunInfoStore],
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
@@ -223,7 +223,7 @@ def run_api(
223
223
  worker,
224
224
  taskins_queue,
225
225
  taskres_queue,
226
- node_states,
226
+ node_info_stores,
227
227
  backend,
228
228
  f_stop,
229
229
  )
@@ -264,7 +264,7 @@ def start_vce(
264
264
  client_app: Optional[ClientApp] = None,
265
265
  client_app_attr: Optional[str] = None,
266
266
  num_supernodes: Optional[int] = None,
267
- state_factory: Optional[StateFactory] = None,
267
+ state_factory: Optional[LinkStateFactory] = None,
268
268
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
269
269
  ) -> None:
270
270
  """Start Fleet API with the Simulation Engine."""
@@ -303,7 +303,7 @@ def start_vce(
303
303
  if not state_factory:
304
304
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
305
305
  # Create an empty in-memory state factory
306
- state_factory = StateFactory(":flwr-in-memory-state:")
306
+ state_factory = LinkStateFactory(":flwr-in-memory-state:")
307
307
  log(INFO, "Created new %s.", state_factory.__class__.__name__)
308
308
 
309
309
  if num_supernodes:
@@ -312,8 +312,8 @@ def start_vce(
312
312
  num_nodes=num_supernodes, state_factory=state_factory
313
313
  )
314
314
 
315
- # Construct mapping of NodeStates
316
- node_states = _register_node_states(
315
+ # Construct mapping of DeprecatedRunInfoStore
316
+ node_info_stores = _register_node_info_stores(
317
317
  nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
318
318
  )
319
319
 
@@ -376,7 +376,7 @@ def start_vce(
376
376
  backend_fn,
377
377
  nodes_mapping,
378
378
  state_factory,
379
- node_states,
379
+ node_info_stores,
380
380
  f_stop,
381
381
  )
382
382
  except LoadClientAppError as loadapp_ex:
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower LinkState."""
16
+
17
+
18
+ from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
19
+ from .linkstate import LinkState as LinkState
20
+ from .linkstate_factory import LinkStateFactory as LinkStateFactory
21
+ from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
22
+
23
+ __all__ = [
24
+ "InMemoryLinkState",
25
+ "LinkState",
26
+ "LinkStateFactory",
27
+ "SqliteLinkState",
28
+ ]
@@ -12,31 +12,54 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """In-memory State implementation."""
15
+ """In-memory LinkState implementation."""
16
16
 
17
17
 
18
18
  import threading
19
19
  import time
20
+ from bisect import bisect_right
21
+ from dataclasses import dataclass, field
20
22
  from logging import ERROR, WARNING
21
23
  from typing import Optional
22
24
  from uuid import UUID, uuid4
23
25
 
24
- from flwr.common import log, now
26
+ from flwr.common import Context, log, now
25
27
  from flwr.common.constant import (
26
28
  MESSAGE_TTL_TOLERANCE,
27
29
  NODE_ID_NUM_BYTES,
28
30
  RUN_ID_NUM_BYTES,
31
+ Status,
29
32
  )
30
- from flwr.common.typing import Run, UserConfig
33
+ from flwr.common.record import ConfigsRecord
34
+ from flwr.common.typing import Run, RunStatus, UserConfig
31
35
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
32
- from flwr.server.superlink.state.state import State
36
+ from flwr.server.superlink.linkstate.linkstate import LinkState
33
37
  from flwr.server.utils import validate_task_ins_or_res
34
38
 
35
- from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
39
+ from .utils import (
40
+ generate_rand_int_from_bytes,
41
+ has_valid_sub_status,
42
+ is_valid_transition,
43
+ make_node_unavailable_taskres,
44
+ )
45
+
46
+
47
+ @dataclass
48
+ class RunRecord: # pylint: disable=R0902
49
+ """The record of a specific run, including its status and timestamps."""
36
50
 
51
+ run: Run
52
+ status: RunStatus
53
+ pending_at: str = ""
54
+ starting_at: str = ""
55
+ running_at: str = ""
56
+ finished_at: str = ""
57
+ logs: list[tuple[float, str]] = field(default_factory=list)
58
+ log_lock: threading.Lock = field(default_factory=threading.Lock)
37
59
 
38
- class InMemoryState(State): # pylint: disable=R0902,R0904
39
- """In-memory State implementation."""
60
+
61
+ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
62
+ """In-memory LinkState implementation."""
40
63
 
41
64
  def __init__(self) -> None:
42
65
 
@@ -44,8 +67,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
44
67
  self.node_ids: dict[int, tuple[float, float]] = {}
45
68
  self.public_key_to_node_id: dict[bytes, int] = {}
46
69
 
47
- # Map run_id to (fab_id, fab_version)
48
- self.run_ids: dict[int, Run] = {}
70
+ # Map run_id to RunRecord
71
+ self.run_ids: dict[int, RunRecord] = {}
72
+ self.contexts: dict[int, Context] = {}
73
+ self.federation_options: dict[int, ConfigsRecord] = {}
49
74
  self.task_ins_store: dict[UUID, TaskIns] = {}
50
75
  self.task_res_store: dict[UUID, TaskRes] = {}
51
76
 
@@ -64,8 +89,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
64
89
  return None
65
90
  # Validate run_id
66
91
  if task_ins.run_id not in self.run_ids:
67
- log(ERROR, "`run_id` is invalid")
92
+ log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
68
93
  return None
94
+ # Validate source node ID
95
+ if task_ins.task.producer.node_id != 0:
96
+ log(
97
+ ERROR,
98
+ "Invalid source node ID for TaskIns: %s",
99
+ task_ins.task.producer.node_id,
100
+ )
101
+ return None
102
+ # Validate destination node ID
103
+ if not task_ins.task.consumer.anonymous:
104
+ if task_ins.task.consumer.node_id not in self.node_ids:
105
+ log(
106
+ ERROR,
107
+ "Invalid destination node ID for TaskIns: %s",
108
+ task_ins.task.consumer.node_id,
109
+ )
110
+ return None
69
111
 
70
112
  # Create task_id
71
113
  task_id = uuid4()
@@ -277,7 +319,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
277
319
  def create_node(
278
320
  self, ping_interval: float, public_key: Optional[bytes] = None
279
321
  ) -> int:
280
- """Create, store in state, and return `node_id`."""
322
+ """Create, store in the link state, and return `node_id`."""
281
323
  # Sample a random int64 as node_id
282
324
  node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
283
325
 
@@ -338,12 +380,14 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
338
380
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
339
381
  return self.public_key_to_node_id.get(node_public_key)
340
382
 
383
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
341
384
  def create_run(
342
385
  self,
343
386
  fab_id: Optional[str],
344
387
  fab_version: Optional[str],
345
388
  fab_hash: Optional[str],
346
389
  override_config: UserConfig,
390
+ federation_options: ConfigsRecord,
347
391
  ) -> int:
348
392
  """Create a new run for the specified `fab_hash`."""
349
393
  # Sample a random int64 as run_id
@@ -351,13 +395,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
351
395
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
352
396
 
353
397
  if run_id not in self.run_ids:
354
- self.run_ids[run_id] = Run(
355
- run_id=run_id,
356
- fab_id=fab_id if fab_id else "",
357
- fab_version=fab_version if fab_version else "",
358
- fab_hash=fab_hash if fab_hash else "",
359
- override_config=override_config,
398
+ run_record = RunRecord(
399
+ run=Run(
400
+ run_id=run_id,
401
+ fab_id=fab_id if fab_id else "",
402
+ fab_version=fab_version if fab_version else "",
403
+ fab_hash=fab_hash if fab_hash else "",
404
+ override_config=override_config,
405
+ ),
406
+ status=RunStatus(
407
+ status=Status.PENDING,
408
+ sub_status="",
409
+ details="",
410
+ ),
411
+ pending_at=now().isoformat(),
360
412
  )
413
+ self.run_ids[run_id] = run_record
414
+
415
+ # Record federation options. Leave empty if not passed
416
+ self.federation_options[run_id] = federation_options
361
417
  return run_id
362
418
  log(ERROR, "Unexpected run creation failure.")
363
419
  return 0
@@ -365,7 +421,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
365
421
  def store_server_private_public_key(
366
422
  self, private_key: bytes, public_key: bytes
367
423
  ) -> None:
368
- """Store `server_private_key` and `server_public_key` in state."""
424
+ """Store `server_private_key` and `server_public_key` in the link state."""
369
425
  with self.lock:
370
426
  if self.server_private_key is None and self.server_public_key is None:
371
427
  self.server_private_key = private_key
@@ -382,12 +438,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
382
438
  return self.server_public_key
383
439
 
384
440
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
385
- """Store a set of `node_public_keys` in state."""
441
+ """Store a set of `node_public_keys` in the link state."""
386
442
  with self.lock:
387
443
  self.node_public_keys = public_keys
388
444
 
389
445
  def store_node_public_key(self, public_key: bytes) -> None:
390
- """Store a `node_public_key` in state."""
446
+ """Store a `node_public_key` in the link state."""
391
447
  with self.lock:
392
448
  self.node_public_keys.add(public_key)
393
449
 
@@ -401,7 +457,77 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
401
457
  if run_id not in self.run_ids:
402
458
  log(ERROR, "`run_id` is invalid")
403
459
  return None
404
- return self.run_ids[run_id]
460
+ return self.run_ids[run_id].run
461
+
462
+ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
463
+ """Retrieve the statuses for the specified runs."""
464
+ with self.lock:
465
+ return {
466
+ run_id: self.run_ids[run_id].status
467
+ for run_id in set(run_ids)
468
+ if run_id in self.run_ids
469
+ }
470
+
471
+ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
472
+ """Update the status of the run with the specified `run_id`."""
473
+ with self.lock:
474
+ # Check if the run_id exists
475
+ if run_id not in self.run_ids:
476
+ log(ERROR, "`run_id` is invalid")
477
+ return False
478
+
479
+ # Check if the status transition is valid
480
+ current_status = self.run_ids[run_id].status
481
+ if not is_valid_transition(current_status, new_status):
482
+ log(
483
+ ERROR,
484
+ 'Invalid status transition: from "%s" to "%s"',
485
+ current_status.status,
486
+ new_status.status,
487
+ )
488
+ return False
489
+
490
+ # Check if the sub-status is valid
491
+ if not has_valid_sub_status(current_status):
492
+ log(
493
+ ERROR,
494
+ 'Invalid sub-status "%s" for status "%s"',
495
+ current_status.sub_status,
496
+ current_status.status,
497
+ )
498
+ return False
499
+
500
+ # Update the status
501
+ run_record = self.run_ids[run_id]
502
+ if new_status.status == Status.STARTING:
503
+ run_record.starting_at = now().isoformat()
504
+ elif new_status.status == Status.RUNNING:
505
+ run_record.running_at = now().isoformat()
506
+ elif new_status.status == Status.FINISHED:
507
+ run_record.finished_at = now().isoformat()
508
+ run_record.status = new_status
509
+ return True
510
+
511
+ def get_pending_run_id(self) -> Optional[int]:
512
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
513
+ pending_run_id = None
514
+
515
+ # Loop through all registered runs
516
+ for run_id, run_rec in self.run_ids.items():
517
+ # Break once a pending run is found
518
+ if run_rec.status.status == Status.PENDING:
519
+ pending_run_id = run_id
520
+ break
521
+
522
+ return pending_run_id
523
+
524
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
525
+ """Retrieve the federation options for the specified `run_id`."""
526
+ with self.lock:
527
+ if run_id not in self.run_ids:
528
+ log(ERROR, "`run_id` is invalid")
529
+ return None
530
+ return self.federation_options[run_id]
405
531
 
406
532
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
407
533
  """Acknowledge a ping received from a node, serving as a heartbeat."""
@@ -410,3 +536,36 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
410
536
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
411
537
  return True
412
538
  return False
539
+
540
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
541
+ """Get the context for the specified `run_id`."""
542
+ return self.contexts.get(run_id)
543
+
544
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
545
+ """Set the context for the specified `run_id`."""
546
+ if run_id not in self.run_ids:
547
+ raise ValueError(f"Run {run_id} not found")
548
+ self.contexts[run_id] = context
549
+
550
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
551
+ """Add a log entry to the serverapp logs for the specified `run_id`."""
552
+ if run_id not in self.run_ids:
553
+ raise ValueError(f"Run {run_id} not found")
554
+ run = self.run_ids[run_id]
555
+ with run.log_lock:
556
+ run.logs.append((now().timestamp(), log_message))
557
+
558
+ def get_serverapp_log(
559
+ self, run_id: int, after_timestamp: Optional[float]
560
+ ) -> tuple[str, float]:
561
+ """Get the serverapp logs for the specified `run_id`."""
562
+ if run_id not in self.run_ids:
563
+ raise ValueError(f"Run {run_id} not found")
564
+ run = self.run_ids[run_id]
565
+ if after_timestamp is None:
566
+ after_timestamp = 0.0
567
+ with run.log_lock:
568
+ # Find the index where the timestamp would be inserted
569
+ index = bisect_right(run.logs, (after_timestamp, ""))
570
+ latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
571
+ return "".join(log for _, log in run.logs[index:]), latest_timestamp