flwr-nightly 1.9.0.dev20240423__py3-none-any.whl → 1.9.0.dev20240425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (30)
  1. flwr/cli/config_utils.py +18 -46
  2. flwr/cli/new/new.py +37 -17
  3. flwr/cli/new/templates/app/README.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  5. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +6 -3
  7. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +6 -3
  8. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +6 -3
  9. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +6 -3
  10. flwr/cli/run/run.py +1 -1
  11. flwr/cli/utils.py +18 -17
  12. flwr/client/grpc_client/connection.py +6 -1
  13. flwr/client/grpc_rere_client/client_interceptor.py +150 -0
  14. flwr/client/grpc_rere_client/connection.py +17 -2
  15. flwr/client/rest_client/connection.py +5 -1
  16. flwr/common/grpc.py +5 -1
  17. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +20 -1
  18. flwr/server/compat/app_utils.py +1 -1
  19. flwr/server/compat/driver_client_proxy.py +27 -72
  20. flwr/server/driver/grpc_driver.py +17 -8
  21. flwr/server/run_serverapp.py +14 -0
  22. flwr/server/superlink/driver/driver_servicer.py +1 -1
  23. flwr/server/superlink/state/in_memory_state.py +37 -1
  24. flwr/server/superlink/state/sqlite_state.py +71 -4
  25. flwr/server/superlink/state/state.py +26 -0
  26. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/RECORD +30 -29
  28. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py CHANGED
@@ -14,49 +14,16 @@
14
14
  # ==============================================================================
15
15
  """Utility to validate the `pyproject.toml` file."""
16
16
 
17
- import os
18
17
  from pathlib import Path
19
18
  from typing import Any, Dict, List, Optional, Tuple
20
19
 
21
20
  import tomli
22
- import typer
23
21
 
24
22
  from flwr.common import object_ref
25
23
 
26
24
 
27
- def validate_project_dir(project_dir: Path) -> Optional[Dict[str, Any]]:
28
- """Check if a Flower App directory is valid."""
29
- config = load(str(project_dir / "pyproject.toml"))
30
- if config is None:
31
- typer.secho(
32
- "❌ Project configuration could not be loaded. "
33
- "`pyproject.toml` does not exist.",
34
- fg=typer.colors.RED,
35
- bold=True,
36
- )
37
- return None
38
-
39
- if not validate(config):
40
- typer.secho(
41
- "❌ Project configuration is invalid.",
42
- fg=typer.colors.RED,
43
- bold=True,
44
- )
45
- return None
46
-
47
- if "publisher" not in config["flower"]:
48
- typer.secho(
49
- "❌ Project configuration is missing required `publisher` field.",
50
- fg=typer.colors.RED,
51
- bold=True,
52
- )
53
- return None
54
-
55
- return config
56
-
57
-
58
25
  def load_and_validate_with_defaults(
59
- path: Optional[str] = None,
26
+ path: Optional[Path] = None,
60
27
  ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
61
28
  """Load and validate pyproject.toml as dict.
62
29
 
@@ -70,7 +37,8 @@ def load_and_validate_with_defaults(
70
37
 
71
38
  if config is None:
72
39
  errors = [
73
- "Project configuration could not be loaded. pyproject.toml does not exist."
40
+ "Project configuration could not be loaded. "
41
+ "`pyproject.toml` does not exist."
74
42
  ]
75
43
  return (None, errors, [])
76
44
 
@@ -90,22 +58,23 @@ def load_and_validate_with_defaults(
90
58
  return (config, errors, warnings)
91
59
 
92
60
 
93
- def load(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
61
+ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
94
62
  """Load pyproject.toml and return as dict."""
95
63
  if path is None:
96
- cur_dir = os.getcwd()
97
- toml_path = os.path.join(cur_dir, "pyproject.toml")
64
+ cur_dir = Path.cwd()
65
+ toml_path = cur_dir / "pyproject.toml"
98
66
  else:
99
67
  toml_path = path
100
68
 
101
- if not os.path.isfile(toml_path):
69
+ if not toml_path.is_file():
102
70
  return None
103
71
 
104
- with open(toml_path, encoding="utf-8") as toml_file:
72
+ with toml_path.open(encoding="utf-8") as toml_file:
105
73
  data = tomli.loads(toml_file.read())
106
74
  return data
107
75
 
108
76
 
77
+ # pylint: disable=too-many-branches
109
78
  def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
110
79
  """Validate pyproject.toml fields."""
111
80
  errors = []
@@ -127,13 +96,16 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
127
96
 
128
97
  if "flower" not in config:
129
98
  errors.append("Missing [flower] section")
130
- elif "components" not in config["flower"]:
131
- errors.append("Missing [flower.components] section")
132
99
  else:
133
- if "serverapp" not in config["flower"]["components"]:
134
- errors.append('Property "serverapp" missing in [flower.components]')
135
- if "clientapp" not in config["flower"]["components"]:
136
- errors.append('Property "clientapp" missing in [flower.components]')
100
+ if "publisher" not in config["flower"]:
101
+ errors.append('Property "publisher" missing in [flower]')
102
+ if "components" not in config["flower"]:
103
+ errors.append("Missing [flower.components] section")
104
+ else:
105
+ if "serverapp" not in config["flower"]["components"]:
106
+ errors.append('Property "serverapp" missing in [flower.components]')
107
+ if "clientapp" not in config["flower"]["components"]:
108
+ errors.append('Property "clientapp" missing in [flower.components]')
137
109
 
138
110
  return len(errors) == 0, errors, warnings
139
111
 
flwr/cli/new/new.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower command line interface `new` command."""
16
16
 
17
17
  import os
18
+ import re
18
19
  from enum import Enum
19
20
  from string import Template
20
21
  from typing import Dict, Optional
@@ -86,25 +87,24 @@ def new(
86
87
  Optional[MlFramework],
87
88
  typer.Option(case_sensitive=False, help="The ML framework to use"),
88
89
  ] = None,
90
+ username: Annotated[
91
+ Optional[str],
92
+ typer.Option(case_sensitive=False, help="The Flower username of the author"),
93
+ ] = None,
89
94
  ) -> None:
90
95
  """Create new Flower project."""
91
96
  if project_name is None:
92
- project_name = prompt_text("Please provide project name")
97
+ project_name = prompt_text("Please provide the project name")
93
98
  if not is_valid_project_name(project_name):
94
99
  project_name = prompt_text(
95
100
  "Please provide a name that only contains "
96
- "characters in {'_', 'a-zA-Z', '0-9'}",
101
+ "characters in {'-', a-zA-Z', '0-9'}",
97
102
  predicate=is_valid_project_name,
98
103
  default=sanitize_project_name(project_name),
99
104
  )
100
105
 
101
- print(
102
- typer.style(
103
- f"🔨 Creating Flower project {project_name}...",
104
- fg=typer.colors.GREEN,
105
- bold=True,
106
- )
107
- )
106
+ if username is None:
107
+ username = prompt_text("Please provide your Flower username")
108
108
 
109
109
  if framework is not None:
110
110
  framework_str = str(framework.value)
@@ -122,19 +122,32 @@ def new(
122
122
 
123
123
  framework_str = framework_str.lower()
124
124
 
125
+ print(
126
+ typer.style(
127
+ f"\n🔨 Creating Flower project {project_name}...",
128
+ fg=typer.colors.GREEN,
129
+ bold=True,
130
+ )
131
+ )
132
+
125
133
  # Set project directory path
126
134
  cwd = os.getcwd()
127
- pnl = project_name.lower()
128
- project_dir = os.path.join(cwd, pnl)
135
+ package_name = re.sub(r"[-_.]+", "-", project_name).lower()
136
+ import_name = package_name.replace("-", "_")
137
+ project_dir = os.path.join(cwd, package_name)
129
138
 
130
139
  # List of files to render
131
140
  files = {
132
141
  ".gitignore": {"template": "app/.gitignore.tpl"},
133
142
  "README.md": {"template": "app/README.md.tpl"},
134
143
  "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
135
- f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
136
- f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"},
137
- f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"},
144
+ f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
145
+ f"{import_name}/server.py": {
146
+ "template": f"app/code/server.{framework_str}.py.tpl"
147
+ },
148
+ f"{import_name}/client.py": {
149
+ "template": f"app/code/client.{framework_str}.py.tpl"
150
+ },
138
151
  }
139
152
 
140
153
  # Depending on the framework, generate task.py file
@@ -142,9 +155,16 @@ def new(
142
155
  MlFramework.PYTORCH.value.lower(),
143
156
  ]
144
157
  if framework_str in frameworks_with_tasks:
145
- files[f"{pnl}/task.py"] = {"template": f"app/code/task.{framework_str}.py.tpl"}
146
-
147
- context = {"project_name": project_name}
158
+ files[f"{import_name}/task.py"] = {
159
+ "template": f"app/code/task.{framework_str}.py.tpl"
160
+ }
161
+
162
+ context = {
163
+ "project_name": project_name,
164
+ "package_name": package_name,
165
+ "import_name": import_name.replace("-", "_"),
166
+ "username": username,
167
+ }
148
168
 
149
169
  for file_path, value in files.items():
150
170
  render_and_create(
@@ -8,7 +8,7 @@ pip install .
8
8
 
9
9
  ## Run (Simulation Engine)
10
10
 
11
- In the `$project_name` directory, use `flwr run` to run a local simulation:
11
+ In the `$import_name` directory, use `flwr run` to run a local simulation:
12
12
 
13
13
  ```bash
14
14
  flwr run
@@ -2,7 +2,7 @@
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
4
 
5
- from $project_name.task import (
5
+ from $import_name.task import (
6
6
  Net,
7
7
  DEVICE,
8
8
  load_data,
@@ -4,7 +4,7 @@ from flwr.common import ndarrays_to_parameters
4
4
  from flwr.server import ServerApp, ServerConfig
5
5
  from flwr.server.strategy import FedAvg
6
6
 
7
- from $project_name.task import Net, get_weights
7
+ from $import_name.task import Net, get_weights
8
8
 
9
9
 
10
10
  # Initialize model parameters
@@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
@@ -18,6 +18,9 @@ dependencies = [
18
18
  [tool.hatch.build.targets.wheel]
19
19
  packages = ["."]
20
20
 
21
+ [flower]
22
+ publisher = "$username"
23
+
21
24
  [flower.components]
22
- serverapp = "$project_name.server:app"
23
- clientapp = "$project_name.client:app"
25
+ serverapp = "$import_name.server:app"
26
+ clientapp = "$import_name.client:app"
@@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
@@ -20,6 +20,9 @@ dependencies = [
20
20
  [tool.hatch.build.targets.wheel]
21
21
  packages = ["."]
22
22
 
23
+ [flower]
24
+ publisher = "$username"
25
+
23
26
  [flower.components]
24
- serverapp = "$project_name.server:app"
25
- clientapp = "$project_name.client:app"
27
+ serverapp = "$import_name.server:app"
28
+ clientapp = "$import_name.client:app"
@@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
@@ -19,6 +19,9 @@ dependencies = [
19
19
  [tool.hatch.build.targets.wheel]
20
20
  packages = ["."]
21
21
 
22
+ [flower]
23
+ publisher = "$username"
24
+
22
25
  [flower.components]
23
- serverapp = "$project_name.server:app"
24
- clientapp = "$project_name.client:app"
26
+ serverapp = "$import_name.server:app"
27
+ clientapp = "$import_name.client:app"
@@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
- name = "$project_name"
6
+ name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
9
  authors = [
@@ -19,6 +19,9 @@ dependencies = [
19
19
  [tool.hatch.build.targets.wheel]
20
20
  packages = ["."]
21
21
 
22
+ [flower]
23
+ publisher = "$username"
24
+
22
25
  [flower.components]
23
- serverapp = "$project_name.server:app"
24
- clientapp = "$project_name.client:app"
26
+ serverapp = "$import_name.server:app"
27
+ clientapp = "$import_name.client:app"
flwr/cli/run/run.py CHANGED
@@ -30,7 +30,7 @@ def run() -> None:
30
30
 
31
31
  if config is None:
32
32
  typer.secho(
33
- "Project configuration could not be loaded.\nflower.toml is invalid:\n"
33
+ "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
34
34
  + "\n".join([f"- {line}" for line in errors]),
35
35
  fg=typer.colors.RED,
36
36
  bold=True,
flwr/cli/utils.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface utils."""
16
16
 
17
+ import re
17
18
  from typing import Callable, List, Optional, cast
18
19
 
19
20
  import typer
@@ -73,51 +74,51 @@ def prompt_options(text: str, options: List[str]) -> str:
73
74
 
74
75
 
75
76
  def is_valid_project_name(name: str) -> bool:
76
- """Check if the given string is a valid Python module name.
77
+ """Check if the given string is a valid Python project name.
77
78
 
78
- A valid module name must start with a letter or an underscore, and can only contain
79
- letters, digits, and underscores.
79
+ A valid project name must start with a letter and can only contain letters, digits,
80
+ and hyphens.
80
81
  """
81
82
  if not name:
82
83
  return False
83
84
 
84
- # Check if the first character is a letter or underscore
85
- if not (name[0].isalpha() or name[0] == "_"):
85
+ # Check if the first character is a letter
86
+ if not name[0].isalpha():
86
87
  return False
87
88
 
88
- # Check if the rest of the characters are valid (letter, digit, or underscore)
89
+ # Check if the rest of the characters are valid (letter, digit, or dash)
89
90
  for char in name[1:]:
90
- if not (char.isalnum() or char == "_"):
91
+ if not (char.isalnum() or char in "-"):
91
92
  return False
92
93
 
93
94
  return True
94
95
 
95
96
 
96
97
  def sanitize_project_name(name: str) -> str:
97
- """Sanitize the given string to make it a valid Python module name.
98
+ """Sanitize the given string to make it a valid Python project name.
98
99
 
99
- This version replaces hyphens with underscores, removes any characters not allowed
100
- in Python module names, makes the string lowercase, and ensures it starts with a
101
- valid character.
100
+ This version replaces spaces, dots, slashes, and underscores with dashes, removes
101
+ any characters not allowed in Python project names, makes the string lowercase, and
102
+ ensures it starts with a valid character.
102
103
  """
103
- # Replace '-' with '_'
104
- name_with_underscores = name.replace("-", "_").replace(" ", "_")
104
+ # Replace whitespace with '_'
105
+ name_with_hyphens = re.sub(r"[ ./_]", "-", name)
105
106
 
106
107
  # Allowed characters in a module name: letters, digits, underscore
107
108
  allowed_chars = set(
108
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
109
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
109
110
  )
110
111
 
111
112
  # Make the string lowercase
112
- sanitized_name = name_with_underscores.lower()
113
+ sanitized_name = name_with_hyphens.lower()
113
114
 
114
115
  # Remove any characters not allowed in Python module names
115
116
  sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)
116
117
 
117
118
  # Ensure the first character is a letter or underscore
118
- if sanitized_name and (
119
+ while sanitized_name and (
119
120
  sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
120
121
  ):
121
- sanitized_name = "_" + sanitized_name
122
+ sanitized_name = sanitized_name[1:]
122
123
 
123
124
  return sanitized_name
@@ -22,6 +22,8 @@ from pathlib import Path
22
22
  from queue import Queue
23
23
  from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
24
 
25
+ from cryptography.hazmat.primitives.asymmetric import ec
26
+
25
27
  from flwr.common import (
26
28
  DEFAULT_TTL,
27
29
  GRPC_MAX_MESSAGE_LENGTH,
@@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
56
58
 
57
59
 
58
60
  @contextmanager
59
- def grpc_connection( # pylint: disable=R0915
61
+ def grpc_connection( # pylint: disable=R0913, R0915
60
62
  server_address: str,
61
63
  insecure: bool,
62
64
  retry_invoker: RetryInvoker, # pylint: disable=unused-argument
63
65
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
64
66
  root_certificates: Optional[Union[bytes, str]] = None,
67
+ authentication_keys: Optional[ # pylint: disable=unused-argument
68
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
69
+ ] = None,
65
70
  ) -> Iterator[
66
71
  Tuple[
67
72
  Callable[[], Optional[Message]],
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower client interceptor."""
16
+
17
+
18
+ import base64
19
+ import collections
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
+ bytes_to_public_key,
27
+ compute_hmac,
28
+ generate_shared_key,
29
+ public_key_to_bytes,
30
+ )
31
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
+ CreateNodeRequest,
33
+ DeleteNodeRequest,
34
+ GetRunRequest,
35
+ PullTaskInsRequest,
36
+ PushTaskResRequest,
37
+ )
38
+
39
+ _PUBLIC_KEY_HEADER = "public-key"
40
+ _AUTH_TOKEN_HEADER = "auth-token"
41
+
42
+ Request = Union[
43
+ CreateNodeRequest,
44
+ DeleteNodeRequest,
45
+ PullTaskInsRequest,
46
+ PushTaskResRequest,
47
+ GetRunRequest,
48
+ ]
49
+
50
+
51
+ def _get_value_from_tuples(
52
+ key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
53
+ ) -> bytes:
54
+ value = next((value for key, value in tuples if key == key_string), "")
55
+ if isinstance(value, str):
56
+ return value.encode()
57
+
58
+ return value
59
+
60
+
61
+ class _ClientCallDetails(
62
+ collections.namedtuple(
63
+ "_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
64
+ ),
65
+ grpc.ClientCallDetails, # type: ignore
66
+ ):
67
+ """Details for each client call.
68
+
69
+ The class will be passed on as the first argument in continuation function.
70
+ In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
71
+ """
72
+
73
+
74
+ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
75
+ """Client interceptor for client authentication."""
76
+
77
+ def __init__(
78
+ self,
79
+ private_key: ec.EllipticCurvePrivateKey,
80
+ public_key: ec.EllipticCurvePublicKey,
81
+ ):
82
+ self.private_key = private_key
83
+ self.public_key = public_key
84
+ self.shared_secret: Optional[bytes] = None
85
+ self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
86
+ self.encoded_public_key = base64.urlsafe_b64encode(
87
+ public_key_to_bytes(self.public_key)
88
+ )
89
+
90
+ def intercept_unary_unary(
91
+ self,
92
+ continuation: Callable[[Any, Any], Any],
93
+ client_call_details: grpc.ClientCallDetails,
94
+ request: Request,
95
+ ) -> grpc.Call:
96
+ """Flower client interceptor.
97
+
98
+ Intercept unary call from client and add necessary authentication header in the
99
+ RPC metadata.
100
+ """
101
+ metadata = []
102
+ postprocess = False
103
+ if client_call_details.metadata is not None:
104
+ metadata = list(client_call_details.metadata)
105
+
106
+ # Always add the public key header
107
+ metadata.append(
108
+ (
109
+ _PUBLIC_KEY_HEADER,
110
+ self.encoded_public_key,
111
+ )
112
+ )
113
+
114
+ if isinstance(request, CreateNodeRequest):
115
+ postprocess = True
116
+ elif isinstance(
117
+ request,
118
+ (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest),
119
+ ):
120
+ if self.shared_secret is None:
121
+ raise RuntimeError("Failure to compute hmac")
122
+
123
+ metadata.append(
124
+ (
125
+ _AUTH_TOKEN_HEADER,
126
+ base64.urlsafe_b64encode(
127
+ compute_hmac(
128
+ self.shared_secret, request.SerializeToString(True)
129
+ )
130
+ ),
131
+ )
132
+ )
133
+
134
+ client_call_details = _ClientCallDetails(
135
+ client_call_details.method,
136
+ client_call_details.timeout,
137
+ metadata,
138
+ client_call_details.credentials,
139
+ )
140
+
141
+ response = continuation(client_call_details, request)
142
+ if postprocess:
143
+ server_public_key_bytes = base64.urlsafe_b64decode(
144
+ _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
145
+ )
146
+ self.server_public_key = bytes_to_public_key(server_public_key_bytes)
147
+ self.shared_secret = generate_shared_key(
148
+ self.private_key, self.server_public_key
149
+ )
150
+ return response
@@ -21,7 +21,10 @@ from contextlib import contextmanager
21
21
  from copy import copy
22
22
  from logging import DEBUG, ERROR
23
23
  from pathlib import Path
24
- from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
25
+
26
+ import grpc
27
+ from cryptography.hazmat.primitives.asymmetric import ec
25
28
 
26
29
  from flwr.client.heartbeat import start_ping_loop
27
30
  from flwr.client.message_handler.message_handler import validate_out_message
@@ -52,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
52
55
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
53
56
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
54
57
 
58
+ from .client_interceptor import AuthenticateClientInterceptor
59
+
55
60
 
56
61
  def on_channel_state_change(channel_connectivity: str) -> None:
57
62
  """Log channel connectivity."""
@@ -59,12 +64,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
59
64
 
60
65
 
61
66
  @contextmanager
62
- def grpc_request_response( # pylint: disable=R0914, R0915
67
+ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
63
68
  server_address: str,
64
69
  insecure: bool,
65
70
  retry_invoker: RetryInvoker,
66
71
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
67
72
  root_certificates: Optional[Union[bytes, str]] = None,
73
+ authentication_keys: Optional[
74
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
75
+ ] = None,
68
76
  ) -> Iterator[
69
77
  Tuple[
70
78
  Callable[[], Optional[Message]],
@@ -109,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
109
117
  if isinstance(root_certificates, str):
110
118
  root_certificates = Path(root_certificates).read_bytes()
111
119
 
120
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
121
+ if authentication_keys is not None:
122
+ interceptors = AuthenticateClientInterceptor(
123
+ authentication_keys[0], authentication_keys[1]
124
+ )
125
+
112
126
  channel = create_channel(
113
127
  server_address=server_address,
114
128
  insecure=insecure,
115
129
  root_certificates=root_certificates,
116
130
  max_message_length=max_message_length,
131
+ interceptors=interceptors,
117
132
  )
118
133
  channel.subscribe(on_channel_state_change)
119
134
 
@@ -23,6 +23,7 @@ from copy import copy
23
23
  from logging import ERROR, INFO, WARN
24
24
  from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
25
25
 
26
+ from cryptography.hazmat.primitives.asymmetric import ec
26
27
  from google.protobuf.message import Message as GrpcMessage
27
28
 
28
29
  from flwr.client.heartbeat import start_ping_loop
@@ -74,7 +75,7 @@ T = TypeVar("T", bound=GrpcMessage)
74
75
 
75
76
 
76
77
  @contextmanager
77
- def http_request_response( # pylint: disable=R0914, R0915
78
+ def http_request_response( # pylint: disable=,R0913, R0914, R0915
78
79
  server_address: str,
79
80
  insecure: bool, # pylint: disable=unused-argument
80
81
  retry_invoker: RetryInvoker,
@@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915
82
83
  root_certificates: Optional[
83
84
  Union[bytes, str]
84
85
  ] = None, # pylint: disable=unused-argument
86
+ authentication_keys: Optional[ # pylint: disable=unused-argument
87
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
88
+ ] = None,
85
89
  ) -> Iterator[
86
90
  Tuple[
87
91
  Callable[[], Optional[Message]],