flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +128 -53
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +55 -24
- flwr/client/typing.py +2 -2
- flwr/common/config.py +87 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +110 -33
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
|
flwr/server/driver/grpc_driver.py
CHANGED
|
@@ -16,19 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
|
-
from logging import DEBUG,
|
|
20
|
-
from typing import Iterable, List, Optional,
|
|
19
|
+
from logging import DEBUG, WARNING
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
24
24
|
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.serde import
|
|
27
|
+
from flwr.common.serde import (
|
|
28
|
+
message_from_taskres,
|
|
29
|
+
message_to_taskins,
|
|
30
|
+
user_config_from_proto,
|
|
31
|
+
)
|
|
28
32
|
from flwr.common.typing import Run
|
|
29
33
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
30
|
-
CreateRunRequest,
|
|
31
|
-
CreateRunResponse,
|
|
32
34
|
GetNodesRequest,
|
|
33
35
|
GetNodesResponse,
|
|
34
36
|
PullTaskResRequest,
|
|
@@ -53,167 +55,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
53
55
|
"""
|
|
54
56
|
|
|
55
57
|
|
|
56
|
-
class
|
|
57
|
-
"""`
|
|
58
|
+
class GrpcDriver(Driver):
|
|
59
|
+
"""`GrpcDriver` provides an interface to the Driver API.
|
|
58
60
|
|
|
59
61
|
Parameters
|
|
60
62
|
----------
|
|
61
|
-
|
|
62
|
-
The
|
|
63
|
-
|
|
63
|
+
run_id : int
|
|
64
|
+
The identifier of the run.
|
|
65
|
+
driver_service_address : str (default: "[::]:9091")
|
|
66
|
+
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
64
67
|
root_certificates : Optional[bytes] (default: None)
|
|
65
68
|
The PEM-encoded root certificates as a byte string.
|
|
66
69
|
If provided, a secure connection using the certificates will be
|
|
67
70
|
established to an SSL-enabled Flower server.
|
|
68
71
|
"""
|
|
69
72
|
|
|
70
|
-
def __init__(
|
|
73
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
71
74
|
self,
|
|
75
|
+
run_id: int,
|
|
72
76
|
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
73
77
|
root_certificates: Optional[bytes] = None,
|
|
74
78
|
) -> None:
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
77
|
-
self.
|
|
78
|
-
self.
|
|
79
|
+
self._run_id = run_id
|
|
80
|
+
self._addr = driver_service_address
|
|
81
|
+
self._cert = root_certificates
|
|
82
|
+
self._run: Optional[Run] = None
|
|
83
|
+
self._grpc_stub: Optional[DriverStub] = None
|
|
84
|
+
self._channel: Optional[grpc.Channel] = None
|
|
85
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
79
86
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
87
|
+
@property
|
|
88
|
+
def _is_connected(self) -> bool:
|
|
89
|
+
"""Check if connected to the Driver API server."""
|
|
90
|
+
return self._channel is not None
|
|
91
|
+
|
|
92
|
+
def _connect(self) -> None:
|
|
93
|
+
"""Connect to the Driver API.
|
|
83
94
|
|
|
84
|
-
|
|
85
|
-
"""
|
|
95
|
+
This will not call GetRun.
|
|
96
|
+
"""
|
|
86
97
|
event(EventType.DRIVER_CONNECT)
|
|
87
|
-
if self.
|
|
98
|
+
if self._is_connected:
|
|
88
99
|
log(WARNING, "Already connected")
|
|
89
100
|
return
|
|
90
|
-
self.
|
|
91
|
-
server_address=self.
|
|
92
|
-
insecure=(self.
|
|
93
|
-
root_certificates=self.
|
|
101
|
+
self._channel = create_channel(
|
|
102
|
+
server_address=self._addr,
|
|
103
|
+
insecure=(self._cert is None),
|
|
104
|
+
root_certificates=self._cert,
|
|
94
105
|
)
|
|
95
|
-
self.
|
|
96
|
-
log(DEBUG, "[Driver] Connected to %s", self.
|
|
106
|
+
self._grpc_stub = DriverStub(self._channel)
|
|
107
|
+
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
97
108
|
|
|
98
|
-
def
|
|
109
|
+
def _disconnect(self) -> None:
|
|
99
110
|
"""Disconnect from the Driver API."""
|
|
100
111
|
event(EventType.DRIVER_DISCONNECT)
|
|
101
|
-
if
|
|
112
|
+
if not self._is_connected:
|
|
102
113
|
log(DEBUG, "Already disconnected")
|
|
103
114
|
return
|
|
104
|
-
channel = self.
|
|
105
|
-
self.
|
|
106
|
-
self.
|
|
115
|
+
channel: grpc.Channel = self._channel
|
|
116
|
+
self._channel = None
|
|
117
|
+
self._grpc_stub = None
|
|
107
118
|
channel.close()
|
|
108
119
|
log(DEBUG, "[Driver] Disconnected")
|
|
109
120
|
|
|
110
|
-
def
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
-
|
|
128
|
-
# Call gRPC Driver API
|
|
129
|
-
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
-
return res
|
|
131
|
-
|
|
132
|
-
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
133
|
-
"""Get client IDs."""
|
|
134
|
-
# Check if channel is open
|
|
135
|
-
if self.stub is None:
|
|
136
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
137
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
138
|
-
|
|
139
|
-
# Call gRPC Driver API
|
|
140
|
-
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
141
|
-
return res
|
|
142
|
-
|
|
143
|
-
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
144
|
-
"""Schedule tasks."""
|
|
145
|
-
# Check if channel is open
|
|
146
|
-
if self.stub is None:
|
|
147
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
148
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
149
|
-
|
|
150
|
-
# Call gRPC Driver API
|
|
151
|
-
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
152
|
-
return res
|
|
153
|
-
|
|
154
|
-
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
155
|
-
"""Get task results."""
|
|
156
|
-
# Check if channel is open
|
|
157
|
-
if self.stub is None:
|
|
158
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
159
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
160
|
-
|
|
161
|
-
# Call Driver API
|
|
162
|
-
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
163
|
-
return res
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
class GrpcDriver(Driver):
|
|
167
|
-
"""`Driver` class provides an interface to the Driver API.
|
|
168
|
-
|
|
169
|
-
Parameters
|
|
170
|
-
----------
|
|
171
|
-
run_id : int
|
|
172
|
-
The identifier of the run.
|
|
173
|
-
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
-
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
-
If None, an instance connected to "[::]:9091" will be created.
|
|
176
|
-
"""
|
|
177
|
-
|
|
178
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
179
|
-
self,
|
|
180
|
-
run_id: int,
|
|
181
|
-
stub: Optional[GrpcDriverStub] = None,
|
|
182
|
-
) -> None:
|
|
183
|
-
self._run_id = run_id
|
|
184
|
-
self._run: Optional[Run] = None
|
|
185
|
-
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
186
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
121
|
+
def _init_run(self) -> None:
|
|
122
|
+
# Check if is initialized
|
|
123
|
+
if self._run is not None:
|
|
124
|
+
return
|
|
125
|
+
# Get the run info
|
|
126
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
127
|
+
res: GetRunResponse = self._stub.GetRun(req)
|
|
128
|
+
if not res.HasField("run"):
|
|
129
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
130
|
+
self._run = Run(
|
|
131
|
+
run_id=res.run.run_id,
|
|
132
|
+
fab_id=res.run.fab_id,
|
|
133
|
+
fab_version=res.run.fab_version,
|
|
134
|
+
override_config=user_config_from_proto(res.run.override_config),
|
|
135
|
+
)
|
|
187
136
|
|
|
188
137
|
@property
|
|
189
138
|
def run(self) -> Run:
|
|
190
139
|
"""Run information."""
|
|
191
|
-
self.
|
|
192
|
-
return Run(**vars(
|
|
140
|
+
self._init_run()
|
|
141
|
+
return Run(**vars(self._run))
|
|
193
142
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
# Get the run info
|
|
201
|
-
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
-
res = self.stub.get_run(req)
|
|
203
|
-
if not res.HasField("run"):
|
|
204
|
-
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
-
self._run = Run(
|
|
206
|
-
run_id=res.run.run_id,
|
|
207
|
-
fab_id=res.run.fab_id,
|
|
208
|
-
fab_version=res.run.fab_version,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
return self.stub, self._run.run_id
|
|
143
|
+
@property
|
|
144
|
+
def _stub(self) -> DriverStub:
|
|
145
|
+
"""Driver stub."""
|
|
146
|
+
if not self._is_connected:
|
|
147
|
+
self._connect()
|
|
148
|
+
return cast(DriverStub, self._grpc_stub)
|
|
212
149
|
|
|
213
150
|
def _check_message(self, message: Message) -> None:
|
|
214
151
|
# Check if the message is valid
|
|
215
152
|
if not (
|
|
216
|
-
|
|
153
|
+
# Assume self._run being initialized
|
|
154
|
+
message.metadata.run_id == self._run_id
|
|
217
155
|
and message.metadata.src_node_id == self.node.node_id
|
|
218
156
|
and message.metadata.message_id == ""
|
|
219
157
|
and message.metadata.reply_to_message == ""
|
|
@@ -234,7 +172,7 @@ class GrpcDriver(Driver):
|
|
|
234
172
|
This method constructs a new `Message` with given content and metadata.
|
|
235
173
|
The `run_id` and `src_node_id` will be set automatically.
|
|
236
174
|
"""
|
|
237
|
-
|
|
175
|
+
self._init_run()
|
|
238
176
|
if ttl:
|
|
239
177
|
warnings.warn(
|
|
240
178
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -245,7 +183,7 @@ class GrpcDriver(Driver):
|
|
|
245
183
|
|
|
246
184
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
247
185
|
metadata = Metadata(
|
|
248
|
-
run_id=
|
|
186
|
+
run_id=self._run_id,
|
|
249
187
|
message_id="", # Will be set by the server
|
|
250
188
|
src_node_id=self.node.node_id,
|
|
251
189
|
dst_node_id=dst_node_id,
|
|
@@ -258,9 +196,11 @@ class GrpcDriver(Driver):
|
|
|
258
196
|
|
|
259
197
|
def get_node_ids(self) -> List[int]:
|
|
260
198
|
"""Get node IDs."""
|
|
261
|
-
|
|
199
|
+
self._init_run()
|
|
262
200
|
# Call GrpcDriverStub method
|
|
263
|
-
res =
|
|
201
|
+
res: GetNodesResponse = self._stub.GetNodes(
|
|
202
|
+
GetNodesRequest(run_id=self._run_id)
|
|
203
|
+
)
|
|
264
204
|
return [node.node_id for node in res.nodes]
|
|
265
205
|
|
|
266
206
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -269,7 +209,7 @@ class GrpcDriver(Driver):
|
|
|
269
209
|
This method takes an iterable of messages and sends each message
|
|
270
210
|
to the node specified in `dst_node_id`.
|
|
271
211
|
"""
|
|
272
|
-
|
|
212
|
+
self._init_run()
|
|
273
213
|
# Construct TaskIns
|
|
274
214
|
task_ins_list: List[TaskIns] = []
|
|
275
215
|
for msg in messages:
|
|
@@ -280,7 +220,9 @@ class GrpcDriver(Driver):
|
|
|
280
220
|
# Add to list
|
|
281
221
|
task_ins_list.append(taskins)
|
|
282
222
|
# Call GrpcDriverStub method
|
|
283
|
-
res =
|
|
223
|
+
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
224
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
225
|
+
)
|
|
284
226
|
return list(res.task_ids)
|
|
285
227
|
|
|
286
228
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -289,9 +231,9 @@ class GrpcDriver(Driver):
|
|
|
289
231
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
290
232
|
set of given message IDs.
|
|
291
233
|
"""
|
|
292
|
-
|
|
234
|
+
self._init_run()
|
|
293
235
|
# Pull TaskRes
|
|
294
|
-
res =
|
|
236
|
+
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
295
237
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
296
238
|
)
|
|
297
239
|
# Convert TaskRes to Message
|
|
@@ -331,7 +273,7 @@ class GrpcDriver(Driver):
|
|
|
331
273
|
def close(self) -> None:
|
|
332
274
|
"""Disconnect from the SuperLink if connected."""
|
|
333
275
|
# Check if `connect` was called before
|
|
334
|
-
if not self.
|
|
276
|
+
if not self._is_connected:
|
|
335
277
|
return
|
|
336
278
|
# Disconnect
|
|
337
|
-
self.
|
|
279
|
+
self._disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,13 +22,22 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
-
from flwr.common.config import
|
|
25
|
+
from flwr.common.config import (
|
|
26
|
+
get_flwr_dir,
|
|
27
|
+
get_fused_config,
|
|
28
|
+
get_project_config,
|
|
29
|
+
get_project_dir,
|
|
30
|
+
)
|
|
26
31
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
27
32
|
from flwr.common.object_ref import load_app
|
|
28
|
-
from flwr.
|
|
33
|
+
from flwr.common.typing import UserConfig
|
|
34
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
35
|
+
CreateRunRequest,
|
|
36
|
+
CreateRunResponse,
|
|
37
|
+
)
|
|
29
38
|
|
|
30
39
|
from .driver import Driver
|
|
31
|
-
from .driver.grpc_driver import GrpcDriver
|
|
40
|
+
from .driver.grpc_driver import GrpcDriver
|
|
32
41
|
from .server_app import LoadServerAppError, ServerApp
|
|
33
42
|
|
|
34
43
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -37,6 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
|
37
46
|
def run(
|
|
38
47
|
driver: Driver,
|
|
39
48
|
server_app_dir: str,
|
|
49
|
+
server_app_run_config: UserConfig,
|
|
40
50
|
server_app_attr: Optional[str] = None,
|
|
41
51
|
loaded_server_app: Optional[ServerApp] = None,
|
|
42
52
|
) -> None:
|
|
@@ -69,7 +79,9 @@ def run(
|
|
|
69
79
|
server_app = _load()
|
|
70
80
|
|
|
71
81
|
# Initialize Context
|
|
72
|
-
context = Context(
|
|
82
|
+
context = Context(
|
|
83
|
+
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
84
|
+
)
|
|
73
85
|
|
|
74
86
|
# Call ServerApp
|
|
75
87
|
server_app(driver=driver, context=context)
|
|
@@ -144,22 +156,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
144
156
|
"For more details, use: ``flower-server-app -h``"
|
|
145
157
|
)
|
|
146
158
|
|
|
147
|
-
|
|
148
|
-
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
149
|
-
)
|
|
159
|
+
# Initialize GrpcDriver
|
|
150
160
|
if args.run_id is not None:
|
|
151
161
|
# User provided `--run-id`, but not `server-app`
|
|
152
|
-
|
|
162
|
+
driver = GrpcDriver(
|
|
163
|
+
run_id=args.run_id,
|
|
164
|
+
driver_service_address=args.superlink,
|
|
165
|
+
root_certificates=root_certificates,
|
|
166
|
+
)
|
|
153
167
|
else:
|
|
154
168
|
# User provided `server-app`, but not `--run-id`
|
|
155
169
|
# Create run if run_id is not provided
|
|
156
|
-
|
|
170
|
+
driver = GrpcDriver(
|
|
171
|
+
run_id=0, # Will be overwritten
|
|
172
|
+
driver_service_address=args.superlink,
|
|
173
|
+
root_certificates=root_certificates,
|
|
174
|
+
)
|
|
175
|
+
# Create run
|
|
157
176
|
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
158
|
-
res =
|
|
159
|
-
|
|
177
|
+
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
178
|
+
# Overwrite driver._run_id
|
|
179
|
+
driver._run_id = res.run_id # pylint: disable=W0212
|
|
160
180
|
|
|
161
|
-
|
|
162
|
-
driver = GrpcDriver(run_id=run_id, stub=stub)
|
|
181
|
+
server_app_run_config = {}
|
|
163
182
|
|
|
164
183
|
# Dynamically obtain ServerApp path based on run_id
|
|
165
184
|
if args.run_id is not None:
|
|
@@ -168,7 +187,8 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
168
187
|
run_ = driver.run
|
|
169
188
|
server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
|
|
170
189
|
config = get_project_config(server_app_dir)
|
|
171
|
-
server_app_attr = config["
|
|
190
|
+
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
191
|
+
server_app_run_config = get_fused_config(run_, flwr_dir)
|
|
172
192
|
else:
|
|
173
193
|
# User provided `server-app`, but not `--run-id`
|
|
174
194
|
server_app_dir = str(Path(args.dir).absolute())
|
|
@@ -182,7 +202,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
182
202
|
)
|
|
183
203
|
|
|
184
204
|
# Run the ServerApp with the Driver
|
|
185
|
-
run(
|
|
205
|
+
run(
|
|
206
|
+
driver=driver,
|
|
207
|
+
server_app_dir=server_app_dir,
|
|
208
|
+
server_app_run_config=server_app_run_config,
|
|
209
|
+
server_app_attr=server_app_attr,
|
|
210
|
+
)
|
|
186
211
|
|
|
187
212
|
# Clean up
|
|
188
213
|
driver.close()
|
flwr/server/server_app.py
CHANGED
|
@@ -17,8 +17,11 @@
|
|
|
17
17
|
|
|
18
18
|
from typing import Callable, Optional
|
|
19
19
|
|
|
20
|
-
from flwr.common import Context
|
|
21
|
-
from flwr.common.logger import
|
|
20
|
+
from flwr.common import Context
|
|
21
|
+
from flwr.common.logger import (
|
|
22
|
+
warn_deprecated_feature_with_example,
|
|
23
|
+
warn_preview_feature,
|
|
24
|
+
)
|
|
22
25
|
from flwr.server.strategy import Strategy
|
|
23
26
|
|
|
24
27
|
from .client_manager import ClientManager
|
|
@@ -26,7 +29,20 @@ from .compat import start_driver
|
|
|
26
29
|
from .driver import Driver
|
|
27
30
|
from .server import Server
|
|
28
31
|
from .server_config import ServerConfig
|
|
29
|
-
from .typing import ServerAppCallable
|
|
32
|
+
from .typing import ServerAppCallable, ServerFn
|
|
33
|
+
|
|
34
|
+
SERVER_FN_USAGE_EXAMPLE = """
|
|
35
|
+
|
|
36
|
+
def server_fn(context: Context):
|
|
37
|
+
server_config = ServerConfig(num_rounds=3)
|
|
38
|
+
strategy = FedAvg()
|
|
39
|
+
return ServerAppComponents(
|
|
40
|
+
strategy=strategy,
|
|
41
|
+
server_config=server_config,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
app = ServerApp(server_fn=server_fn)
|
|
45
|
+
"""
|
|
30
46
|
|
|
31
47
|
|
|
32
48
|
class ServerApp:
|
|
@@ -36,13 +52,15 @@ class ServerApp:
|
|
|
36
52
|
--------
|
|
37
53
|
Use the `ServerApp` with an existing `Strategy`:
|
|
38
54
|
|
|
39
|
-
>>>
|
|
40
|
-
>>>
|
|
55
|
+
>>> def server_fn(context: Context):
|
|
56
|
+
>>> server_config = ServerConfig(num_rounds=3)
|
|
57
|
+
>>> strategy = FedAvg()
|
|
58
|
+
>>> return ServerAppComponents(
|
|
59
|
+
>>> strategy=strategy,
|
|
60
|
+
>>> server_config=server_config,
|
|
61
|
+
>>> )
|
|
41
62
|
>>>
|
|
42
|
-
>>> app = ServerApp(
|
|
43
|
-
>>> server_config=server_config,
|
|
44
|
-
>>> strategy=strategy,
|
|
45
|
-
>>> )
|
|
63
|
+
>>> app = ServerApp(server_fn=server_fn)
|
|
46
64
|
|
|
47
65
|
Use the `ServerApp` with a custom main function:
|
|
48
66
|
|
|
@@ -53,23 +71,52 @@ class ServerApp:
|
|
|
53
71
|
>>> print("ServerApp running")
|
|
54
72
|
"""
|
|
55
73
|
|
|
74
|
+
# pylint: disable=too-many-arguments
|
|
56
75
|
def __init__(
|
|
57
76
|
self,
|
|
58
77
|
server: Optional[Server] = None,
|
|
59
78
|
config: Optional[ServerConfig] = None,
|
|
60
79
|
strategy: Optional[Strategy] = None,
|
|
61
80
|
client_manager: Optional[ClientManager] = None,
|
|
81
|
+
server_fn: Optional[ServerFn] = None,
|
|
62
82
|
) -> None:
|
|
83
|
+
if any([server, config, strategy, client_manager]):
|
|
84
|
+
warn_deprecated_feature_with_example(
|
|
85
|
+
deprecation_message="Passing either `server`, `config`, `strategy` or "
|
|
86
|
+
"`client_manager` directly to the ServerApp "
|
|
87
|
+
"constructor is deprecated.",
|
|
88
|
+
example_message="Pass `ServerApp` arguments wrapped "
|
|
89
|
+
"in a `flwr.server.ServerAppComponents` object that gets "
|
|
90
|
+
"returned by a function passed as the `server_fn` argument "
|
|
91
|
+
"to the `ServerApp` constructor. For example: ",
|
|
92
|
+
code_example=SERVER_FN_USAGE_EXAMPLE,
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
if server_fn:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
"Passing `server_fn` is incompatible with passing the "
|
|
98
|
+
"other arguments (now deprecated) to ServerApp. "
|
|
99
|
+
"Use `server_fn` exclusively."
|
|
100
|
+
)
|
|
101
|
+
|
|
63
102
|
self._server = server
|
|
64
103
|
self._config = config
|
|
65
104
|
self._strategy = strategy
|
|
66
105
|
self._client_manager = client_manager
|
|
106
|
+
self._server_fn = server_fn
|
|
67
107
|
self._main: Optional[ServerAppCallable] = None
|
|
68
108
|
|
|
69
109
|
def __call__(self, driver: Driver, context: Context) -> None:
|
|
70
110
|
"""Execute `ServerApp`."""
|
|
71
111
|
# Compatibility mode
|
|
72
112
|
if not self._main:
|
|
113
|
+
if self._server_fn:
|
|
114
|
+
# Execute server_fn()
|
|
115
|
+
components = self._server_fn(context)
|
|
116
|
+
self._server = components.server
|
|
117
|
+
self._config = components.config
|
|
118
|
+
self._strategy = components.strategy
|
|
119
|
+
self._client_manager = components.client_manager
|
|
73
120
|
start_driver(
|
|
74
121
|
server=self._server,
|
|
75
122
|
config=self._config,
|
|
@@ -80,7 +127,6 @@ class ServerApp:
|
|
|
80
127
|
return
|
|
81
128
|
|
|
82
129
|
# New execution mode
|
|
83
|
-
context = Context(state=RecordSet())
|
|
84
130
|
self._main(driver, context)
|
|
85
131
|
|
|
86
132
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
|
flwr/server/serverapp_components.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ServerAppComponents for the ServerApp."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from .client_manager import ClientManager
|
|
22
|
+
from .server import Server
|
|
23
|
+
from .server_config import ServerConfig
|
|
24
|
+
from .strategy import Strategy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class ServerAppComponents: # pylint: disable=too-many-instance-attributes
|
|
29
|
+
"""Components to construct a ServerApp.
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
server : Optional[Server] (default: None)
|
|
34
|
+
A server implementation, either `flwr.server.Server` or a subclass
|
|
35
|
+
thereof. If no instance is provided, one will be created internally.
|
|
36
|
+
config : Optional[ServerConfig] (default: None)
|
|
37
|
+
Currently supported values are `num_rounds` (int, default: 1) and
|
|
38
|
+
`round_timeout` in seconds (float, default: None).
|
|
39
|
+
strategy : Optional[Strategy] (default: None)
|
|
40
|
+
An implementation of the abstract base class
|
|
41
|
+
`flwr.server.strategy.Strategy`. If no strategy is provided, then
|
|
42
|
+
`flwr.server.strategy.FedAvg` will be used.
|
|
43
|
+
client_manager : Optional[ClientManager] (default: None)
|
|
44
|
+
An implementation of the class `flwr.server.ClientManager`. If no
|
|
45
|
+
implementation is provided, then `flwr.server.SimpleClientManager`
|
|
46
|
+
will be used.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
server: Optional[Server] = None
|
|
50
|
+
config: Optional[ServerConfig] = None
|
|
51
|
+
strategy: Optional[Strategy] = None
|
|
52
|
+
client_manager: Optional[ClientManager] = None
|
|
flwr/server/superlink/driver/driver_servicer.py
CHANGED
|
@@ -23,6 +23,7 @@ from uuid import UUID
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.serde import user_config_from_proto, user_config_to_proto
|
|
26
27
|
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
27
28
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
28
29
|
CreateRunRequest,
|
|
@@ -69,7 +70,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
69
70
|
"""Create run ID."""
|
|
70
71
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
71
72
|
state: State = self.state_factory.state()
|
|
72
|
-
run_id = state.create_run(
|
|
73
|
+
run_id = state.create_run(
|
|
74
|
+
request.fab_id,
|
|
75
|
+
request.fab_version,
|
|
76
|
+
user_config_from_proto(request.override_config),
|
|
77
|
+
)
|
|
73
78
|
return CreateRunResponse(run_id=run_id)
|
|
74
79
|
|
|
75
80
|
def PushTaskIns(
|
|
@@ -145,8 +150,18 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
145
150
|
|
|
146
151
|
# Retrieve run information
|
|
147
152
|
run = state.get_run(request.run_id)
|
|
148
|
-
|
|
149
|
-
|
|
153
|
+
|
|
154
|
+
if run is None:
|
|
155
|
+
return GetRunResponse()
|
|
156
|
+
|
|
157
|
+
return GetRunResponse(
|
|
158
|
+
run=Run(
|
|
159
|
+
run_id=run.run_id,
|
|
160
|
+
fab_id=run.fab_id,
|
|
161
|
+
fab_version=run.fab_version,
|
|
162
|
+
override_config=user_config_to_proto(run.override_config),
|
|
163
|
+
)
|
|
164
|
+
)
|
|
150
165
|
|
|
151
166
|
|
|
152
167
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
flwr/server/superlink/fleet/message_handler/message_handler.py
CHANGED
|
@@ -19,6 +19,7 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.serde import user_config_to_proto
|
|
22
23
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
23
24
|
CreateNodeRequest,
|
|
24
25
|
CreateNodeResponse,
|
|
@@ -113,5 +114,15 @@ def get_run(
|
|
|
113
114
|
) -> GetRunResponse:
|
|
114
115
|
"""Get run information."""
|
|
115
116
|
run = state.get_run(request.run_id)
|
|
116
|
-
|
|
117
|
-
|
|
117
|
+
|
|
118
|
+
if run is None:
|
|
119
|
+
return GetRunResponse()
|
|
120
|
+
|
|
121
|
+
return GetRunResponse(
|
|
122
|
+
run=Run(
|
|
123
|
+
run_id=run.run_id,
|
|
124
|
+
fab_id=run.fab_id,
|
|
125
|
+
fab_version=run.fab_version,
|
|
126
|
+
override_config=user_config_to_proto(run.override_config),
|
|
127
|
+
)
|
|
128
|
+
)
|