flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/flower_toml.py +151 -0
  3. flwr/cli/new/new.py +1 -0
  4. flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
  5. flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
  6. flwr/cli/new/templates/app/flower.toml.tpl +2 -2
  7. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
  8. flwr/cli/run/__init__.py +21 -0
  9. flwr/cli/run/run.py +102 -0
  10. flwr/client/app.py +93 -8
  11. flwr/client/grpc_client/connection.py +16 -14
  12. flwr/client/grpc_rere_client/connection.py +14 -4
  13. flwr/client/message_handler/message_handler.py +5 -10
  14. flwr/client/mod/centraldp_mods.py +5 -5
  15. flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
  16. flwr/client/rest_client/connection.py +16 -4
  17. flwr/common/__init__.py +6 -0
  18. flwr/common/constant.py +21 -4
  19. flwr/server/app.py +7 -7
  20. flwr/server/compat/driver_client_proxy.py +5 -11
  21. flwr/server/run_serverapp.py +14 -9
  22. flwr/server/server.py +5 -5
  23. flwr/server/superlink/driver/driver_servicer.py +1 -1
  24. flwr/server/superlink/fleet/vce/vce_api.py +17 -5
  25. flwr/server/workflow/default_workflows.py +4 -8
  26. flwr/simulation/__init__.py +2 -5
  27. flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
  28. flwr/simulation/run_simulation.py +301 -76
  29. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
  30. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
  31. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
  32. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
  33. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
flwr/cli/app.py CHANGED
@@ -18,6 +18,7 @@ import typer
18
18
 
19
19
  from .example import example
20
20
  from .new import new
21
+ from .run import run
21
22
 
22
23
  app = typer.Typer(
23
24
  help=typer.style(
@@ -30,6 +31,7 @@ app = typer.Typer(
30
31
 
31
32
  app.command()(new)
32
33
  app.command()(example)
34
+ app.command()(run)
33
35
 
34
36
  if __name__ == "__main__":
35
37
  app()
@@ -0,0 +1,151 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility to validate the `flower.toml` file."""
16
+
17
+ import importlib
18
+ import os
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ import tomli
22
+
23
+
24
def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Read a `flower.toml` file and parse it into a dictionary.

    Parameters
    ----------
    path : Optional[str] (default: None)
        Location of the TOML file. When omitted, `flower.toml` in the current
        working directory is used.

    Returns
    -------
    Optional[Dict[str, Any]]
        The parsed TOML content, or None if no regular file exists at the
        resolved location.
    """
    toml_path = os.path.join(os.getcwd(), "flower.toml") if path is None else path

    # A missing file is signalled to the caller with None instead of raising
    if not os.path.isfile(toml_path):
        return None

    with open(toml_path, encoding="utf-8") as toml_file:
        content = toml_file.read()
    return tomli.loads(content)
38
+
39
+
40
def validate_flower_toml_fields(
    config: Dict[str, Any]
) -> Tuple[bool, List[str], List[str]]:
    """Check that `flower.toml` contains its required and recommended fields.

    Parameters
    ----------
    config : Dict[str, Any]
        Parsed `flower.toml` content.

    Returns
    -------
    Tuple[bool, List[str], List[str]]
        Whether no errors were found, the error messages (missing required
        sections/properties), and the warning messages (missing recommended
        properties).
    """
    errors: List[str] = []
    warnings: List[str] = []

    if "project" in config:
        project = config["project"]
        required_project = (
            ("name", 'Property "name" missing in [project]'),
            ("version", 'Property "version" missing in [project]'),
        )
        recommended_project = (
            (
                "description",
                'Recommended property "description" missing in [project]',
            ),
            ("license", 'Recommended property "license" missing in [project]'),
            ("authors", 'Recommended property "authors" missing in [project]'),
        )
        errors.extend(msg for prop, msg in required_project if prop not in project)
        warnings.extend(
            msg for prop, msg in recommended_project if prop not in project
        )
    else:
        errors.append("Missing [project] section")

    if "flower" not in config:
        errors.append("Missing [flower] section")
    elif "components" in config["flower"]:
        components = config["flower"]["components"]
        required_components = (
            ("serverapp", 'Property "serverapp" missing in [flower.components]'),
            ("clientapp", 'Property "clientapp" missing in [flower.components]'),
        )
        errors.extend(
            msg for prop, msg in required_components if prop not in components
        )
    else:
        errors.append("Missing [flower.components] section")

    return not errors, errors, warnings
72
+
73
+
74
def validate_object_reference(ref: str) -> Tuple[bool, Optional[str]]:
    """Validate an object reference of the form ``<module>:<attribute>``.

    The module part is imported with `importlib.import_module`, which executes
    the module's top-level code. The (possibly dotted) attribute path is then
    resolved on it one component at a time.

    Parameters
    ----------
    ref : str
        Reference such as ``"package.module:obj"`` or
        ``"package.module:obj.attr"``.

    Returns
    -------
    Tuple[bool, Optional[str]]
        A boolean indicating whether an object reference is valid and
        the reason why it might not be (None when it is valid).
    """
    module_str, _, attributes_str = ref.partition(":")
    if not module_str:
        return (False, f"Missing module in {ref}")
    if not attributes_str:
        return (False, f"Missing attribute in {ref}")

    # Load module. ImportError is the base class of ModuleNotFoundError, so
    # this also reports failures raised while importing the module itself
    # (e.g. a `from x import missing_name` inside it) instead of crashing.
    try:
        module = importlib.import_module(module_str)
    except ImportError:
        return False, f"Unable to load module {module_str}"

    # Recursively load attribute, following each dot-separated component
    attribute = module
    try:
        for attribute_str in attributes_str.split("."):
            attribute = getattr(attribute, attribute_str)
    except AttributeError:
        return (
            False,
            f"Unable to load attribute {attributes_str} from module {module_str}",
        )

    return (True, None)
113
+
114
+
115
def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
    """Validate the fields of `flower.toml` and its component references.

    First checks that required sections/properties are present. Then checks
    that the `serverapp` and `clientapp` references in `[flower.components]`
    can be imported and resolved.

    Parameters
    ----------
    config : Dict[str, Any]
        Parsed `flower.toml` content.

    Returns
    -------
    Tuple[bool, List[str], List[str]]
        Whether the configuration is valid, the error messages, and the
        warnings about missing recommended properties. Warnings are returned
        on every path so callers can report them even when the file is valid.
    """
    is_valid, errors, warnings = validate_flower_toml_fields(config)

    if not is_valid:
        return False, errors, warnings

    components = config["flower"]["components"]
    # Validate serverapp first, then clientapp; stop at the first failure
    for component in ("serverapp", "clientapp"):
        is_valid, reason = validate_object_reference(components[component])
        if not is_valid and isinstance(reason, str):
            return False, [reason], warnings

    return True, [], warnings
138
+
139
+
140
def apply_defaults(
    config: Dict[str, Any],
    defaults: Dict[str, Any],
) -> Dict[str, Any]:
    """Fill in entries missing from ``config`` using ``defaults``, recursively.

    Values already present in ``config`` always take precedence. When both
    ``config`` and ``defaults`` hold a dict under the same key, the nested
    dicts are merged the same way. Keys absent from ``config`` receive a deep
    copy of the default, so later mutation of ``config`` cannot alter
    ``defaults``.

    Parameters
    ----------
    config : Dict[str, Any]
        Configuration to complete. It is modified in place.
    defaults : Dict[str, Any]
        Default values to use where ``config`` has no entry.

    Returns
    -------
    Dict[str, Any]
        The same ``config`` object, with defaults applied.
    """
    import copy  # pylint: disable=import-outside-toplevel

    for key, default_value in defaults.items():
        if key not in config:
            config[key] = copy.deepcopy(default_value)
        elif isinstance(config[key], dict) and isinstance(default_value, dict):
            apply_defaults(config[key], default_value)
        # Otherwise keep the user-provided value untouched
    return config
flwr/cli/new/new.py CHANGED
@@ -28,6 +28,7 @@ from ..utils import prompt_options
28
28
  class MlFramework(str, Enum):
29
29
  """Available frameworks."""
30
30
 
31
+ NUMPY = "NumPy"
31
32
  PYTORCH = "PyTorch"
32
33
  TENSORFLOW = "TensorFlow"
33
34
 
@@ -0,0 +1,24 @@
1
+ """$project_name: A Flower / NumPy app."""
2
+
3
+ import flwr as fl
4
+ import numpy as np
5
+
6
+
7
# Flower client for this NumPy template. It does no real training or
# evaluation: every method returns fixed placeholder values.
class FlowerClient(fl.client.NumPyClient):
    """NumPyClient that returns a constant 1x1 array and fixed results."""

    def get_parameters(self, config):
        """Return a list holding a single 1x1 array of ones."""
        weights = [np.ones((1, 1))]
        return weights

    def fit(self, parameters, config):
        """Ignore the inputs and return ([1x1 array of ones], 1, {})."""
        weights = [np.ones((1, 1))]
        num_examples = 1
        return (weights, num_examples, {})

    def evaluate(self, parameters, config):
        """Ignore the inputs and return (0.0, 1, {"accuracy": 1.0})."""
        loss = 0.0
        num_examples = 1
        return loss, num_examples, {"accuracy": 1.0}
17
+
18
+
19
def client_fn(cid: str):
    """Create the client for this app.

    `cid` is not used by this template: every call builds an identical
    `FlowerClient` and converts it with `to_client()`.
    """
    numpy_client = FlowerClient()
    return numpy_client.to_client()


# ClientApp for Flower-Next, built from `client_fn`
app = fl.client.ClientApp(client_fn=client_fn)
@@ -0,0 +1,12 @@
1
+ """$project_name: A Flower / NumPy app."""
2
+
3
+ import flwr as fl
4
+
5
# Federated averaging strategy, constructed with no arguments
strategy = fl.server.strategy.FedAvg()

# Flower ServerApp running a single round with the strategy above
app = fl.server.ServerApp(
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=1),
)
@@ -1,10 +1,10 @@
1
- [flower]
1
+ [project]
2
2
  name = "$project_name"
3
3
  version = "1.0.0"
4
4
  description = ""
5
5
  license = "Apache-2.0"
6
6
  authors = ["The Flower Authors <hello@flower.ai>"]
7
7
 
8
- [components]
8
+ [flower.components]
9
9
  serverapp = "$project_name.server:app"
10
10
  clientapp = "$project_name.client:app"
@@ -0,0 +1,2 @@
1
+ flwr>=1.8, <2.0
2
+ numpy >= 1.21.0
@@ -0,0 +1,21 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `run` command."""
16
+
17
+ from .run import run as run
18
+
19
# Names re-exported by the `flwr.cli.run` package.
__all__ = ["run"]
flwr/cli/run/run.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `run` command."""
16
+
17
+ import sys
18
+
19
+ import typer
20
+
21
+ from flwr.cli.flower_toml import apply_defaults, load_flower_toml, validate_flower_toml
22
+ from flwr.simulation.run_simulation import _run_simulation
23
+
24
+
25
def run() -> None:
    """Run the Flower project located in the current working directory.

    Steps:

    1. Load `flower.toml` via `load_flower_toml`.
    2. Validate it via `validate_flower_toml`. Print any warnings, and abort
       if it is invalid.
    3. Merge default engine settings (engine "simulation" with 100
       SuperNodes) into the configuration via `apply_defaults`.
    4. For the "simulation" engine, launch the configured ServerApp and
       ClientApp with `_run_simulation`. Other engines are rejected.

    Every failure path exits with status 1 so that shells and CI scripts can
    detect it. A bare `sys.exit()` would report success (status 0).
    """
    print(
        typer.style("Loading project configuration... ", fg=typer.colors.BLUE),
        end="",
    )
    config = load_flower_toml()
    if not config:
        print(
            typer.style(
                "Project configuration could not be loaded. "
                "flower.toml does not exist.",
                fg=typer.colors.RED,
                bold=True,
            )
        )
        sys.exit(1)
    print(typer.style("Success", fg=typer.colors.GREEN))

    print(
        typer.style("Validating project configuration... ", fg=typer.colors.BLUE),
        end="",
    )
    is_valid, errors, warnings = validate_flower_toml(config)
    if warnings:
        print(
            typer.style(
                "Project configuration is missing the following "
                "recommended properties:\n"
                + "\n".join([f"- {line}" for line in warnings]),
                fg=typer.colors.RED,
                bold=True,
            )
        )

    if not is_valid:
        print(
            typer.style(
                "Project configuration could not be loaded.\nflower.toml is invalid:\n"
                + "\n".join([f"- {line}" for line in errors]),
                fg=typer.colors.RED,
                bold=True,
            )
        )
        sys.exit(1)
    print(typer.style("Success", fg=typer.colors.GREEN))

    # Apply defaults
    defaults = {
        "flower": {
            "engine": {"name": "simulation", "simulation": {"super-node": {"num": 100}}}
        }
    }
    config = apply_defaults(config, defaults)

    server_app_ref = config["flower"]["components"]["serverapp"]
    client_app_ref = config["flower"]["components"]["clientapp"]
    engine = config["flower"]["engine"]["name"]

    if engine != "simulation":
        print(
            typer.style(
                f"Engine '{engine}' is not yet supported in `flwr run`",
                fg=typer.colors.RED,
                bold=True,
            )
        )
        sys.exit(1)

    num_supernodes = config["flower"]["engine"]["simulation"]["super-node"]["num"]

    print(
        typer.style("Starting run... ", fg=typer.colors.BLUE),
    )
    _run_simulation(
        server_app_attr=server_app_ref,
        client_app_attr=client_app_ref,
        num_supernodes=num_supernodes,
    )
flwr/client/app.py CHANGED
@@ -20,7 +20,9 @@ import sys
20
20
  import time
21
21
  from logging import DEBUG, INFO, WARN
22
22
  from pathlib import Path
23
- from typing import Callable, ContextManager, Optional, Tuple, Union
23
+ from typing import Callable, ContextManager, Optional, Tuple, Type, Union
24
+
25
+ from grpc import RpcError
24
26
 
25
27
  from flwr.client.client import Client
26
28
  from flwr.client.client_app import ClientApp
@@ -36,6 +38,7 @@ from flwr.common.constant import (
36
38
  )
37
39
  from flwr.common.exit_handlers import register_exit_handlers
38
40
  from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
41
+ from flwr.common.retry_invoker import RetryInvoker, exponential
39
42
 
40
43
  from .client_app import load_client_app
41
44
  from .grpc_client.connection import grpc_connection
@@ -104,6 +107,8 @@ def run_client_app() -> None:
104
107
  transport="rest" if args.rest else "grpc-rere",
105
108
  root_certificates=root_certificates,
106
109
  insecure=args.insecure,
110
+ max_retries=args.max_retries,
111
+ max_wait_time=args.max_wait_time,
107
112
  )
108
113
  register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
109
114
 
@@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
141
146
  default="0.0.0.0:9092",
142
147
  help="Server address",
143
148
  )
149
+ parser.add_argument(
150
+ "--max-retries",
151
+ type=int,
152
+ default=None,
153
+ help="The maximum number of times the client will try to connect to the"
154
+ "server before giving up in case of a connection error. By default,"
155
+ "it is set to None, meaning there is no limit to the number of tries.",
156
+ )
157
+ parser.add_argument(
158
+ "--max-wait-time",
159
+ type=float,
160
+ default=None,
161
+ help="The maximum duration before the client stops trying to"
162
+ "connect to the server in case of connection error. By default, it"
163
+ "is set to None, meaning there is no limit to the total time.",
164
+ )
144
165
  parser.add_argument(
145
166
  "--dir",
146
167
  default="",
@@ -180,6 +201,8 @@ def start_client(
180
201
  root_certificates: Optional[Union[bytes, str]] = None,
181
202
  insecure: Optional[bool] = None,
182
203
  transport: Optional[str] = None,
204
+ max_retries: Optional[int] = None,
205
+ max_wait_time: Optional[float] = None,
183
206
  ) -> None:
184
207
  """Start a Flower client node which connects to a Flower server.
185
208
 
@@ -213,6 +236,14 @@ def start_client(
213
236
  - 'grpc-bidi': gRPC, bidirectional streaming
214
237
  - 'grpc-rere': gRPC, request-response (experimental)
215
238
  - 'rest': HTTP (experimental)
239
+ max_retries: Optional[int] (default: None)
240
+ The maximum number of times the client will try to connect to the
241
+ server before giving up in case of a connection error. If set to None,
242
+ there is no limit to the number of tries.
243
+ max_wait_time: Optional[float] (default: None)
244
+ The maximum duration before the client stops trying to
245
+ connect to the server in case of connection error.
246
+ If set to None, there is no limit to the total time.
216
247
 
217
248
  Examples
218
249
  --------
@@ -254,6 +285,8 @@ def start_client(
254
285
  root_certificates=root_certificates,
255
286
  insecure=insecure,
256
287
  transport=transport,
288
+ max_retries=max_retries,
289
+ max_wait_time=max_wait_time,
257
290
  )
258
291
  event(EventType.START_CLIENT_LEAVE)
259
292
 
@@ -272,6 +305,8 @@ def _start_client_internal(
272
305
  root_certificates: Optional[Union[bytes, str]] = None,
273
306
  insecure: Optional[bool] = None,
274
307
  transport: Optional[str] = None,
308
+ max_retries: Optional[int] = None,
309
+ max_wait_time: Optional[float] = None,
275
310
  ) -> None:
276
311
  """Start a Flower client node which connects to a Flower server.
277
312
 
@@ -299,7 +334,7 @@ def _start_client_internal(
299
334
  The PEM-encoded root certificates as a byte string or a path string.
300
335
  If provided, a secure connection using the certificates will be
301
336
  established to an SSL-enabled Flower server.
302
- insecure : bool (default: True)
337
+ insecure : Optional[bool] (default: None)
303
338
  Starts an insecure gRPC connection when True. Enables HTTPS connection
304
339
  when False, using system certificates if `root_certificates` is None.
305
340
  transport : Optional[str] (default: None)
@@ -307,6 +342,14 @@ def _start_client_internal(
307
342
  - 'grpc-bidi': gRPC, bidirectional streaming
308
343
  - 'grpc-rere': gRPC, request-response (experimental)
309
344
  - 'rest': HTTP (experimental)
345
+ max_retries: Optional[int] (default: None)
346
+ The maximum number of times the client will try to connect to the
347
+ server before giving up in case of a connection error. If set to None,
348
+ there is no limit to the number of tries.
349
+ max_wait_time: Optional[float] (default: None)
350
+ The maximum duration before the client stops trying to
351
+ connect to the server in case of connection error.
352
+ If set to None, there is no limit to the total time.
310
353
  """
311
354
  if insecure is None:
312
355
  insecure = root_certificates is None
@@ -338,7 +381,45 @@ def _start_client_internal(
338
381
  # Both `client` and `client_fn` must not be used directly
339
382
 
340
383
  # Initialize connection context manager
341
- connection, address = _init_connection(transport, server_address)
384
+ connection, address, connection_error_type = _init_connection(
385
+ transport, server_address
386
+ )
387
+
388
+ retry_invoker = RetryInvoker(
389
+ wait_factory=exponential,
390
+ recoverable_exceptions=connection_error_type,
391
+ max_tries=max_retries,
392
+ max_time=max_wait_time,
393
+ on_giveup=lambda retry_state: (
394
+ log(
395
+ WARN,
396
+ "Giving up reconnection after %.2f seconds and %s tries.",
397
+ retry_state.elapsed_time,
398
+ retry_state.tries,
399
+ )
400
+ if retry_state.tries > 1
401
+ else None
402
+ ),
403
+ on_success=lambda retry_state: (
404
+ log(
405
+ INFO,
406
+ "Connection successful after %.2f seconds and %s tries.",
407
+ retry_state.elapsed_time,
408
+ retry_state.tries,
409
+ )
410
+ if retry_state.tries > 1
411
+ else None
412
+ ),
413
+ on_backoff=lambda retry_state: (
414
+ log(WARN, "Connection attempt failed, retrying...")
415
+ if retry_state.tries == 1
416
+ else log(
417
+ DEBUG,
418
+ "Connection attempt failed, retrying in %.2f seconds",
419
+ retry_state.actual_wait,
420
+ )
421
+ ),
422
+ )
342
423
 
343
424
  node_state = NodeState()
344
425
 
@@ -347,6 +428,7 @@ def _start_client_internal(
347
428
  with connection(
348
429
  address,
349
430
  insecure,
431
+ retry_invoker,
350
432
  grpc_max_message_length,
351
433
  root_certificates,
352
434
  ) as conn:
@@ -509,7 +591,7 @@ def start_numpy_client(
509
591
 
510
592
  def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
511
593
  Callable[
512
- [str, bool, int, Union[bytes, str, None]],
594
+ [str, bool, RetryInvoker, int, Union[bytes, str, None]],
513
595
  ContextManager[
514
596
  Tuple[
515
597
  Callable[[], Optional[Message]],
@@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
520
602
  ],
521
603
  ],
522
604
  str,
605
+ Type[Exception],
523
606
  ]:
524
607
  # Parse IP address
525
608
  parsed_address = parse_address(server_address)
@@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
535
618
  # Use either gRPC bidirectional streaming or REST request/response
536
619
  if transport == TRANSPORT_TYPE_REST:
537
620
  try:
621
+ from requests.exceptions import ConnectionError as RequestsConnectionError
622
+
538
623
  from .rest_client.connection import http_request_response
539
624
  except ModuleNotFoundError:
540
625
  sys.exit(MISSING_EXTRA_REST)
@@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
543
628
  "When using the REST API, please provide `https://` or "
544
629
  "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
545
630
  )
546
- connection = http_request_response
631
+ connection, error_type = http_request_response, RequestsConnectionError
547
632
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
548
- connection = grpc_request_response
633
+ connection, error_type = grpc_request_response, RpcError
549
634
  elif transport == TRANSPORT_TYPE_GRPC_BIDI:
550
- connection = grpc_connection
635
+ connection, error_type = grpc_connection, RpcError
551
636
  else:
552
637
  raise ValueError(
553
638
  f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
554
639
  )
555
640
 
556
- return connection, address
641
+ return connection, address, error_type
@@ -31,14 +31,10 @@ from flwr.common import (
31
31
  )
32
32
  from flwr.common import recordset_compat as compat
33
33
  from flwr.common import serde
34
- from flwr.common.constant import (
35
- MESSAGE_TYPE_EVALUATE,
36
- MESSAGE_TYPE_FIT,
37
- MESSAGE_TYPE_GET_PARAMETERS,
38
- MESSAGE_TYPE_GET_PROPERTIES,
39
- )
34
+ from flwr.common.constant import MessageType, MessageTypeLegacy
40
35
  from flwr.common.grpc import create_channel
41
36
  from flwr.common.logger import log
37
+ from flwr.common.retry_invoker import RetryInvoker
42
38
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
43
39
  ClientMessage,
44
40
  Reason,
@@ -62,6 +58,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
62
58
  def grpc_connection( # pylint: disable=R0915
63
59
  server_address: str,
64
60
  insecure: bool,
61
+ retry_invoker: RetryInvoker, # pylint: disable=unused-argument
65
62
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
66
63
  root_certificates: Optional[Union[bytes, str]] = None,
67
64
  ) -> Iterator[
@@ -80,6 +77,11 @@ def grpc_connection( # pylint: disable=R0915
80
77
  The IPv4 or IPv6 address of the server. If the Flower server runs on the same
81
78
  machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
82
79
  `"[::]:8080"`.
80
+ insecure : bool
81
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
82
+ when False, using system certificates if `root_certificates` is None.
83
+ retry_invoker: RetryInvoker
84
+ Unused argument present for compatibilty.
83
85
  max_message_length : int
84
86
  The maximum length of gRPC messages that can be exchanged with the Flower
85
87
  server. The default should be sufficient for most models. Users who train
@@ -141,22 +143,22 @@ def grpc_connection( # pylint: disable=R0915
141
143
  recordset = compat.getpropertiesins_to_recordset(
142
144
  serde.get_properties_ins_from_proto(proto.get_properties_ins)
143
145
  )
144
- message_type = MESSAGE_TYPE_GET_PROPERTIES
146
+ message_type = MessageTypeLegacy.GET_PROPERTIES
145
147
  elif field == "get_parameters_ins":
146
148
  recordset = compat.getparametersins_to_recordset(
147
149
  serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
148
150
  )
149
- message_type = MESSAGE_TYPE_GET_PARAMETERS
151
+ message_type = MessageTypeLegacy.GET_PARAMETERS
150
152
  elif field == "fit_ins":
151
153
  recordset = compat.fitins_to_recordset(
152
154
  serde.fit_ins_from_proto(proto.fit_ins), False
153
155
  )
154
- message_type = MESSAGE_TYPE_FIT
156
+ message_type = MessageType.TRAIN
155
157
  elif field == "evaluate_ins":
156
158
  recordset = compat.evaluateins_to_recordset(
157
159
  serde.evaluate_ins_from_proto(proto.evaluate_ins), False
158
160
  )
159
- message_type = MESSAGE_TYPE_EVALUATE
161
+ message_type = MessageType.EVALUATE
160
162
  elif field == "reconnect_ins":
161
163
  recordset = RecordSet()
162
164
  recordset.configs_records["config"] = ConfigsRecord(
@@ -190,20 +192,20 @@ def grpc_connection( # pylint: disable=R0915
190
192
  message_type = message.metadata.message_type
191
193
 
192
194
  # RecordSet --> *Res --> *Res proto -> ClientMessage proto
193
- if message_type == MESSAGE_TYPE_GET_PROPERTIES:
195
+ if message_type == MessageTypeLegacy.GET_PROPERTIES:
194
196
  getpropres = compat.recordset_to_getpropertiesres(recordset)
195
197
  msg_proto = ClientMessage(
196
198
  get_properties_res=serde.get_properties_res_to_proto(getpropres)
197
199
  )
198
- elif message_type == MESSAGE_TYPE_GET_PARAMETERS:
200
+ elif message_type == MessageTypeLegacy.GET_PARAMETERS:
199
201
  getparamres = compat.recordset_to_getparametersres(recordset, False)
200
202
  msg_proto = ClientMessage(
201
203
  get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
202
204
  )
203
- elif message_type == MESSAGE_TYPE_FIT:
205
+ elif message_type == MessageType.TRAIN:
204
206
  fitres = compat.recordset_to_fitres(recordset, False)
205
207
  msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
206
- elif message_type == MESSAGE_TYPE_EVALUATE:
208
+ elif message_type == MessageType.EVALUATE:
207
209
  evalres = compat.recordset_to_evaluateres(recordset)
208
210
  msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
209
211
  elif message_type == "reconnect":