flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package registry page for more details.

Files changed (64):
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +42 -18
  5. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  6. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  8. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  9. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  10. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  12. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  13. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  14. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  20. flwr/cli/run/run.py +1 -1
  21. flwr/cli/utils.py +18 -17
  22. flwr/client/__init__.py +1 -1
  23. flwr/client/app.py +17 -93
  24. flwr/client/grpc_client/connection.py +6 -1
  25. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  26. flwr/client/grpc_rere_client/connection.py +17 -2
  27. flwr/client/mod/centraldp_mods.py +4 -2
  28. flwr/client/mod/localdp_mod.py +9 -3
  29. flwr/client/rest_client/connection.py +5 -1
  30. flwr/client/supernode/__init__.py +2 -0
  31. flwr/client/supernode/app.py +181 -7
  32. flwr/common/grpc.py +5 -1
  33. flwr/common/logger.py +37 -4
  34. flwr/common/message.py +105 -86
  35. flwr/common/record/parametersrecord.py +0 -1
  36. flwr/common/record/recordset.py +17 -5
  37. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  38. flwr/server/app.py +111 -1
  39. flwr/server/compat/app.py +2 -2
  40. flwr/server/compat/app_utils.py +1 -1
  41. flwr/server/compat/driver_client_proxy.py +27 -72
  42. flwr/server/driver/__init__.py +3 -0
  43. flwr/server/driver/driver.py +12 -242
  44. flwr/server/driver/grpc_driver.py +315 -0
  45. flwr/server/run_serverapp.py +18 -4
  46. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  47. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  48. flwr/server/superlink/driver/driver_servicer.py +1 -1
  49. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  51. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  52. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  53. flwr/server/superlink/state/in_memory_state.py +76 -8
  54. flwr/server/superlink/state/sqlite_state.py +116 -11
  55. flwr/server/superlink/state/state.py +35 -3
  56. flwr/simulation/__init__.py +2 -2
  57. flwr/simulation/app.py +16 -1
  58. flwr/simulation/run_simulation.py +10 -7
  59. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  60. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
  61. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
  62. flwr/server/driver/abc_driver.py +0 -140
  63. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  64. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
@@ -0,0 +1,158 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower client interceptor."""
16
+
17
+
18
+ import base64
19
+ import collections
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
+ bytes_to_public_key,
27
+ compute_hmac,
28
+ generate_shared_key,
29
+ public_key_to_bytes,
30
+ )
31
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
+ CreateNodeRequest,
33
+ DeleteNodeRequest,
34
+ GetRunRequest,
35
+ PingRequest,
36
+ PullTaskInsRequest,
37
+ PushTaskResRequest,
38
+ )
39
+
40
# Metadata key carrying the client's URL-safe base64-encoded public key.
_PUBLIC_KEY_HEADER = "public-key"
# Metadata key carrying the URL-safe base64-encoded HMAC of the request.
_AUTH_TOKEN_HEADER = "auth-token"

# Fleet API request messages accepted by `AuthenticateClientInterceptor`.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
    PingRequest,
]
51
+
52
+
53
+ def _get_value_from_tuples(
54
+ key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
55
+ ) -> bytes:
56
+ value = next((value for key, value in tuples if key == key_string), "")
57
+ if isinstance(value, str):
58
+ return value.encode()
59
+
60
+ return value
61
+
62
+
63
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", "method timeout metadata credentials"
    ),
    grpc.ClientCallDetails,  # type: ignore
):
    """Immutable call-details record handed to an interceptor continuation.

    `AuthenticateClientInterceptor` builds an instance with the original
    method, timeout and credentials, plus its extended metadata. It then
    passes that instance as the first argument of the continuation function.
    """
74
+
75
+
76
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor):  # type: ignore
    """Client interceptor for client authentication.

    The client's public key is attached to the metadata of every unary-unary
    call. When the response to a `CreateNodeRequest` arrives, the server's
    public key is read from the response's initial metadata. A shared secret
    is then derived from the client's private key and that server key.

    Every later `DeleteNodeRequest`, `PullTaskInsRequest`,
    `PushTaskResRequest`, `GetRunRequest` or `PingRequest` additionally
    carries an HMAC of the serialized request, computed with the shared
    secret.
    """

    def __init__(
        self,
        private_key: ec.EllipticCurvePrivateKey,
        public_key: ec.EllipticCurvePublicKey,
    ):
        self.private_key = private_key
        self.public_key = public_key
        # Both are populated once the `CreateNodeRequest` response is processed.
        self.shared_secret: Optional[bytes] = None
        self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
        # The public-key header value never changes, so encode it only once.
        self.encoded_public_key = base64.urlsafe_b64encode(
            public_key_to_bytes(self.public_key)
        )

    def intercept_unary_unary(
        self,
        continuation: Callable[[Any, Any], Any],
        client_call_details: grpc.ClientCallDetails,
        request: Request,
    ) -> grpc.Call:
        """Flower client interceptor.

        Intercept unary call from client and add necessary authentication header in the
        RPC metadata.

        Parameters
        ----------
        continuation : Callable[[Any, Any], Any]
            Function that proceeds with the invocation. It is called with the
            updated call details and the unchanged request.
        client_call_details : grpc.ClientCallDetails
            Details of the outgoing call. Its metadata is copied and extended.
        request : Request
            The outgoing Fleet API request message.

        Returns
        -------
        grpc.Call
            Whatever ``continuation`` returns.

        Raises
        ------
        RuntimeError
            If a request that must carry an HMAC is sent before the shared
            secret has been established by a `CreateNodeRequest`.
        """
        metadata = []
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)

        # Always add the public key header
        metadata.append((_PUBLIC_KEY_HEADER, self.encoded_public_key))

        postprocess = False
        if isinstance(request, CreateNodeRequest):
            # The shared secret is derived from the server's reply (see below).
            postprocess = True
        elif isinstance(
            request,
            (
                DeleteNodeRequest,
                PullTaskInsRequest,
                PushTaskResRequest,
                GetRunRequest,
                PingRequest,
            ),
        ):
            if self.shared_secret is None:
                raise RuntimeError("Failure to compute hmac")

            # `deterministic` must be passed by keyword. The pure-Python
            # protobuf implementation defines `SerializeToString(**kwargs)`,
            # so a positional `True` raises `TypeError` there. The serialized
            # bytes are identical to the positional form on C-backed
            # implementations. Presumably the server recomputes the HMAC over
            # the same deterministic serialization -- confirm against the
            # server interceptor.
            serialized_request = request.SerializeToString(deterministic=True)
            metadata.append(
                (
                    _AUTH_TOKEN_HEADER,
                    base64.urlsafe_b64encode(
                        compute_hmac(self.shared_secret, serialized_request)
                    ),
                )
            )

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )

        response = continuation(client_call_details, request)
        if postprocess:
            # The server returns its public key in the initial metadata of the
            # `CreateNodeRequest` response.
            server_public_key_bytes = base64.urlsafe_b64decode(
                _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
            )
            self.server_public_key = bytes_to_public_key(server_public_key_bytes)
            self.shared_secret = generate_shared_key(
                self.private_key, self.server_public_key
            )
        return response
@@ -21,7 +21,10 @@ from contextlib import contextmanager
21
21
  from copy import copy
22
22
  from logging import DEBUG, ERROR
23
23
  from pathlib import Path
24
- from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
25
+
26
+ import grpc
27
+ from cryptography.hazmat.primitives.asymmetric import ec
25
28
 
26
29
  from flwr.client.heartbeat import start_ping_loop
27
30
  from flwr.client.message_handler.message_handler import validate_out_message
@@ -52,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
52
55
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
53
56
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
54
57
 
58
+ from .client_interceptor import AuthenticateClientInterceptor
59
+
55
60
 
56
61
  def on_channel_state_change(channel_connectivity: str) -> None:
57
62
  """Log channel connectivity."""
@@ -59,12 +64,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
59
64
 
60
65
 
61
66
  @contextmanager
62
- def grpc_request_response( # pylint: disable=R0914, R0915
67
+ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
63
68
  server_address: str,
64
69
  insecure: bool,
65
70
  retry_invoker: RetryInvoker,
66
71
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
67
72
  root_certificates: Optional[Union[bytes, str]] = None,
73
+ authentication_keys: Optional[
74
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
75
+ ] = None,
68
76
  ) -> Iterator[
69
77
  Tuple[
70
78
  Callable[[], Optional[Message]],
@@ -109,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
109
117
  if isinstance(root_certificates, str):
110
118
  root_certificates = Path(root_certificates).read_bytes()
111
119
 
120
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
121
+ if authentication_keys is not None:
122
+ interceptors = AuthenticateClientInterceptor(
123
+ authentication_keys[0], authentication_keys[1]
124
+ )
125
+
112
126
  channel = create_channel(
113
127
  server_address=server_address,
114
128
  insecure=insecure,
115
129
  root_certificates=root_certificates,
116
130
  max_message_length=max_message_length,
131
+ interceptors=interceptors,
117
132
  )
118
133
  channel.subscribe(on_channel_state_change)
119
134
 
@@ -82,7 +82,9 @@ def fixedclipping_mod(
82
82
  clipping_norm,
83
83
  )
84
84
 
85
- log(INFO, "fixedclipping_mod: parameters are clipped by value: %s.", clipping_norm)
85
+ log(
86
+ INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
87
+ )
86
88
 
87
89
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
88
90
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
146
148
  )
147
149
  log(
148
150
  INFO,
149
- "adaptiveclipping_mod: parameters are clipped by value: %s.",
151
+ "adaptiveclipping_mod: parameters are clipped by value: %.4f.",
150
152
  clipping_norm,
151
153
  )
152
154
 
@@ -128,7 +128,9 @@ class LocalDpMod:
128
128
  self.clipping_norm,
129
129
  )
130
130
  log(
131
- INFO, "LocalDpMod: parameters are clipped by value: %s.", self.clipping_norm
131
+ INFO,
132
+ "LocalDpMod: parameters are clipped by value: %.4f.",
133
+ self.clipping_norm,
132
134
  )
133
135
 
134
136
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
@@ -137,11 +139,15 @@ class LocalDpMod:
137
139
  add_localdp_gaussian_noise_to_params(
138
140
  fit_res.parameters, self.sensitivity, self.epsilon, self.delta
139
141
  )
142
+
143
+ noise_value_sd = (
144
+ self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
145
+ )
140
146
  log(
141
147
  INFO,
142
148
  "LocalDpMod: local DP noise with "
143
- "standard deviation: %s added to parameters.",
144
- self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon,
149
+ "standard deviation: %.4f added to parameters.",
150
+ noise_value_sd,
145
151
  )
146
152
 
147
153
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
@@ -23,6 +23,7 @@ from copy import copy
23
23
  from logging import ERROR, INFO, WARN
24
24
  from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
25
25
 
26
+ from cryptography.hazmat.primitives.asymmetric import ec
26
27
  from google.protobuf.message import Message as GrpcMessage
27
28
 
28
29
  from flwr.client.heartbeat import start_ping_loop
@@ -74,7 +75,7 @@ T = TypeVar("T", bound=GrpcMessage)
74
75
 
75
76
 
76
77
  @contextmanager
77
- def http_request_response( # pylint: disable=R0914, R0915
78
+ def http_request_response( # pylint: disable=,R0913, R0914, R0915
78
79
  server_address: str,
79
80
  insecure: bool, # pylint: disable=unused-argument
80
81
  retry_invoker: RetryInvoker,
@@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915
82
83
  root_certificates: Optional[
83
84
  Union[bytes, str]
84
85
  ] = None, # pylint: disable=unused-argument
86
+ authentication_keys: Optional[ # pylint: disable=unused-argument
87
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
88
+ ] = None,
85
89
  ) -> Iterator[
86
90
  Tuple[
87
91
  Callable[[], Optional[Message]],
@@ -15,8 +15,10 @@
15
15
  """Flower SuperNode."""
16
16
 
17
17
 
18
+ from .app import run_client_app as run_client_app
18
19
  from .app import run_supernode as run_supernode
19
20
 
20
21
  __all__ = [
22
+ "run_client_app",
21
23
  "run_supernode",
22
24
  ]
@@ -15,11 +15,27 @@
15
15
  """Flower SuperNode."""
16
16
 
17
17
  import argparse
18
- from logging import DEBUG, INFO
18
+ import sys
19
+ from logging import DEBUG, INFO, WARN
20
+ from pathlib import Path
21
+ from typing import Callable, Optional, Tuple
19
22
 
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+ from cryptography.hazmat.primitives.serialization import (
25
+ load_ssh_private_key,
26
+ load_ssh_public_key,
27
+ )
28
+
29
+ from flwr.client.client_app import ClientApp, LoadClientAppError
20
30
  from flwr.common import EventType, event
21
31
  from flwr.common.exit_handlers import register_exit_handlers
22
32
  from flwr.common.logger import log
33
+ from flwr.common.object_ref import load_app, validate
34
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
35
+ ssh_types_to_elliptic_curve,
36
+ )
37
+
38
+ from ..app import _start_client_internal
23
39
 
24
40
 
25
41
  def run_supernode() -> None:
@@ -28,12 +44,11 @@ def run_supernode() -> None:
28
44
 
29
45
  event(EventType.RUN_SUPERNODE_ENTER)
30
46
 
31
- args = _parse_args_run_supernode().parse_args()
47
+ _ = _parse_args_run_supernode().parse_args()
32
48
 
33
49
  log(
34
50
  DEBUG,
35
- "Flower will load ClientApp `%s`",
36
- getattr(args, "client-app"),
51
+ "Flower SuperNode starting...",
37
52
  )
38
53
 
39
54
  # Graceful shutdown
@@ -42,23 +57,144 @@ def run_supernode() -> None:
42
57
  )
43
58
 
44
59
 
60
def run_client_app() -> None:
    """Entry point of the long-running `flower-client-app` command.

    Parses the CLI arguments, resolves certificates, the `ClientApp` loader
    and the optional authentication keys, then starts the client.
    """
    log(INFO, "Long-running Flower client starting")
    event(EventType.RUN_CLIENT_APP_ENTER)

    parsed_args = _parse_args_run_client_app().parse_args()

    certificates = _get_certificates(parsed_args)
    log(
        DEBUG,
        "Flower will load ClientApp `%s`",
        getattr(parsed_args, "client-app"),
    )
    client_app_loader = _get_load_client_app_fn(parsed_args)
    auth_keys = _try_setup_client_authentication(parsed_args)
    selected_transport = "rest" if parsed_args.rest else "grpc-rere"

    _start_client_internal(
        server_address=parsed_args.server,
        load_client_app_fn=client_app_loader,
        transport=selected_transport,
        root_certificates=certificates,
        insecure=parsed_args.insecure,
        authentication_keys=auth_keys,
        max_retries=parsed_args.max_retries,
        max_wait_time=parsed_args.max_wait_time,
    )
    # NOTE(review): exit handlers are registered only after
    # `_start_client_internal` returns, so they do not cover the client's run
    # itself -- confirm this ordering is intended.
    register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
88
+
89
+
90
def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
    """Return the root certificates selected on the command line.

    In insecure mode no certificates are used. Combining ``--insecure`` with
    ``--root-certificates`` terminates the process via ``sys.exit``.

    Otherwise the file at ``args.root_certificates`` is read and returned as
    bytes. ``None`` is returned when no path was given.
    """
    cert_path = args.root_certificates

    if not args.insecure:
        # Without a path, `None` is handed on and the caller decides which
        # certificates to use (presumably the system defaults -- confirm).
        certificates = None if cert_path is None else Path(cert_path).read_bytes()
        log(
            DEBUG,
            "Starting secure HTTPS client connected to %s "
            "with the following certificates: %s.",
            args.server,
            cert_path,
        )
        return certificates

    if cert_path is not None:
        sys.exit(
            "Conflicting options: The '--insecure' flag disables HTTPS, "
            "but '--root-certificates' was also specified. Please remove "
            "the '--root-certificates' option when running in insecure mode, "
            "or omit '--insecure' to use HTTPS."
        )
    log(
        WARN,
        "Option `--insecure` was set. "
        "Starting insecure HTTP client connected to %s.",
        args.server,
    )
    return None
123
+
124
+
125
def _get_load_client_app_fn(
    args: argparse.Namespace,
) -> Callable[[], ClientApp]:
    """Get the load_client_app_fn function.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. ``args.dir`` (an optional directory) and the
        positional ``client-app`` reference are read.

    Returns
    -------
    Callable[[], ClientApp]
        A zero-argument function. Each time it is called, it loads the
        referenced object and returns it if it is a `ClientApp`.

    Raises
    ------
    LoadClientAppError
        Raised immediately if the reference fails validation. Raised by the
        returned function if the loaded object is not a `ClientApp`. It is
        also the error type handed to `load_app`, presumably for load
        failures.
    """
    client_app_dir = args.dir
    if client_app_dir is not None:
        # Make modules in `--dir` importable. This mutates `sys.path` for the
        # whole process as soon as this function is called.
        sys.path.insert(0, client_app_dir)

    app_ref: str = getattr(args, "client-app")
    valid, error_msg = validate(app_ref)
    if not valid:
        # Fail fast on any invalid reference, even if `validate` supplies no
        # message, instead of continuing with a reference known to be bad.
        raise LoadClientAppError(
            error_msg or f"Invalid ClientApp reference: `{app_ref}`"
        ) from None

    def _load() -> ClientApp:
        client_app = load_app(app_ref, LoadClientAppError)

        if not isinstance(client_app, ClientApp):
            raise LoadClientAppError(
                f"Attribute {app_ref} is not of type {ClientApp}",
            ) from None

        return client_app

    return _load
149
+
150
+
45
151
def _parse_args_run_supernode() -> argparse.ArgumentParser:
    """Parse flower-supernode command line arguments.

    Returns
    -------
    argparse.ArgumentParser
        Parser with an optional positional ``client-app`` (default ``""``),
        the options shared with ``flower-client-app``, and ``--flwr-dir``.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower SuperNode",
    )

    parser.add_argument(
        "client-app",
        nargs="?",
        default="",
        help="For example: `client:app` or `project.package.module:wrapper.app`. "
        "This is optional and serves as the default ClientApp to be loaded when "
        "the ServerApp does not specify `fab_id` and `fab_version`. "
        "If not provided, defaults to an empty string.",
    )
    _parse_args_common(parser)
    # Help text typo fixed: "isequal" -> "is equal".
    parser.add_argument(
        "--flwr-dir",
        default=None,
        help="""The path containing installed Flower Apps.
    By default, this value is equal to:

        - `$FLWR_HOME/` if `$FLWR_HOME` is defined
        - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
        - `$HOME/.flwr/` in all other cases
    """,
    )

    return parser
54
180
 
55
181
 
56
- def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
57
- """Parse command line arguments."""
182
def _parse_args_run_client_app() -> argparse.ArgumentParser:
    """Build the argument parser for the `flower-client-app` command.

    The parser takes a required positional ``client-app`` reference plus the
    options shared with ``flower-supernode``.
    """
    client_app_parser = argparse.ArgumentParser(description="Start a Flower client app")
    client_app_parser.add_argument(
        "client-app",
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )
    _parse_args_common(parser=client_app_parser)
    return client_app_parser
195
+
196
+
197
+ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
62
198
  parser.add_argument(
63
199
  "--insecure",
64
200
  action="store_true",
@@ -105,3 +241,41 @@ def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
105
241
  "app from there."
106
242
  " Default: current working directory.",
107
243
  )
244
+ parser.add_argument(
245
+ "--authentication-keys",
246
+ nargs=2,
247
+ metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
248
+ type=str,
249
+ help="Provide two file paths: (1) the client's private "
250
+ "key file, and (2) the client's public key file.",
251
+ )
252
+
253
+
254
def _try_setup_client_authentication(
    args: argparse.Namespace,
) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
    """Load the client's elliptic-curve key pair used for authentication.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. ``args.authentication_keys`` is either empty or
        unset, or a pair of file paths
        ``(private_key_path, public_key_path)`` from
        ``--authentication-keys``.

    Returns
    -------
    Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]
        ``None`` when ``--authentication-keys`` was not given, otherwise the
        loaded ``(private_key, public_key)`` pair.

    Notes
    -----
    The process is terminated via ``sys.exit`` with an explanatory message
    if either file cannot be read or parsed as an OpenSSH key. The same
    applies if the private key is password-protected (no password is
    supplied) or the keys are not an elliptic-curve pair.
    """
    if not args.authentication_keys:
        return None

    # `nargs=2` on the CLI option guarantees exactly two paths.
    private_key_path, public_key_path = args.authentication_keys

    try:
        ssh_private_key = load_ssh_private_key(
            Path(private_key_path).read_bytes(),
            None,
        )
        ssh_public_key = load_ssh_public_key(Path(public_key_path).read_bytes())
    except (OSError, ValueError, TypeError) as err:
        # OSError: unreadable or missing file. ValueError: not a parseable
        # key. TypeError: e.g. an encrypted private key, since no password
        # is passed.
        sys.exit(
            "Unable to load the key files provided to '--authentication-keys' "
            f"({private_key_path}, {public_key_path}): {err}. Please provide "
            "readable OpenSSH-formatted private and public key files; the "
            "private key must not be password-protected."
        )

    try:
        client_private_key, client_public_key = ssh_types_to_elliptic_curve(
            ssh_private_key, ssh_public_key
        )
    except TypeError:
        sys.exit(
            "The file paths provided could not be read as a private and public "
            "key pair. Client authentication requires an elliptic curve public and "
            "private key pair. Please provide the file paths containing elliptic "
            "curve private and public keys to '--authentication-keys'."
        )

    return (
        client_private_key,
        client_public_key,
    )
flwr/common/grpc.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from logging import DEBUG
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
20
 
21
21
  import grpc
22
22
 
@@ -30,6 +30,7 @@ def create_channel(
30
30
  insecure: bool,
31
31
  root_certificates: Optional[bytes] = None,
32
32
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
33
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
33
34
  ) -> grpc.Channel:
34
35
  """Create a gRPC channel, either secure or insecure."""
35
36
  # Check for conflicting parameters
@@ -57,4 +58,7 @@ def create_channel(
57
58
  )
58
59
  log(DEBUG, "Opened secure gRPC connection using certificates")
59
60
 
61
+ if interceptors is not None:
62
+ channel = grpc.intercept_channel(channel, interceptors)
63
+
60
64
  return channel
flwr/common/logger.py CHANGED
@@ -82,13 +82,20 @@ class ConsoleHandler(StreamHandler):
82
82
  return formatter.format(record)
83
83
 
84
84
 
85
- def update_console_handler(level: int, timestamps: bool, colored: bool) -> None:
85
def update_console_handler(
    level: Optional[int] = None,
    timestamps: Optional[bool] = None,
    colored: Optional[bool] = None,
) -> None:
    """Reconfigure every `ConsoleHandler` attached to the Flower logger.

    Each argument is applied only when it is not ``None``, so callers can
    change one setting without touching the others.

    Parameters
    ----------
    level : Optional[int]
        New log level passed to ``handler.setLevel``.
    timestamps : Optional[bool]
        New value for the handler's ``timestamps`` attribute.
    colored : Optional[bool]
        New value for the handler's ``colored`` attribute.
    """
    console_handlers = [
        candidate
        for candidate in logging.getLogger(LOGGER_NAME).handlers
        if isinstance(candidate, ConsoleHandler)
    ]
    for console_handler in console_handlers:
        if level is not None:
            console_handler.setLevel(level)
        if timestamps is not None:
            console_handler.timestamps = timestamps
        if colored is not None:
            console_handler.colored = colored
92
99
 
93
100
 
94
101
  # Configure console logger
@@ -188,3 +195,29 @@ def warn_deprecated_feature(name: str) -> None:
188
195
  """,
189
196
  name,
190
197
  )
198
+
199
+
200
def set_logger_propagation(
    child_logger: logging.Logger, value: bool = True
) -> logging.Logger:
    """Set whether `child_logger` propagates records to its parent loggers.

    Parameters
    ----------
    child_logger : logging.Logger
        Logger whose ``propagate`` attribute is updated in place.
    value : bool
        ``True`` lets both the parent and the child logger emit messages.
        ``False`` confines output to the child logger. This prevents
        duplicate logs in Colab notebooks.
        Reference: https://stackoverflow.com/a/19561320

    Returns
    -------
    logging.Logger
        The same logger object, with the new propagation setting.
    """
    child_logger.propagate = value
    if child_logger.propagate:
        return child_logger

    # Record the change on the child logger itself when propagation is off.
    child_logger.log(logging.DEBUG, "Logger propagate set to False")
    return child_logger