flwr-nightly 1.8.0.dev20240303__py3-none-any.whl → 1.8.0.dev20240305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +1 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
- flwr/cli/new/templates/app/flower.toml.tpl +2 -2
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
- flwr/client/app.py +93 -8
- flwr/client/grpc_client/connection.py +7 -0
- flwr/client/grpc_rere_client/connection.py +14 -4
- flwr/client/rest_client/connection.py +16 -4
- flwr/common/__init__.py +4 -0
- flwr/common/constant.py +11 -0
- flwr/server/app.py +7 -7
- flwr/server/run_serverapp.py +14 -9
- flwr/server/server.py +5 -5
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +17 -5
- flwr/simulation/__init__.py +2 -5
- flwr/simulation/run_simulation.py +301 -76
- {flwr_nightly-1.8.0.dev20240303.dist-info → flwr_nightly-1.8.0.dev20240305.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240303.dist-info → flwr_nightly-1.8.0.dev20240305.dist-info}/RECORD +23 -20
- {flwr_nightly-1.8.0.dev20240303.dist-info → flwr_nightly-1.8.0.dev20240305.dist-info}/entry_points.txt +1 -1
- {flwr_nightly-1.8.0.dev20240303.dist-info → flwr_nightly-1.8.0.dev20240305.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240303.dist-info → flwr_nightly-1.8.0.dev20240305.dist-info}/WHEEL +0 -0
flwr/cli/new/new.py
CHANGED
@@ -0,0 +1,24 @@
|
|
1
|
+
"""$project_name: A Flower / NumPy app."""
|
2
|
+
|
3
|
+
import flwr as fl
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
|
7
|
+
# Flower client, adapted from Pytorch quickstart example
|
8
|
+
class FlowerClient(fl.client.NumPyClient):
|
9
|
+
def get_parameters(self, config):
|
10
|
+
return [np.ones((1, 1))]
|
11
|
+
|
12
|
+
def fit(self, parameters, config):
|
13
|
+
return ([np.ones((1, 1))], 1, {})
|
14
|
+
|
15
|
+
def evaluate(self, parameters, config):
|
16
|
+
return float(0.0), 1, {"accuracy": float(1.0)}
|
17
|
+
|
18
|
+
|
19
|
+
def client_fn(cid: str):
|
20
|
+
return FlowerClient().to_client()
|
21
|
+
|
22
|
+
|
23
|
+
# ClientApp for Flower-Next
|
24
|
+
app = fl.client.ClientApp(client_fn=client_fn)
|
@@ -0,0 +1,12 @@
|
|
1
|
+
"""$project_name: A Flower / NumPy app."""
|
2
|
+
|
3
|
+
import flwr as fl
|
4
|
+
|
5
|
+
# Configure the strategy
|
6
|
+
strategy = fl.server.strategy.FedAvg()
|
7
|
+
|
8
|
+
# Flower ServerApp
|
9
|
+
app = fl.server.ServerApp(
|
10
|
+
config=fl.server.ServerConfig(num_rounds=1),
|
11
|
+
strategy=strategy,
|
12
|
+
)
|
@@ -1,10 +1,10 @@
|
|
1
|
-
[
|
1
|
+
[project]
|
2
2
|
name = "$project_name"
|
3
3
|
version = "1.0.0"
|
4
4
|
description = ""
|
5
5
|
license = "Apache-2.0"
|
6
6
|
authors = ["The Flower Authors <hello@flower.ai>"]
|
7
7
|
|
8
|
-
[components]
|
8
|
+
[flower.components]
|
9
9
|
serverapp = "$project_name.server:app"
|
10
10
|
clientapp = "$project_name.client:app"
|
flwr/client/app.py
CHANGED
@@ -20,7 +20,9 @@ import sys
|
|
20
20
|
import time
|
21
21
|
from logging import DEBUG, INFO, WARN
|
22
22
|
from pathlib import Path
|
23
|
-
from typing import Callable, ContextManager, Optional, Tuple, Union
|
23
|
+
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
24
|
+
|
25
|
+
from grpc import RpcError
|
24
26
|
|
25
27
|
from flwr.client.client import Client
|
26
28
|
from flwr.client.client_app import ClientApp
|
@@ -36,6 +38,7 @@ from flwr.common.constant import (
|
|
36
38
|
)
|
37
39
|
from flwr.common.exit_handlers import register_exit_handlers
|
38
40
|
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
|
41
|
+
from flwr.common.retry_invoker import RetryInvoker, exponential
|
39
42
|
|
40
43
|
from .client_app import load_client_app
|
41
44
|
from .grpc_client.connection import grpc_connection
|
@@ -104,6 +107,8 @@ def run_client_app() -> None:
|
|
104
107
|
transport="rest" if args.rest else "grpc-rere",
|
105
108
|
root_certificates=root_certificates,
|
106
109
|
insecure=args.insecure,
|
110
|
+
max_retries=args.max_retries,
|
111
|
+
max_wait_time=args.max_wait_time,
|
107
112
|
)
|
108
113
|
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
109
114
|
|
@@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
|
141
146
|
default="0.0.0.0:9092",
|
142
147
|
help="Server address",
|
143
148
|
)
|
149
|
+
parser.add_argument(
|
150
|
+
"--max-retries",
|
151
|
+
type=int,
|
152
|
+
default=None,
|
153
|
+
help="The maximum number of times the client will try to connect to the"
|
154
|
+
"server before giving up in case of a connection error. By default,"
|
155
|
+
"it is set to None, meaning there is no limit to the number of tries.",
|
156
|
+
)
|
157
|
+
parser.add_argument(
|
158
|
+
"--max-wait-time",
|
159
|
+
type=float,
|
160
|
+
default=None,
|
161
|
+
help="The maximum duration before the client stops trying to"
|
162
|
+
"connect to the server in case of connection error. By default, it"
|
163
|
+
"is set to None, meaning there is no limit to the total time.",
|
164
|
+
)
|
144
165
|
parser.add_argument(
|
145
166
|
"--dir",
|
146
167
|
default="",
|
@@ -180,6 +201,8 @@ def start_client(
|
|
180
201
|
root_certificates: Optional[Union[bytes, str]] = None,
|
181
202
|
insecure: Optional[bool] = None,
|
182
203
|
transport: Optional[str] = None,
|
204
|
+
max_retries: Optional[int] = None,
|
205
|
+
max_wait_time: Optional[float] = None,
|
183
206
|
) -> None:
|
184
207
|
"""Start a Flower client node which connects to a Flower server.
|
185
208
|
|
@@ -213,6 +236,14 @@ def start_client(
|
|
213
236
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
214
237
|
- 'grpc-rere': gRPC, request-response (experimental)
|
215
238
|
- 'rest': HTTP (experimental)
|
239
|
+
max_retries: Optional[int] (default: None)
|
240
|
+
The maximum number of times the client will try to connect to the
|
241
|
+
server before giving up in case of a connection error. If set to None,
|
242
|
+
there is no limit to the number of tries.
|
243
|
+
max_wait_time: Optional[float] (default: None)
|
244
|
+
The maximum duration before the client stops trying to
|
245
|
+
connect to the server in case of connection error.
|
246
|
+
If set to None, there is no limit to the total time.
|
216
247
|
|
217
248
|
Examples
|
218
249
|
--------
|
@@ -254,6 +285,8 @@ def start_client(
|
|
254
285
|
root_certificates=root_certificates,
|
255
286
|
insecure=insecure,
|
256
287
|
transport=transport,
|
288
|
+
max_retries=max_retries,
|
289
|
+
max_wait_time=max_wait_time,
|
257
290
|
)
|
258
291
|
event(EventType.START_CLIENT_LEAVE)
|
259
292
|
|
@@ -272,6 +305,8 @@ def _start_client_internal(
|
|
272
305
|
root_certificates: Optional[Union[bytes, str]] = None,
|
273
306
|
insecure: Optional[bool] = None,
|
274
307
|
transport: Optional[str] = None,
|
308
|
+
max_retries: Optional[int] = None,
|
309
|
+
max_wait_time: Optional[float] = None,
|
275
310
|
) -> None:
|
276
311
|
"""Start a Flower client node which connects to a Flower server.
|
277
312
|
|
@@ -299,7 +334,7 @@ def _start_client_internal(
|
|
299
334
|
The PEM-encoded root certificates as a byte string or a path string.
|
300
335
|
If provided, a secure connection using the certificates will be
|
301
336
|
established to an SSL-enabled Flower server.
|
302
|
-
insecure : bool (default:
|
337
|
+
insecure : Optional[bool] (default: None)
|
303
338
|
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
304
339
|
when False, using system certificates if `root_certificates` is None.
|
305
340
|
transport : Optional[str] (default: None)
|
@@ -307,6 +342,14 @@ def _start_client_internal(
|
|
307
342
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
308
343
|
- 'grpc-rere': gRPC, request-response (experimental)
|
309
344
|
- 'rest': HTTP (experimental)
|
345
|
+
max_retries: Optional[int] (default: None)
|
346
|
+
The maximum number of times the client will try to connect to the
|
347
|
+
server before giving up in case of a connection error. If set to None,
|
348
|
+
there is no limit to the number of tries.
|
349
|
+
max_wait_time: Optional[float] (default: None)
|
350
|
+
The maximum duration before the client stops trying to
|
351
|
+
connect to the server in case of connection error.
|
352
|
+
If set to None, there is no limit to the total time.
|
310
353
|
"""
|
311
354
|
if insecure is None:
|
312
355
|
insecure = root_certificates is None
|
@@ -338,7 +381,45 @@ def _start_client_internal(
|
|
338
381
|
# Both `client` and `client_fn` must not be used directly
|
339
382
|
|
340
383
|
# Initialize connection context manager
|
341
|
-
connection, address = _init_connection(
|
384
|
+
connection, address, connection_error_type = _init_connection(
|
385
|
+
transport, server_address
|
386
|
+
)
|
387
|
+
|
388
|
+
retry_invoker = RetryInvoker(
|
389
|
+
wait_factory=exponential,
|
390
|
+
recoverable_exceptions=connection_error_type,
|
391
|
+
max_tries=max_retries,
|
392
|
+
max_time=max_wait_time,
|
393
|
+
on_giveup=lambda retry_state: (
|
394
|
+
log(
|
395
|
+
WARN,
|
396
|
+
"Giving up reconnection after %.2f seconds and %s tries.",
|
397
|
+
retry_state.elapsed_time,
|
398
|
+
retry_state.tries,
|
399
|
+
)
|
400
|
+
if retry_state.tries > 1
|
401
|
+
else None
|
402
|
+
),
|
403
|
+
on_success=lambda retry_state: (
|
404
|
+
log(
|
405
|
+
INFO,
|
406
|
+
"Connection successful after %.2f seconds and %s tries.",
|
407
|
+
retry_state.elapsed_time,
|
408
|
+
retry_state.tries,
|
409
|
+
)
|
410
|
+
if retry_state.tries > 1
|
411
|
+
else None
|
412
|
+
),
|
413
|
+
on_backoff=lambda retry_state: (
|
414
|
+
log(WARN, "Connection attempt failed, retrying...")
|
415
|
+
if retry_state.tries == 1
|
416
|
+
else log(
|
417
|
+
DEBUG,
|
418
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
419
|
+
retry_state.actual_wait,
|
420
|
+
)
|
421
|
+
),
|
422
|
+
)
|
342
423
|
|
343
424
|
node_state = NodeState()
|
344
425
|
|
@@ -347,6 +428,7 @@ def _start_client_internal(
|
|
347
428
|
with connection(
|
348
429
|
address,
|
349
430
|
insecure,
|
431
|
+
retry_invoker,
|
350
432
|
grpc_max_message_length,
|
351
433
|
root_certificates,
|
352
434
|
) as conn:
|
@@ -509,7 +591,7 @@ def start_numpy_client(
|
|
509
591
|
|
510
592
|
def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
511
593
|
Callable[
|
512
|
-
[str, bool, int, Union[bytes, str, None]],
|
594
|
+
[str, bool, RetryInvoker, int, Union[bytes, str, None]],
|
513
595
|
ContextManager[
|
514
596
|
Tuple[
|
515
597
|
Callable[[], Optional[Message]],
|
@@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
520
602
|
],
|
521
603
|
],
|
522
604
|
str,
|
605
|
+
Type[Exception],
|
523
606
|
]:
|
524
607
|
# Parse IP address
|
525
608
|
parsed_address = parse_address(server_address)
|
@@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
535
618
|
# Use either gRPC bidirectional streaming or REST request/response
|
536
619
|
if transport == TRANSPORT_TYPE_REST:
|
537
620
|
try:
|
621
|
+
from requests.exceptions import ConnectionError as RequestsConnectionError
|
622
|
+
|
538
623
|
from .rest_client.connection import http_request_response
|
539
624
|
except ModuleNotFoundError:
|
540
625
|
sys.exit(MISSING_EXTRA_REST)
|
@@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
543
628
|
"When using the REST API, please provide `https://` or "
|
544
629
|
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
|
545
630
|
)
|
546
|
-
connection = http_request_response
|
631
|
+
connection, error_type = http_request_response, RequestsConnectionError
|
547
632
|
elif transport == TRANSPORT_TYPE_GRPC_RERE:
|
548
|
-
connection = grpc_request_response
|
633
|
+
connection, error_type = grpc_request_response, RpcError
|
549
634
|
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
|
550
|
-
connection = grpc_connection
|
635
|
+
connection, error_type = grpc_connection, RpcError
|
551
636
|
else:
|
552
637
|
raise ValueError(
|
553
638
|
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
|
554
639
|
)
|
555
640
|
|
556
|
-
return connection, address
|
641
|
+
return connection, address, error_type
|
@@ -39,6 +39,7 @@ from flwr.common.constant import (
|
|
39
39
|
)
|
40
40
|
from flwr.common.grpc import create_channel
|
41
41
|
from flwr.common.logger import log
|
42
|
+
from flwr.common.retry_invoker import RetryInvoker
|
42
43
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
43
44
|
ClientMessage,
|
44
45
|
Reason,
|
@@ -62,6 +63,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
62
63
|
def grpc_connection( # pylint: disable=R0915
|
63
64
|
server_address: str,
|
64
65
|
insecure: bool,
|
66
|
+
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
65
67
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
66
68
|
root_certificates: Optional[Union[bytes, str]] = None,
|
67
69
|
) -> Iterator[
|
@@ -80,6 +82,11 @@ def grpc_connection( # pylint: disable=R0915
|
|
80
82
|
The IPv4 or IPv6 address of the server. If the Flower server runs on the same
|
81
83
|
machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
|
82
84
|
`"[::]:8080"`.
|
85
|
+
insecure : bool
|
86
|
+
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
87
|
+
when False, using system certificates if `root_certificates` is None.
|
88
|
+
retry_invoker: RetryInvoker
|
89
|
+
Unused argument present for compatibilty.
|
83
90
|
max_message_length : int
|
84
91
|
The maximum length of gRPC messages that can be exchanged with the Flower
|
85
92
|
server. The default should be sufficient for most models. Users who train
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.grpc import create_channel
|
28
28
|
from flwr.common.logger import log, warn_experimental_feature
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
51
52
|
def grpc_request_response(
|
52
53
|
server_address: str,
|
53
54
|
insecure: bool,
|
55
|
+
retry_invoker: RetryInvoker,
|
54
56
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
55
57
|
root_certificates: Optional[Union[bytes, str]] = None,
|
56
58
|
) -> Iterator[
|
@@ -72,6 +74,13 @@ def grpc_request_response(
|
|
72
74
|
The IPv6 address of the server with `http://` or `https://`.
|
73
75
|
If the Flower server runs on the same machine
|
74
76
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
77
|
+
insecure : bool
|
78
|
+
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
79
|
+
when False, using system certificates if `root_certificates` is None.
|
80
|
+
retry_invoker: RetryInvoker
|
81
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
82
|
+
after gRPC errors. If None, the client will only try to
|
83
|
+
reconnect once after a failure.
|
75
84
|
max_message_length : int
|
76
85
|
Ignored, only present to preserve API-compatibility.
|
77
86
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -113,7 +122,8 @@ def grpc_request_response(
|
|
113
122
|
def create_node() -> None:
|
114
123
|
"""Set create_node."""
|
115
124
|
create_node_request = CreateNodeRequest()
|
116
|
-
create_node_response =
|
125
|
+
create_node_response = retry_invoker.invoke(
|
126
|
+
stub.CreateNode,
|
117
127
|
request=create_node_request,
|
118
128
|
)
|
119
129
|
node_store[KEY_NODE] = create_node_response.node
|
@@ -127,7 +137,7 @@ def grpc_request_response(
|
|
127
137
|
node: Node = cast(Node, node_store[KEY_NODE])
|
128
138
|
|
129
139
|
delete_node_request = DeleteNodeRequest(node=node)
|
130
|
-
stub.DeleteNode
|
140
|
+
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
|
131
141
|
|
132
142
|
del node_store[KEY_NODE]
|
133
143
|
|
@@ -141,7 +151,7 @@ def grpc_request_response(
|
|
141
151
|
|
142
152
|
# Request instructions (task) from server
|
143
153
|
request = PullTaskInsRequest(node=node)
|
144
|
-
response = stub.PullTaskIns
|
154
|
+
response = retry_invoker.invoke(stub.PullTaskIns, request=request)
|
145
155
|
|
146
156
|
# Get the current TaskIns
|
147
157
|
task_ins: Optional[TaskIns] = get_task_ins(response)
|
@@ -185,7 +195,7 @@ def grpc_request_response(
|
|
185
195
|
|
186
196
|
# Serialize ProtoBuf to bytes
|
187
197
|
request = PushTaskResRequest(task_res_list=[task_res])
|
188
|
-
_ = stub.PushTaskRes
|
198
|
+
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
189
199
|
|
190
200
|
state[KEY_METADATA] = None
|
191
201
|
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.constant import MISSING_EXTRA_REST
|
28
28
|
from flwr.common.logger import log
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
61
62
|
def http_request_response(
|
62
63
|
server_address: str,
|
63
64
|
insecure: bool, # pylint: disable=unused-argument
|
65
|
+
retry_invoker: RetryInvoker,
|
64
66
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
65
67
|
root_certificates: Optional[
|
66
68
|
Union[bytes, str]
|
@@ -84,6 +86,12 @@ def http_request_response(
|
|
84
86
|
The IPv6 address of the server with `http://` or `https://`.
|
85
87
|
If the Flower server runs on the same machine
|
86
88
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
89
|
+
insecure : bool
|
90
|
+
Unused argument present for compatibilty.
|
91
|
+
retry_invoker: RetryInvoker
|
92
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
93
|
+
after REST connection errors. If None, the client will only try to
|
94
|
+
reconnect once after a failure.
|
87
95
|
max_message_length : int
|
88
96
|
Ignored, only present to preserve API-compatibility.
|
89
97
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -134,7 +142,8 @@ def http_request_response(
|
|
134
142
|
create_node_req_proto = CreateNodeRequest()
|
135
143
|
create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
|
136
144
|
|
137
|
-
res =
|
145
|
+
res = retry_invoker.invoke(
|
146
|
+
requests.post,
|
138
147
|
url=f"{base_url}/{PATH_CREATE_NODE}",
|
139
148
|
headers={
|
140
149
|
"Accept": "application/protobuf",
|
@@ -177,7 +186,8 @@ def http_request_response(
|
|
177
186
|
node: Node = cast(Node, node_store[KEY_NODE])
|
178
187
|
delete_node_req_proto = DeleteNodeRequest(node=node)
|
179
188
|
delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
|
180
|
-
res =
|
189
|
+
res = retry_invoker.invoke(
|
190
|
+
requests.post,
|
181
191
|
url=f"{base_url}/{PATH_DELETE_NODE}",
|
182
192
|
headers={
|
183
193
|
"Accept": "application/protobuf",
|
@@ -218,7 +228,8 @@ def http_request_response(
|
|
218
228
|
pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
|
219
229
|
|
220
230
|
# Request instructions (task) from server
|
221
|
-
res =
|
231
|
+
res = retry_invoker.invoke(
|
232
|
+
requests.post,
|
222
233
|
url=f"{base_url}/{PATH_PULL_TASK_INS}",
|
223
234
|
headers={
|
224
235
|
"Accept": "application/protobuf",
|
@@ -298,7 +309,8 @@ def http_request_response(
|
|
298
309
|
)
|
299
310
|
|
300
311
|
# Send ClientMessage to server
|
301
|
-
res =
|
312
|
+
res = retry_invoker.invoke(
|
313
|
+
requests.post,
|
302
314
|
url=f"{base_url}/{PATH_PUSH_TASK_RES}",
|
303
315
|
headers={
|
304
316
|
"Accept": "application/protobuf",
|
flwr/common/__init__.py
CHANGED
@@ -15,11 +15,13 @@
|
|
15
15
|
"""Common components shared between server and client."""
|
16
16
|
|
17
17
|
|
18
|
+
from .constant import MessageType as MessageType
|
18
19
|
from .context import Context as Context
|
19
20
|
from .date import now as now
|
20
21
|
from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
21
22
|
from .logger import configure as configure
|
22
23
|
from .logger import log as log
|
24
|
+
from .message import Error as Error
|
23
25
|
from .message import Message as Message
|
24
26
|
from .message import Metadata as Metadata
|
25
27
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
@@ -74,6 +76,7 @@ __all__ = [
|
|
74
76
|
"EventType",
|
75
77
|
"FitIns",
|
76
78
|
"FitRes",
|
79
|
+
"Error",
|
77
80
|
"GetParametersIns",
|
78
81
|
"GetParametersRes",
|
79
82
|
"GetPropertiesIns",
|
@@ -81,6 +84,7 @@ __all__ = [
|
|
81
84
|
"GRPC_MAX_MESSAGE_LENGTH",
|
82
85
|
"log",
|
83
86
|
"Message",
|
87
|
+
"MessageType",
|
84
88
|
"Metadata",
|
85
89
|
"Metrics",
|
86
90
|
"MetricsAggregationFn",
|
flwr/common/constant.py
CHANGED
@@ -42,6 +42,17 @@ MESSAGE_TYPE_FIT = "fit"
|
|
42
42
|
MESSAGE_TYPE_EVALUATE = "evaluate"
|
43
43
|
|
44
44
|
|
45
|
+
class MessageType:
|
46
|
+
"""Message type."""
|
47
|
+
|
48
|
+
TRAIN = "train"
|
49
|
+
EVALUATE = "evaluate"
|
50
|
+
|
51
|
+
def __new__(cls) -> MessageType:
|
52
|
+
"""Prevent instantiation."""
|
53
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
54
|
+
|
55
|
+
|
45
56
|
class SType:
|
46
57
|
"""Serialisation type."""
|
47
58
|
|
flwr/server/app.py
CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
|
|
362
362
|
f_stop = asyncio.Event() # Does nothing
|
363
363
|
_run_fleet_api_vce(
|
364
364
|
num_supernodes=args.num_supernodes,
|
365
|
-
|
365
|
+
client_app_attr=args.client_app,
|
366
366
|
backend_name=args.backend,
|
367
367
|
backend_config_json_stream=args.backend_config,
|
368
|
-
|
368
|
+
app_dir=args.app_dir,
|
369
369
|
state_factory=state_factory,
|
370
370
|
f_stop=f_stop,
|
371
371
|
)
|
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
|
|
438
438
|
# pylint: disable=too-many-arguments
|
439
439
|
def _run_fleet_api_vce(
|
440
440
|
num_supernodes: int,
|
441
|
-
|
441
|
+
client_app_attr: str,
|
442
442
|
backend_name: str,
|
443
443
|
backend_config_json_stream: str,
|
444
|
-
|
444
|
+
app_dir: str,
|
445
445
|
state_factory: StateFactory,
|
446
446
|
f_stop: asyncio.Event,
|
447
447
|
) -> None:
|
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
|
|
449
449
|
|
450
450
|
start_vce(
|
451
451
|
num_supernodes=num_supernodes,
|
452
|
-
|
452
|
+
client_app_attr=client_app_attr,
|
453
453
|
backend_name=backend_name,
|
454
454
|
backend_config_json_stream=backend_config_json_stream,
|
455
455
|
state_factory=state_factory,
|
456
|
-
|
456
|
+
app_dir=app_dir,
|
457
457
|
f_stop=f_stop,
|
458
458
|
)
|
459
459
|
|
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
705
705
|
"`flwr.common.typing.ConfigsRecordValues`. ",
|
706
706
|
)
|
707
707
|
parser.add_argument(
|
708
|
-
"--dir",
|
708
|
+
"--app-dir",
|
709
709
|
default="",
|
710
710
|
help="Add specified directory to the PYTHONPATH and load"
|
711
711
|
"ClientApp from there."
|
flwr/server/run_serverapp.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import argparse
|
19
|
-
import asyncio
|
20
19
|
import sys
|
21
20
|
from logging import DEBUG, WARN
|
22
21
|
from pathlib import Path
|
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
|
|
30
29
|
|
31
30
|
|
32
31
|
def run(
|
33
|
-
server_app_attr: str,
|
34
32
|
driver: Driver,
|
35
33
|
server_app_dir: str,
|
36
|
-
|
34
|
+
server_app_attr: Optional[str] = None,
|
35
|
+
loaded_server_app: Optional[ServerApp] = None,
|
37
36
|
) -> None:
|
38
37
|
"""Run ServerApp with a given Driver."""
|
38
|
+
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
39
|
+
raise ValueError(
|
40
|
+
"Either `server_app_attr` or `loaded_server_app` should be set "
|
41
|
+
"but not both. "
|
42
|
+
)
|
43
|
+
|
39
44
|
if server_app_dir is not None:
|
40
45
|
sys.path.insert(0, server_app_dir)
|
41
46
|
|
47
|
+
# Load ServerApp if needed
|
42
48
|
def _load() -> ServerApp:
|
43
|
-
|
49
|
+
if server_app_attr:
|
50
|
+
server_app: ServerApp = load_server_app(server_app_attr)
|
51
|
+
if loaded_server_app:
|
52
|
+
server_app = loaded_server_app
|
44
53
|
return server_app
|
45
54
|
|
46
55
|
server_app = _load()
|
@@ -52,10 +61,6 @@ def run(
|
|
52
61
|
server_app(driver=driver, context=context)
|
53
62
|
|
54
63
|
log(DEBUG, "ServerApp finished running.")
|
55
|
-
# Upon completion, trigger stop event if one was passed
|
56
|
-
if stop_event is not None:
|
57
|
-
log(DEBUG, "Triggering stop event.")
|
58
|
-
stop_event.set()
|
59
64
|
|
60
65
|
|
61
66
|
def run_server_app() -> None:
|
@@ -117,7 +122,7 @@ def run_server_app() -> None:
|
|
117
122
|
)
|
118
123
|
|
119
124
|
# Run the Server App with the Driver
|
120
|
-
run(
|
125
|
+
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
121
126
|
|
122
127
|
# Clean up
|
123
128
|
driver.__del__() # pylint: disable=unnecessary-dunder-call
|
flwr/server/server.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import concurrent.futures
|
19
19
|
import timeit
|
20
|
-
from logging import
|
20
|
+
from logging import INFO, WARN
|
21
21
|
from typing import Dict, List, Optional, Tuple, Union
|
22
22
|
|
23
23
|
from flwr.common import (
|
@@ -173,7 +173,7 @@ class Server:
|
|
173
173
|
log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
|
174
174
|
return None
|
175
175
|
log(
|
176
|
-
|
176
|
+
INFO,
|
177
177
|
"evaluate_round %s: strategy sampled %s clients (out of %s)",
|
178
178
|
server_round,
|
179
179
|
len(client_instructions),
|
@@ -188,7 +188,7 @@ class Server:
|
|
188
188
|
group_id=server_round,
|
189
189
|
)
|
190
190
|
log(
|
191
|
-
|
191
|
+
INFO,
|
192
192
|
"evaluate_round %s received %s results and %s failures",
|
193
193
|
server_round,
|
194
194
|
len(results),
|
@@ -223,7 +223,7 @@ class Server:
|
|
223
223
|
log(INFO, "fit_round %s: no clients selected, cancel", server_round)
|
224
224
|
return None
|
225
225
|
log(
|
226
|
-
|
226
|
+
INFO,
|
227
227
|
"fit_round %s: strategy sampled %s clients (out of %s)",
|
228
228
|
server_round,
|
229
229
|
len(client_instructions),
|
@@ -238,7 +238,7 @@ class Server:
|
|
238
238
|
group_id=server_round,
|
239
239
|
)
|
240
240
|
log(
|
241
|
-
|
241
|
+
INFO,
|
242
242
|
"fit_round %s received %s results and %s failures",
|
243
243
|
server_round,
|
244
244
|
len(results),
|
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
49
49
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
50
50
|
) -> GetNodesResponse:
|
51
51
|
"""Get available nodes."""
|
52
|
-
log(
|
52
|
+
log(DEBUG, "DriverServicer.GetNodes")
|
53
53
|
state: State = self.state_factory.state()
|
54
54
|
all_ids: Set[int] = state.get_nodes(request.run_id)
|
55
55
|
nodes: List[Node] = [
|