flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +3 -7
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +8 -1
- flwr/client/client_app.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +61 -18
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +1 -1
- flwr/server/superlink/state/sqlite_state.py +1 -1
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +47 -28
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +98 -88
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
19
|
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
|
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
+
from flwr.common.typing import Run
|
|
28
29
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
30
|
CreateRunRequest,
|
|
30
31
|
CreateRunResponse,
|
|
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
40
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
42
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
43
|
|
|
42
44
|
from .driver import Driver
|
|
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
|
46
48
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
47
49
|
[Driver] Error: Not connected.
|
|
48
50
|
|
|
49
|
-
Call `connect()` on the `
|
|
50
|
-
`
|
|
51
|
+
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
52
|
+
`GrpcDriverStub` methods.
|
|
51
53
|
"""
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
class
|
|
55
|
-
"""`
|
|
56
|
+
class GrpcDriverStub:
|
|
57
|
+
"""`GrpcDriverStub` provides access to the gRPC Driver API/service.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
driver_service_address : Optional[str]
|
|
62
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
63
|
+
Defaults to `"[::]:9091"`.
|
|
64
|
+
root_certificates : Optional[bytes] (default: None)
|
|
65
|
+
The PEM-encoded root certificates as a byte string.
|
|
66
|
+
If provided, a secure connection using the certificates will be
|
|
67
|
+
established to an SSL-enabled Flower server.
|
|
68
|
+
"""
|
|
56
69
|
|
|
57
70
|
def __init__(
|
|
58
71
|
self,
|
|
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
|
|
|
64
77
|
self.channel: Optional[grpc.Channel] = None
|
|
65
78
|
self.stub: Optional[DriverStub] = None
|
|
66
79
|
|
|
80
|
+
def is_connected(self) -> bool:
|
|
81
|
+
"""Return True if connected to the Driver API server, otherwise False."""
|
|
82
|
+
return self.channel is not None
|
|
83
|
+
|
|
67
84
|
def connect(self) -> None:
|
|
68
85
|
"""Connect to the Driver API."""
|
|
69
86
|
event(EventType.DRIVER_CONNECT)
|
|
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
|
|
|
95
112
|
# Check if channel is open
|
|
96
113
|
if self.stub is None:
|
|
97
114
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
98
|
-
raise ConnectionError("`
|
|
115
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
99
116
|
|
|
100
117
|
# Call Driver API
|
|
101
118
|
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
102
119
|
return res
|
|
103
120
|
|
|
121
|
+
def get_run(self, req: GetRunRequest) -> GetRunResponse:
|
|
122
|
+
"""Get run information."""
|
|
123
|
+
# Check if channel is open
|
|
124
|
+
if self.stub is None:
|
|
125
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
+
|
|
128
|
+
# Call gRPC Driver API
|
|
129
|
+
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
+
return res
|
|
131
|
+
|
|
104
132
|
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
105
133
|
"""Get client IDs."""
|
|
106
134
|
# Check if channel is open
|
|
107
135
|
if self.stub is None:
|
|
108
136
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
109
|
-
raise ConnectionError("`
|
|
137
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
110
138
|
|
|
111
139
|
# Call gRPC Driver API
|
|
112
140
|
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
|
|
|
117
145
|
# Check if channel is open
|
|
118
146
|
if self.stub is None:
|
|
119
147
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
120
|
-
raise ConnectionError("`
|
|
148
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
121
149
|
|
|
122
150
|
# Call gRPC Driver API
|
|
123
151
|
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
|
|
|
128
156
|
# Check if channel is open
|
|
129
157
|
if self.stub is None:
|
|
130
158
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
131
|
-
raise ConnectionError("`
|
|
159
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
132
160
|
|
|
133
161
|
# Call Driver API
|
|
134
162
|
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
|
|
|
140
168
|
|
|
141
169
|
Parameters
|
|
142
170
|
----------
|
|
143
|
-
|
|
144
|
-
The
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
-
three bytes elements in the following order:
|
|
150
|
-
|
|
151
|
-
* CA certificate.
|
|
152
|
-
* server certificate.
|
|
153
|
-
* server private key.
|
|
154
|
-
fab_id : str (default: None)
|
|
155
|
-
The identifier of the FAB used in the run.
|
|
156
|
-
fab_version : str (default: None)
|
|
157
|
-
The version of the FAB used in the run.
|
|
171
|
+
run_id : int
|
|
172
|
+
The identifier of the run.
|
|
173
|
+
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
+
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
+
If None, an instance connected to "[::]:9091" will be created.
|
|
158
176
|
"""
|
|
159
177
|
|
|
160
|
-
def __init__(
|
|
178
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
161
179
|
self,
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
fab_id: Optional[str] = None,
|
|
165
|
-
fab_version: Optional[str] = None,
|
|
180
|
+
run_id: int,
|
|
181
|
+
stub: Optional[GrpcDriverStub] = None,
|
|
166
182
|
) -> None:
|
|
167
|
-
self.
|
|
168
|
-
self.
|
|
169
|
-
self.
|
|
170
|
-
self.run_id: Optional[int] = None
|
|
171
|
-
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
183
|
+
self._run_id = run_id
|
|
184
|
+
self._run: Optional[Run] = None
|
|
185
|
+
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
173
186
|
self.node = Node(node_id=0, anonymous=True)
|
|
174
187
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
188
|
+
@property
|
|
189
|
+
def run(self) -> Run:
|
|
190
|
+
"""Run information."""
|
|
191
|
+
self._get_stub_and_run_id()
|
|
192
|
+
return Run(**vars(cast(Run, self._run)))
|
|
193
|
+
|
|
194
|
+
def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
|
|
195
|
+
# Check if is initialized
|
|
196
|
+
if self._run is None:
|
|
197
|
+
# Connect
|
|
198
|
+
if not self.stub.is_connected():
|
|
199
|
+
self.stub.connect()
|
|
200
|
+
# Get the run info
|
|
201
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
+
res = self.stub.get_run(req)
|
|
203
|
+
if not res.HasField("run"):
|
|
204
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
+
self._run = Run(
|
|
206
|
+
run_id=res.run.run_id,
|
|
207
|
+
fab_id=res.run.fab_id,
|
|
208
|
+
fab_version=res.run.fab_version,
|
|
182
209
|
)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
res = self.driver_helper.create_run(req)
|
|
186
|
-
self.run_id = res.run_id
|
|
187
|
-
return self.driver_helper, self.run_id
|
|
210
|
+
|
|
211
|
+
return self.stub, self._run.run_id
|
|
188
212
|
|
|
189
213
|
def _check_message(self, message: Message) -> None:
|
|
190
214
|
# Check if the message is valid
|
|
191
215
|
if not (
|
|
192
|
-
message.metadata.run_id == self.run_id
|
|
216
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
193
217
|
and message.metadata.src_node_id == self.node.node_id
|
|
194
218
|
and message.metadata.message_id == ""
|
|
195
219
|
and message.metadata.reply_to_message == ""
|
|
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
|
|
|
210
234
|
This method constructs a new `Message` with given content and metadata.
|
|
211
235
|
The `run_id` and `src_node_id` will be set automatically.
|
|
212
236
|
"""
|
|
213
|
-
_, run_id = self.
|
|
237
|
+
_, run_id = self._get_stub_and_run_id()
|
|
214
238
|
if ttl:
|
|
215
239
|
warnings.warn(
|
|
216
240
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
|
|
|
234
258
|
|
|
235
259
|
def get_node_ids(self) -> List[int]:
|
|
236
260
|
"""Get node IDs."""
|
|
237
|
-
|
|
238
|
-
# Call
|
|
239
|
-
res =
|
|
261
|
+
stub, run_id = self._get_stub_and_run_id()
|
|
262
|
+
# Call GrpcDriverStub method
|
|
263
|
+
res = stub.get_nodes(GetNodesRequest(run_id=run_id))
|
|
240
264
|
return [node.node_id for node in res.nodes]
|
|
241
265
|
|
|
242
266
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
|
|
|
245
269
|
This method takes an iterable of messages and sends each message
|
|
246
270
|
to the node specified in `dst_node_id`.
|
|
247
271
|
"""
|
|
248
|
-
|
|
272
|
+
stub, _ = self._get_stub_and_run_id()
|
|
249
273
|
# Construct TaskIns
|
|
250
274
|
task_ins_list: List[TaskIns] = []
|
|
251
275
|
for msg in messages:
|
|
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
|
|
|
255
279
|
taskins = message_to_taskins(msg)
|
|
256
280
|
# Add to list
|
|
257
281
|
task_ins_list.append(taskins)
|
|
258
|
-
# Call
|
|
259
|
-
res =
|
|
260
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
261
|
-
)
|
|
282
|
+
# Call GrpcDriverStub method
|
|
283
|
+
res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
262
284
|
return list(res.task_ids)
|
|
263
285
|
|
|
264
286
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
|
|
|
267
289
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
268
290
|
set of given message IDs.
|
|
269
291
|
"""
|
|
270
|
-
|
|
292
|
+
stub, _ = self._get_stub_and_run_id()
|
|
271
293
|
# Pull TaskRes
|
|
272
|
-
res =
|
|
294
|
+
res = stub.pull_task_res(
|
|
273
295
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
274
296
|
)
|
|
275
297
|
# Convert TaskRes to Message
|
|
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
|
|
|
308
330
|
|
|
309
331
|
def close(self) -> None:
|
|
310
332
|
"""Disconnect from the SuperLink if connected."""
|
|
311
|
-
# Check if
|
|
312
|
-
if self.
|
|
333
|
+
# Check if `connect` was called before
|
|
334
|
+
if not self.stub.is_connected():
|
|
313
335
|
return
|
|
314
336
|
# Disconnect
|
|
315
|
-
self.
|
|
337
|
+
self.stub.disconnect()
|
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
import warnings
|
|
20
|
-
from typing import Iterable, List, Optional
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
23
23
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
24
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.common.typing import Run
|
|
25
26
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
27
|
from flwr.server.superlink.state import StateFactory
|
|
27
28
|
|
|
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
|
|
|
33
34
|
|
|
34
35
|
Parameters
|
|
35
36
|
----------
|
|
37
|
+
run_id : int
|
|
38
|
+
The identifier of the run.
|
|
36
39
|
state_factory : StateFactory
|
|
37
40
|
A StateFactory embedding a state that this driver can interface with.
|
|
38
|
-
fab_id : str (default: None)
|
|
39
|
-
The identifier of the FAB used in the run.
|
|
40
|
-
fab_version : str (default: None)
|
|
41
|
-
The version of the FAB used in the run.
|
|
42
41
|
"""
|
|
43
42
|
|
|
44
43
|
def __init__(
|
|
45
44
|
self,
|
|
45
|
+
run_id: int,
|
|
46
46
|
state_factory: StateFactory,
|
|
47
|
-
fab_id: Optional[str] = None,
|
|
48
|
-
fab_version: Optional[str] = None,
|
|
49
47
|
) -> None:
|
|
50
|
-
self.
|
|
51
|
-
self.
|
|
52
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
53
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
48
|
+
self._run_id = run_id
|
|
49
|
+
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
51
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
55
52
|
|
|
56
53
|
def _check_message(self, message: Message) -> None:
|
|
54
|
+
self._init_run()
|
|
57
55
|
# Check if the message is valid
|
|
58
56
|
if not (
|
|
59
|
-
message.metadata.run_id == self.run_id
|
|
57
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
60
58
|
and message.metadata.src_node_id == self.node.node_id
|
|
61
59
|
and message.metadata.message_id == ""
|
|
62
60
|
and message.metadata.reply_to_message == ""
|
|
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
|
|
|
64
62
|
):
|
|
65
63
|
raise ValueError(f"Invalid message: {message}")
|
|
66
64
|
|
|
67
|
-
def
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
if
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
def _init_run(self) -> None:
|
|
66
|
+
"""Initialize the run."""
|
|
67
|
+
if self._run is not None:
|
|
68
|
+
return
|
|
69
|
+
run = self.state.get_run(self._run_id)
|
|
70
|
+
if run is None:
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
72
|
+
self._run = run
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def run(self) -> Run:
|
|
76
|
+
"""Run ID."""
|
|
77
|
+
self._init_run()
|
|
78
|
+
return Run(**vars(cast(Run, self._run)))
|
|
77
79
|
|
|
78
80
|
def create_message( # pylint: disable=too-many-arguments
|
|
79
81
|
self,
|
|
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
|
|
|
88
90
|
This method constructs a new `Message` with given content and metadata.
|
|
89
91
|
The `run_id` and `src_node_id` will be set automatically.
|
|
90
92
|
"""
|
|
91
|
-
|
|
93
|
+
self._init_run()
|
|
92
94
|
if ttl:
|
|
93
95
|
warnings.warn(
|
|
94
96
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
|
|
|
99
101
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
100
102
|
|
|
101
103
|
metadata = Metadata(
|
|
102
|
-
run_id=run_id,
|
|
104
|
+
run_id=cast(Run, self._run).run_id,
|
|
103
105
|
message_id="", # Will be set by the server
|
|
104
106
|
src_node_id=self.node.node_id,
|
|
105
107
|
dst_node_id=dst_node_id,
|
|
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
|
|
|
112
114
|
|
|
113
115
|
def get_node_ids(self) -> List[int]:
|
|
114
116
|
"""Get node IDs."""
|
|
115
|
-
|
|
116
|
-
return list(self.state.get_nodes(run_id))
|
|
117
|
+
self._init_run()
|
|
118
|
+
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
117
119
|
|
|
118
120
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
119
121
|
"""Push messages to specified node IDs.
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,10 +22,13 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
+
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
|
|
25
26
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
26
27
|
from flwr.common.object_ref import load_app
|
|
28
|
+
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
|
|
27
29
|
|
|
28
|
-
from .driver import Driver
|
|
30
|
+
from .driver import Driver
|
|
31
|
+
from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
|
|
29
32
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
33
|
|
|
31
34
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -41,7 +44,7 @@ def run(
|
|
|
41
44
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
42
45
|
raise ValueError(
|
|
43
46
|
"Either `server_app_attr` or `loaded_server_app` should be set "
|
|
44
|
-
"but not both.
|
|
47
|
+
"but not both."
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
if server_app_dir is not None:
|
|
@@ -74,7 +77,7 @@ def run(
|
|
|
74
77
|
log(DEBUG, "ServerApp finished running.")
|
|
75
78
|
|
|
76
79
|
|
|
77
|
-
def run_server_app() -> None:
|
|
80
|
+
def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
78
81
|
"""Run Flower server app."""
|
|
79
82
|
event(EventType.RUN_SERVER_APP_ENTER)
|
|
80
83
|
|
|
@@ -134,11 +137,43 @@ def run_server_app() -> None:
|
|
|
134
137
|
cert_path,
|
|
135
138
|
)
|
|
136
139
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
140
|
+
server_app_attr: Optional[str] = getattr(args, "server-app")
|
|
141
|
+
if not (server_app_attr is None) ^ (args.run_id is None):
|
|
142
|
+
raise sys.exit(
|
|
143
|
+
"Please provide either a ServerApp reference or a Run ID, but not both. "
|
|
144
|
+
"For more details, use: ``flower-server-app -h``"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
stub = GrpcDriverStub(
|
|
148
|
+
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
141
149
|
)
|
|
150
|
+
if args.run_id is not None:
|
|
151
|
+
# User provided `--run-id`, but not `server-app`
|
|
152
|
+
run_id = args.run_id
|
|
153
|
+
else:
|
|
154
|
+
# User provided `server-app`, but not `--run-id`
|
|
155
|
+
# Create run if run_id is not provided
|
|
156
|
+
stub.connect()
|
|
157
|
+
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
158
|
+
res = stub.create_run(req)
|
|
159
|
+
run_id = res.run_id
|
|
160
|
+
|
|
161
|
+
# Initialize GrpcDriver
|
|
162
|
+
driver = GrpcDriver(run_id=run_id, stub=stub)
|
|
163
|
+
|
|
164
|
+
# Dynamically obtain ServerApp path based on run_id
|
|
165
|
+
if args.run_id is not None:
|
|
166
|
+
# User provided `--run-id`, but not `server-app`
|
|
167
|
+
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
168
|
+
run_ = driver.run
|
|
169
|
+
server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
|
|
170
|
+
config = get_project_config(server_app_dir)
|
|
171
|
+
server_app_attr = config["flower"]["components"]["serverapp"]
|
|
172
|
+
else:
|
|
173
|
+
# User provided `server-app`, but not `--run-id`
|
|
174
|
+
server_app_dir = str(Path(args.dir).absolute())
|
|
175
|
+
|
|
176
|
+
log(DEBUG, "Flower will load ServerApp `%s` in %s", server_app_attr, server_app_dir)
|
|
142
177
|
|
|
143
178
|
log(
|
|
144
179
|
DEBUG,
|
|
@@ -146,17 +181,6 @@ def run_server_app() -> None:
|
|
|
146
181
|
root_certificates,
|
|
147
182
|
)
|
|
148
183
|
|
|
149
|
-
server_app_dir = args.dir
|
|
150
|
-
server_app_attr = getattr(args, "server-app")
|
|
151
|
-
|
|
152
|
-
# Initialize GrpcDriver
|
|
153
|
-
driver = GrpcDriver(
|
|
154
|
-
driver_service_address=args.superlink,
|
|
155
|
-
root_certificates=root_certificates,
|
|
156
|
-
fab_id=args.fab_id,
|
|
157
|
-
fab_version=args.fab_version,
|
|
158
|
-
)
|
|
159
|
-
|
|
160
184
|
# Run the ServerApp with the Driver
|
|
161
185
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
|
162
186
|
|
|
@@ -174,6 +198,8 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
174
198
|
|
|
175
199
|
parser.add_argument(
|
|
176
200
|
"server-app",
|
|
201
|
+
nargs="?",
|
|
202
|
+
default=None,
|
|
177
203
|
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
|
178
204
|
)
|
|
179
205
|
parser.add_argument(
|
|
@@ -223,5 +249,22 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
223
249
|
type=str,
|
|
224
250
|
help="The version of the FAB used in the run.",
|
|
225
251
|
)
|
|
252
|
+
parser.add_argument(
|
|
253
|
+
"--run-id",
|
|
254
|
+
default=None,
|
|
255
|
+
type=int,
|
|
256
|
+
help="The identifier of the run.",
|
|
257
|
+
)
|
|
258
|
+
parser.add_argument(
|
|
259
|
+
"--flwr-dir",
|
|
260
|
+
default=None,
|
|
261
|
+
help="""The path containing installed Flower Apps.
|
|
262
|
+
By default, this value is equal to:
|
|
263
|
+
|
|
264
|
+
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
|
|
265
|
+
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
|
|
266
|
+
- `$HOME/.flwr/` in all other cases
|
|
267
|
+
""",
|
|
268
|
+
)
|
|
226
269
|
|
|
227
270
|
return parser
|
flwr/server/strategy/bulyan.py
CHANGED
flwr/server/strategy/fedadam.py
CHANGED
flwr/server/strategy/fedavgm.py
CHANGED
flwr/server/strategy/fedopt.py
CHANGED
flwr/server/strategy/fedprox.py
CHANGED
flwr/server/strategy/fedyogi.py
CHANGED
flwr/server/strategy/krum.py
CHANGED
flwr/server/strategy/qfedavg.py
CHANGED