flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/flower_toml.py +151 -0
- flwr/cli/new/new.py +1 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
- flwr/cli/new/templates/app/flower.toml.tpl +2 -2
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
- flwr/cli/run/__init__.py +21 -0
- flwr/cli/run/run.py +102 -0
- flwr/client/app.py +93 -8
- flwr/client/grpc_client/connection.py +16 -14
- flwr/client/grpc_rere_client/connection.py +14 -4
- flwr/client/message_handler/message_handler.py +5 -10
- flwr/client/mod/centraldp_mods.py +5 -5
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
- flwr/client/rest_client/connection.py +16 -4
- flwr/common/__init__.py +6 -0
- flwr/common/constant.py +21 -4
- flwr/server/app.py +7 -7
- flwr/server/compat/driver_client_proxy.py +5 -11
- flwr/server/run_serverapp.py +14 -9
- flwr/server/server.py +5 -5
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +17 -5
- flwr/server/workflow/default_workflows.py +4 -8
- flwr/simulation/__init__.py +2 -5
- flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
- flwr/simulation/run_simulation.py +301 -76
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
flwr/cli/app.py
CHANGED
@@ -18,6 +18,7 @@ import typer
|
|
18
18
|
|
19
19
|
from .example import example
|
20
20
|
from .new import new
|
21
|
+
from .run import run
|
21
22
|
|
22
23
|
app = typer.Typer(
|
23
24
|
help=typer.style(
|
@@ -30,6 +31,7 @@ app = typer.Typer(
|
|
30
31
|
|
31
32
|
app.command()(new)
|
32
33
|
app.command()(example)
|
34
|
+
app.command()(run)
|
33
35
|
|
34
36
|
if __name__ == "__main__":
|
35
37
|
app()
|
flwr/cli/flower_toml.py
ADDED
@@ -0,0 +1,151 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Utility to validate the `flower.toml` file."""
|
16
|
+
|
17
|
+
import importlib
|
18
|
+
import os
|
19
|
+
from typing import Any, Dict, List, Optional, Tuple
|
20
|
+
|
21
|
+
import tomli
|
22
|
+
|
23
|
+
|
24
|
+
def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Read and parse a `flower.toml` file.

    Parameters
    ----------
    path : Optional[str] (default: None)
        Location of the TOML file. When None, `flower.toml` inside the
        current working directory is used.

    Returns
    -------
    Optional[Dict[str, Any]]
        The parsed TOML document, or None if `path` does not point to an
        existing file.
    """
    toml_path = os.path.join(os.getcwd(), "flower.toml") if path is None else path

    if os.path.isfile(toml_path):
        with open(toml_path, encoding="utf-8") as toml_file:
            return tomli.loads(toml_file.read())
    return None
|
38
|
+
|
39
|
+
|
40
|
+
def validate_flower_toml_fields(
    config: Dict[str, Any]
) -> Tuple[bool, List[str], List[str]]:
    """Check that the required sections and properties of flower.toml exist.

    Parameters
    ----------
    config : Dict[str, Any]
        Parsed content of a flower.toml file.

    Returns
    -------
    Tuple[bool, List[str], List[str]]
        Whether no errors were found, the list of errors (missing required
        entries) and the list of warnings (missing recommended entries).
    """
    errors: List[str] = []
    warnings: List[str] = []

    # [project]: `name` and `version` are required, the rest is recommended
    if "project" in config:
        project = config["project"]
        errors.extend(
            f'Property "{prop}" missing in [project]'
            for prop in ("name", "version")
            if prop not in project
        )
        warnings.extend(
            f'Recommended property "{prop}" missing in [project]'
            for prop in ("description", "license", "authors")
            if prop not in project
        )
    else:
        errors.append("Missing [project] section")

    # [flower.components]: both app references are required
    if "flower" not in config:
        errors.append("Missing [flower] section")
    elif "components" not in config["flower"]:
        errors.append("Missing [flower.components] section")
    else:
        components = config["flower"]["components"]
        errors.extend(
            f'Property "{prop}" missing in [flower.components]'
            for prop in ("serverapp", "clientapp")
            if prop not in components
        )

    return not errors, errors, warnings
|
72
|
+
|
73
|
+
|
74
|
+
def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
    """Validate object reference.

    An object reference has the form ``<module>:<attr>[.<attr>...]``
    (e.g. ``my_app.server:app``). The module is imported and the possibly
    dotted attribute path is resolved on it.

    Parameters
    ----------
    ref : str
        The object reference to validate.

    Returns
    -------
    Tuple[bool, Optional[str]]
        A boolean indicating whether an object reference is valid and
        the reason why it might not be (None when it is valid).
    """
    module_str, _, attributes_str = ref.partition(":")
    if not module_str:
        return False, f"Missing module in {ref}"
    if not attributes_str:
        return False, f"Missing attribute in {ref}"

    # `importlib.import_module` raises TypeError (not an ImportError) for a
    # relative name like ".app" when no package is given; report it as an
    # unloadable module instead of crashing the validator.
    if module_str.startswith("."):
        return False, f"Unable to load module {module_str}"

    # Load module. Catching `ImportError` (base class of
    # `ModuleNotFoundError`) also covers modules that exist but fail with an
    # ImportError while being imported.
    try:
        module = importlib.import_module(module_str)
    except ImportError:
        return False, f"Unable to load module {module_str}"

    # Resolve the dotted attribute path one component at a time
    attribute = module
    try:
        for attribute_str in attributes_str.split("."):
            attribute = getattr(attribute, attribute_str)
    except AttributeError:
        return (
            False,
            f"Unable to load attribute {attributes_str} from module {module_str}",
        )

    return True, None
|
113
|
+
|
114
|
+
|
115
|
+
def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
    """Validate flower.toml.

    First checks that all required sections/properties exist, then checks
    that the `serverapp` and `clientapp` object references can be loaded.

    Parameters
    ----------
    config : Dict[str, Any]
        Parsed content of a flower.toml file.

    Returns
    -------
    Tuple[bool, List[str], List[str]]
        Whether the configuration is valid, the list of errors, and the list
        of warnings (missing recommended properties). Warnings are returned
        in every case so callers can display them even for a valid config.
    """
    is_valid, errors, warnings = validate_flower_toml_fields(config)

    if not is_valid:
        return False, errors, warnings

    # Validate serverapp, then clientapp; the first failure is reported
    components = config["flower"]["components"]
    for component in ("serverapp", "clientapp"):
        is_valid, reason = validate_object_reference(components[component])
        if not is_valid:
            return False, [reason] if reason else [], warnings

    return True, [], warnings
|
138
|
+
|
139
|
+
|
140
|
+
def apply_defaults(
    config: Dict[str, Any],
    defaults: Dict[str, Any],
) -> Dict[str, Any]:
    """Apply defaults to config.

    Every key of `defaults` that is missing from `config` is added. When both
    hold a dict under the same key, the two are merged recursively; otherwise
    the value already present in `config` is kept.

    Parameters
    ----------
    config : Dict[str, Any]
        Configuration to complete. It is modified in place.
    defaults : Dict[str, Any]
        Default values. Not modified; values inserted into `config` are deep
        copies, so later changes to `config` cannot leak back into it.

    Returns
    -------
    Dict[str, Any]
        The same `config` object.
    """
    import copy  # pylint: disable=import-outside-toplevel

    for key, default_value in defaults.items():
        if key not in config:
            # Copy so `config` and `defaults` never share nested objects
            config[key] = copy.deepcopy(default_value)
        elif isinstance(config[key], dict) and isinstance(default_value, dict):
            apply_defaults(config[key], default_value)
    return config
|
flwr/cli/new/new.py
CHANGED
@@ -0,0 +1,24 @@
|
|
1
|
+
"""$project_name: A Flower / NumPy app."""
|
2
|
+
|
3
|
+
import flwr as fl
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
|
7
|
+
# Flower client for the NumPy template. It performs no real training: it
# always reports a single 1x1 array of ones as its model parameters.
class FlowerClient(fl.client.NumPyClient):
    """Minimal NumPyClient returning fixed dummy values."""

    def get_parameters(self, config):
        """Return the constant local parameters (one 1x1 array of ones)."""
        return [np.ones((1, 1))]

    def fit(self, parameters, config):
        """Ignore the received parameters and return a fixed fit result."""
        return ([np.ones((1, 1))], 1, {})

    def evaluate(self, parameters, config):
        """Return a fixed evaluation result with an accuracy of 1.0."""
        return float(0.0), 1, {"accuracy": float(1.0)}
|
17
|
+
|
18
|
+
|
19
|
+
def client_fn(cid: str):
    """Create a fresh client; the client id `cid` is not used."""
    numpy_client = FlowerClient()
    return numpy_client.to_client()
|
21
|
+
|
22
|
+
|
23
|
+
# ClientApp for Flower Next, constructed with `client_fn` as the function
# that creates the client instances.
app = fl.client.ClientApp(client_fn=client_fn)
|
@@ -0,0 +1,12 @@
|
|
1
|
+
"""$project_name: A Flower / NumPy app."""
|
2
|
+
|
3
|
+
import flwr as fl
|
4
|
+
|
5
|
+
# Federated Averaging strategy with its default settings
strategy = fl.server.strategy.FedAvg()

# Flower ServerApp running a single round with the strategy above
app = fl.server.ServerApp(
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=1),
)
|
@@ -1,10 +1,10 @@
|
|
1
|
-
[
|
1
|
+
[project]
|
2
2
|
name = "$project_name"
|
3
3
|
version = "1.0.0"
|
4
4
|
description = ""
|
5
5
|
license = "Apache-2.0"
|
6
6
|
authors = ["The Flower Authors <hello@flower.ai>"]
|
7
7
|
|
8
|
-
[components]
|
8
|
+
[flower.components]
|
9
9
|
serverapp = "$project_name.server:app"
|
10
10
|
clientapp = "$project_name.client:app"
|
flwr/cli/run/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower command line interface `run` command."""
|
16
|
+
|
17
|
+
from .run import run as run
|
18
|
+
|
19
|
+
# Public API of this package: the `run` command re-exported from `.run`.
__all__ = [
    "run",
]
|
flwr/cli/run/run.py
ADDED
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower command line interface `run` command."""
|
16
|
+
|
17
|
+
import sys
|
18
|
+
|
19
|
+
import typer
|
20
|
+
|
21
|
+
from flwr.cli.flower_toml import apply_defaults, load_flower_toml, validate_flower_toml
|
22
|
+
from flwr.simulation.run_simulation import _run_simulation
|
23
|
+
|
24
|
+
|
25
|
+
def run() -> None:
    """Run Flower project.

    Loads `flower.toml` from the current working directory, validates it,
    fills in default engine settings (simulation engine with 100 super-nodes)
    and, for the `simulation` engine, starts the configured ServerApp and
    ClientApp via `_run_simulation`. Exits with status 1 when the
    configuration is missing or invalid, or when the engine is unsupported.
    """
    print(
        typer.style("Loading project configuration... ", fg=typer.colors.BLUE),
        end="",
    )
    config = load_flower_toml()
    # None means the file is missing. An existing but empty file parses to
    # `{}` and is reported precisely by the validation step below instead of
    # being misreported as missing.
    if config is None:
        print(
            typer.style(
                "Project configuration could not be loaded. "
                "flower.toml does not exist.",
                fg=typer.colors.RED,
                bold=True,
            )
        )
        # Non-zero status so scripts and CI can detect the failure
        sys.exit(1)
    print(typer.style("Success", fg=typer.colors.GREEN))

    print(
        typer.style("Validating project configuration... ", fg=typer.colors.BLUE),
        end="",
    )
    is_valid, errors, warnings = validate_flower_toml(config)
    if warnings:
        print(
            typer.style(
                "Project configuration is missing the following "
                "recommended properties:\n"
                + "\n".join([f"- {line}" for line in warnings]),
                fg=typer.colors.RED,
                bold=True,
            )
        )

    if not is_valid:
        print(
            typer.style(
                "Project configuration could not be loaded.\nflower.toml is invalid:\n"
                + "\n".join([f"- {line}" for line in errors]),
                fg=typer.colors.RED,
                bold=True,
            )
        )
        sys.exit(1)
    print(typer.style("Success", fg=typer.colors.GREEN))

    # Apply defaults
    defaults = {
        "flower": {
            "engine": {"name": "simulation", "simulation": {"super-node": {"num": 100}}}
        }
    }
    config = apply_defaults(config, defaults)

    server_app_ref = config["flower"]["components"]["serverapp"]
    client_app_ref = config["flower"]["components"]["clientapp"]
    engine = config["flower"]["engine"]["name"]

    if engine == "simulation":
        num_supernodes = config["flower"]["engine"]["simulation"]["super-node"]["num"]

        print(
            typer.style("Starting run... ", fg=typer.colors.BLUE),
        )
        _run_simulation(
            server_app_attr=server_app_ref,
            client_app_attr=client_app_ref,
            num_supernodes=num_supernodes,
        )
    else:
        print(
            typer.style(
                f"Engine '{engine}' is not yet supported in `flwr run`",
                fg=typer.colors.RED,
                bold=True,
            )
        )
        sys.exit(1)
|
flwr/client/app.py
CHANGED
@@ -20,7 +20,9 @@ import sys
|
|
20
20
|
import time
|
21
21
|
from logging import DEBUG, INFO, WARN
|
22
22
|
from pathlib import Path
|
23
|
-
from typing import Callable, ContextManager, Optional, Tuple, Union
|
23
|
+
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
24
|
+
|
25
|
+
from grpc import RpcError
|
24
26
|
|
25
27
|
from flwr.client.client import Client
|
26
28
|
from flwr.client.client_app import ClientApp
|
@@ -36,6 +38,7 @@ from flwr.common.constant import (
|
|
36
38
|
)
|
37
39
|
from flwr.common.exit_handlers import register_exit_handlers
|
38
40
|
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
|
41
|
+
from flwr.common.retry_invoker import RetryInvoker, exponential
|
39
42
|
|
40
43
|
from .client_app import load_client_app
|
41
44
|
from .grpc_client.connection import grpc_connection
|
@@ -104,6 +107,8 @@ def run_client_app() -> None:
|
|
104
107
|
transport="rest" if args.rest else "grpc-rere",
|
105
108
|
root_certificates=root_certificates,
|
106
109
|
insecure=args.insecure,
|
110
|
+
max_retries=args.max_retries,
|
111
|
+
max_wait_time=args.max_wait_time,
|
107
112
|
)
|
108
113
|
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
109
114
|
|
@@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
|
141
146
|
default="0.0.0.0:9092",
|
142
147
|
help="Server address",
|
143
148
|
)
|
149
|
+
parser.add_argument(
|
150
|
+
"--max-retries",
|
151
|
+
type=int,
|
152
|
+
default=None,
|
153
|
+
help="The maximum number of times the client will try to connect to the"
|
154
|
+
"server before giving up in case of a connection error. By default,"
|
155
|
+
"it is set to None, meaning there is no limit to the number of tries.",
|
156
|
+
)
|
157
|
+
parser.add_argument(
|
158
|
+
"--max-wait-time",
|
159
|
+
type=float,
|
160
|
+
default=None,
|
161
|
+
help="The maximum duration before the client stops trying to"
|
162
|
+
"connect to the server in case of connection error. By default, it"
|
163
|
+
"is set to None, meaning there is no limit to the total time.",
|
164
|
+
)
|
144
165
|
parser.add_argument(
|
145
166
|
"--dir",
|
146
167
|
default="",
|
@@ -180,6 +201,8 @@ def start_client(
|
|
180
201
|
root_certificates: Optional[Union[bytes, str]] = None,
|
181
202
|
insecure: Optional[bool] = None,
|
182
203
|
transport: Optional[str] = None,
|
204
|
+
max_retries: Optional[int] = None,
|
205
|
+
max_wait_time: Optional[float] = None,
|
183
206
|
) -> None:
|
184
207
|
"""Start a Flower client node which connects to a Flower server.
|
185
208
|
|
@@ -213,6 +236,14 @@ def start_client(
|
|
213
236
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
214
237
|
- 'grpc-rere': gRPC, request-response (experimental)
|
215
238
|
- 'rest': HTTP (experimental)
|
239
|
+
max_retries: Optional[int] (default: None)
|
240
|
+
The maximum number of times the client will try to connect to the
|
241
|
+
server before giving up in case of a connection error. If set to None,
|
242
|
+
there is no limit to the number of tries.
|
243
|
+
max_wait_time: Optional[float] (default: None)
|
244
|
+
The maximum duration before the client stops trying to
|
245
|
+
connect to the server in case of connection error.
|
246
|
+
If set to None, there is no limit to the total time.
|
216
247
|
|
217
248
|
Examples
|
218
249
|
--------
|
@@ -254,6 +285,8 @@ def start_client(
|
|
254
285
|
root_certificates=root_certificates,
|
255
286
|
insecure=insecure,
|
256
287
|
transport=transport,
|
288
|
+
max_retries=max_retries,
|
289
|
+
max_wait_time=max_wait_time,
|
257
290
|
)
|
258
291
|
event(EventType.START_CLIENT_LEAVE)
|
259
292
|
|
@@ -272,6 +305,8 @@ def _start_client_internal(
|
|
272
305
|
root_certificates: Optional[Union[bytes, str]] = None,
|
273
306
|
insecure: Optional[bool] = None,
|
274
307
|
transport: Optional[str] = None,
|
308
|
+
max_retries: Optional[int] = None,
|
309
|
+
max_wait_time: Optional[float] = None,
|
275
310
|
) -> None:
|
276
311
|
"""Start a Flower client node which connects to a Flower server.
|
277
312
|
|
@@ -299,7 +334,7 @@ def _start_client_internal(
|
|
299
334
|
The PEM-encoded root certificates as a byte string or a path string.
|
300
335
|
If provided, a secure connection using the certificates will be
|
301
336
|
established to an SSL-enabled Flower server.
|
302
|
-
insecure : bool (default:
|
337
|
+
insecure : Optional[bool] (default: None)
|
303
338
|
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
304
339
|
when False, using system certificates if `root_certificates` is None.
|
305
340
|
transport : Optional[str] (default: None)
|
@@ -307,6 +342,14 @@ def _start_client_internal(
|
|
307
342
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
308
343
|
- 'grpc-rere': gRPC, request-response (experimental)
|
309
344
|
- 'rest': HTTP (experimental)
|
345
|
+
max_retries: Optional[int] (default: None)
|
346
|
+
The maximum number of times the client will try to connect to the
|
347
|
+
server before giving up in case of a connection error. If set to None,
|
348
|
+
there is no limit to the number of tries.
|
349
|
+
max_wait_time: Optional[float] (default: None)
|
350
|
+
The maximum duration before the client stops trying to
|
351
|
+
connect to the server in case of connection error.
|
352
|
+
If set to None, there is no limit to the total time.
|
310
353
|
"""
|
311
354
|
if insecure is None:
|
312
355
|
insecure = root_certificates is None
|
@@ -338,7 +381,45 @@ def _start_client_internal(
|
|
338
381
|
# Both `client` and `client_fn` must not be used directly
|
339
382
|
|
340
383
|
# Initialize connection context manager
|
341
|
-
connection, address = _init_connection(
|
384
|
+
connection, address, connection_error_type = _init_connection(
|
385
|
+
transport, server_address
|
386
|
+
)
|
387
|
+
|
388
|
+
retry_invoker = RetryInvoker(
|
389
|
+
wait_factory=exponential,
|
390
|
+
recoverable_exceptions=connection_error_type,
|
391
|
+
max_tries=max_retries,
|
392
|
+
max_time=max_wait_time,
|
393
|
+
on_giveup=lambda retry_state: (
|
394
|
+
log(
|
395
|
+
WARN,
|
396
|
+
"Giving up reconnection after %.2f seconds and %s tries.",
|
397
|
+
retry_state.elapsed_time,
|
398
|
+
retry_state.tries,
|
399
|
+
)
|
400
|
+
if retry_state.tries > 1
|
401
|
+
else None
|
402
|
+
),
|
403
|
+
on_success=lambda retry_state: (
|
404
|
+
log(
|
405
|
+
INFO,
|
406
|
+
"Connection successful after %.2f seconds and %s tries.",
|
407
|
+
retry_state.elapsed_time,
|
408
|
+
retry_state.tries,
|
409
|
+
)
|
410
|
+
if retry_state.tries > 1
|
411
|
+
else None
|
412
|
+
),
|
413
|
+
on_backoff=lambda retry_state: (
|
414
|
+
log(WARN, "Connection attempt failed, retrying...")
|
415
|
+
if retry_state.tries == 1
|
416
|
+
else log(
|
417
|
+
DEBUG,
|
418
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
419
|
+
retry_state.actual_wait,
|
420
|
+
)
|
421
|
+
),
|
422
|
+
)
|
342
423
|
|
343
424
|
node_state = NodeState()
|
344
425
|
|
@@ -347,6 +428,7 @@ def _start_client_internal(
|
|
347
428
|
with connection(
|
348
429
|
address,
|
349
430
|
insecure,
|
431
|
+
retry_invoker,
|
350
432
|
grpc_max_message_length,
|
351
433
|
root_certificates,
|
352
434
|
) as conn:
|
@@ -509,7 +591,7 @@ def start_numpy_client(
|
|
509
591
|
|
510
592
|
def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
511
593
|
Callable[
|
512
|
-
[str, bool, int, Union[bytes, str, None]],
|
594
|
+
[str, bool, RetryInvoker, int, Union[bytes, str, None]],
|
513
595
|
ContextManager[
|
514
596
|
Tuple[
|
515
597
|
Callable[[], Optional[Message]],
|
@@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
520
602
|
],
|
521
603
|
],
|
522
604
|
str,
|
605
|
+
Type[Exception],
|
523
606
|
]:
|
524
607
|
# Parse IP address
|
525
608
|
parsed_address = parse_address(server_address)
|
@@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
535
618
|
# Use either gRPC bidirectional streaming or REST request/response
|
536
619
|
if transport == TRANSPORT_TYPE_REST:
|
537
620
|
try:
|
621
|
+
from requests.exceptions import ConnectionError as RequestsConnectionError
|
622
|
+
|
538
623
|
from .rest_client.connection import http_request_response
|
539
624
|
except ModuleNotFoundError:
|
540
625
|
sys.exit(MISSING_EXTRA_REST)
|
@@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
543
628
|
"When using the REST API, please provide `https://` or "
|
544
629
|
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
|
545
630
|
)
|
546
|
-
connection = http_request_response
|
631
|
+
connection, error_type = http_request_response, RequestsConnectionError
|
547
632
|
elif transport == TRANSPORT_TYPE_GRPC_RERE:
|
548
|
-
connection = grpc_request_response
|
633
|
+
connection, error_type = grpc_request_response, RpcError
|
549
634
|
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
|
550
|
-
connection = grpc_connection
|
635
|
+
connection, error_type = grpc_connection, RpcError
|
551
636
|
else:
|
552
637
|
raise ValueError(
|
553
638
|
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
|
554
639
|
)
|
555
640
|
|
556
|
-
return connection, address
|
641
|
+
return connection, address, error_type
|
@@ -31,14 +31,10 @@ from flwr.common import (
|
|
31
31
|
)
|
32
32
|
from flwr.common import recordset_compat as compat
|
33
33
|
from flwr.common import serde
|
34
|
-
from flwr.common.constant import
|
35
|
-
MESSAGE_TYPE_EVALUATE,
|
36
|
-
MESSAGE_TYPE_FIT,
|
37
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
38
|
-
MESSAGE_TYPE_GET_PROPERTIES,
|
39
|
-
)
|
34
|
+
from flwr.common.constant import MessageType, MessageTypeLegacy
|
40
35
|
from flwr.common.grpc import create_channel
|
41
36
|
from flwr.common.logger import log
|
37
|
+
from flwr.common.retry_invoker import RetryInvoker
|
42
38
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
43
39
|
ClientMessage,
|
44
40
|
Reason,
|
@@ -62,6 +58,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
62
58
|
def grpc_connection( # pylint: disable=R0915
|
63
59
|
server_address: str,
|
64
60
|
insecure: bool,
|
61
|
+
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
65
62
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
66
63
|
root_certificates: Optional[Union[bytes, str]] = None,
|
67
64
|
) -> Iterator[
|
@@ -80,6 +77,11 @@ def grpc_connection( # pylint: disable=R0915
|
|
80
77
|
The IPv4 or IPv6 address of the server. If the Flower server runs on the same
|
81
78
|
machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
|
82
79
|
`"[::]:8080"`.
|
80
|
+
insecure : bool
|
81
|
+
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
82
|
+
when False, using system certificates if `root_certificates` is None.
|
83
|
+
retry_invoker: RetryInvoker
|
84
|
+
Unused argument present for compatibilty.
|
83
85
|
max_message_length : int
|
84
86
|
The maximum length of gRPC messages that can be exchanged with the Flower
|
85
87
|
server. The default should be sufficient for most models. Users who train
|
@@ -141,22 +143,22 @@ def grpc_connection( # pylint: disable=R0915
|
|
141
143
|
recordset = compat.getpropertiesins_to_recordset(
|
142
144
|
serde.get_properties_ins_from_proto(proto.get_properties_ins)
|
143
145
|
)
|
144
|
-
message_type =
|
146
|
+
message_type = MessageTypeLegacy.GET_PROPERTIES
|
145
147
|
elif field == "get_parameters_ins":
|
146
148
|
recordset = compat.getparametersins_to_recordset(
|
147
149
|
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
|
148
150
|
)
|
149
|
-
message_type =
|
151
|
+
message_type = MessageTypeLegacy.GET_PARAMETERS
|
150
152
|
elif field == "fit_ins":
|
151
153
|
recordset = compat.fitins_to_recordset(
|
152
154
|
serde.fit_ins_from_proto(proto.fit_ins), False
|
153
155
|
)
|
154
|
-
message_type =
|
156
|
+
message_type = MessageType.TRAIN
|
155
157
|
elif field == "evaluate_ins":
|
156
158
|
recordset = compat.evaluateins_to_recordset(
|
157
159
|
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
|
158
160
|
)
|
159
|
-
message_type =
|
161
|
+
message_type = MessageType.EVALUATE
|
160
162
|
elif field == "reconnect_ins":
|
161
163
|
recordset = RecordSet()
|
162
164
|
recordset.configs_records["config"] = ConfigsRecord(
|
@@ -190,20 +192,20 @@ def grpc_connection( # pylint: disable=R0915
|
|
190
192
|
message_type = message.metadata.message_type
|
191
193
|
|
192
194
|
# RecordSet --> *Res --> *Res proto -> ClientMessage proto
|
193
|
-
if message_type ==
|
195
|
+
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
194
196
|
getpropres = compat.recordset_to_getpropertiesres(recordset)
|
195
197
|
msg_proto = ClientMessage(
|
196
198
|
get_properties_res=serde.get_properties_res_to_proto(getpropres)
|
197
199
|
)
|
198
|
-
elif message_type ==
|
200
|
+
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
199
201
|
getparamres = compat.recordset_to_getparametersres(recordset, False)
|
200
202
|
msg_proto = ClientMessage(
|
201
203
|
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
|
202
204
|
)
|
203
|
-
elif message_type ==
|
205
|
+
elif message_type == MessageType.TRAIN:
|
204
206
|
fitres = compat.recordset_to_fitres(recordset, False)
|
205
207
|
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
|
206
|
-
elif message_type ==
|
208
|
+
elif message_type == MessageType.EVALUATE:
|
207
209
|
evalres = compat.recordset_to_evaluateres(recordset)
|
208
210
|
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
|
209
211
|
elif message_type == "reconnect":
|