flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see this release's page on the package registry for more details.

Files changed (92)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +128 -53
  33. flwr/client/app.py +56 -24
  34. flwr/client/client_app.py +28 -8
  35. flwr/client/grpc_adapter_client/connection.py +3 -2
  36. flwr/client/grpc_client/connection.py +3 -2
  37. flwr/client/grpc_rere_client/connection.py +17 -6
  38. flwr/client/message_handler/message_handler.py +1 -1
  39. flwr/client/node_state.py +59 -12
  40. flwr/client/node_state_tests.py +4 -3
  41. flwr/client/rest_client/connection.py +19 -8
  42. flwr/client/supernode/app.py +55 -24
  43. flwr/client/typing.py +2 -2
  44. flwr/common/config.py +87 -2
  45. flwr/common/constant.py +3 -0
  46. flwr/common/context.py +24 -9
  47. flwr/common/logger.py +25 -0
  48. flwr/common/serde.py +45 -0
  49. flwr/common/telemetry.py +17 -0
  50. flwr/common/typing.py +5 -0
  51. flwr/proto/common_pb2.py +36 -0
  52. flwr/proto/common_pb2.pyi +121 -0
  53. flwr/proto/common_pb2_grpc.py +4 -0
  54. flwr/proto/common_pb2_grpc.pyi +4 -0
  55. flwr/proto/driver_pb2.py +24 -19
  56. flwr/proto/driver_pb2.pyi +21 -1
  57. flwr/proto/exec_pb2.py +16 -11
  58. flwr/proto/exec_pb2.pyi +22 -1
  59. flwr/proto/run_pb2.py +12 -7
  60. flwr/proto/run_pb2.pyi +22 -1
  61. flwr/proto/task_pb2.py +7 -8
  62. flwr/server/__init__.py +2 -0
  63. flwr/server/compat/legacy_context.py +5 -4
  64. flwr/server/driver/grpc_driver.py +82 -140
  65. flwr/server/run_serverapp.py +40 -15
  66. flwr/server/server_app.py +56 -10
  67. flwr/server/serverapp_components.py +52 -0
  68. flwr/server/superlink/driver/driver_servicer.py +18 -3
  69. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  70. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  71. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  72. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  73. flwr/server/superlink/state/in_memory_state.py +11 -3
  74. flwr/server/superlink/state/sqlite_state.py +23 -8
  75. flwr/server/superlink/state/state.py +7 -2
  76. flwr/server/typing.py +2 -0
  77. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  78. flwr/simulation/app.py +4 -3
  79. flwr/simulation/ray_transport/ray_actor.py +15 -19
  80. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  81. flwr/simulation/run_simulation.py +237 -66
  82. flwr/superexec/app.py +14 -7
  83. flwr/superexec/deployment.py +110 -33
  84. flwr/superexec/exec_grpc.py +5 -1
  85. flwr/superexec/exec_servicer.py +4 -1
  86. flwr/superexec/executor.py +18 -0
  87. flwr/superexec/simulation.py +151 -0
  88. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  89. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
  90. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  91. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  92. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/driver/grpc_driver.py CHANGED
@@ -16,19 +16,21 @@
16
16
 
17
17
  import time
18
18
  import warnings
19
- from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple, cast
19
+ from logging import DEBUG, WARNING
20
+ from typing import Iterable, List, Optional, cast
21
21
 
22
22
  import grpc
23
23
 
24
24
  from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
- from flwr.common.serde import message_from_taskres, message_to_taskins
27
+ from flwr.common.serde import (
28
+ message_from_taskres,
29
+ message_to_taskins,
30
+ user_config_from_proto,
31
+ )
28
32
  from flwr.common.typing import Run
29
33
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
30
- CreateRunRequest,
31
- CreateRunResponse,
32
34
  GetNodesRequest,
33
35
  GetNodesResponse,
34
36
  PullTaskResRequest,
@@ -53,167 +55,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
53
55
  """
54
56
 
55
57
 
56
- class GrpcDriverStub:
57
- """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+ class GrpcDriver(Driver):
59
+ """`GrpcDriver` provides an interface to the Driver API.
58
60
 
59
61
  Parameters
60
62
  ----------
61
- driver_service_address : Optional[str]
62
- The IPv4 or IPv6 address of the Driver API server.
63
- Defaults to `"[::]:9091"`.
63
+ run_id : int
64
+ The identifier of the run.
65
+ driver_service_address : str (default: "[::]:9091")
66
+ The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
64
67
  root_certificates : Optional[bytes] (default: None)
65
68
  The PEM-encoded root certificates as a byte string.
66
69
  If provided, a secure connection using the certificates will be
67
70
  established to an SSL-enabled Flower server.
68
71
  """
69
72
 
70
- def __init__(
73
+ def __init__( # pylint: disable=too-many-arguments
71
74
  self,
75
+ run_id: int,
72
76
  driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
73
77
  root_certificates: Optional[bytes] = None,
74
78
  ) -> None:
75
- self.driver_service_address = driver_service_address
76
- self.root_certificates = root_certificates
77
- self.channel: Optional[grpc.Channel] = None
78
- self.stub: Optional[DriverStub] = None
79
+ self._run_id = run_id
80
+ self._addr = driver_service_address
81
+ self._cert = root_certificates
82
+ self._run: Optional[Run] = None
83
+ self._grpc_stub: Optional[DriverStub] = None
84
+ self._channel: Optional[grpc.Channel] = None
85
+ self.node = Node(node_id=0, anonymous=True)
79
86
 
80
- def is_connected(self) -> bool:
81
- """Return True if connected to the Driver API server, otherwise False."""
82
- return self.channel is not None
87
+ @property
88
+ def _is_connected(self) -> bool:
89
+ """Check if connected to the Driver API server."""
90
+ return self._channel is not None
91
+
92
+ def _connect(self) -> None:
93
+ """Connect to the Driver API.
83
94
 
84
- def connect(self) -> None:
85
- """Connect to the Driver API."""
95
+ This will not call GetRun.
96
+ """
86
97
  event(EventType.DRIVER_CONNECT)
87
- if self.channel is not None or self.stub is not None:
98
+ if self._is_connected:
88
99
  log(WARNING, "Already connected")
89
100
  return
90
- self.channel = create_channel(
91
- server_address=self.driver_service_address,
92
- insecure=(self.root_certificates is None),
93
- root_certificates=self.root_certificates,
101
+ self._channel = create_channel(
102
+ server_address=self._addr,
103
+ insecure=(self._cert is None),
104
+ root_certificates=self._cert,
94
105
  )
95
- self.stub = DriverStub(self.channel)
96
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
106
+ self._grpc_stub = DriverStub(self._channel)
107
+ log(DEBUG, "[Driver] Connected to %s", self._addr)
97
108
 
98
- def disconnect(self) -> None:
109
+ def _disconnect(self) -> None:
99
110
  """Disconnect from the Driver API."""
100
111
  event(EventType.DRIVER_DISCONNECT)
101
- if self.channel is None or self.stub is None:
112
+ if not self._is_connected:
102
113
  log(DEBUG, "Already disconnected")
103
114
  return
104
- channel = self.channel
105
- self.channel = None
106
- self.stub = None
115
+ channel: grpc.Channel = self._channel
116
+ self._channel = None
117
+ self._grpc_stub = None
107
118
  channel.close()
108
119
  log(DEBUG, "[Driver] Disconnected")
109
120
 
110
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
111
- """Request for run ID."""
112
- # Check if channel is open
113
- if self.stub is None:
114
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
115
- raise ConnectionError("`GrpcDriverStub` instance not connected")
116
-
117
- # Call Driver API
118
- res: CreateRunResponse = self.stub.CreateRun(request=req)
119
- return res
120
-
121
- def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
- """Get run information."""
123
- # Check if channel is open
124
- if self.stub is None:
125
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
- raise ConnectionError("`GrpcDriverStub` instance not connected")
127
-
128
- # Call gRPC Driver API
129
- res: GetRunResponse = self.stub.GetRun(request=req)
130
- return res
131
-
132
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
133
- """Get client IDs."""
134
- # Check if channel is open
135
- if self.stub is None:
136
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
137
- raise ConnectionError("`GrpcDriverStub` instance not connected")
138
-
139
- # Call gRPC Driver API
140
- res: GetNodesResponse = self.stub.GetNodes(request=req)
141
- return res
142
-
143
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
144
- """Schedule tasks."""
145
- # Check if channel is open
146
- if self.stub is None:
147
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
148
- raise ConnectionError("`GrpcDriverStub` instance not connected")
149
-
150
- # Call gRPC Driver API
151
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
152
- return res
153
-
154
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
155
- """Get task results."""
156
- # Check if channel is open
157
- if self.stub is None:
158
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
159
- raise ConnectionError("`GrpcDriverStub` instance not connected")
160
-
161
- # Call Driver API
162
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
163
- return res
164
-
165
-
166
- class GrpcDriver(Driver):
167
- """`Driver` class provides an interface to the Driver API.
168
-
169
- Parameters
170
- ----------
171
- run_id : int
172
- The identifier of the run.
173
- stub : Optional[GrpcDriverStub] (default: None)
174
- The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
- If None, an instance connected to "[::]:9091" will be created.
176
- """
177
-
178
- def __init__( # pylint: disable=too-many-arguments
179
- self,
180
- run_id: int,
181
- stub: Optional[GrpcDriverStub] = None,
182
- ) -> None:
183
- self._run_id = run_id
184
- self._run: Optional[Run] = None
185
- self.stub = stub if stub is not None else GrpcDriverStub()
186
- self.node = Node(node_id=0, anonymous=True)
121
+ def _init_run(self) -> None:
122
+ # Check if is initialized
123
+ if self._run is not None:
124
+ return
125
+ # Get the run info
126
+ req = GetRunRequest(run_id=self._run_id)
127
+ res: GetRunResponse = self._stub.GetRun(req)
128
+ if not res.HasField("run"):
129
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
130
+ self._run = Run(
131
+ run_id=res.run.run_id,
132
+ fab_id=res.run.fab_id,
133
+ fab_version=res.run.fab_version,
134
+ override_config=user_config_from_proto(res.run.override_config),
135
+ )
187
136
 
188
137
  @property
189
138
  def run(self) -> Run:
190
139
  """Run information."""
191
- self._get_stub_and_run_id()
192
- return Run(**vars(cast(Run, self._run)))
140
+ self._init_run()
141
+ return Run(**vars(self._run))
193
142
 
194
- def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
- # Check if is initialized
196
- if self._run is None:
197
- # Connect
198
- if not self.stub.is_connected():
199
- self.stub.connect()
200
- # Get the run info
201
- req = GetRunRequest(run_id=self._run_id)
202
- res = self.stub.get_run(req)
203
- if not res.HasField("run"):
204
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
- self._run = Run(
206
- run_id=res.run.run_id,
207
- fab_id=res.run.fab_id,
208
- fab_version=res.run.fab_version,
209
- )
210
-
211
- return self.stub, self._run.run_id
143
+ @property
144
+ def _stub(self) -> DriverStub:
145
+ """Driver stub."""
146
+ if not self._is_connected:
147
+ self._connect()
148
+ return cast(DriverStub, self._grpc_stub)
212
149
 
213
150
  def _check_message(self, message: Message) -> None:
214
151
  # Check if the message is valid
215
152
  if not (
216
- message.metadata.run_id == cast(Run, self._run).run_id
153
+ # Assume self._run being initialized
154
+ message.metadata.run_id == self._run_id
217
155
  and message.metadata.src_node_id == self.node.node_id
218
156
  and message.metadata.message_id == ""
219
157
  and message.metadata.reply_to_message == ""
@@ -234,7 +172,7 @@ class GrpcDriver(Driver):
234
172
  This method constructs a new `Message` with given content and metadata.
235
173
  The `run_id` and `src_node_id` will be set automatically.
236
174
  """
237
- _, run_id = self._get_stub_and_run_id()
175
+ self._init_run()
238
176
  if ttl:
239
177
  warnings.warn(
240
178
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -245,7 +183,7 @@ class GrpcDriver(Driver):
245
183
 
246
184
  ttl_ = DEFAULT_TTL if ttl is None else ttl
247
185
  metadata = Metadata(
248
- run_id=run_id,
186
+ run_id=self._run_id,
249
187
  message_id="", # Will be set by the server
250
188
  src_node_id=self.node.node_id,
251
189
  dst_node_id=dst_node_id,
@@ -258,9 +196,11 @@ class GrpcDriver(Driver):
258
196
 
259
197
  def get_node_ids(self) -> List[int]:
260
198
  """Get node IDs."""
261
- stub, run_id = self._get_stub_and_run_id()
199
+ self._init_run()
262
200
  # Call GrpcDriverStub method
263
- res = stub.get_nodes(GetNodesRequest(run_id=run_id))
201
+ res: GetNodesResponse = self._stub.GetNodes(
202
+ GetNodesRequest(run_id=self._run_id)
203
+ )
264
204
  return [node.node_id for node in res.nodes]
265
205
 
266
206
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -269,7 +209,7 @@ class GrpcDriver(Driver):
269
209
  This method takes an iterable of messages and sends each message
270
210
  to the node specified in `dst_node_id`.
271
211
  """
272
- stub, _ = self._get_stub_and_run_id()
212
+ self._init_run()
273
213
  # Construct TaskIns
274
214
  task_ins_list: List[TaskIns] = []
275
215
  for msg in messages:
@@ -280,7 +220,9 @@ class GrpcDriver(Driver):
280
220
  # Add to list
281
221
  task_ins_list.append(taskins)
282
222
  # Call GrpcDriverStub method
283
- res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
223
+ res: PushTaskInsResponse = self._stub.PushTaskIns(
224
+ PushTaskInsRequest(task_ins_list=task_ins_list)
225
+ )
284
226
  return list(res.task_ids)
285
227
 
286
228
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -289,9 +231,9 @@ class GrpcDriver(Driver):
289
231
  This method is used to collect messages from the SuperLink that correspond to a
290
232
  set of given message IDs.
291
233
  """
292
- stub, _ = self._get_stub_and_run_id()
234
+ self._init_run()
293
235
  # Pull TaskRes
294
- res = stub.pull_task_res(
236
+ res: PullTaskResResponse = self._stub.PullTaskRes(
295
237
  PullTaskResRequest(node=self.node, task_ids=message_ids)
296
238
  )
297
239
  # Convert TaskRes to Message
@@ -331,7 +273,7 @@ class GrpcDriver(Driver):
331
273
  def close(self) -> None:
332
274
  """Disconnect from the SuperLink if connected."""
333
275
  # Check if `connect` was called before
334
- if not self.stub.is_connected():
276
+ if not self._is_connected:
335
277
  return
336
278
  # Disconnect
337
- self.stub.disconnect()
279
+ self._disconnect()
flwr/server/run_serverapp.py CHANGED
@@ -22,13 +22,22 @@ from pathlib import Path
22
22
  from typing import Optional
23
23
 
24
24
  from flwr.common import Context, EventType, RecordSet, event
25
- from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
25
+ from flwr.common.config import (
26
+ get_flwr_dir,
27
+ get_fused_config,
28
+ get_project_config,
29
+ get_project_dir,
30
+ )
26
31
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
27
32
  from flwr.common.object_ref import load_app
28
- from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
33
+ from flwr.common.typing import UserConfig
34
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
+ CreateRunRequest,
36
+ CreateRunResponse,
37
+ )
29
38
 
30
39
  from .driver import Driver
31
- from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
40
+ from .driver.grpc_driver import GrpcDriver
32
41
  from .server_app import LoadServerAppError, ServerApp
33
42
 
34
43
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -37,6 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
37
46
  def run(
38
47
  driver: Driver,
39
48
  server_app_dir: str,
49
+ server_app_run_config: UserConfig,
40
50
  server_app_attr: Optional[str] = None,
41
51
  loaded_server_app: Optional[ServerApp] = None,
42
52
  ) -> None:
@@ -69,7 +79,9 @@ def run(
69
79
  server_app = _load()
70
80
 
71
81
  # Initialize Context
72
- context = Context(state=RecordSet())
82
+ context = Context(
83
+ node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
84
+ )
73
85
 
74
86
  # Call ServerApp
75
87
  server_app(driver=driver, context=context)
@@ -144,22 +156,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
144
156
  "For more details, use: ``flower-server-app -h``"
145
157
  )
146
158
 
147
- stub = GrpcDriverStub(
148
- driver_service_address=args.superlink, root_certificates=root_certificates
149
- )
159
+ # Initialize GrpcDriver
150
160
  if args.run_id is not None:
151
161
  # User provided `--run-id`, but not `server-app`
152
- run_id = args.run_id
162
+ driver = GrpcDriver(
163
+ run_id=args.run_id,
164
+ driver_service_address=args.superlink,
165
+ root_certificates=root_certificates,
166
+ )
153
167
  else:
154
168
  # User provided `server-app`, but not `--run-id`
155
169
  # Create run if run_id is not provided
156
- stub.connect()
170
+ driver = GrpcDriver(
171
+ run_id=0, # Will be overwritten
172
+ driver_service_address=args.superlink,
173
+ root_certificates=root_certificates,
174
+ )
175
+ # Create run
157
176
  req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
158
- res = stub.create_run(req)
159
- run_id = res.run_id
177
+ res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
178
+ # Overwrite driver._run_id
179
+ driver._run_id = res.run_id # pylint: disable=W0212
160
180
 
161
- # Initialize GrpcDriver
162
- driver = GrpcDriver(run_id=run_id, stub=stub)
181
+ server_app_run_config = {}
163
182
 
164
183
  # Dynamically obtain ServerApp path based on run_id
165
184
  if args.run_id is not None:
@@ -168,7 +187,8 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
168
187
  run_ = driver.run
169
188
  server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
170
189
  config = get_project_config(server_app_dir)
171
- server_app_attr = config["flower"]["components"]["serverapp"]
190
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
191
+ server_app_run_config = get_fused_config(run_, flwr_dir)
172
192
  else:
173
193
  # User provided `server-app`, but not `--run-id`
174
194
  server_app_dir = str(Path(args.dir).absolute())
@@ -182,7 +202,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
182
202
  )
183
203
 
184
204
  # Run the ServerApp with the Driver
185
- run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
205
+ run(
206
+ driver=driver,
207
+ server_app_dir=server_app_dir,
208
+ server_app_run_config=server_app_run_config,
209
+ server_app_attr=server_app_attr,
210
+ )
186
211
 
187
212
  # Clean up
188
213
  driver.close()
flwr/server/server_app.py CHANGED
@@ -17,8 +17,11 @@
17
17
 
18
18
  from typing import Callable, Optional
19
19
 
20
- from flwr.common import Context, RecordSet
21
- from flwr.common.logger import warn_preview_feature
20
+ from flwr.common import Context
21
+ from flwr.common.logger import (
22
+ warn_deprecated_feature_with_example,
23
+ warn_preview_feature,
24
+ )
22
25
  from flwr.server.strategy import Strategy
23
26
 
24
27
  from .client_manager import ClientManager
@@ -26,7 +29,20 @@ from .compat import start_driver
26
29
  from .driver import Driver
27
30
  from .server import Server
28
31
  from .server_config import ServerConfig
29
- from .typing import ServerAppCallable
32
+ from .typing import ServerAppCallable, ServerFn
33
+
34
+ SERVER_FN_USAGE_EXAMPLE = """
35
+
36
+ def server_fn(context: Context):
37
+ server_config = ServerConfig(num_rounds=3)
38
+ strategy = FedAvg()
39
+ return ServerAppComponents(
40
+ strategy=strategy,
41
+ server_config=server_config,
42
+ )
43
+
44
+ app = ServerApp(server_fn=server_fn)
45
+ """
30
46
 
31
47
 
32
48
  class ServerApp:
@@ -36,13 +52,15 @@ class ServerApp:
36
52
  --------
37
53
  Use the `ServerApp` with an existing `Strategy`:
38
54
 
39
- >>> server_config = ServerConfig(num_rounds=3)
40
- >>> strategy = FedAvg()
55
+ >>> def server_fn(context: Context):
56
+ >>> server_config = ServerConfig(num_rounds=3)
57
+ >>> strategy = FedAvg()
58
+ >>> return ServerAppComponents(
59
+ >>> strategy=strategy,
60
+ >>> server_config=server_config,
61
+ >>> )
41
62
  >>>
42
- >>> app = ServerApp(
43
- >>> server_config=server_config,
44
- >>> strategy=strategy,
45
- >>> )
63
+ >>> app = ServerApp(server_fn=server_fn)
46
64
 
47
65
  Use the `ServerApp` with a custom main function:
48
66
 
@@ -53,23 +71,52 @@ class ServerApp:
53
71
  >>> print("ServerApp running")
54
72
  """
55
73
 
74
+ # pylint: disable=too-many-arguments
56
75
  def __init__(
57
76
  self,
58
77
  server: Optional[Server] = None,
59
78
  config: Optional[ServerConfig] = None,
60
79
  strategy: Optional[Strategy] = None,
61
80
  client_manager: Optional[ClientManager] = None,
81
+ server_fn: Optional[ServerFn] = None,
62
82
  ) -> None:
83
+ if any([server, config, strategy, client_manager]):
84
+ warn_deprecated_feature_with_example(
85
+ deprecation_message="Passing either `server`, `config`, `strategy` or "
86
+ "`client_manager` directly to the ServerApp "
87
+ "constructor is deprecated.",
88
+ example_message="Pass `ServerApp` arguments wrapped "
89
+ "in a `flwr.server.ServerAppComponents` object that gets "
90
+ "returned by a function passed as the `server_fn` argument "
91
+ "to the `ServerApp` constructor. For example: ",
92
+ code_example=SERVER_FN_USAGE_EXAMPLE,
93
+ )
94
+
95
+ if server_fn:
96
+ raise ValueError(
97
+ "Passing `server_fn` is incompatible with passing the "
98
+ "other arguments (now deprecated) to ServerApp. "
99
+ "Use `server_fn` exclusively."
100
+ )
101
+
63
102
  self._server = server
64
103
  self._config = config
65
104
  self._strategy = strategy
66
105
  self._client_manager = client_manager
106
+ self._server_fn = server_fn
67
107
  self._main: Optional[ServerAppCallable] = None
68
108
 
69
109
  def __call__(self, driver: Driver, context: Context) -> None:
70
110
  """Execute `ServerApp`."""
71
111
  # Compatibility mode
72
112
  if not self._main:
113
+ if self._server_fn:
114
+ # Execute server_fn()
115
+ components = self._server_fn(context)
116
+ self._server = components.server
117
+ self._config = components.config
118
+ self._strategy = components.strategy
119
+ self._client_manager = components.client_manager
73
120
  start_driver(
74
121
  server=self._server,
75
122
  config=self._config,
@@ -80,7 +127,6 @@ class ServerApp:
80
127
  return
81
128
 
82
129
  # New execution mode
83
- context = Context(state=RecordSet())
84
130
  self._main(driver, context)
85
131
 
86
132
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
flwr/server/serverapp_components.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ServerAppComponents for the ServerApp."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from .client_manager import ClientManager
22
+ from .server import Server
23
+ from .server_config import ServerConfig
24
+ from .strategy import Strategy
25
+
26
+
27
+ @dataclass
28
+ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
29
+ """Components to construct a ServerApp.
30
+
31
+ Parameters
32
+ ----------
33
+ server : Optional[Server] (default: None)
34
+ A server implementation, either `flwr.server.Server` or a subclass
35
+ thereof. If no instance is provided, one will be created internally.
36
+ config : Optional[ServerConfig] (default: None)
37
+ Currently supported values are `num_rounds` (int, default: 1) and
38
+ `round_timeout` in seconds (float, default: None).
39
+ strategy : Optional[Strategy] (default: None)
40
+ An implementation of the abstract base class
41
+ `flwr.server.strategy.Strategy`. If no strategy is provided, then
42
+ `flwr.server.strategy.FedAvg` will be used.
43
+ client_manager : Optional[ClientManager] (default: None)
44
+ An implementation of the class `flwr.server.ClientManager`. If no
45
+ implementation is provided, then `flwr.server.SimpleClientManager`
46
+ will be used.
47
+ """
48
+
49
+ server: Optional[Server] = None
50
+ config: Optional[ServerConfig] = None
51
+ strategy: Optional[Strategy] = None
52
+ client_manager: Optional[ClientManager] = None
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -23,6 +23,7 @@ from uuid import UUID
23
23
  import grpc
24
24
 
25
25
  from flwr.common.logger import log
26
+ from flwr.common.serde import user_config_from_proto, user_config_to_proto
26
27
  from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
27
28
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
28
29
  CreateRunRequest,
@@ -69,7 +70,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
69
70
  """Create run ID."""
70
71
  log(DEBUG, "DriverServicer.CreateRun")
71
72
  state: State = self.state_factory.state()
72
- run_id = state.create_run(request.fab_id, request.fab_version)
73
+ run_id = state.create_run(
74
+ request.fab_id,
75
+ request.fab_version,
76
+ user_config_from_proto(request.override_config),
77
+ )
73
78
  return CreateRunResponse(run_id=run_id)
74
79
 
75
80
  def PushTaskIns(
@@ -145,8 +150,18 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
145
150
 
146
151
  # Retrieve run information
147
152
  run = state.get_run(request.run_id)
148
- run_proto = None if run is None else Run(**vars(run))
149
- return GetRunResponse(run=run_proto)
153
+
154
+ if run is None:
155
+ return GetRunResponse()
156
+
157
+ return GetRunResponse(
158
+ run=Run(
159
+ run_id=run.run_id,
160
+ fab_id=run.fab_id,
161
+ fab_version=run.fab_version,
162
+ override_config=user_config_to_proto(run.override_config),
163
+ )
164
+ )
150
165
 
151
166
 
152
167
  def _raise_if(validation_error: bool, detail: str) -> None:
flwr/server/superlink/fleet/message_handler/message_handler.py CHANGED
@@ -19,6 +19,7 @@ import time
19
19
  from typing import List, Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.serde import user_config_to_proto
22
23
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
23
24
  CreateNodeRequest,
24
25
  CreateNodeResponse,
@@ -113,5 +114,15 @@ def get_run(
113
114
  ) -> GetRunResponse:
114
115
  """Get run information."""
115
116
  run = state.get_run(request.run_id)
116
- run_proto = None if run is None else Run(**vars(run))
117
- return GetRunResponse(run=run_proto)
117
+
118
+ if run is None:
119
+ return GetRunResponse()
120
+
121
+ return GetRunResponse(
122
+ run=Run(
123
+ run_id=run.run_id,
124
+ fab_id=run.fab_id,
125
+ fab_version=run.fab_version,
126
+ override_config=user_config_to_proto(run.override_config),
127
+ )
128
+ )