flwr-nightly 1.9.0.dev20240426__py3-none-any.whl → 1.9.0.dev20240430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

flwr/common/message.py CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
19
19
  import time
20
20
  import warnings
21
21
  from dataclasses import dataclass
22
+ from typing import Optional, cast
22
23
 
23
24
  from .record import RecordSet
24
25
 
@@ -55,17 +56,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
55
56
  is more relevant when conducting simulations.
56
57
  """
57
58
 
58
- _run_id: int
59
- _message_id: str
60
- _src_node_id: int
61
- _dst_node_id: int
62
- _reply_to_message: str
63
- _group_id: str
64
- _ttl: float
65
- _message_type: str
66
- _partition_id: int | None
67
- _created_at: float # Unix timestamp (in seconds) to be set upon message creation
68
-
69
59
  def __init__( # pylint: disable=too-many-arguments
70
60
  self,
71
61
  run_id: int,
@@ -78,95 +68,98 @@ class Metadata: # pylint: disable=too-many-instance-attributes
78
68
  message_type: str,
79
69
  partition_id: int | None = None,
80
70
  ) -> None:
81
- self._run_id = run_id
82
- self._message_id = message_id
83
- self._src_node_id = src_node_id
84
- self._dst_node_id = dst_node_id
85
- self._reply_to_message = reply_to_message
86
- self._group_id = group_id
87
- self._ttl = ttl
88
- self._message_type = message_type
89
- self._partition_id = partition_id
71
+ var_dict = {
72
+ "_run_id": run_id,
73
+ "_message_id": message_id,
74
+ "_src_node_id": src_node_id,
75
+ "_dst_node_id": dst_node_id,
76
+ "_reply_to_message": reply_to_message,
77
+ "_group_id": group_id,
78
+ "_ttl": ttl,
79
+ "_message_type": message_type,
80
+ "_partition_id": partition_id,
81
+ }
82
+ self.__dict__.update(var_dict)
90
83
 
91
84
  @property
92
85
  def run_id(self) -> int:
93
86
  """An identifier for the current run."""
94
- return self._run_id
87
+ return cast(int, self.__dict__["_run_id"])
95
88
 
96
89
  @property
97
90
  def message_id(self) -> str:
98
91
  """An identifier for the current message."""
99
- return self._message_id
92
+ return cast(str, self.__dict__["_message_id"])
100
93
 
101
94
  @property
102
95
  def src_node_id(self) -> int:
103
96
  """An identifier for the node sending this message."""
104
- return self._src_node_id
97
+ return cast(int, self.__dict__["_src_node_id"])
105
98
 
106
99
  @property
107
100
  def reply_to_message(self) -> str:
108
101
  """An identifier for the message this message replies to."""
109
- return self._reply_to_message
102
+ return cast(str, self.__dict__["_reply_to_message"])
110
103
 
111
104
  @property
112
105
  def dst_node_id(self) -> int:
113
106
  """An identifier for the node receiving this message."""
114
- return self._dst_node_id
107
+ return cast(int, self.__dict__["_dst_node_id"])
115
108
 
116
109
  @dst_node_id.setter
117
110
  def dst_node_id(self, value: int) -> None:
118
111
  """Set dst_node_id."""
119
- self._dst_node_id = value
112
+ self.__dict__["_dst_node_id"] = value
120
113
 
121
114
  @property
122
115
  def group_id(self) -> str:
123
116
  """An identifier for grouping messages."""
124
- return self._group_id
117
+ return cast(str, self.__dict__["_group_id"])
125
118
 
126
119
  @group_id.setter
127
120
  def group_id(self, value: str) -> None:
128
121
  """Set group_id."""
129
- self._group_id = value
122
+ self.__dict__["_group_id"] = value
130
123
 
131
124
  @property
132
125
  def created_at(self) -> float:
133
126
  """Unix timestamp when the message was created."""
134
- return self._created_at
127
+ return cast(float, self.__dict__["_created_at"])
135
128
 
136
129
  @created_at.setter
137
130
  def created_at(self, value: float) -> None:
138
- """Set creation timestamp for this messages."""
139
- self._created_at = value
131
+ """Set creation timestamp for this message."""
132
+ self.__dict__["_created_at"] = value
140
133
 
141
134
  @property
142
135
  def ttl(self) -> float:
143
136
  """Time-to-live for this message."""
144
- return self._ttl
137
+ return cast(float, self.__dict__["_ttl"])
145
138
 
146
139
  @ttl.setter
147
140
  def ttl(self, value: float) -> None:
148
141
  """Set ttl."""
149
- self._ttl = value
142
+ self.__dict__["_ttl"] = value
150
143
 
151
144
  @property
152
145
  def message_type(self) -> str:
153
146
  """A string that encodes the action to be executed on the receiving end."""
154
- return self._message_type
147
+ return cast(str, self.__dict__["_message_type"])
155
148
 
156
149
  @message_type.setter
157
150
  def message_type(self, value: str) -> None:
158
151
  """Set message_type."""
159
- self._message_type = value
152
+ self.__dict__["_message_type"] = value
160
153
 
161
154
  @property
162
155
  def partition_id(self) -> int | None:
163
156
  """An identifier telling which data partition a ClientApp should use."""
164
- return self._partition_id
157
+ return cast(int, self.__dict__["_partition_id"])
165
158
 
166
159
  @partition_id.setter
167
160
  def partition_id(self, value: int) -> None:
168
- """Set patition_id."""
169
- self._partition_id = value
161
+ """Set partition_id."""
162
+ self.__dict__["_partition_id"] = value
170
163
 
171
164
 
172
165
  @dataclass
@@ -181,22 +174,22 @@ class Error:
181
174
  A reason for why the error arose (e.g. an exception stack-trace)
182
175
  """
183
176
 
184
- _code: int
185
- _reason: str | None = None
186
-
187
177
  def __init__(self, code: int, reason: str | None = None) -> None:
188
- self._code = code
189
- self._reason = reason
178
+ var_dict = {
179
+ "_code": code,
180
+ "_reason": reason,
181
+ }
182
+ self.__dict__.update(var_dict)
190
183
 
191
184
  @property
192
185
  def code(self) -> int:
193
186
  """Error code."""
194
- return self._code
187
+ return cast(int, self.__dict__["_code"])
195
188
 
196
189
  @property
197
190
  def reason(self) -> str | None:
198
191
  """Reason reported about the error."""
199
- return self._reason
192
+ return cast(Optional[str], self.__dict__["_reason"])
200
193
 
201
194
 
202
195
  @dataclass
@@ -215,88 +208,70 @@ class Message:
215
208
  when processing another message.
216
209
  """
217
210
 
218
- _metadata: Metadata
219
- _content: RecordSet | None = None
220
- _error: Error | None = None
221
-
222
211
  def __init__(
223
212
  self,
224
213
  metadata: Metadata,
225
214
  content: RecordSet | None = None,
226
215
  error: Error | None = None,
227
216
  ) -> None:
228
- self._metadata = metadata
229
-
230
- # Set message creation timestamp
231
- self._metadata.created_at = time.time()
232
-
233
217
  if not (content is None) ^ (error is None):
234
218
  raise ValueError("Either `content` or `error` must be set, but not both.")
235
219
 
236
- self._content = content
237
- self._error = error
220
+ metadata.created_at = time.time() # Set the message creation timestamp
221
+ var_dict = {
222
+ "_metadata": metadata,
223
+ "_content": content,
224
+ "_error": error,
225
+ }
226
+ self.__dict__.update(var_dict)
238
227
 
239
228
  @property
240
229
  def metadata(self) -> Metadata:
241
230
  """A dataclass including information about the message to be executed."""
242
- return self._metadata
231
+ return cast(Metadata, self.__dict__["_metadata"])
243
232
 
244
233
  @property
245
234
  def content(self) -> RecordSet:
246
235
  """The content of this message."""
247
- if self._content is None:
236
+ if self.__dict__["_content"] is None:
248
237
  raise ValueError(
249
238
  "Message content is None. Use <message>.has_content() "
250
239
  "to check if a message has content."
251
240
  )
252
- return self._content
241
+ return cast(RecordSet, self.__dict__["_content"])
253
242
 
254
243
  @content.setter
255
244
  def content(self, value: RecordSet) -> None:
256
245
  """Set content."""
257
- if self._error is None:
258
- self._content = value
246
+ if self.__dict__["_error"] is None:
247
+ self.__dict__["_content"] = value
259
248
  else:
260
249
  raise ValueError("A message with an error set cannot have content.")
261
250
 
262
251
  @property
263
252
  def error(self) -> Error:
264
253
  """Error captured by this message."""
265
- if self._error is None:
254
+ if self.__dict__["_error"] is None:
266
255
  raise ValueError(
267
256
  "Message error is None. Use <message>.has_error() "
268
257
  "to check first if a message carries an error."
269
258
  )
270
- return self._error
259
+ return cast(Error, self.__dict__["_error"])
271
260
 
272
261
  @error.setter
273
262
  def error(self, value: Error) -> None:
274
263
  """Set error."""
275
264
  if self.has_content():
276
265
  raise ValueError("A message with content set cannot carry an error.")
277
- self._error = value
266
+ self.__dict__["_error"] = value
278
267
 
279
268
  def has_content(self) -> bool:
280
269
  """Return True if message has content, else False."""
281
- return self._content is not None
270
+ return self.__dict__["_content"] is not None
282
271
 
283
272
  def has_error(self) -> bool:
284
273
  """Return True if message has an error, else False."""
285
- return self._error is not None
286
-
287
- def _create_reply_metadata(self, ttl: float) -> Metadata:
288
- """Construct metadata for a reply message."""
289
- return Metadata(
290
- run_id=self.metadata.run_id,
291
- message_id="",
292
- src_node_id=self.metadata.dst_node_id,
293
- dst_node_id=self.metadata.src_node_id,
294
- reply_to_message=self.metadata.message_id,
295
- group_id=self.metadata.group_id,
296
- ttl=ttl,
297
- message_type=self.metadata.message_type,
298
- partition_id=self.metadata.partition_id,
299
- )
274
+ return self.__dict__["_error"] is not None
300
275
 
301
276
  def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
302
277
  """Construct a reply message indicating an error happened.
@@ -323,7 +298,7 @@ class Message:
323
298
  # message creation)
324
299
  ttl_ = DEFAULT_TTL if ttl is None else ttl
325
300
  # Create reply with error
326
- message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
301
+ message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
327
302
 
328
303
  if ttl is None:
329
304
  # Set TTL equal to the remaining time for the received message to expire
@@ -369,7 +344,7 @@ class Message:
369
344
  ttl_ = DEFAULT_TTL if ttl is None else ttl
370
345
 
371
346
  message = Message(
372
- metadata=self._create_reply_metadata(ttl_),
347
+ metadata=_create_reply_metadata(self, ttl_),
373
348
  content=content,
374
349
  )
375
350
 
@@ -381,3 +356,18 @@ class Message:
381
356
  message.metadata.ttl = ttl
382
357
 
383
358
  return message
359
+
360
+
361
+ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
362
+ """Construct metadata for a reply message."""
363
+ return Metadata(
364
+ run_id=msg.metadata.run_id,
365
+ message_id="",
366
+ src_node_id=msg.metadata.dst_node_id,
367
+ dst_node_id=msg.metadata.src_node_id,
368
+ reply_to_message=msg.metadata.message_id,
369
+ group_id=msg.metadata.group_id,
370
+ ttl=ttl,
371
+ message_type=msg.metadata.message_type,
372
+ partition_id=msg.metadata.partition_id,
373
+ )
flwr/common/record/recordset.py CHANGED
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
24
24
  from .typeddict import TypedDict
25
25
 
26
26
 
27
+ @dataclass
27
28
  class RecordSetData:
28
29
  """Inner data container for the RecordSet class."""
29
30
 
@@ -97,22 +98,22 @@ class RecordSet:
97
98
  metrics_records=metrics_records,
98
99
  configs_records=configs_records,
99
100
  )
100
- setattr(self, "_data", data) # noqa
101
+ self.__dict__["_data"] = data
101
102
 
102
103
  @property
103
104
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
104
105
  """Dictionary holding ParametersRecord instances."""
105
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
106
+ data = cast(RecordSetData, self.__dict__["_data"])
106
107
  return data.parameters_records
107
108
 
108
109
  @property
109
110
  def metrics_records(self) -> TypedDict[str, MetricsRecord]:
110
111
  """Dictionary holding MetricsRecord instances."""
111
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
112
+ data = cast(RecordSetData, self.__dict__["_data"])
112
113
  return data.metrics_records
113
114
 
114
115
  @property
115
116
  def configs_records(self) -> TypedDict[str, ConfigsRecord]:
116
117
  """Dictionary holding ConfigsRecord instances."""
117
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
118
+ data = cast(RecordSetData, self.__dict__["_data"])
118
119
  return data.configs_records
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -117,3 +117,18 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
117
117
  return True
118
118
  except InvalidSignature:
119
119
  return False
120
+
121
+
122
+ def ssh_types_to_elliptic_curve(
123
+ private_key: serialization.SSHPrivateKeyTypes,
124
+ public_key: serialization.SSHPublicKeyTypes,
125
+ ) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
126
+ """Cast SSH key types to elliptic curve."""
127
+ if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
128
+ public_key, ec.EllipticCurvePublicKey
129
+ ):
130
+ return (private_key, public_key)
131
+
132
+ raise TypeError(
133
+ "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
134
+ )
flwr/server/app.py CHANGED
@@ -16,15 +16,21 @@
16
16
 
17
17
  import argparse
18
18
  import asyncio
19
+ import csv
19
20
  import importlib.util
20
21
  import sys
21
22
  import threading
22
23
  from logging import ERROR, INFO, WARN
23
24
  from os.path import isfile
24
25
  from pathlib import Path
25
- from typing import List, Optional, Tuple
26
+ from typing import List, Optional, Sequence, Set, Tuple
26
27
 
27
28
  import grpc
29
+ from cryptography.hazmat.primitives.asymmetric import ec
30
+ from cryptography.hazmat.primitives.serialization import (
31
+ load_ssh_private_key,
32
+ load_ssh_public_key,
33
+ )
28
34
 
29
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
30
36
  from flwr.common.address import parse_address
@@ -36,6 +42,10 @@ from flwr.common.constant import (
36
42
  )
37
43
  from flwr.common.exit_handlers import register_exit_handlers
38
44
  from flwr.common.logger import log
45
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
46
+ public_key_to_bytes,
47
+ ssh_types_to_elliptic_curve,
48
+ )
39
49
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
40
50
  add_FleetServicer_to_server,
41
51
  )
@@ -51,6 +61,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
51
61
  start_grpc_server,
52
62
  )
53
63
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
64
+ from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
54
65
  from .superlink.fleet.vce import start_vce
55
66
  from .superlink.state import StateFactory
56
67
 
@@ -354,10 +365,28 @@ def run_superlink() -> None:
354
365
  sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
355
366
  host, port, is_v6 = parsed_address
356
367
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
368
+
369
+ maybe_keys = _try_setup_client_authentication(args, certificates)
370
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
371
+ if maybe_keys is not None:
372
+ (
373
+ client_public_keys,
374
+ server_private_key,
375
+ server_public_key,
376
+ ) = maybe_keys
377
+ interceptors = [
378
+ AuthenticateServerInterceptor(
379
+ client_public_keys,
380
+ server_private_key,
381
+ server_public_key,
382
+ )
383
+ ]
384
+
357
385
  fleet_server = _run_fleet_api_grpc_rere(
358
386
  address=address,
359
387
  state_factory=state_factory,
360
388
  certificates=certificates,
389
+ interceptors=interceptors,
361
390
  )
362
391
  grpc_servers.append(fleet_server)
363
392
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
@@ -390,6 +419,70 @@ def run_superlink() -> None:
390
419
  driver_server.wait_for_termination(timeout=1)
391
420
 
392
421
 
422
+ def _try_setup_client_authentication(
423
+ args: argparse.Namespace,
424
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
425
+ ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
426
+ if not args.require_client_authentication:
427
+ return None
428
+
429
+ if certificates is None:
430
+ sys.exit(
431
+ "Client authentication only works over secure connections. "
432
+ "Please provide certificate paths using '--certificates' when "
433
+ "enabling '--require-client-authentication'."
434
+ )
435
+
436
+ client_keys_file_path = Path(args.require_client_authentication[0])
437
+ if not client_keys_file_path.exists():
438
+ sys.exit(
439
+ "The provided path to the client public keys CSV file does not exist: "
440
+ f"{client_keys_file_path}. "
441
+ "Please provide the CSV file path containing known client public keys "
442
+ "to '--require-client-authentication'."
443
+ )
444
+
445
+ client_public_keys: Set[bytes] = set()
446
+ ssh_private_key = load_ssh_private_key(
447
+ Path(args.require_client_authentication[1]).read_bytes(),
448
+ None,
449
+ )
450
+ ssh_public_key = load_ssh_public_key(
451
+ Path(args.require_client_authentication[2]).read_bytes()
452
+ )
453
+
454
+ try:
455
+ server_private_key, server_public_key = ssh_types_to_elliptic_curve(
456
+ ssh_private_key, ssh_public_key
457
+ )
458
+ except TypeError:
459
+ sys.exit(
460
+ "The file paths provided could not be read as a private and public "
461
+ "key pair. Client authentication requires an elliptic curve public and "
462
+ "private key pair. Please provide the file paths containing elliptic "
463
+ "curve private and public keys to '--require-client-authentication'."
464
+ )
465
+
466
+ with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
467
+ reader = csv.reader(csvfile)
468
+ for row in reader:
469
+ for element in row:
470
+ public_key = load_ssh_public_key(element.encode())
471
+ if isinstance(public_key, ec.EllipticCurvePublicKey):
472
+ client_public_keys.add(public_key_to_bytes(public_key))
473
+ else:
474
+ sys.exit(
475
+ "Error: Unable to parse the public keys in the .csv "
476
+ "file. Please ensure that the .csv file contains valid "
477
+ "SSH public keys and try again."
478
+ )
479
+ return (
480
+ client_public_keys,
481
+ server_private_key,
482
+ server_public_key,
483
+ )
484
+
485
+
393
486
  def _try_obtain_certificates(
394
487
  args: argparse.Namespace,
395
488
  ) -> Optional[Tuple[bytes, bytes, bytes]]:
@@ -417,6 +510,7 @@ def _run_fleet_api_grpc_rere(
417
510
  address: str,
418
511
  state_factory: StateFactory,
419
512
  certificates: Optional[Tuple[bytes, bytes, bytes]],
513
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
420
514
  ) -> grpc.Server:
421
515
  """Run Fleet API (gRPC, request-response)."""
422
516
  # Create Fleet API gRPC server
@@ -429,6 +523,7 @@ def _run_fleet_api_grpc_rere(
429
523
  server_address=address,
430
524
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
431
525
  certificates=certificates,
526
+ interceptors=interceptors,
432
527
  )
433
528
 
434
529
  log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
@@ -606,6 +701,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
606
701
  "Flower will just create a state in memory.",
607
702
  default=DATABASE,
608
703
  )
704
+ parser.add_argument(
705
+ "--require-client-authentication",
706
+ nargs=3,
707
+ metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
708
+ type=str,
709
+ help="Provide three file paths: (1) a .csv file containing a list of "
710
+ "known client public keys for authentication, (2) the server's private "
711
+ "key file, and (3) the server's public key file.",
712
+ )
609
713
 
610
714
 
611
715
  def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
flwr/server/superlink/fleet/grpc_bidi/grpc_server.py CHANGED
@@ -18,7 +18,7 @@
18
18
  import concurrent.futures
19
19
  import sys
20
20
  from logging import ERROR
21
- from typing import Any, Callable, Optional, Tuple, Union
21
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
22
22
 
23
23
  import grpc
24
24
 
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
162
162
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
163
163
  keepalive_time_ms: int = 210000,
164
164
  certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
165
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
165
166
  ) -> grpc.Server:
166
167
  """Create a gRPC server with a single servicer.
167
168
 
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
249
250
  # returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
250
251
  maximum_concurrent_rpcs=max_concurrent_workers,
251
252
  options=options,
253
+ interceptors=interceptors,
252
254
  )
253
255
  add_servicer_to_server_fn(servicer, server)
254
256