flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the registry's release page for more details.

Files changed (92)
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +34 -88
  8. flwr/client/app.py +23 -20
  9. flwr/client/clientapp/app.py +22 -18
  10. flwr/client/nodestate/__init__.py +25 -0
  11. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  12. flwr/client/nodestate/nodestate.py +30 -0
  13. flwr/client/nodestate/nodestate_factory.py +37 -0
  14. flwr/client/{node_state.py → run_info_store.py} +4 -3
  15. flwr/client/supernode/app.py +6 -8
  16. flwr/common/args.py +83 -0
  17. flwr/common/config.py +10 -0
  18. flwr/common/constant.py +39 -5
  19. flwr/common/context.py +9 -4
  20. flwr/common/date.py +3 -3
  21. flwr/common/logger.py +108 -1
  22. flwr/common/object_ref.py +47 -16
  23. flwr/common/serde.py +24 -0
  24. flwr/common/telemetry.py +0 -6
  25. flwr/common/typing.py +10 -1
  26. flwr/proto/exec_pb2.py +14 -17
  27. flwr/proto/exec_pb2.pyi +14 -22
  28. flwr/proto/log_pb2.py +29 -0
  29. flwr/proto/log_pb2.pyi +39 -0
  30. flwr/proto/log_pb2_grpc.py +4 -0
  31. flwr/proto/log_pb2_grpc.pyi +4 -0
  32. flwr/proto/message_pb2.py +8 -8
  33. flwr/proto/message_pb2.pyi +4 -1
  34. flwr/proto/run_pb2.py +32 -27
  35. flwr/proto/run_pb2.pyi +26 -0
  36. flwr/proto/serverappio_pb2.py +52 -0
  37. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  38. flwr/proto/serverappio_pb2_grpc.py +376 -0
  39. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  40. flwr/proto/simulationio_pb2.py +38 -0
  41. flwr/proto/simulationio_pb2.pyi +65 -0
  42. flwr/proto/simulationio_pb2_grpc.py +205 -0
  43. flwr/proto/simulationio_pb2_grpc.pyi +81 -0
  44. flwr/server/app.py +272 -105
  45. flwr/server/driver/driver.py +15 -1
  46. flwr/server/driver/grpc_driver.py +25 -36
  47. flwr/server/driver/inmemory_driver.py +6 -16
  48. flwr/server/run_serverapp.py +29 -23
  49. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  50. flwr/server/serverapp/app.py +214 -0
  51. flwr/server/strategy/aggregate.py +4 -4
  52. flwr/server/strategy/fedadam.py +11 -1
  53. flwr/server/superlink/driver/__init__.py +1 -1
  54. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  55. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  56. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  57. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  58. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  59. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  60. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  61. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  62. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  63. flwr/server/superlink/linkstate/__init__.py +28 -0
  64. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
  65. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
  66. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  67. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
  68. flwr/server/superlink/{state → linkstate}/utils.py +81 -30
  69. flwr/server/superlink/simulation/__init__.py +15 -0
  70. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  71. flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
  72. flwr/simulation/__init__.py +5 -1
  73. flwr/simulation/app.py +273 -345
  74. flwr/simulation/legacy_app.py +382 -0
  75. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  76. flwr/simulation/run_simulation.py +57 -131
  77. flwr/simulation/simulationio_connection.py +86 -0
  78. flwr/superexec/app.py +6 -134
  79. flwr/superexec/deployment.py +61 -66
  80. flwr/superexec/exec_grpc.py +15 -8
  81. flwr/superexec/exec_servicer.py +36 -65
  82. flwr/superexec/executor.py +26 -7
  83. flwr/superexec/simulation.py +54 -107
  84. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
  85. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
  86. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
  87. flwr/client/node_state_tests.py +0 -66
  88. flwr/proto/driver_pb2.py +0 -42
  89. flwr/proto/driver_pb2_grpc.py +0 -239
  90. flwr/proto/driver_pb2_grpc.pyi +0 -94
  91. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
  92. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver gRPC API."""
15
+ """ServerAppIo gRPC API."""
16
16
 
17
17
  from logging import INFO
18
18
  from typing import Optional
@@ -21,37 +21,40 @@ import grpc
21
21
 
22
22
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
23
23
  from flwr.common.logger import log
24
- from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
- add_DriverServicer_to_server,
24
+ from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
25
+ add_ServerAppIoServicer_to_server,
26
26
  )
27
27
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
28
- from flwr.server.superlink.state import StateFactory
28
+ from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
30
  from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
31
- from .driver_servicer import DriverServicer
31
+ from .serverappio_servicer import ServerAppIoServicer
32
32
 
33
33
 
34
- def run_driver_api_grpc(
34
+ def run_serverappio_api_grpc(
35
35
  address: str,
36
- state_factory: StateFactory,
36
+ state_factory: LinkStateFactory,
37
37
  ffs_factory: FfsFactory,
38
38
  certificates: Optional[tuple[bytes, bytes, bytes]],
39
39
  ) -> grpc.Server:
40
- """Run Driver API (gRPC, request-response)."""
41
- # Create Driver API gRPC server
42
- driver_servicer: grpc.Server = DriverServicer(
40
+ """Run ServerAppIo API (gRPC, request-response)."""
41
+ # Create ServerAppIo API gRPC server
42
+ serverappio_servicer: grpc.Server = ServerAppIoServicer(
43
43
  state_factory=state_factory,
44
44
  ffs_factory=ffs_factory,
45
45
  )
46
- driver_add_servicer_to_server_fn = add_DriverServicer_to_server
47
- driver_grpc_server = generic_create_grpc_server(
48
- servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
46
+ serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server
47
+ serverappio_grpc_server = generic_create_grpc_server(
48
+ servicer_and_add_fn=(
49
+ serverappio_servicer,
50
+ serverappio_add_servicer_to_server_fn,
51
+ ),
49
52
  server_address=address,
50
53
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
51
54
  certificates=certificates,
52
55
  )
53
56
 
54
- log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
55
- driver_grpc_server.start()
57
+ log(INFO, "Flower ECE: Starting ServerAppIo API (gRPC-rere) on %s", address)
58
+ serverappio_grpc_server.start()
56
59
 
57
- return driver_grpc_server
60
+ return serverappio_grpc_server
@@ -12,62 +12,80 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver API servicer."""
15
+ """ServerAppIo API servicer."""
16
16
 
17
17
 
18
+ import threading
18
19
  import time
19
- from logging import DEBUG
20
+ from logging import DEBUG, INFO
20
21
  from typing import Optional
21
22
  from uuid import UUID
22
23
 
23
24
  import grpc
24
25
 
26
+ from flwr.common import ConfigsRecord
27
+ from flwr.common.constant import Status
25
28
  from flwr.common.logger import log
26
29
  from flwr.common.serde import (
30
+ context_from_proto,
31
+ context_to_proto,
27
32
  fab_from_proto,
28
33
  fab_to_proto,
34
+ run_status_from_proto,
35
+ run_to_proto,
29
36
  user_config_from_proto,
30
- user_config_to_proto,
31
- )
32
- from flwr.common.typing import Fab
33
- from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
34
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
- GetNodesRequest,
36
- GetNodesResponse,
37
- PullTaskResRequest,
38
- PullTaskResResponse,
39
- PushTaskInsRequest,
40
- PushTaskInsResponse,
41
37
  )
38
+ from flwr.common.typing import Fab, RunStatus
39
+ from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
42
40
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
41
+ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
42
+ PushLogsRequest,
43
+ PushLogsResponse,
44
+ )
43
45
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
44
46
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
45
47
  CreateRunRequest,
46
48
  CreateRunResponse,
47
49
  GetRunRequest,
48
50
  GetRunResponse,
49
- Run,
51
+ UpdateRunStatusRequest,
52
+ UpdateRunStatusResponse,
53
+ )
54
+ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
55
+ GetNodesRequest,
56
+ GetNodesResponse,
57
+ PullServerAppInputsRequest,
58
+ PullServerAppInputsResponse,
59
+ PullTaskResRequest,
60
+ PullTaskResResponse,
61
+ PushServerAppOutputsRequest,
62
+ PushServerAppOutputsResponse,
63
+ PushTaskInsRequest,
64
+ PushTaskInsResponse,
50
65
  )
51
66
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
67
  from flwr.server.superlink.ffs.ffs import Ffs
53
68
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
54
- from flwr.server.superlink.state import State, StateFactory
69
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
55
70
  from flwr.server.utils.validator import validate_task_ins_or_res
56
71
 
57
72
 
58
- class DriverServicer(driver_pb2_grpc.DriverServicer):
59
- """Driver API servicer."""
73
+ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
74
+ """ServerAppIo API servicer."""
60
75
 
61
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
76
+ def __init__(
77
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
78
+ ) -> None:
62
79
  self.state_factory = state_factory
63
80
  self.ffs_factory = ffs_factory
81
+ self.lock = threading.RLock()
64
82
 
65
83
  def GetNodes(
66
84
  self, request: GetNodesRequest, context: grpc.ServicerContext
67
85
  ) -> GetNodesResponse:
68
86
  """Get available nodes."""
69
- log(DEBUG, "DriverServicer.GetNodes")
70
- state: State = self.state_factory.state()
87
+ log(DEBUG, "ServerAppIoServicer.GetNodes")
88
+ state: LinkState = self.state_factory.state()
71
89
  all_ids: set[int] = state.get_nodes(request.run_id)
72
90
  nodes: list[Node] = [
73
91
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -78,8 +96,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
78
96
  self, request: CreateRunRequest, context: grpc.ServicerContext
79
97
  ) -> CreateRunResponse:
80
98
  """Create run ID."""
81
- log(DEBUG, "DriverServicer.CreateRun")
82
- state: State = self.state_factory.state()
99
+ log(DEBUG, "ServerAppIoServicer.CreateRun")
100
+ state: LinkState = self.state_factory.state()
83
101
  if request.HasField("fab"):
84
102
  fab = fab_from_proto(request.fab)
85
103
  ffs: Ffs = self.ffs_factory.ffs()
@@ -95,6 +113,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
113
  request.fab_version,
96
114
  fab_hash,
97
115
  user_config_from_proto(request.override_config),
116
+ ConfigsRecord(),
98
117
  )
99
118
  return CreateRunResponse(run_id=run_id)
100
119
 
@@ -102,7 +121,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
102
121
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
103
122
  ) -> PushTaskInsResponse:
104
123
  """Push a set of TaskIns."""
105
- log(DEBUG, "DriverServicer.PushTaskIns")
124
+ log(DEBUG, "ServerAppIoServicer.PushTaskIns")
106
125
 
107
126
  # Set pushed_at (timestamp in seconds)
108
127
  pushed_at = time.time()
@@ -116,7 +135,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
116
135
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
117
136
 
118
137
  # Init state
119
- state: State = self.state_factory.state()
138
+ state: LinkState = self.state_factory.state()
120
139
 
121
140
  # Store each TaskIns
122
141
  task_ids: list[Optional[UUID]] = []
@@ -132,17 +151,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
132
151
  self, request: PullTaskResRequest, context: grpc.ServicerContext
133
152
  ) -> PullTaskResResponse:
134
153
  """Pull a set of TaskRes."""
135
- log(DEBUG, "DriverServicer.PullTaskRes")
154
+ log(DEBUG, "ServerAppIoServicer.PullTaskRes")
136
155
 
137
156
  # Convert each task_id str to UUID
138
157
  task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
139
158
 
140
159
  # Init state
141
- state: State = self.state_factory.state()
160
+ state: LinkState = self.state_factory.state()
142
161
 
143
162
  # Register callback
144
163
  def on_rpc_done() -> None:
145
- log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
164
+ log(
165
+ DEBUG,
166
+ "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
+ )
146
168
 
147
169
  if context.is_active():
148
170
  return
@@ -164,10 +186,10 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
164
186
  self, request: GetRunRequest, context: grpc.ServicerContext
165
187
  ) -> GetRunResponse:
166
188
  """Get run information."""
167
- log(DEBUG, "DriverServicer.GetRun")
189
+ log(DEBUG, "ServerAppIoServicer.GetRun")
168
190
 
169
191
  # Init state
170
- state: State = self.state_factory.state()
192
+ state: LinkState = self.state_factory.state()
171
193
 
172
194
  # Retrieve run information
173
195
  run = state.get_run(request.run_id)
@@ -175,21 +197,13 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
175
197
  if run is None:
176
198
  return GetRunResponse()
177
199
 
178
- return GetRunResponse(
179
- run=Run(
180
- run_id=run.run_id,
181
- fab_id=run.fab_id,
182
- fab_version=run.fab_version,
183
- override_config=user_config_to_proto(run.override_config),
184
- fab_hash=run.fab_hash,
185
- )
186
- )
200
+ return GetRunResponse(run=run_to_proto(run))
187
201
 
188
202
  def GetFab(
189
203
  self, request: GetFabRequest, context: grpc.ServicerContext
190
204
  ) -> GetFabResponse:
191
205
  """Get FAB from Ffs."""
192
- log(DEBUG, "DriverServicer.GetFab")
206
+ log(DEBUG, "ServerAppIoServicer.GetFab")
193
207
 
194
208
  ffs: Ffs = self.ffs_factory.ffs()
195
209
  if result := ffs.get(request.hash_str):
@@ -198,6 +212,78 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
198
212
 
199
213
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
200
214
 
215
+ def PullServerAppInputs(
216
+ self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
217
+ ) -> PullServerAppInputsResponse:
218
+ """Pull ServerApp process inputs."""
219
+ log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
220
+ # Init access to LinkState and Ffs
221
+ state = self.state_factory.state()
222
+ ffs = self.ffs_factory.ffs()
223
+
224
+ # Lock access to LinkState, preventing obtaining the same pending run_id
225
+ with self.lock:
226
+ # Attempt getting the run_id of a pending run
227
+ run_id = state.get_pending_run_id()
228
+ # If there's no pending run, return an empty response
229
+ if run_id is None:
230
+ return PullServerAppInputsResponse()
231
+
232
+ # Retrieve Context, Run and Fab for the run_id
233
+ serverapp_ctxt = state.get_serverapp_context(run_id)
234
+ run = state.get_run(run_id)
235
+ fab = None
236
+ if run and run.fab_hash:
237
+ if result := ffs.get(run.fab_hash):
238
+ fab = Fab(run.fab_hash, result[0])
239
+ if run and fab and serverapp_ctxt:
240
+ # Update run status to STARTING
241
+ if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
242
+ log(INFO, "Starting run %d", run_id)
243
+ return PullServerAppInputsResponse(
244
+ context=context_to_proto(serverapp_ctxt),
245
+ run=run_to_proto(run),
246
+ fab=fab_to_proto(fab),
247
+ )
248
+
249
+ # Raise an exception if the Run or Fab is not found,
250
+ # or if the status cannot be updated to STARTING
251
+ raise RuntimeError(f"Failed to start run {run_id}")
252
+
253
+ def PushServerAppOutputs(
254
+ self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
255
+ ) -> PushServerAppOutputsResponse:
256
+ """Push ServerApp process outputs."""
257
+ log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
258
+ state = self.state_factory.state()
259
+ state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
+ return PushServerAppOutputsResponse()
261
+
262
+ def UpdateRunStatus(
263
+ self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
264
+ ) -> UpdateRunStatusResponse:
265
+ """Update the status of a run."""
266
+ log(DEBUG, "ControlServicer.UpdateRunStatus")
267
+ state = self.state_factory.state()
268
+
269
+ # Update the run status
270
+ state.update_run_status(
271
+ run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
272
+ )
273
+ return UpdateRunStatusResponse()
274
+
275
+ def PushLogs(
276
+ self, request: PushLogsRequest, context: grpc.ServicerContext
277
+ ) -> PushLogsResponse:
278
+ """Push logs."""
279
+ log(DEBUG, "ServerAppIoServicer.PushLogs")
280
+ state = self.state_factory.state()
281
+
282
+ # Add logs to LinkState
283
+ merged_logs = "".join(request.logs)
284
+ state.add_serverapp_log(request.run_id, merged_logs)
285
+ return PushLogsResponse()
286
+
201
287
 
202
288
  def _raise_if(validation_error: bool, detail: str) -> None:
203
289
  if validation_error:
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
48
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
49
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
50
  from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.state import StateFactory
51
+ from flwr.server.superlink.linkstate import LinkStateFactory
52
52
 
53
53
  T = TypeVar("T", bound=GrpcMessage)
54
54
 
@@ -77,7 +77,9 @@ def _handle(
77
77
  class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
78
78
  """Fleet API via GrpcAdapter servicer."""
79
79
 
80
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
80
+ def __init__(
81
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
+ ) -> None:
81
83
  self.state_factory = state_factory
82
84
  self.ffs_factory = ffs_factory
83
85
 
@@ -30,7 +30,7 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
30
30
  add_FlowerServiceServicer_to_server,
31
31
  )
32
32
  from flwr.server.client_manager import ClientManager
33
- from flwr.server.superlink.driver.driver_servicer import DriverServicer
33
+ from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
34
34
  from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
35
35
  GrpcAdapterServicer,
36
36
  )
@@ -161,7 +161,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
161
161
  tuple[FleetServicer, AddServicerToServerFn],
162
162
  tuple[GrpcAdapterServicer, AddServicerToServerFn],
163
163
  tuple[FlowerServiceServicer, AddServicerToServerFn],
164
- tuple[DriverServicer, AddServicerToServerFn],
164
+ tuple[ServerAppIoServicer, AddServicerToServerFn],
165
165
  ],
166
166
  server_address: str,
167
167
  max_concurrent_workers: int = 1000,
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
37
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
38
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
39
  from flwr.server.superlink.fleet.message_handler import message_handler
40
- from flwr.server.superlink.state import StateFactory
40
+ from flwr.server.superlink.linkstate import LinkStateFactory
41
41
 
42
42
 
43
43
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
44
44
  """Fleet API servicer."""
45
45
 
46
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
46
+ def __init__(
47
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
48
+ ) -> None:
47
49
  self.state_factory = state_factory
48
50
  self.ffs_factory = ffs_factory
49
51
 
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.state import State
48
+ from flwr.server.superlink.linkstate import LinkState
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: State):
87
+ def __init__(self, state: LinkState):
88
88
  self.state = state
89
89
 
90
90
  self.node_public_keys = state.get_node_public_keys()
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
43
  )
44
44
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
45
  from flwr.server.superlink.ffs.ffs import Ffs
46
- from flwr.server.superlink.state import State
46
+ from flwr.server.superlink.linkstate import LinkState
47
47
 
48
48
 
49
49
  def create_node(
50
50
  request: CreateNodeRequest, # pylint: disable=unused-argument
51
- state: State,
51
+ state: LinkState,
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
54
  # Create node
@@ -56,7 +56,7 @@ def create_node(
56
56
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
57
57
 
58
58
 
59
- def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
59
+ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
60
60
  """."""
61
61
  # Validate node_id
62
62
  if request.node.anonymous or request.node.node_id == 0:
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
69
69
 
70
70
  def ping(
71
71
  request: PingRequest, # pylint: disable=unused-argument
72
- state: State, # pylint: disable=unused-argument
72
+ state: LinkState, # pylint: disable=unused-argument
73
73
  ) -> PingResponse:
74
74
  """."""
75
75
  res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
76
76
  return PingResponse(success=res)
77
77
 
78
78
 
79
- def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
79
+ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
80
80
  """Pull TaskIns handler."""
81
81
  # Get node_id if client node is not anonymous
82
82
  node = request.node # pylint: disable=no-member
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
92
92
  return response
93
93
 
94
94
 
95
- def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
95
+ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
96
96
  """Push TaskRes handler."""
97
97
  # pylint: disable=no-member
98
98
  task_res: TaskRes = request.task_res_list[0]
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
113
113
 
114
114
 
115
115
  def get_run(
116
- request: GetRunRequest, state: State # pylint: disable=W0613
116
+ request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
117
  ) -> GetRunResponse:
118
118
  """Get run information."""
119
119
  run = state.get_run(request.run_id)
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
41
  from flwr.server.superlink.ffs.ffs import Ffs
42
42
  from flwr.server.superlink.fleet.message_handler import message_handler
43
- from flwr.server.superlink.state import State
43
+ from flwr.server.superlink.linkstate import LinkState
44
44
 
45
45
  try:
46
46
  from starlette.applications import Starlette
@@ -90,7 +90,7 @@ def rest_request_response(
90
90
  async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
91
91
  """Create Node."""
92
92
  # Get state from app
93
- state: State = app.state.STATE_FACTORY.state()
93
+ state: LinkState = app.state.STATE_FACTORY.state()
94
94
 
95
95
  # Handle message
96
96
  return message_handler.create_node(request=request, state=state)
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
100
100
  async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
101
101
  """Delete Node Id."""
102
102
  # Get state from app
103
- state: State = app.state.STATE_FACTORY.state()
103
+ state: LinkState = app.state.STATE_FACTORY.state()
104
104
 
105
105
  # Handle message
106
106
  return message_handler.delete_node(request=request, state=state)
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
110
110
  async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
111
111
  """Pull TaskIns."""
112
112
  # Get state from app
113
- state: State = app.state.STATE_FACTORY.state()
113
+ state: LinkState = app.state.STATE_FACTORY.state()
114
114
 
115
115
  # Handle message
116
116
  return message_handler.pull_task_ins(request=request, state=state)
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
121
121
  async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
122
122
  """Push TaskRes."""
123
123
  # Get state from app
124
- state: State = app.state.STATE_FACTORY.state()
124
+ state: LinkState = app.state.STATE_FACTORY.state()
125
125
 
126
126
  # Handle message
127
127
  return message_handler.push_task_res(request=request, state=state)
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
131
131
  async def ping(request: PingRequest) -> PingResponse:
132
132
  """Ping."""
133
133
  # Get state from app
134
- state: State = app.state.STATE_FACTORY.state()
134
+ state: LinkState = app.state.STATE_FACTORY.state()
135
135
 
136
136
  # Handle message
137
137
  return message_handler.ping(request=request, state=state)
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
141
141
  async def get_run(request: GetRunRequest) -> GetRunResponse:
142
142
  """GetRun."""
143
143
  # Get state from app
144
- state: State = app.state.STATE_FACTORY.state()
144
+ state: LinkState = app.state.STATE_FACTORY.state()
145
145
 
146
146
  # Handle message
147
147
  return message_handler.get_run(request=request, state=state)