flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (81) hide show
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +18 -83
  8. flwr/client/app.py +13 -14
  9. flwr/client/clientapp/app.py +1 -2
  10. flwr/client/{node_state.py → run_info_store.py} +4 -3
  11. flwr/client/supernode/app.py +6 -8
  12. flwr/common/constant.py +39 -4
  13. flwr/common/context.py +9 -4
  14. flwr/common/date.py +3 -3
  15. flwr/common/logger.py +103 -0
  16. flwr/common/serde.py +24 -0
  17. flwr/common/telemetry.py +0 -6
  18. flwr/common/typing.py +9 -0
  19. flwr/proto/exec_pb2.py +6 -6
  20. flwr/proto/exec_pb2.pyi +8 -2
  21. flwr/proto/log_pb2.py +29 -0
  22. flwr/proto/log_pb2.pyi +39 -0
  23. flwr/proto/log_pb2_grpc.py +4 -0
  24. flwr/proto/log_pb2_grpc.pyi +4 -0
  25. flwr/proto/message_pb2.py +8 -8
  26. flwr/proto/message_pb2.pyi +4 -1
  27. flwr/proto/serverappio_pb2.py +52 -0
  28. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  29. flwr/proto/serverappio_pb2_grpc.py +376 -0
  30. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  31. flwr/proto/simulationio_pb2.py +38 -0
  32. flwr/proto/simulationio_pb2.pyi +65 -0
  33. flwr/proto/simulationio_pb2_grpc.py +171 -0
  34. flwr/proto/simulationio_pb2_grpc.pyi +68 -0
  35. flwr/server/app.py +247 -105
  36. flwr/server/driver/driver.py +15 -1
  37. flwr/server/driver/grpc_driver.py +26 -33
  38. flwr/server/driver/inmemory_driver.py +6 -14
  39. flwr/server/run_serverapp.py +29 -23
  40. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  41. flwr/server/serverapp/app.py +270 -0
  42. flwr/server/strategy/fedadam.py +11 -1
  43. flwr/server/superlink/driver/__init__.py +1 -1
  44. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  45. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  46. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  47. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  48. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  49. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  50. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  51. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  52. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  53. flwr/server/superlink/linkstate/__init__.py +28 -0
  54. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
  55. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
  56. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  57. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
  58. flwr/server/superlink/{state → linkstate}/utils.py +84 -2
  59. flwr/server/superlink/simulation/__init__.py +15 -0
  60. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  61. flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
  62. flwr/simulation/__init__.py +2 -0
  63. flwr/simulation/app.py +1 -1
  64. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  65. flwr/simulation/run_simulation.py +57 -131
  66. flwr/simulation/simulationio_connection.py +86 -0
  67. flwr/superexec/app.py +6 -134
  68. flwr/superexec/deployment.py +60 -65
  69. flwr/superexec/exec_grpc.py +15 -8
  70. flwr/superexec/exec_servicer.py +34 -63
  71. flwr/superexec/executor.py +22 -4
  72. flwr/superexec/simulation.py +13 -8
  73. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
  74. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
  75. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
  76. flwr/client/node_state_tests.py +0 -66
  77. flwr/proto/driver_pb2.py +0 -42
  78. flwr/proto/driver_pb2_grpc.py +0 -239
  79. flwr/proto/driver_pb2_grpc.pyi +0 -94
  80. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
  81. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
@@ -0,0 +1,270 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower ServerApp process."""
16
+
17
+ import argparse
18
+ import sys
19
+ from logging import DEBUG, ERROR, INFO, WARN
20
+ from os.path import isfile
21
+ from pathlib import Path
22
+ from queue import Queue
23
+ from time import sleep
24
+ from typing import Optional
25
+
26
+ from flwr.cli.config_utils import get_fab_metadata
27
+ from flwr.cli.install import install_from_fab
28
+ from flwr.common.config import (
29
+ get_flwr_dir,
30
+ get_fused_config_from_dir,
31
+ get_project_config,
32
+ get_project_dir,
33
+ )
34
+ from flwr.common.constant import Status, SubStatus
35
+ from flwr.common.logger import (
36
+ log,
37
+ mirror_output_to_queue,
38
+ restore_output,
39
+ start_log_uploader,
40
+ stop_log_uploader,
41
+ )
42
+ from flwr.common.serde import (
43
+ context_from_proto,
44
+ context_to_proto,
45
+ fab_from_proto,
46
+ run_from_proto,
47
+ run_status_to_proto,
48
+ )
49
+ from flwr.common.typing import RunStatus
50
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
51
+ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
52
+ PullServerAppInputsRequest,
53
+ PullServerAppInputsResponse,
54
+ PushServerAppOutputsRequest,
55
+ )
56
+ from flwr.server.driver.grpc_driver import GrpcDriver
57
+ from flwr.server.run_serverapp import run as run_
58
+
59
+
60
def flwr_serverapp() -> None:
    """Run process-isolated Flower ServerApp.

    Parses the command line, resolves the root certificates via
    `_try_obtain_certificates`, and hands control to `run_serverapp`.
    The queue that mirrors stdout/stderr is passed to `run_serverapp` as
    `log_queue`.
    """
    # Capture stdout/stderr
    log_queue: Queue[Optional[str]] = Queue()
    mirror_output_to_queue(log_queue)

    # Everything after the mirroring starts runs under try/finally so that the
    # output is restored even when argument parsing or certificate loading
    # calls `sys.exit`, or when the run loop raises.
    try:
        parser = argparse.ArgumentParser(
            description="Run a Flower ServerApp",
        )
        parser.add_argument(
            "--superlink",
            type=str,
            help="Address of SuperLink's DriverAPI",
        )
        parser.add_argument(
            "--run-once",
            action="store_true",
            help="When set, this process will start a single ServerApp "
            "for a pending Run. If no pending run the process will exit. ",
        )
        parser.add_argument(
            "--flwr-dir",
            default=None,
            help="""The path containing installed Flower Apps.
        By default, this value is equal to:

            - `$FLWR_HOME/` if `$FLWR_HOME` is defined
            - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
            - `$HOME/.flwr/` in all other cases
        """,
        )
        parser.add_argument(
            "--insecure",
            action="store_true",
            help="Run the server without HTTPS, regardless of whether certificate "
            "paths are provided. By default, the server runs with HTTPS enabled. "
            "Use this flag only if you understand the risks.",
        )
        parser.add_argument(
            "--root-certificates",
            metavar="ROOT_CERT",
            type=str,
            help="Specifies the path to the PEM-encoded root certificate file for "
            "establishing secure HTTPS connections.",
        )
        args = parser.parse_args()

        log(INFO, "Starting Flower ServerApp")
        certificates = _try_obtain_certificates(args)

        log(
            DEBUG,
            "Starting isolated `ServerApp` connected to SuperLink DriverAPI at %s",
            args.superlink,
        )
        run_serverapp(
            superlink=args.superlink,
            log_queue=log_queue,
            run_once=args.run_once,
            flwr_dir_=args.flwr_dir,
            certificates=certificates,
        )
    finally:
        # Restore stdout/stderr
        restore_output()
125
+
126
+
127
+ def _try_obtain_certificates(
128
+ args: argparse.Namespace,
129
+ ) -> Optional[bytes]:
130
+
131
+ if args.insecure:
132
+ if args.root_certificates is not None:
133
+ sys.exit(
134
+ "Conflicting options: The '--insecure' flag disables HTTPS, "
135
+ "but '--root-certificates' was also specified. Please remove "
136
+ "the '--root-certificates' option when running in insecure mode, "
137
+ "or omit '--insecure' to use HTTPS."
138
+ )
139
+ log(
140
+ WARN,
141
+ "Option `--insecure` was set. Starting insecure HTTP channel to %s.",
142
+ args.superlink,
143
+ )
144
+ root_certificates = None
145
+ else:
146
+ # Load the certificates if provided, or load the system certificates
147
+ if not isfile(args.root_certificates):
148
+ sys.exit("Path argument `--root-certificates` does not point to a file.")
149
+ root_certificates = Path(args.root_certificates).read_bytes()
150
+ log(
151
+ DEBUG,
152
+ "Starting secure HTTPS channel to %s "
153
+ "with the following certificates: %s.",
154
+ args.superlink,
155
+ args.root_certificates,
156
+ )
157
+ return root_certificates
158
+
159
+
160
def run_serverapp(  # pylint: disable=R0914, disable=W0212
    superlink: str,
    log_queue: Queue[Optional[str]],
    run_once: bool,
    flwr_dir_: Optional[str] = None,
    certificates: Optional[bytes] = None,
) -> None:
    """Run Flower ServerApp process.

    Repeatedly pulls ServerApp inputs (context, run, FAB) through the driver
    stub, installs the FAB, runs the ServerApp, pushes the resulting context
    back, and reports the final run status.

    Parameters
    ----------
    superlink : str
        Address passed to `GrpcDriver` as `serverappio_service_address`.
    log_queue : Queue[Optional[str]]
        Queue handed to the log uploader that is started for each run.
    run_once : bool
        If `True`, leave the loop after the first iteration that reaches the
        end of the loop body.
    flwr_dir_ : Optional[str]
        Value forwarded to `get_flwr_dir` to resolve where FABs are installed.
    certificates : Optional[bytes]
        Root certificates passed to `GrpcDriver` as `root_certificates`.
    """
    driver = GrpcDriver(
        serverappio_service_address=superlink,
        root_certificates=certificates,
    )

    # Resolve directory where FABs are installed
    flwr_dir = get_flwr_dir(flwr_dir_)
    log_uploader = None

    while True:
        # Reset per-iteration state. Without this, a failure before `run` is
        # assigned would make the `finally` block hit an unbound `run` on the
        # first iteration, or report the failure against the previous run on
        # later ones.
        run = None
        run_status = None

        try:
            # Pull ServerAppInputs from LinkState
            req = PullServerAppInputsRequest()
            res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
            if not res.HasField("run"):
                sleep(3)
                continue

            context = context_from_proto(res.context)
            run = run_from_proto(res.run)
            fab = fab_from_proto(res.fab)

            driver.init_run(run.run_id)

            # Start log uploader for this run
            log_uploader = start_log_uploader(
                log_queue=log_queue,
                node_id=0,
                run_id=run.run_id,
                stub=driver._stub,
            )

            log(DEBUG, "ServerApp process starts FAB installation.")
            install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)

            fab_id, fab_version = get_fab_metadata(fab.content)

            app_path = str(get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir))
            config = get_project_config(app_path)

            # Obtain server app reference and the run config
            server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
            server_app_run_config = get_fused_config_from_dir(
                Path(app_path), run.override_config
            )

            # Update run_config in context
            context.run_config = server_app_run_config

            log(
                DEBUG,
                "Flower will load ServerApp `%s` in %s",
                server_app_attr,
                app_path,
            )

            # Change status to Running
            run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
            driver._stub.UpdateRunStatus(
                UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
            )

            # Load and run the ServerApp with the Driver
            updated_context = run_(
                driver=driver,
                server_app_dir=app_path,
                server_app_attr=server_app_attr,
                context=context,
            )

            # Send resulting context
            context_proto = context_to_proto(updated_context)
            out_req = PushServerAppOutputsRequest(
                run_id=run.run_id, context=context_proto
            )
            _ = driver._stub.PushServerAppOutputs(out_req)

            run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

        except Exception as ex:  # pylint: disable=broad-exception-caught
            exc_entity = "ServerApp"
            log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
            if run is not None:
                run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
            elif not run_once:
                # The failure happened before a run was obtained, so there is
                # no run to mark as failed. Wait like the idle-poll branch
                # does instead of retrying in a tight loop.
                sleep(3)

        finally:
            # Only report a status when this iteration actually has a run
            if run is not None and run_status is not None:
                run_status_proto = run_status_to_proto(run_status)
                driver._stub.UpdateRunStatus(
                    UpdateRunStatusRequest(
                        run_id=run.run_id, run_status=run_status_proto
                    )
                )

            # Stop log uploader for this run
            if log_uploader:
                stop_log_uploader(log_queue, log_uploader)
                log_uploader = None

        # Stop the loop if `flwr-serverapp` is expected to process a single run
        if run_once:
            break
@@ -170,8 +170,18 @@ class FedAdam(FedOpt):
170
170
  for x, y in zip(self.v_t, delta_t)
171
171
  ]
172
172
 
173
+ # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
174
+ # in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
175
+ # Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
176
+ # Optimization" in the formula line right before Section 2.1.
177
+ eta_norm = (
178
+ self.eta
179
+ * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
180
+ / (1 - np.power(self.beta_1, server_round + 1.0))
181
+ )
182
+
173
183
  new_weights = [
174
- x + self.eta * y / (np.sqrt(z) + self.tau)
184
+ x + eta_norm * y / (np.sqrt(z) + self.tau)
175
185
  for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
176
186
  ]
177
187
 
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service."""
15
+ """Flower ServerAppIo service."""
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver gRPC API."""
15
+ """ServerAppIo gRPC API."""
16
16
 
17
17
  from logging import INFO
18
18
  from typing import Optional
@@ -21,37 +21,40 @@ import grpc
21
21
 
22
22
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
23
23
  from flwr.common.logger import log
24
- from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
- add_DriverServicer_to_server,
24
+ from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
25
+ add_ServerAppIoServicer_to_server,
26
26
  )
27
27
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
28
- from flwr.server.superlink.state import StateFactory
28
+ from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
30
  from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
31
- from .driver_servicer import DriverServicer
31
+ from .serverappio_servicer import ServerAppIoServicer
32
32
 
33
33
 
34
- def run_driver_api_grpc(
34
+ def run_serverappio_api_grpc(
35
35
  address: str,
36
- state_factory: StateFactory,
36
+ state_factory: LinkStateFactory,
37
37
  ffs_factory: FfsFactory,
38
38
  certificates: Optional[tuple[bytes, bytes, bytes]],
39
39
  ) -> grpc.Server:
40
- """Run Driver API (gRPC, request-response)."""
41
- # Create Driver API gRPC server
42
- driver_servicer: grpc.Server = DriverServicer(
40
+ """Run ServerAppIo API (gRPC, request-response)."""
41
+ # Create ServerAppIo API gRPC server
42
+ serverappio_servicer: grpc.Server = ServerAppIoServicer(
43
43
  state_factory=state_factory,
44
44
  ffs_factory=ffs_factory,
45
45
  )
46
- driver_add_servicer_to_server_fn = add_DriverServicer_to_server
47
- driver_grpc_server = generic_create_grpc_server(
48
- servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
46
+ serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server
47
+ serverappio_grpc_server = generic_create_grpc_server(
48
+ servicer_and_add_fn=(
49
+ serverappio_servicer,
50
+ serverappio_add_servicer_to_server_fn,
51
+ ),
49
52
  server_address=address,
50
53
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
51
54
  certificates=certificates,
52
55
  )
53
56
 
54
- log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
55
- driver_grpc_server.start()
57
+ log(INFO, "Flower ECE: Starting ServerAppIo API (gRPC-rere) on %s", address)
58
+ serverappio_grpc_server.start()
56
59
 
57
- return driver_grpc_server
60
+ return serverappio_grpc_server
@@ -12,62 +12,80 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver API servicer."""
15
+ """ServerAppIo API servicer."""
16
16
 
17
17
 
18
+ import threading
18
19
  import time
19
- from logging import DEBUG
20
+ from logging import DEBUG, INFO
20
21
  from typing import Optional
21
22
  from uuid import UUID
22
23
 
23
24
  import grpc
24
25
 
26
+ from flwr.common import ConfigsRecord
27
+ from flwr.common.constant import Status
25
28
  from flwr.common.logger import log
26
29
  from flwr.common.serde import (
30
+ context_from_proto,
31
+ context_to_proto,
27
32
  fab_from_proto,
28
33
  fab_to_proto,
34
+ run_status_from_proto,
35
+ run_to_proto,
29
36
  user_config_from_proto,
30
- user_config_to_proto,
31
- )
32
- from flwr.common.typing import Fab
33
- from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
34
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
- GetNodesRequest,
36
- GetNodesResponse,
37
- PullTaskResRequest,
38
- PullTaskResResponse,
39
- PushTaskInsRequest,
40
- PushTaskInsResponse,
41
37
  )
38
+ from flwr.common.typing import Fab, RunStatus
39
+ from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
42
40
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
41
+ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
42
+ PushLogsRequest,
43
+ PushLogsResponse,
44
+ )
43
45
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
44
46
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
45
47
  CreateRunRequest,
46
48
  CreateRunResponse,
47
49
  GetRunRequest,
48
50
  GetRunResponse,
49
- Run,
51
+ UpdateRunStatusRequest,
52
+ UpdateRunStatusResponse,
53
+ )
54
+ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
55
+ GetNodesRequest,
56
+ GetNodesResponse,
57
+ PullServerAppInputsRequest,
58
+ PullServerAppInputsResponse,
59
+ PullTaskResRequest,
60
+ PullTaskResResponse,
61
+ PushServerAppOutputsRequest,
62
+ PushServerAppOutputsResponse,
63
+ PushTaskInsRequest,
64
+ PushTaskInsResponse,
50
65
  )
51
66
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
67
  from flwr.server.superlink.ffs.ffs import Ffs
53
68
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
54
- from flwr.server.superlink.state import State, StateFactory
69
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
55
70
  from flwr.server.utils.validator import validate_task_ins_or_res
56
71
 
57
72
 
58
- class DriverServicer(driver_pb2_grpc.DriverServicer):
59
- """Driver API servicer."""
73
+ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
74
+ """ServerAppIo API servicer."""
60
75
 
61
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
76
+ def __init__(
77
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
78
+ ) -> None:
62
79
  self.state_factory = state_factory
63
80
  self.ffs_factory = ffs_factory
81
+ self.lock = threading.RLock()
64
82
 
65
83
  def GetNodes(
66
84
  self, request: GetNodesRequest, context: grpc.ServicerContext
67
85
  ) -> GetNodesResponse:
68
86
  """Get available nodes."""
69
- log(DEBUG, "DriverServicer.GetNodes")
70
- state: State = self.state_factory.state()
87
+ log(DEBUG, "ServerAppIoServicer.GetNodes")
88
+ state: LinkState = self.state_factory.state()
71
89
  all_ids: set[int] = state.get_nodes(request.run_id)
72
90
  nodes: list[Node] = [
73
91
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -78,8 +96,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
78
96
  self, request: CreateRunRequest, context: grpc.ServicerContext
79
97
  ) -> CreateRunResponse:
80
98
  """Create run ID."""
81
- log(DEBUG, "DriverServicer.CreateRun")
82
- state: State = self.state_factory.state()
99
+ log(DEBUG, "ServerAppIoServicer.CreateRun")
100
+ state: LinkState = self.state_factory.state()
83
101
  if request.HasField("fab"):
84
102
  fab = fab_from_proto(request.fab)
85
103
  ffs: Ffs = self.ffs_factory.ffs()
@@ -95,6 +113,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
113
  request.fab_version,
96
114
  fab_hash,
97
115
  user_config_from_proto(request.override_config),
116
+ ConfigsRecord(),
98
117
  )
99
118
  return CreateRunResponse(run_id=run_id)
100
119
 
@@ -102,7 +121,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
102
121
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
103
122
  ) -> PushTaskInsResponse:
104
123
  """Push a set of TaskIns."""
105
- log(DEBUG, "DriverServicer.PushTaskIns")
124
+ log(DEBUG, "ServerAppIoServicer.PushTaskIns")
106
125
 
107
126
  # Set pushed_at (timestamp in seconds)
108
127
  pushed_at = time.time()
@@ -116,7 +135,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
116
135
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
117
136
 
118
137
  # Init state
119
- state: State = self.state_factory.state()
138
+ state: LinkState = self.state_factory.state()
120
139
 
121
140
  # Store each TaskIns
122
141
  task_ids: list[Optional[UUID]] = []
@@ -132,17 +151,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
132
151
  self, request: PullTaskResRequest, context: grpc.ServicerContext
133
152
  ) -> PullTaskResResponse:
134
153
  """Pull a set of TaskRes."""
135
- log(DEBUG, "DriverServicer.PullTaskRes")
154
+ log(DEBUG, "ServerAppIoServicer.PullTaskRes")
136
155
 
137
156
  # Convert each task_id str to UUID
138
157
  task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
139
158
 
140
159
  # Init state
141
- state: State = self.state_factory.state()
160
+ state: LinkState = self.state_factory.state()
142
161
 
143
162
  # Register callback
144
163
  def on_rpc_done() -> None:
145
- log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
164
+ log(
165
+ DEBUG,
166
+ "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
+ )
146
168
 
147
169
  if context.is_active():
148
170
  return
@@ -164,10 +186,10 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
164
186
  self, request: GetRunRequest, context: grpc.ServicerContext
165
187
  ) -> GetRunResponse:
166
188
  """Get run information."""
167
- log(DEBUG, "DriverServicer.GetRun")
189
+ log(DEBUG, "ServerAppIoServicer.GetRun")
168
190
 
169
191
  # Init state
170
- state: State = self.state_factory.state()
192
+ state: LinkState = self.state_factory.state()
171
193
 
172
194
  # Retrieve run information
173
195
  run = state.get_run(request.run_id)
@@ -175,21 +197,13 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
175
197
  if run is None:
176
198
  return GetRunResponse()
177
199
 
178
- return GetRunResponse(
179
- run=Run(
180
- run_id=run.run_id,
181
- fab_id=run.fab_id,
182
- fab_version=run.fab_version,
183
- override_config=user_config_to_proto(run.override_config),
184
- fab_hash=run.fab_hash,
185
- )
186
- )
200
+ return GetRunResponse(run=run_to_proto(run))
187
201
 
188
202
  def GetFab(
189
203
  self, request: GetFabRequest, context: grpc.ServicerContext
190
204
  ) -> GetFabResponse:
191
205
  """Get FAB from Ffs."""
192
- log(DEBUG, "DriverServicer.GetFab")
206
+ log(DEBUG, "ServerAppIoServicer.GetFab")
193
207
 
194
208
  ffs: Ffs = self.ffs_factory.ffs()
195
209
  if result := ffs.get(request.hash_str):
@@ -198,6 +212,78 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
198
212
 
199
213
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
200
214
 
215
def PullServerAppInputs(
    self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
) -> PullServerAppInputsResponse:
    """Pull ServerApp process inputs.

    Returns an empty response when no run is pending. Otherwise returns the
    Context, Run and Fab of the pending run once its status has been updated
    to STARTING.

    Raises
    ------
    RuntimeError
        If the Run, Fab or Context is missing, or if the status update to
        STARTING is rejected.
    """
    log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
    # Init access to LinkState and Ffs
    link_state = self.state_factory.state()
    file_store = self.ffs_factory.ffs()

    # Lock access to LinkState, preventing obtaining the same pending run_id
    with self.lock:
        pending_id = link_state.get_pending_run_id()
        if pending_id is None:
            # No pending run: answer with an empty response
            return PullServerAppInputsResponse()

        # Retrieve Context, Run and Fab for the pending run
        app_context = link_state.get_serverapp_context(pending_id)
        run_info = link_state.get_run(pending_id)

        stored = None
        if run_info and run_info.fab_hash:
            stored = file_store.get(run_info.fab_hash)
        fab = Fab(run_info.fab_hash, stored[0]) if stored else None

        inputs_complete = run_info and fab and app_context
        # The status is only moved to STARTING when all inputs are present
        if inputs_complete and link_state.update_run_status(
            pending_id, RunStatus(Status.STARTING, "", "")
        ):
            log(INFO, "Starting run %d", pending_id)
            return PullServerAppInputsResponse(
                context=context_to_proto(app_context),
                run=run_to_proto(run_info),
                fab=fab_to_proto(fab),
            )

    # Reached when an input is missing or the status update was rejected
    raise RuntimeError(f"Failed to start run {pending_id}")
252
+
253
def PushServerAppOutputs(
    self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
) -> PushServerAppOutputsResponse:
    """Push ServerApp process outputs.

    Stores the context carried by the request in the LinkState under
    `request.run_id` and returns an empty response.
    """
    log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
    link_state = self.state_factory.state()
    received_context = context_from_proto(request.context)
    link_state.set_serverapp_context(request.run_id, received_context)
    return PushServerAppOutputsResponse()
261
+
262
def UpdateRunStatus(
    self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
) -> UpdateRunStatusResponse:
    """Update the status of a run.

    Converts `request.run_status` from its proto form and passes it to the
    LinkState for `request.run_id`. The return value of `update_run_status`
    is not inspected here; an empty response is returned in all cases.
    """
    # Use this servicer's own name in the log line, like the other methods do
    log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
    state = self.state_factory.state()

    # Update the run status
    state.update_run_status(
        run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
    )
    return UpdateRunStatusResponse()
274
+
275
def PushLogs(
    self, request: PushLogsRequest, context: grpc.ServicerContext
) -> PushLogsResponse:
    """Push logs.

    Concatenates the entries of `request.logs` into one string and adds it to
    the LinkState as ServerApp log for `request.run_id`.
    """
    log(DEBUG, "ServerAppIoServicer.PushLogs")
    link_state = self.state_factory.state()

    # Add logs to LinkState as a single string
    link_state.add_serverapp_log(request.run_id, "".join(request.logs))
    return PushLogsResponse()
286
+
201
287
 
202
288
  def _raise_if(validation_error: bool, detail: str) -> None:
203
289
  if validation_error:
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
48
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
49
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
50
  from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.state import StateFactory
51
+ from flwr.server.superlink.linkstate import LinkStateFactory
52
52
 
53
53
  T = TypeVar("T", bound=GrpcMessage)
54
54
 
@@ -77,7 +77,9 @@ def _handle(
77
77
  class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
78
78
  """Fleet API via GrpcAdapter servicer."""
79
79
 
80
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
80
+ def __init__(
81
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
+ ) -> None:
81
83
  self.state_factory = state_factory
82
84
  self.ffs_factory = ffs_factory
83
85
 
@@ -30,7 +30,7 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
30
30
  add_FlowerServiceServicer_to_server,
31
31
  )
32
32
  from flwr.server.client_manager import ClientManager
33
- from flwr.server.superlink.driver.driver_servicer import DriverServicer
33
+ from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
34
34
  from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
35
35
  GrpcAdapterServicer,
36
36
  )
@@ -161,7 +161,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
161
161
  tuple[FleetServicer, AddServicerToServerFn],
162
162
  tuple[GrpcAdapterServicer, AddServicerToServerFn],
163
163
  tuple[FlowerServiceServicer, AddServicerToServerFn],
164
- tuple[DriverServicer, AddServicerToServerFn],
164
+ tuple[ServerAppIoServicer, AddServicerToServerFn],
165
165
  ],
166
166
  server_address: str,
167
167
  max_concurrent_workers: int = 1000,