flwr 1.12.0__py3-none-any.whl → 1.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/install.py +0 -16
- flwr/cli/log.py +63 -97
- flwr/cli/ls.py +228 -0
- flwr/cli/new/new.py +23 -13
- flwr/cli/new/templates/app/README.md.tpl +11 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +37 -89
- flwr/client/app.py +73 -34
- flwr/client/clientapp/app.py +58 -37
- flwr/client/grpc_rere_client/connection.py +7 -12
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/rest_client/connection.py +4 -14
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +34 -58
- flwr/common/args.py +152 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +59 -7
- flwr/common/context.py +9 -4
- flwr/common/date.py +21 -3
- flwr/common/grpc.py +4 -1
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +34 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +32 -2
- flwr/proto/exec_pb2.py +23 -17
- flwr/proto/exec_pb2.pyi +58 -22
- flwr/proto/exec_pb2_grpc.py +34 -0
- flwr/proto/exec_pb2_grpc.pyi +13 -0
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +44 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +297 -162
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +89 -50
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +11 -235
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +234 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +237 -64
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +166 -22
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +383 -174
- flwr/server/superlink/linkstate/utils.py +389 -0
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +236 -347
- flwr/simulation/legacy_app.py +402 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +56 -141
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +70 -69
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +65 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +62 -150
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/METADATA +9 -7
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/RECORD +105 -85
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- flwr/server/superlink/state/utils.py +0 -148
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/LICENSE +0 -0
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/WHEEL +0 -0
|
@@ -17,22 +17,21 @@
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
19
|
from collections.abc import Iterable
|
|
20
|
-
from logging import DEBUG, WARNING
|
|
21
|
-
from typing import Optional, cast
|
|
20
|
+
from logging import DEBUG, INFO, WARN, WARNING
|
|
21
|
+
from typing import Any, Optional, cast
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import
|
|
26
|
+
from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
27
27
|
from flwr.common.grpc import create_channel
|
|
28
28
|
from flwr.common.logger import log
|
|
29
|
-
from flwr.common.
|
|
30
|
-
|
|
31
|
-
message_to_taskins,
|
|
32
|
-
user_config_from_proto,
|
|
33
|
-
)
|
|
29
|
+
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
30
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
|
|
34
31
|
from flwr.common.typing import Run
|
|
35
|
-
from flwr.proto.
|
|
32
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
33
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
34
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
36
35
|
GetNodesRequest,
|
|
37
36
|
GetNodesResponse,
|
|
38
37
|
PullTaskResRequest,
|
|
@@ -40,9 +39,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
40
39
|
PushTaskInsRequest,
|
|
41
40
|
PushTaskInsResponse,
|
|
42
41
|
)
|
|
43
|
-
from flwr.proto.
|
|
44
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
42
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
46
43
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
47
44
|
|
|
48
45
|
from .driver import Driver
|
|
@@ -56,14 +53,12 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
56
53
|
|
|
57
54
|
|
|
58
55
|
class GrpcDriver(Driver):
|
|
59
|
-
"""`GrpcDriver` provides an interface to the
|
|
56
|
+
"""`GrpcDriver` provides an interface to the ServerAppIo API.
|
|
60
57
|
|
|
61
58
|
Parameters
|
|
62
59
|
----------
|
|
63
|
-
|
|
64
|
-
The
|
|
65
|
-
driver_service_address : str (default: "[::]:9091")
|
|
66
|
-
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
60
|
+
serverappio_service_address : str (default: "[::]:9091")
|
|
61
|
+
The address (URL, IPv6, IPv4) of the SuperLink ServerAppIo API service.
|
|
67
62
|
root_certificates : Optional[bytes] (default: None)
|
|
68
63
|
The PEM-encoded root certificates as a byte string.
|
|
69
64
|
If provided, a secure connection using the certificates will be
|
|
@@ -72,25 +67,24 @@ class GrpcDriver(Driver):
|
|
|
72
67
|
|
|
73
68
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
69
|
self,
|
|
75
|
-
|
|
76
|
-
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
70
|
+
serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
77
71
|
root_certificates: Optional[bytes] = None,
|
|
78
72
|
) -> None:
|
|
79
|
-
self.
|
|
80
|
-
self._addr = driver_service_address
|
|
73
|
+
self._addr = serverappio_service_address
|
|
81
74
|
self._cert = root_certificates
|
|
82
75
|
self._run: Optional[Run] = None
|
|
83
|
-
self._grpc_stub: Optional[
|
|
76
|
+
self._grpc_stub: Optional[ServerAppIoStub] = None
|
|
84
77
|
self._channel: Optional[grpc.Channel] = None
|
|
85
78
|
self.node = Node(node_id=0, anonymous=True)
|
|
79
|
+
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
86
80
|
|
|
87
81
|
@property
|
|
88
82
|
def _is_connected(self) -> bool:
|
|
89
|
-
"""Check if connected to the
|
|
83
|
+
"""Check if connected to the ServerAppIo API server."""
|
|
90
84
|
return self._channel is not None
|
|
91
85
|
|
|
92
86
|
def _connect(self) -> None:
|
|
93
|
-
"""Connect to the
|
|
87
|
+
"""Connect to the ServerAppIo API.
|
|
94
88
|
|
|
95
89
|
This will not call GetRun.
|
|
96
90
|
"""
|
|
@@ -102,11 +96,12 @@ class GrpcDriver(Driver):
|
|
|
102
96
|
insecure=(self._cert is None),
|
|
103
97
|
root_certificates=self._cert,
|
|
104
98
|
)
|
|
105
|
-
self._grpc_stub =
|
|
99
|
+
self._grpc_stub = ServerAppIoStub(self._channel)
|
|
100
|
+
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
|
106
101
|
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
107
102
|
|
|
108
103
|
def _disconnect(self) -> None:
|
|
109
|
-
"""Disconnect from the
|
|
104
|
+
"""Disconnect from the ServerAppIo API."""
|
|
110
105
|
if not self._is_connected:
|
|
111
106
|
log(DEBUG, "Already disconnected")
|
|
112
107
|
return
|
|
@@ -116,41 +111,32 @@ class GrpcDriver(Driver):
|
|
|
116
111
|
channel.close()
|
|
117
112
|
log(DEBUG, "[Driver] Disconnected")
|
|
118
113
|
|
|
119
|
-
def
|
|
120
|
-
|
|
121
|
-
if self._run is not None:
|
|
122
|
-
return
|
|
114
|
+
def set_run(self, run_id: int) -> None:
|
|
115
|
+
"""Set the run."""
|
|
123
116
|
# Get the run info
|
|
124
|
-
req = GetRunRequest(run_id=
|
|
117
|
+
req = GetRunRequest(run_id=run_id)
|
|
125
118
|
res: GetRunResponse = self._stub.GetRun(req)
|
|
126
119
|
if not res.HasField("run"):
|
|
127
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
128
|
-
self._run =
|
|
129
|
-
run_id=res.run.run_id,
|
|
130
|
-
fab_id=res.run.fab_id,
|
|
131
|
-
fab_version=res.run.fab_version,
|
|
132
|
-
fab_hash=res.run.fab_hash,
|
|
133
|
-
override_config=user_config_from_proto(res.run.override_config),
|
|
134
|
-
)
|
|
120
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
121
|
+
self._run = run_from_proto(res.run)
|
|
135
122
|
|
|
136
123
|
@property
|
|
137
124
|
def run(self) -> Run:
|
|
138
125
|
"""Run information."""
|
|
139
|
-
self._init_run()
|
|
140
126
|
return Run(**vars(self._run))
|
|
141
127
|
|
|
142
128
|
@property
|
|
143
|
-
def _stub(self) ->
|
|
144
|
-
"""
|
|
129
|
+
def _stub(self) -> ServerAppIoStub:
|
|
130
|
+
"""ServerAppIo stub."""
|
|
145
131
|
if not self._is_connected:
|
|
146
132
|
self._connect()
|
|
147
|
-
return cast(
|
|
133
|
+
return cast(ServerAppIoStub, self._grpc_stub)
|
|
148
134
|
|
|
149
135
|
def _check_message(self, message: Message) -> None:
|
|
150
136
|
# Check if the message is valid
|
|
151
137
|
if not (
|
|
152
138
|
# Assume self._run being initialized
|
|
153
|
-
message.metadata.run_id == self.
|
|
139
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
154
140
|
and message.metadata.src_node_id == self.node.node_id
|
|
155
141
|
and message.metadata.message_id == ""
|
|
156
142
|
and message.metadata.reply_to_message == ""
|
|
@@ -171,7 +157,6 @@ class GrpcDriver(Driver):
|
|
|
171
157
|
This method constructs a new `Message` with given content and metadata.
|
|
172
158
|
The `run_id` and `src_node_id` will be set automatically.
|
|
173
159
|
"""
|
|
174
|
-
self._init_run()
|
|
175
160
|
if ttl:
|
|
176
161
|
warnings.warn(
|
|
177
162
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -182,7 +167,7 @@ class GrpcDriver(Driver):
|
|
|
182
167
|
|
|
183
168
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
184
169
|
metadata = Metadata(
|
|
185
|
-
run_id=self.
|
|
170
|
+
run_id=cast(Run, self._run).run_id,
|
|
186
171
|
message_id="", # Will be set by the server
|
|
187
172
|
src_node_id=self.node.node_id,
|
|
188
173
|
dst_node_id=dst_node_id,
|
|
@@ -195,10 +180,9 @@ class GrpcDriver(Driver):
|
|
|
195
180
|
|
|
196
181
|
def get_node_ids(self) -> list[int]:
|
|
197
182
|
"""Get node IDs."""
|
|
198
|
-
self._init_run()
|
|
199
183
|
# Call GrpcDriverStub method
|
|
200
184
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
201
|
-
GetNodesRequest(run_id=self.
|
|
185
|
+
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
202
186
|
)
|
|
203
187
|
return [node.node_id for node in res.nodes]
|
|
204
188
|
|
|
@@ -208,7 +192,6 @@ class GrpcDriver(Driver):
|
|
|
208
192
|
This method takes an iterable of messages and sends each message
|
|
209
193
|
to the node specified in `dst_node_id`.
|
|
210
194
|
"""
|
|
211
|
-
self._init_run()
|
|
212
195
|
# Construct TaskIns
|
|
213
196
|
task_ins_list: list[TaskIns] = []
|
|
214
197
|
for msg in messages:
|
|
@@ -230,7 +213,6 @@ class GrpcDriver(Driver):
|
|
|
230
213
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
231
214
|
set of given message IDs.
|
|
232
215
|
"""
|
|
233
|
-
self._init_run()
|
|
234
216
|
# Pull TaskRes
|
|
235
217
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
236
218
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
@@ -276,3 +258,60 @@ class GrpcDriver(Driver):
|
|
|
276
258
|
return
|
|
277
259
|
# Disconnect
|
|
278
260
|
self._disconnect()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
264
|
+
"""Create a simple gRPC retry invoker."""
|
|
265
|
+
|
|
266
|
+
def _on_sucess(retry_state: RetryState) -> None:
|
|
267
|
+
if retry_state.tries > 1:
|
|
268
|
+
log(
|
|
269
|
+
INFO,
|
|
270
|
+
"Connection successful after %.2f seconds and %s tries.",
|
|
271
|
+
retry_state.elapsed_time,
|
|
272
|
+
retry_state.tries,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
def _on_backoff(retry_state: RetryState) -> None:
|
|
276
|
+
if retry_state.tries == 1:
|
|
277
|
+
log(WARN, "Connection attempt failed, retrying...")
|
|
278
|
+
else:
|
|
279
|
+
log(
|
|
280
|
+
WARN,
|
|
281
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
|
282
|
+
retry_state.actual_wait,
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
def _on_giveup(retry_state: RetryState) -> None:
|
|
286
|
+
if retry_state.tries > 1:
|
|
287
|
+
log(
|
|
288
|
+
WARN,
|
|
289
|
+
"Giving up reconnection after %.2f seconds and %s tries.",
|
|
290
|
+
retry_state.elapsed_time,
|
|
291
|
+
retry_state.tries,
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
return RetryInvoker(
|
|
295
|
+
wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
|
|
296
|
+
recoverable_exceptions=grpc.RpcError,
|
|
297
|
+
max_tries=None,
|
|
298
|
+
max_time=None,
|
|
299
|
+
on_success=_on_sucess,
|
|
300
|
+
on_backoff=_on_backoff,
|
|
301
|
+
on_giveup=_on_giveup,
|
|
302
|
+
should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
|
|
307
|
+
"""Wrap the gRPC stub with a retry invoker."""
|
|
308
|
+
|
|
309
|
+
def make_lambda(original_method: Any) -> Any:
|
|
310
|
+
return lambda *args, **kwargs: retry_invoker.invoke(
|
|
311
|
+
original_method, *args, **kwargs
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
for method_name in vars(stub):
|
|
315
|
+
method = getattr(stub, method_name)
|
|
316
|
+
if callable(method):
|
|
317
|
+
setattr(stub, method_name, make_lambda(method))
|
|
@@ -25,18 +25,16 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
|
25
25
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from .driver import Driver
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class InMemoryDriver(Driver):
|
|
34
|
-
"""`InMemoryDriver` class provides an interface to the
|
|
34
|
+
"""`InMemoryDriver` class provides an interface to the ServerAppIo API.
|
|
35
35
|
|
|
36
36
|
Parameters
|
|
37
37
|
----------
|
|
38
|
-
run_id : int
|
|
39
|
-
The identifier of the run.
|
|
40
38
|
state_factory : StateFactory
|
|
41
39
|
A StateFactory embedding a state that this driver can interface with.
|
|
42
40
|
pull_interval : float (default=0.1)
|
|
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
|
|
|
45
43
|
|
|
46
44
|
def __init__(
|
|
47
45
|
self,
|
|
48
|
-
|
|
49
|
-
state_factory: StateFactory,
|
|
46
|
+
state_factory: LinkStateFactory,
|
|
50
47
|
pull_interval: float = 0.1,
|
|
51
48
|
) -> None:
|
|
52
|
-
self._run_id = run_id
|
|
53
49
|
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
55
51
|
self.pull_interval = pull_interval
|
|
56
52
|
self.node = Node(node_id=0, anonymous=True)
|
|
57
53
|
|
|
58
54
|
def _check_message(self, message: Message) -> None:
|
|
59
|
-
self._init_run()
|
|
60
55
|
# Check if the message is valid
|
|
61
56
|
if not (
|
|
62
57
|
message.metadata.run_id == cast(Run, self._run).run_id
|
|
@@ -67,19 +62,16 @@ class InMemoryDriver(Driver):
|
|
|
67
62
|
):
|
|
68
63
|
raise ValueError(f"Invalid message: {message}")
|
|
69
64
|
|
|
70
|
-
def
|
|
65
|
+
def set_run(self, run_id: int) -> None:
|
|
71
66
|
"""Initialize the run."""
|
|
72
|
-
|
|
73
|
-
return
|
|
74
|
-
run = self.state.get_run(self._run_id)
|
|
67
|
+
run = self.state.get_run(run_id)
|
|
75
68
|
if run is None:
|
|
76
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
69
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
77
70
|
self._run = run
|
|
78
71
|
|
|
79
72
|
@property
|
|
80
73
|
def run(self) -> Run:
|
|
81
74
|
"""Run ID."""
|
|
82
|
-
self._init_run()
|
|
83
75
|
return Run(**vars(cast(Run, self._run)))
|
|
84
76
|
|
|
85
77
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
@@ -95,7 +87,6 @@ class InMemoryDriver(Driver):
|
|
|
95
87
|
This method constructs a new `Message` with given content and metadata.
|
|
96
88
|
The `run_id` and `src_node_id` will be set automatically.
|
|
97
89
|
"""
|
|
98
|
-
self._init_run()
|
|
99
90
|
if ttl:
|
|
100
91
|
warnings.warn(
|
|
101
92
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -119,7 +110,6 @@ class InMemoryDriver(Driver):
|
|
|
119
110
|
|
|
120
111
|
def get_node_ids(self) -> list[int]:
|
|
121
112
|
"""Get node IDs."""
|
|
122
|
-
self._init_run()
|
|
123
113
|
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
124
114
|
|
|
125
115
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -15,44 +15,25 @@
|
|
|
15
15
|
"""Run ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import argparse
|
|
19
18
|
import sys
|
|
20
|
-
from logging import DEBUG,
|
|
21
|
-
from pathlib import Path
|
|
19
|
+
from logging import DEBUG, ERROR
|
|
22
20
|
from typing import Optional
|
|
23
21
|
|
|
24
|
-
from flwr.
|
|
25
|
-
from flwr.
|
|
26
|
-
from flwr.common import Context, EventType, RecordSet, event
|
|
27
|
-
from flwr.common.config import (
|
|
28
|
-
get_flwr_dir,
|
|
29
|
-
get_fused_config_from_dir,
|
|
30
|
-
get_metadata_from_config,
|
|
31
|
-
get_project_config,
|
|
32
|
-
get_project_dir,
|
|
33
|
-
)
|
|
34
|
-
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
35
|
-
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
22
|
+
from flwr.common import Context
|
|
23
|
+
from flwr.common.logger import log, warn_unsupported_feature
|
|
36
24
|
from flwr.common.object_ref import load_app
|
|
37
|
-
from flwr.common.typing import UserConfig
|
|
38
|
-
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
|
-
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
|
-
CreateRunRequest,
|
|
41
|
-
CreateRunResponse,
|
|
42
|
-
)
|
|
43
25
|
|
|
44
26
|
from .driver import Driver
|
|
45
|
-
from .driver.grpc_driver import GrpcDriver
|
|
46
27
|
from .server_app import LoadServerAppError, ServerApp
|
|
47
28
|
|
|
48
29
|
|
|
49
30
|
def run(
|
|
50
31
|
driver: Driver,
|
|
32
|
+
context: Context,
|
|
51
33
|
server_app_dir: str,
|
|
52
|
-
server_app_run_config: UserConfig,
|
|
53
34
|
server_app_attr: Optional[str] = None,
|
|
54
35
|
loaded_server_app: Optional[ServerApp] = None,
|
|
55
|
-
) ->
|
|
36
|
+
) -> Context:
|
|
56
37
|
"""Run ServerApp with a given Driver."""
|
|
57
38
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
58
39
|
raise ValueError(
|
|
@@ -78,224 +59,19 @@ def run(
|
|
|
78
59
|
|
|
79
60
|
server_app = _load()
|
|
80
61
|
|
|
81
|
-
# Initialize Context
|
|
82
|
-
context = Context(
|
|
83
|
-
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
84
|
-
)
|
|
85
|
-
|
|
86
62
|
# Call ServerApp
|
|
87
63
|
server_app(driver=driver, context=context)
|
|
88
64
|
|
|
89
65
|
log(DEBUG, "ServerApp finished running.")
|
|
66
|
+
return context
|
|
90
67
|
|
|
91
68
|
|
|
92
69
|
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
93
70
|
def run_server_app() -> None:
|
|
94
71
|
"""Run Flower server app."""
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
# Check if the server app reference is passed.
|
|
100
|
-
# Since Flower 1.11, passing a reference is not allowed.
|
|
101
|
-
app_path: Optional[str] = args.app
|
|
102
|
-
# If the provided app_path doesn't exist, and contains a ":",
|
|
103
|
-
# it is likely to be a server app reference instead of a path.
|
|
104
|
-
if app_path is not None and not Path(app_path).exists() and ":" in app_path:
|
|
105
|
-
sys.exit(
|
|
106
|
-
"It appears you've passed a reference like `server:app`.\n\n"
|
|
107
|
-
"Note that since version `1.11.0`, `flower-server-app` no longer supports "
|
|
108
|
-
"passing a reference to a `ServerApp` attribute. Instead, you need to pass "
|
|
109
|
-
"the path to Flower app via the argument `--app`. This is the path to a "
|
|
110
|
-
"directory containing a `pyproject.toml`. You can create a valid Flower "
|
|
111
|
-
"app by executing `flwr new` and following the prompt."
|
|
112
|
-
)
|
|
113
|
-
|
|
114
|
-
if args.server != DRIVER_API_DEFAULT_ADDRESS:
|
|
115
|
-
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
116
|
-
warn_deprecated_feature(warn)
|
|
117
|
-
|
|
118
|
-
if args.superlink != DRIVER_API_DEFAULT_ADDRESS:
|
|
119
|
-
# if `--superlink` also passed, then
|
|
120
|
-
# warn user that this argument overrides what was passed with `--server`
|
|
121
|
-
log(
|
|
122
|
-
WARN,
|
|
123
|
-
"Both `--server` and `--superlink` were passed. "
|
|
124
|
-
"`--server` will be ignored. Connecting to the Superlink Driver API "
|
|
125
|
-
"at %s.",
|
|
126
|
-
args.superlink,
|
|
127
|
-
)
|
|
128
|
-
else:
|
|
129
|
-
args.superlink = args.server
|
|
130
|
-
|
|
131
|
-
update_console_handler(
|
|
132
|
-
level=DEBUG if args.verbose else INFO,
|
|
133
|
-
timestamps=args.verbose,
|
|
134
|
-
colored=True,
|
|
135
|
-
)
|
|
136
|
-
|
|
137
|
-
# Obtain certificates
|
|
138
|
-
if args.insecure:
|
|
139
|
-
if args.root_certificates is not None:
|
|
140
|
-
sys.exit(
|
|
141
|
-
"Conflicting options: The '--insecure' flag disables HTTPS, "
|
|
142
|
-
"but '--root-certificates' was also specified. Please remove "
|
|
143
|
-
"the '--root-certificates' option when running in insecure mode, "
|
|
144
|
-
"or omit '--insecure' to use HTTPS."
|
|
145
|
-
)
|
|
146
|
-
log(
|
|
147
|
-
WARN,
|
|
148
|
-
"Option `--insecure` was set. "
|
|
149
|
-
"Starting insecure HTTP client connected to %s.",
|
|
150
|
-
args.superlink,
|
|
151
|
-
)
|
|
152
|
-
root_certificates = None
|
|
153
|
-
else:
|
|
154
|
-
# Load the certificates if provided, or load the system certificates
|
|
155
|
-
cert_path = args.root_certificates
|
|
156
|
-
if cert_path is None:
|
|
157
|
-
root_certificates = None
|
|
158
|
-
else:
|
|
159
|
-
root_certificates = Path(cert_path).read_bytes()
|
|
160
|
-
log(
|
|
161
|
-
DEBUG,
|
|
162
|
-
"Starting secure HTTPS client connected to %s "
|
|
163
|
-
"with the following certificates: %s.",
|
|
164
|
-
args.superlink,
|
|
165
|
-
cert_path,
|
|
166
|
-
)
|
|
167
|
-
|
|
168
|
-
if not (app_path is None) ^ (args.run_id is None):
|
|
169
|
-
raise sys.exit(
|
|
170
|
-
"Please provide either a Flower App path or a Run ID, but not both. "
|
|
171
|
-
"For more details, use: ``flower-server-app -h``"
|
|
172
|
-
)
|
|
173
|
-
|
|
174
|
-
# Initialize GrpcDriver
|
|
175
|
-
if app_path is None:
|
|
176
|
-
# User provided `--run-id`, but not `app_dir`
|
|
177
|
-
driver = GrpcDriver(
|
|
178
|
-
run_id=args.run_id,
|
|
179
|
-
driver_service_address=args.superlink,
|
|
180
|
-
root_certificates=root_certificates,
|
|
181
|
-
)
|
|
182
|
-
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
183
|
-
run_ = driver.run
|
|
184
|
-
if not run_.fab_hash:
|
|
185
|
-
raise ValueError("FAB hash not provided.")
|
|
186
|
-
fab_req = GetFabRequest(hash_str=run_.fab_hash)
|
|
187
|
-
# pylint: disable-next=W0212
|
|
188
|
-
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
|
|
189
|
-
if fab_res.fab.hash_str != run_.fab_hash:
|
|
190
|
-
raise ValueError("FAB hashes don't match.")
|
|
191
|
-
install_from_fab(fab_res.fab.content, flwr_dir, True)
|
|
192
|
-
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
|
|
193
|
-
|
|
194
|
-
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
195
|
-
config = get_project_config(app_path)
|
|
196
|
-
else:
|
|
197
|
-
# User provided `app_dir`, but not `--run-id`
|
|
198
|
-
# Create run if run_id is not provided
|
|
199
|
-
driver = GrpcDriver(
|
|
200
|
-
run_id=0, # Will be overwritten
|
|
201
|
-
driver_service_address=args.superlink,
|
|
202
|
-
root_certificates=root_certificates,
|
|
203
|
-
)
|
|
204
|
-
# Load config from the project directory
|
|
205
|
-
config = get_project_config(app_path)
|
|
206
|
-
fab_version, fab_id = get_metadata_from_config(config)
|
|
207
|
-
|
|
208
|
-
# Create run
|
|
209
|
-
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
|
|
210
|
-
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
211
|
-
# Overwrite driver._run_id
|
|
212
|
-
driver._run_id = res.run_id # pylint: disable=W0212
|
|
213
|
-
|
|
214
|
-
# Obtain server app reference and the run config
|
|
215
|
-
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
216
|
-
server_app_run_config = get_fused_config_from_dir(
|
|
217
|
-
Path(app_path), driver.run.override_config
|
|
72
|
+
warn_unsupported_feature(
|
|
73
|
+
"The command `flower-server-app` is deprecated and no longer in use. "
|
|
74
|
+
"Use the `flwr-serverapp` exclusively instead."
|
|
218
75
|
)
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
log(
|
|
223
|
-
DEBUG,
|
|
224
|
-
"root_certificates: `%s`",
|
|
225
|
-
root_certificates,
|
|
226
|
-
)
|
|
227
|
-
|
|
228
|
-
# Run the ServerApp with the Driver
|
|
229
|
-
run(
|
|
230
|
-
driver=driver,
|
|
231
|
-
server_app_dir=app_path,
|
|
232
|
-
server_app_run_config=server_app_run_config,
|
|
233
|
-
server_app_attr=server_app_attr,
|
|
234
|
-
)
|
|
235
|
-
|
|
236
|
-
# Clean up
|
|
237
|
-
driver.close()
|
|
238
|
-
|
|
239
|
-
event(EventType.RUN_SERVER_APP_LEAVE)
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
243
|
-
"""Parse flower-server-app command line arguments."""
|
|
244
|
-
parser = argparse.ArgumentParser(
|
|
245
|
-
description="Start a Flower server app",
|
|
246
|
-
)
|
|
247
|
-
|
|
248
|
-
parser.add_argument(
|
|
249
|
-
"app",
|
|
250
|
-
nargs="?",
|
|
251
|
-
default=None,
|
|
252
|
-
help="Load and run the `ServerApp` from the specified Flower App path. "
|
|
253
|
-
"The `pyproject.toml` file must be located in the root of this path.",
|
|
254
|
-
)
|
|
255
|
-
parser.add_argument(
|
|
256
|
-
"--insecure",
|
|
257
|
-
action="store_true",
|
|
258
|
-
help="Run the `ServerApp` without HTTPS. By default, the app runs with "
|
|
259
|
-
"HTTPS enabled. Use this flag only if you understand the risks.",
|
|
260
|
-
)
|
|
261
|
-
parser.add_argument(
|
|
262
|
-
"--verbose",
|
|
263
|
-
action="store_true",
|
|
264
|
-
help="Set the logging to `DEBUG`.",
|
|
265
|
-
)
|
|
266
|
-
parser.add_argument(
|
|
267
|
-
"--root-certificates",
|
|
268
|
-
metavar="ROOT_CERT",
|
|
269
|
-
type=str,
|
|
270
|
-
help="Specifies the path to the PEM-encoded root certificate file for "
|
|
271
|
-
"establishing secure HTTPS connections.",
|
|
272
|
-
)
|
|
273
|
-
parser.add_argument(
|
|
274
|
-
"--server",
|
|
275
|
-
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
276
|
-
help="Server address",
|
|
277
|
-
)
|
|
278
|
-
parser.add_argument(
|
|
279
|
-
"--superlink",
|
|
280
|
-
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
281
|
-
help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
282
|
-
)
|
|
283
|
-
parser.add_argument(
|
|
284
|
-
"--run-id",
|
|
285
|
-
default=None,
|
|
286
|
-
type=int,
|
|
287
|
-
help="The identifier of the run.",
|
|
288
|
-
)
|
|
289
|
-
parser.add_argument(
|
|
290
|
-
"--flwr-dir",
|
|
291
|
-
default=None,
|
|
292
|
-
help="""The path containing installed Flower Apps.
|
|
293
|
-
By default, this value is equal to:
|
|
294
|
-
|
|
295
|
-
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
|
|
296
|
-
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
|
|
297
|
-
- `$HOME/.flwr/` in all other cases
|
|
298
|
-
""",
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
return parser
|
|
76
|
+
log(ERROR, "`flower-server-app` used.")
|
|
77
|
+
sys.exit()
|
|
@@ -12,17 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower AppIO service."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .sqlite_state import SqliteState as SqliteState
|
|
20
|
-
from .state import State as State
|
|
21
|
-
from .state_factory import StateFactory as StateFactory
|
|
18
|
+
from .app import flwr_serverapp as flwr_serverapp
|
|
22
19
|
|
|
23
20
|
__all__ = [
|
|
24
|
-
"
|
|
25
|
-
"SqliteState",
|
|
26
|
-
"State",
|
|
27
|
-
"StateFactory",
|
|
21
|
+
"flwr_serverapp",
|
|
28
22
|
]
|