flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (64)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +42 -18
  5. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  6. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  8. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  9. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  10. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  12. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  13. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  14. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  20. flwr/cli/run/run.py +1 -1
  21. flwr/cli/utils.py +18 -17
  22. flwr/client/__init__.py +1 -1
  23. flwr/client/app.py +17 -93
  24. flwr/client/grpc_client/connection.py +6 -1
  25. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  26. flwr/client/grpc_rere_client/connection.py +17 -2
  27. flwr/client/mod/centraldp_mods.py +4 -2
  28. flwr/client/mod/localdp_mod.py +9 -3
  29. flwr/client/rest_client/connection.py +5 -1
  30. flwr/client/supernode/__init__.py +2 -0
  31. flwr/client/supernode/app.py +181 -7
  32. flwr/common/grpc.py +5 -1
  33. flwr/common/logger.py +37 -4
  34. flwr/common/message.py +105 -86
  35. flwr/common/record/parametersrecord.py +0 -1
  36. flwr/common/record/recordset.py +17 -5
  37. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  38. flwr/server/app.py +111 -1
  39. flwr/server/compat/app.py +2 -2
  40. flwr/server/compat/app_utils.py +1 -1
  41. flwr/server/compat/driver_client_proxy.py +27 -72
  42. flwr/server/driver/__init__.py +3 -0
  43. flwr/server/driver/driver.py +12 -242
  44. flwr/server/driver/grpc_driver.py +315 -0
  45. flwr/server/run_serverapp.py +18 -4
  46. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  47. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  48. flwr/server/superlink/driver/driver_servicer.py +1 -1
  49. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  51. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  52. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  53. flwr/server/superlink/state/in_memory_state.py +76 -8
  54. flwr/server/superlink/state/sqlite_state.py +116 -11
  55. flwr/server/superlink/state/state.py +35 -3
  56. flwr/simulation/__init__.py +2 -2
  57. flwr/simulation/app.py +16 -1
  58. flwr/simulation/run_simulation.py +10 -7
  59. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  60. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
  61. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
  62. flwr/server/driver/abc_driver.py +0 -140
  63. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  64. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
@@ -1,18 +1,26 @@
1
1
  """$project_name: A Flower / TensorFlow app."""
2
2
 
3
+ from flwr.common import ndarrays_to_parameters
3
4
  from flwr.server import ServerApp, ServerConfig
4
5
  from flwr.server.strategy import FedAvg
5
6
 
7
+ from $import_name.task import load_model
8
+
6
9
  # Define config
7
10
  config = ServerConfig(num_rounds=3)
8
11
 
12
+ parameters = ndarrays_to_parameters(load_model().get_weights())
13
+
14
+ # Define strategy
9
15
  strategy = FedAvg(
10
16
  fraction_fit=1.0,
11
17
  fraction_evaluate=1.0,
12
18
  min_available_clients=2,
19
+ initial_parameters=parameters,
13
20
  )
14
21
 
15
- # Flower ServerApp
22
+
23
+ # Create ServerApp
16
24
  app = ServerApp(
17
25
  config=config,
18
26
  strategy=strategy,
@@ -0,0 +1,89 @@
1
+ """$project_name: A Flower / MLX app."""
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ import numpy as np
6
+ from datasets.utils.logging import disable_progress_bar
7
+ from flwr_datasets import FederatedDataset
8
+
9
+
10
+ disable_progress_bar()
11
+
12
+ class MLP(nn.Module):
13
+ """A simple MLP."""
14
+
15
+ def __init__(
16
+ self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
17
+ ):
18
+ super().__init__()
19
+ layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
20
+ self.layers = [
21
+ nn.Linear(idim, odim)
22
+ for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
23
+ ]
24
+
25
+ def __call__(self, x):
26
+ for l in self.layers[:-1]:
27
+ x = mx.maximum(l(x), 0.0)
28
+ return self.layers[-1](x)
29
+
30
+
31
+ def loss_fn(model, X, y):
32
+ return mx.mean(nn.losses.cross_entropy(model(X), y))
33
+
34
+
35
+ def eval_fn(model, X, y):
36
+ return mx.mean(mx.argmax(model(X), axis=1) == y)
37
+
38
+
39
+ def batch_iterate(batch_size, X, y):
40
+ perm = mx.array(np.random.permutation(y.size))
41
+ for s in range(0, y.size, batch_size):
42
+ ids = perm[s : s + batch_size]
43
+ yield X[ids], y[ids]
44
+
45
+
46
+ def load_data(partition_id, num_clients):
47
+ fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients})
48
+ partition = fds.load_partition(partition_id)
49
+ partition_splits = partition.train_test_split(test_size=0.2, seed=42)
50
+
51
+ partition_splits["train"].set_format("numpy")
52
+ partition_splits["test"].set_format("numpy")
53
+
54
+ train_partition = partition_splits["train"].map(
55
+ lambda img: {
56
+ "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
57
+ },
58
+ input_columns="image",
59
+ )
60
+ test_partition = partition_splits["test"].map(
61
+ lambda img: {
62
+ "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
63
+ },
64
+ input_columns="image",
65
+ )
66
+
67
+ data = (
68
+ train_partition["img"],
69
+ train_partition["label"].astype(np.uint32),
70
+ test_partition["img"],
71
+ test_partition["label"].astype(np.uint32),
72
+ )
73
+
74
+ train_images, train_labels, test_images, test_labels = map(mx.array, data)
75
+ return train_images, train_labels, test_images, test_labels
76
+
77
+
78
+ def get_params(model):
79
+ layers = model.parameters()["layers"]
80
+ return [np.array(val) for layer in layers for _, val in layer.items()]
81
+
82
+
83
+ def set_params(model, parameters):
84
+ new_params = {}
85
+ new_params["layers"] = [
86
+ {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
87
+ for i in range(0, len(parameters), 2)
88
+ ]
89
+ model.update(new_params)
@@ -0,0 +1,29 @@
1
+ """$project_name: A Flower / TensorFlow app."""
2
+
3
+ import os
4
+
5
+ import tensorflow as tf
6
+ from flwr_datasets import FederatedDataset
7
+
8
+
9
+ # Make TensorFlow log less verbose
10
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
11
+
12
+ def load_model():
13
+ # Load model and data (MobileNetV2, CIFAR-10)
14
+ model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
15
+ model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
16
+ return model
17
+
18
+
19
+ def load_data(partition_id, num_partitions):
20
+ # Download and partition dataset
21
+ fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
22
+ partition = fds.load_partition(partition_id, "train")
23
+ partition.set_format("numpy")
24
+
25
+ # Divide data on each node: 80% train, 20% test
26
+ partition = partition.train_test_split(test_size=0.2)
27
+ x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
28
+ x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
29
+ return x_train, y_train, x_test, y_test
@@ -0,0 +1,28 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ authors = [
10
+ { name = "The Flower Authors", email = "hello@flower.ai" },
11
+ ]
12
+ license = { text = "Apache License (2.0)" }
13
+ dependencies = [
14
+ "flwr[simulation]>=1.8.0,<2.0",
15
+ "flwr-datasets[vision]>=0.0.2,<1.0.0",
16
+ "mlx==0.10.0",
17
+ "numpy==1.24.4",
18
+ ]
19
+
20
+ [tool.hatch.build.targets.wheel]
21
+ packages = ["."]
22
+
23
+ [flower]
24
+ publisher = "$username"
25
+
26
+ [flower.components]
27
+ serverapp = "$import_name.server:app"
28
+ clientapp = "$import_name.client:app"
@@ -3,13 +3,13 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
10
10
  { name = "The Flower Authors", email = "hello@flower.ai" },
11
11
  ]
12
- license = {text = "Apache License (2.0)"}
12
+ license = { text = "Apache License (2.0)" }
13
13
  dependencies = [
14
14
  "flwr[simulation]>=1.8.0,<2.0",
15
15
  "numpy>=1.21.0",
@@ -18,6 +18,9 @@ dependencies = [
18
18
  [tool.hatch.build.targets.wheel]
19
19
  packages = ["."]
20
20
 
21
+ [flower]
22
+ publisher = "$username"
23
+
21
24
  [flower.components]
22
- serverapp = "$project_name.server:app"
23
- clientapp = "$project_name.client:app"
25
+ serverapp = "$import_name.server:app"
26
+ clientapp = "$import_name.client:app"
@@ -3,13 +3,13 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
10
10
  { name = "The Flower Authors", email = "hello@flower.ai" },
11
11
  ]
12
- license = {text = "Apache License (2.0)"}
12
+ license = { text = "Apache License (2.0)" }
13
13
  dependencies = [
14
14
  "flwr[simulation]>=1.8.0,<2.0",
15
15
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -20,6 +20,9 @@ dependencies = [
20
20
  [tool.hatch.build.targets.wheel]
21
21
  packages = ["."]
22
22
 
23
+ [flower]
24
+ publisher = "$username"
25
+
23
26
  [flower.components]
24
- serverapp = "$project_name.server:app"
25
- clientapp = "$project_name.client:app"
27
+ serverapp = "$import_name.server:app"
28
+ clientapp = "$import_name.client:app"
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ authors = [
10
+ { name = "The Flower Authors", email = "hello@flower.ai" },
11
+ ]
12
+ license = { text = "Apache License (2.0)" }
13
+ dependencies = [
14
+ "flwr[simulation]>=1.8.0,<2.0",
15
+ "flwr-datasets[vision]>=0.0.2,<1.0.0",
16
+ "scikit-learn>=1.1.1",
17
+ ]
18
+
19
+ [tool.hatch.build.targets.wheel]
20
+ packages = ["."]
21
+
22
+ [flower]
23
+ publisher = "$username"
24
+
25
+ [flower.components]
26
+ serverapp = "$import_name.server:app"
27
+ clientapp = "$import_name.client:app"
@@ -3,13 +3,13 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
10
10
  { name = "The Flower Authors", email = "hello@flower.ai" },
11
11
  ]
12
- license = {text = "Apache License (2.0)"}
12
+ license = { text = "Apache License (2.0)" }
13
13
  dependencies = [
14
14
  "flwr[simulation]>=1.8.0,<2.0",
15
15
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -19,6 +19,9 @@ dependencies = [
19
19
  [tool.hatch.build.targets.wheel]
20
20
  packages = ["."]
21
21
 
22
+ [flower]
23
+ publisher = "$username"
24
+
22
25
  [flower.components]
23
- serverapp = "$project_name.server:app"
24
- clientapp = "$project_name.client:app"
26
+ serverapp = "$import_name.server:app"
27
+ clientapp = "$import_name.client:app"
flwr/cli/run/run.py CHANGED
@@ -30,7 +30,7 @@ def run() -> None:
30
30
 
31
31
  if config is None:
32
32
  typer.secho(
33
- "Project configuration could not be loaded.\nflower.toml is invalid:\n"
33
+ "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
34
34
  + "\n".join([f"- {line}" for line in errors]),
35
35
  fg=typer.colors.RED,
36
36
  bold=True,
flwr/cli/utils.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface utils."""
16
16
 
17
+ import re
17
18
  from typing import Callable, List, Optional, cast
18
19
 
19
20
  import typer
@@ -73,51 +74,51 @@ def prompt_options(text: str, options: List[str]) -> str:
73
74
 
74
75
 
75
76
  def is_valid_project_name(name: str) -> bool:
76
- """Check if the given string is a valid Python module name.
77
+ """Check if the given string is a valid Python project name.
77
78
 
78
- A valid module name must start with a letter or an underscore, and can only contain
79
- letters, digits, and underscores.
79
+ A valid project name must start with a letter and can only contain letters, digits,
80
+ and hyphens.
80
81
  """
81
82
  if not name:
82
83
  return False
83
84
 
84
- # Check if the first character is a letter or underscore
85
- if not (name[0].isalpha() or name[0] == "_"):
85
+ # Check if the first character is a letter
86
+ if not name[0].isalpha():
86
87
  return False
87
88
 
88
- # Check if the rest of the characters are valid (letter, digit, or underscore)
89
+ # Check if the rest of the characters are valid (letter, digit, or dash)
89
90
  for char in name[1:]:
90
- if not (char.isalnum() or char == "_"):
91
+ if not (char.isalnum() or char in "-"):
91
92
  return False
92
93
 
93
94
  return True
94
95
 
95
96
 
96
97
  def sanitize_project_name(name: str) -> str:
97
- """Sanitize the given string to make it a valid Python module name.
98
+ """Sanitize the given string to make it a valid Python project name.
98
99
 
99
- This version replaces hyphens with underscores, removes any characters not allowed
100
- in Python module names, makes the string lowercase, and ensures it starts with a
101
- valid character.
100
+ This version replaces spaces, dots, slashes, and underscores with dashes, removes
101
+ any characters not allowed in Python project names, makes the string lowercase, and
102
+ ensures it starts with a valid character.
102
103
  """
103
- # Replace '-' with '_'
104
- name_with_underscores = name.replace("-", "_").replace(" ", "_")
104
+ # Replace spaces, dots, slashes, and underscores with '-'
105
+ name_with_hyphens = re.sub(r"[ ./_]", "-", name)
105
106
 
106
107
  # Allowed characters in a project name: letters, digits, hyphen
107
108
  allowed_chars = set(
108
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
109
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
109
110
  )
110
111
 
111
112
  # Make the string lowercase
112
- sanitized_name = name_with_underscores.lower()
113
+ sanitized_name = name_with_hyphens.lower()
113
114
 
114
115
  # Remove any characters not allowed in Python project names
115
116
  sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)
116
117
 
117
118
  # Strip leading digits from the name (NOTE: a leading '-' is not removed here)
118
- if sanitized_name and (
119
+ while sanitized_name and (
119
120
  sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
120
121
  ):
121
- sanitized_name = "_" + sanitized_name
122
+ sanitized_name = sanitized_name[1:]
122
123
 
123
124
  return sanitized_name
flwr/client/__init__.py CHANGED
@@ -15,12 +15,12 @@
15
15
  """Flower client."""
16
16
 
17
17
 
18
- from .app import run_client_app as run_client_app
19
18
  from .app import start_client as start_client
20
19
  from .app import start_numpy_client as start_numpy_client
21
20
  from .client import Client as Client
22
21
  from .client_app import ClientApp as ClientApp
23
22
  from .numpy_client import NumPyClient as NumPyClient
23
+ from .supernode import run_client_app as run_client_app
24
24
  from .supernode import run_supernode as run_supernode
25
25
  from .typing import ClientFn as ClientFn
26
26
 
flwr/client/app.py CHANGED
@@ -14,13 +14,12 @@
14
14
  # ==============================================================================
15
15
  """Flower client app."""
16
16
 
17
- import argparse
18
17
  import sys
19
18
  import time
20
19
  from logging import DEBUG, ERROR, INFO, WARN
21
- from pathlib import Path
22
20
  from typing import Callable, ContextManager, Optional, Tuple, Type, Union
23
21
 
22
+ from cryptography.hazmat.primitives.asymmetric import ec
24
23
  from grpc import RpcError
25
24
 
26
25
  from flwr.client.client import Client
@@ -36,10 +35,8 @@ from flwr.common.constant import (
36
35
  TRANSPORT_TYPES,
37
36
  ErrorCode,
38
37
  )
39
- from flwr.common.exit_handlers import register_exit_handlers
40
38
  from flwr.common.logger import log, warn_deprecated_feature
41
39
  from flwr.common.message import Error
42
- from flwr.common.object_ref import load_app, validate
43
40
  from flwr.common.retry_invoker import RetryInvoker, exponential
44
41
 
45
42
  from .grpc_client.connection import grpc_connection
@@ -47,94 +44,6 @@ from .grpc_rere_client.connection import grpc_request_response
47
44
  from .message_handler.message_handler import handle_control_message
48
45
  from .node_state import NodeState
49
46
  from .numpy_client import NumPyClient
50
- from .supernode.app import parse_args_run_client_app
51
-
52
-
53
- def run_client_app() -> None:
54
- """Run Flower client app."""
55
- log(INFO, "Long-running Flower client starting")
56
-
57
- event(EventType.RUN_CLIENT_APP_ENTER)
58
-
59
- args = _parse_args_run_client_app().parse_args()
60
-
61
- # Obtain certificates
62
- if args.insecure:
63
- if args.root_certificates is not None:
64
- sys.exit(
65
- "Conflicting options: The '--insecure' flag disables HTTPS, "
66
- "but '--root-certificates' was also specified. Please remove "
67
- "the '--root-certificates' option when running in insecure mode, "
68
- "or omit '--insecure' to use HTTPS."
69
- )
70
- log(
71
- WARN,
72
- "Option `--insecure` was set. "
73
- "Starting insecure HTTP client connected to %s.",
74
- args.server,
75
- )
76
- root_certificates = None
77
- else:
78
- # Load the certificates if provided, or load the system certificates
79
- cert_path = args.root_certificates
80
- if cert_path is None:
81
- root_certificates = None
82
- else:
83
- root_certificates = Path(cert_path).read_bytes()
84
- log(
85
- DEBUG,
86
- "Starting secure HTTPS client connected to %s "
87
- "with the following certificates: %s.",
88
- args.server,
89
- cert_path,
90
- )
91
-
92
- log(
93
- DEBUG,
94
- "Flower will load ClientApp `%s`",
95
- getattr(args, "client-app"),
96
- )
97
-
98
- client_app_dir = args.dir
99
- if client_app_dir is not None:
100
- sys.path.insert(0, client_app_dir)
101
-
102
- app_ref: str = getattr(args, "client-app")
103
- valid, error_msg = validate(app_ref)
104
- if not valid and error_msg:
105
- raise LoadClientAppError(error_msg) from None
106
-
107
- def _load() -> ClientApp:
108
- client_app = load_app(app_ref, LoadClientAppError)
109
-
110
- if not isinstance(client_app, ClientApp):
111
- raise LoadClientAppError(
112
- f"Attribute {app_ref} is not of type {ClientApp}",
113
- ) from None
114
-
115
- return client_app
116
-
117
- _start_client_internal(
118
- server_address=args.server,
119
- load_client_app_fn=_load,
120
- transport="rest" if args.rest else "grpc-rere",
121
- root_certificates=root_certificates,
122
- insecure=args.insecure,
123
- max_retries=args.max_retries,
124
- max_wait_time=args.max_wait_time,
125
- )
126
- register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
127
-
128
-
129
- def _parse_args_run_client_app() -> argparse.ArgumentParser:
130
- """Parse flower-client-app command line arguments."""
131
- parser = argparse.ArgumentParser(
132
- description="Start a Flower client app",
133
- )
134
-
135
- parse_args_run_client_app(parser=parser)
136
-
137
- return parser
138
47
 
139
48
 
140
49
  def _check_actionable_client(
@@ -165,6 +74,9 @@ def start_client(
165
74
  root_certificates: Optional[Union[bytes, str]] = None,
166
75
  insecure: Optional[bool] = None,
167
76
  transport: Optional[str] = None,
77
+ authentication_keys: Optional[
78
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
79
+ ] = None,
168
80
  max_retries: Optional[int] = None,
169
81
  max_wait_time: Optional[float] = None,
170
82
  ) -> None:
@@ -249,6 +161,7 @@ def start_client(
249
161
  root_certificates=root_certificates,
250
162
  insecure=insecure,
251
163
  transport=transport,
164
+ authentication_keys=authentication_keys,
252
165
  max_retries=max_retries,
253
166
  max_wait_time=max_wait_time,
254
167
  )
@@ -269,6 +182,9 @@ def _start_client_internal(
269
182
  root_certificates: Optional[Union[bytes, str]] = None,
270
183
  insecure: Optional[bool] = None,
271
184
  transport: Optional[str] = None,
185
+ authentication_keys: Optional[
186
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
187
+ ] = None,
272
188
  max_retries: Optional[int] = None,
273
189
  max_wait_time: Optional[float] = None,
274
190
  ) -> None:
@@ -393,6 +309,7 @@ def _start_client_internal(
393
309
  retry_invoker,
394
310
  grpc_max_message_length,
395
311
  root_certificates,
312
+ authentication_keys,
396
313
  ) as conn:
397
314
  # pylint: disable-next=W0612
398
315
  receive, send, create_node, delete_node, get_run = conn
@@ -606,7 +523,14 @@ def start_numpy_client(
606
523
 
607
524
  def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
608
525
  Callable[
609
- [str, bool, RetryInvoker, int, Union[bytes, str, None]],
526
+ [
527
+ str,
528
+ bool,
529
+ RetryInvoker,
530
+ int,
531
+ Union[bytes, str, None],
532
+ Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
533
+ ],
610
534
  ContextManager[
611
535
  Tuple[
612
536
  Callable[[], Optional[Message]],
@@ -22,6 +22,8 @@ from pathlib import Path
22
22
  from queue import Queue
23
23
  from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
24
 
25
+ from cryptography.hazmat.primitives.asymmetric import ec
26
+
25
27
  from flwr.common import (
26
28
  DEFAULT_TTL,
27
29
  GRPC_MAX_MESSAGE_LENGTH,
@@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
56
58
 
57
59
 
58
60
  @contextmanager
59
- def grpc_connection( # pylint: disable=R0915
61
+ def grpc_connection( # pylint: disable=R0913, R0915
60
62
  server_address: str,
61
63
  insecure: bool,
62
64
  retry_invoker: RetryInvoker, # pylint: disable=unused-argument
63
65
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
64
66
  root_certificates: Optional[Union[bytes, str]] = None,
67
+ authentication_keys: Optional[ # pylint: disable=unused-argument
68
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
69
+ ] = None,
65
70
  ) -> Iterator[
66
71
  Tuple[
67
72
  Callable[[], Optional[Message]],