flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (64) hide show
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +42 -18
  5. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  6. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  8. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  9. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  10. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  12. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  13. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  14. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  20. flwr/cli/run/run.py +1 -1
  21. flwr/cli/utils.py +18 -17
  22. flwr/client/__init__.py +1 -1
  23. flwr/client/app.py +17 -93
  24. flwr/client/grpc_client/connection.py +6 -1
  25. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  26. flwr/client/grpc_rere_client/connection.py +17 -2
  27. flwr/client/mod/centraldp_mods.py +4 -2
  28. flwr/client/mod/localdp_mod.py +9 -3
  29. flwr/client/rest_client/connection.py +5 -1
  30. flwr/client/supernode/__init__.py +2 -0
  31. flwr/client/supernode/app.py +181 -7
  32. flwr/common/grpc.py +5 -1
  33. flwr/common/logger.py +37 -4
  34. flwr/common/message.py +105 -86
  35. flwr/common/record/parametersrecord.py +0 -1
  36. flwr/common/record/recordset.py +17 -5
  37. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  38. flwr/server/app.py +111 -1
  39. flwr/server/compat/app.py +2 -2
  40. flwr/server/compat/app_utils.py +1 -1
  41. flwr/server/compat/driver_client_proxy.py +27 -72
  42. flwr/server/driver/__init__.py +3 -0
  43. flwr/server/driver/driver.py +12 -242
  44. flwr/server/driver/grpc_driver.py +315 -0
  45. flwr/server/run_serverapp.py +18 -4
  46. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  47. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  48. flwr/server/superlink/driver/driver_servicer.py +1 -1
  49. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  51. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  52. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  53. flwr/server/superlink/state/in_memory_state.py +76 -8
  54. flwr/server/superlink/state/sqlite_state.py +116 -11
  55. flwr/server/superlink/state/state.py +35 -3
  56. flwr/simulation/__init__.py +2 -2
  57. flwr/simulation/app.py +16 -1
  58. flwr/simulation/run_simulation.py +10 -7
  59. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  60. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
  61. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
  62. flwr/server/driver/abc_driver.py +0 -140
  63. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  64. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/common/message.py CHANGED
@@ -18,14 +18,13 @@ from __future__ import annotations
18
18
 
19
19
  import time
20
20
  import warnings
21
- from dataclasses import dataclass
21
+ from typing import Optional, cast
22
22
 
23
23
  from .record import RecordSet
24
24
 
25
25
  DEFAULT_TTL = 3600
26
26
 
27
27
 
28
- @dataclass
29
28
  class Metadata: # pylint: disable=too-many-instance-attributes
30
29
  """A dataclass holding metadata associated with the current message.
31
30
 
@@ -55,17 +54,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
55
54
  is more relevant when conducting simulations.
56
55
  """
57
56
 
58
- _run_id: int
59
- _message_id: str
60
- _src_node_id: int
61
- _dst_node_id: int
62
- _reply_to_message: str
63
- _group_id: str
64
- _ttl: float
65
- _message_type: str
66
- _partition_id: int | None
67
- _created_at: float # Unix timestamp (in seconds) to be set upon message creation
68
-
69
57
  def __init__( # pylint: disable=too-many-arguments
70
58
  self,
71
59
  run_id: int,
@@ -78,98 +66,111 @@ class Metadata: # pylint: disable=too-many-instance-attributes
78
66
  message_type: str,
79
67
  partition_id: int | None = None,
80
68
  ) -> None:
81
- self._run_id = run_id
82
- self._message_id = message_id
83
- self._src_node_id = src_node_id
84
- self._dst_node_id = dst_node_id
85
- self._reply_to_message = reply_to_message
86
- self._group_id = group_id
87
- self._ttl = ttl
88
- self._message_type = message_type
89
- self._partition_id = partition_id
69
+ var_dict = {
70
+ "_run_id": run_id,
71
+ "_message_id": message_id,
72
+ "_src_node_id": src_node_id,
73
+ "_dst_node_id": dst_node_id,
74
+ "_reply_to_message": reply_to_message,
75
+ "_group_id": group_id,
76
+ "_ttl": ttl,
77
+ "_message_type": message_type,
78
+ "_partition_id": partition_id,
79
+ }
80
+ self.__dict__.update(var_dict)
90
81
 
91
82
  @property
92
83
  def run_id(self) -> int:
93
84
  """An identifier for the current run."""
94
- return self._run_id
85
+ return cast(int, self.__dict__["_run_id"])
95
86
 
96
87
  @property
97
88
  def message_id(self) -> str:
98
89
  """An identifier for the current message."""
99
- return self._message_id
90
+ return cast(str, self.__dict__["_message_id"])
100
91
 
101
92
  @property
102
93
  def src_node_id(self) -> int:
103
94
  """An identifier for the node sending this message."""
104
- return self._src_node_id
95
+ return cast(int, self.__dict__["_src_node_id"])
105
96
 
106
97
  @property
107
98
  def reply_to_message(self) -> str:
108
99
  """An identifier for the message this message replies to."""
109
- return self._reply_to_message
100
+ return cast(str, self.__dict__["_reply_to_message"])
110
101
 
111
102
  @property
112
103
  def dst_node_id(self) -> int:
113
104
  """An identifier for the node receiving this message."""
114
- return self._dst_node_id
105
+ return cast(int, self.__dict__["_dst_node_id"])
115
106
 
116
107
  @dst_node_id.setter
117
108
  def dst_node_id(self, value: int) -> None:
118
109
  """Set dst_node_id."""
119
- self._dst_node_id = value
110
+ self.__dict__["_dst_node_id"] = value
120
111
 
121
112
  @property
122
113
  def group_id(self) -> str:
123
114
  """An identifier for grouping messages."""
124
- return self._group_id
115
+ return cast(str, self.__dict__["_group_id"])
125
116
 
126
117
  @group_id.setter
127
118
  def group_id(self, value: str) -> None:
128
119
  """Set group_id."""
129
- self._group_id = value
120
+ self.__dict__["_group_id"] = value
130
121
 
131
122
  @property
132
123
  def created_at(self) -> float:
133
124
  """Unix timestamp when the message was created."""
134
- return self._created_at
125
+ return cast(float, self.__dict__["_created_at"])
135
126
 
136
127
  @created_at.setter
137
128
  def created_at(self, value: float) -> None:
138
- """Set creation timestamp for this messages."""
139
- self._created_at = value
129
+ """Set creation timestamp for this message."""
130
+ self.__dict__["_created_at"] = value
140
131
 
141
132
  @property
142
133
  def ttl(self) -> float:
143
134
  """Time-to-live for this message."""
144
- return self._ttl
135
+ return cast(float, self.__dict__["_ttl"])
145
136
 
146
137
  @ttl.setter
147
138
  def ttl(self, value: float) -> None:
148
139
  """Set ttl."""
149
- self._ttl = value
140
+ self.__dict__["_ttl"] = value
150
141
 
151
142
  @property
152
143
  def message_type(self) -> str:
153
144
  """A string that encodes the action to be executed on the receiving end."""
154
- return self._message_type
145
+ return cast(str, self.__dict__["_message_type"])
155
146
 
156
147
  @message_type.setter
157
148
  def message_type(self, value: str) -> None:
158
149
  """Set message_type."""
159
- self._message_type = value
150
+ self.__dict__["_message_type"] = value
160
151
 
161
152
  @property
162
153
  def partition_id(self) -> int | None:
163
154
  """An identifier telling which data partition a ClientApp should use."""
164
- return self._partition_id
155
+ return cast(int, self.__dict__["_partition_id"])
165
156
 
166
157
  @partition_id.setter
167
158
  def partition_id(self, value: int) -> None:
168
- """Set patition_id."""
169
- self._partition_id = value
159
+ """Set partition_id."""
160
+ self.__dict__["_partition_id"] = value
161
+
162
+ def __repr__(self) -> str:
163
+ """Return a string representation of this instance."""
164
+ view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
165
+ return f"{self.__class__.__qualname__}({view})"
166
+
167
+ def __eq__(self, other: object) -> bool:
168
+ """Compare two instances of the class."""
169
+ if not isinstance(other, self.__class__):
170
+ raise NotImplementedError
171
+ return self.__dict__ == other.__dict__
170
172
 
171
173
 
172
- @dataclass
173
174
  class Error:
174
175
  """A dataclass that stores information about an error that occurred.
175
176
 
@@ -181,25 +182,35 @@ class Error:
181
182
  A reason for why the error arose (e.g. an exception stack-trace)
182
183
  """
183
184
 
184
- _code: int
185
- _reason: str | None = None
186
-
187
185
  def __init__(self, code: int, reason: str | None = None) -> None:
188
- self._code = code
189
- self._reason = reason
186
+ var_dict = {
187
+ "_code": code,
188
+ "_reason": reason,
189
+ }
190
+ self.__dict__.update(var_dict)
190
191
 
191
192
  @property
192
193
  def code(self) -> int:
193
194
  """Error code."""
194
- return self._code
195
+ return cast(int, self.__dict__["_code"])
195
196
 
196
197
  @property
197
198
  def reason(self) -> str | None:
198
199
  """Reason reported about the error."""
199
- return self._reason
200
+ return cast(Optional[str], self.__dict__["_reason"])
201
+
202
+ def __repr__(self) -> str:
203
+ """Return a string representation of this instance."""
204
+ view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
205
+ return f"{self.__class__.__qualname__}({view})"
206
+
207
+ def __eq__(self, other: object) -> bool:
208
+ """Compare two instances of the class."""
209
+ if not isinstance(other, self.__class__):
210
+ raise NotImplementedError
211
+ return self.__dict__ == other.__dict__
200
212
 
201
213
 
202
- @dataclass
203
214
  class Message:
204
215
  """State of your application from the viewpoint of the entity using it.
205
216
 
@@ -215,88 +226,70 @@ class Message:
215
226
  when processing another message.
216
227
  """
217
228
 
218
- _metadata: Metadata
219
- _content: RecordSet | None = None
220
- _error: Error | None = None
221
-
222
229
  def __init__(
223
230
  self,
224
231
  metadata: Metadata,
225
232
  content: RecordSet | None = None,
226
233
  error: Error | None = None,
227
234
  ) -> None:
228
- self._metadata = metadata
229
-
230
- # Set message creation timestamp
231
- self._metadata.created_at = time.time()
232
-
233
235
  if not (content is None) ^ (error is None):
234
236
  raise ValueError("Either `content` or `error` must be set, but not both.")
235
237
 
236
- self._content = content
237
- self._error = error
238
+ metadata.created_at = time.time() # Set the message creation timestamp
239
+ var_dict = {
240
+ "_metadata": metadata,
241
+ "_content": content,
242
+ "_error": error,
243
+ }
244
+ self.__dict__.update(var_dict)
238
245
 
239
246
  @property
240
247
  def metadata(self) -> Metadata:
241
248
  """A dataclass including information about the message to be executed."""
242
- return self._metadata
249
+ return cast(Metadata, self.__dict__["_metadata"])
243
250
 
244
251
  @property
245
252
  def content(self) -> RecordSet:
246
253
  """The content of this message."""
247
- if self._content is None:
254
+ if self.__dict__["_content"] is None:
248
255
  raise ValueError(
249
256
  "Message content is None. Use <message>.has_content() "
250
257
  "to check if a message has content."
251
258
  )
252
- return self._content
259
+ return cast(RecordSet, self.__dict__["_content"])
253
260
 
254
261
  @content.setter
255
262
  def content(self, value: RecordSet) -> None:
256
263
  """Set content."""
257
- if self._error is None:
258
- self._content = value
264
+ if self.__dict__["_error"] is None:
265
+ self.__dict__["_content"] = value
259
266
  else:
260
267
  raise ValueError("A message with an error set cannot have content.")
261
268
 
262
269
  @property
263
270
  def error(self) -> Error:
264
271
  """Error captured by this message."""
265
- if self._error is None:
272
+ if self.__dict__["_error"] is None:
266
273
  raise ValueError(
267
274
  "Message error is None. Use <message>.has_error() "
268
275
  "to check first if a message carries an error."
269
276
  )
270
- return self._error
277
+ return cast(Error, self.__dict__["_error"])
271
278
 
272
279
  @error.setter
273
280
  def error(self, value: Error) -> None:
274
281
  """Set error."""
275
282
  if self.has_content():
276
283
  raise ValueError("A message with content set cannot carry an error.")
277
- self._error = value
284
+ self.__dict__["_error"] = value
278
285
 
279
286
  def has_content(self) -> bool:
280
287
  """Return True if message has content, else False."""
281
- return self._content is not None
288
+ return self.__dict__["_content"] is not None
282
289
 
283
290
  def has_error(self) -> bool:
284
291
  """Return True if message has an error, else False."""
285
- return self._error is not None
286
-
287
- def _create_reply_metadata(self, ttl: float) -> Metadata:
288
- """Construct metadata for a reply message."""
289
- return Metadata(
290
- run_id=self.metadata.run_id,
291
- message_id="",
292
- src_node_id=self.metadata.dst_node_id,
293
- dst_node_id=self.metadata.src_node_id,
294
- reply_to_message=self.metadata.message_id,
295
- group_id=self.metadata.group_id,
296
- ttl=ttl,
297
- message_type=self.metadata.message_type,
298
- partition_id=self.metadata.partition_id,
299
- )
292
+ return self.__dict__["_error"] is not None
300
293
 
301
294
  def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
302
295
  """Construct a reply message indicating an error happened.
@@ -323,7 +316,7 @@ class Message:
323
316
  # message creation)
324
317
  ttl_ = DEFAULT_TTL if ttl is None else ttl
325
318
  # Create reply with error
326
- message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
319
+ message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
327
320
 
328
321
  if ttl is None:
329
322
  # Set TTL equal to the remaining time for the received message to expire
@@ -369,7 +362,7 @@ class Message:
369
362
  ttl_ = DEFAULT_TTL if ttl is None else ttl
370
363
 
371
364
  message = Message(
372
- metadata=self._create_reply_metadata(ttl_),
365
+ metadata=_create_reply_metadata(self, ttl_),
373
366
  content=content,
374
367
  )
375
368
 
@@ -381,3 +374,29 @@ class Message:
381
374
  message.metadata.ttl = ttl
382
375
 
383
376
  return message
377
+
378
+ def __repr__(self) -> str:
379
+ """Return a string representation of this instance."""
380
+ view = ", ".join(
381
+ [
382
+ f"{k.lstrip('_')}={v!r}"
383
+ for k, v in self.__dict__.items()
384
+ if v is not None
385
+ ]
386
+ )
387
+ return f"{self.__class__.__qualname__}({view})"
388
+
389
+
390
+ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
391
+ """Construct metadata for a reply message."""
392
+ return Metadata(
393
+ run_id=msg.metadata.run_id,
394
+ message_id="",
395
+ src_node_id=msg.metadata.dst_node_id,
396
+ dst_node_id=msg.metadata.src_node_id,
397
+ reply_to_message=msg.metadata.message_id,
398
+ group_id=msg.metadata.group_id,
399
+ ttl=ttl,
400
+ message_type=msg.metadata.message_type,
401
+ partition_id=msg.metadata.partition_id,
402
+ )
@@ -82,7 +82,6 @@ def _check_value(value: Array) -> None:
82
82
  )
83
83
 
84
84
 
85
- @dataclass
86
85
  class ParametersRecord(TypedDict[str, Array]):
87
86
  """Parameters record.
88
87
 
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
24
24
  from .typeddict import TypedDict
25
25
 
26
26
 
27
+ @dataclass
27
28
  class RecordSetData:
28
29
  """Inner data container for the RecordSet class."""
29
30
 
@@ -82,7 +83,6 @@ class RecordSetData:
82
83
  )
83
84
 
84
85
 
85
- @dataclass
86
86
  class RecordSet:
87
87
  """RecordSet stores groups of parameters, metrics and configs."""
88
88
 
@@ -97,22 +97,34 @@ class RecordSet:
97
97
  metrics_records=metrics_records,
98
98
  configs_records=configs_records,
99
99
  )
100
- setattr(self, "_data", data) # noqa
100
+ self.__dict__["_data"] = data
101
101
 
102
102
  @property
103
103
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
104
104
  """Dictionary holding ParametersRecord instances."""
105
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
105
+ data = cast(RecordSetData, self.__dict__["_data"])
106
106
  return data.parameters_records
107
107
 
108
108
  @property
109
109
  def metrics_records(self) -> TypedDict[str, MetricsRecord]:
110
110
  """Dictionary holding MetricsRecord instances."""
111
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
111
+ data = cast(RecordSetData, self.__dict__["_data"])
112
112
  return data.metrics_records
113
113
 
114
114
  @property
115
115
  def configs_records(self) -> TypedDict[str, ConfigsRecord]:
116
116
  """Dictionary holding ConfigsRecord instances."""
117
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
117
+ data = cast(RecordSetData, self.__dict__["_data"])
118
118
  return data.configs_records
119
+
120
+ def __repr__(self) -> str:
121
+ """Return a string representation of this instance."""
122
+ flds = ("parameters_records", "metrics_records", "configs_records")
123
+ view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
124
+ return f"{self.__class__.__qualname__}({view})"
125
+
126
+ def __eq__(self, other: object) -> bool:
127
+ """Compare two instances of the class."""
128
+ if not isinstance(other, self.__class__):
129
+ raise NotImplementedError
130
+ return self.__dict__ == other.__dict__
@@ -18,8 +18,9 @@
18
18
  import base64
19
19
  from typing import Tuple, cast
20
20
 
21
+ from cryptography.exceptions import InvalidSignature
21
22
  from cryptography.fernet import Fernet
22
- from cryptography.hazmat.primitives import hashes, serialization
23
+ from cryptography.hazmat.primitives import hashes, hmac, serialization
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
25
26
 
@@ -98,3 +99,36 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
98
99
  # The input key must be url safe
99
100
  fernet = Fernet(key)
100
101
  return fernet.decrypt(ciphertext)
102
+
103
+
104
+ def compute_hmac(key: bytes, message: bytes) -> bytes:
105
+ """Compute hmac of a message using key as hash."""
106
+ computed_hmac = hmac.HMAC(key, hashes.SHA256())
107
+ computed_hmac.update(message)
108
+ return computed_hmac.finalize()
109
+
110
+
111
+ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
112
+ """Verify hmac of a message using key as hash."""
113
+ computed_hmac = hmac.HMAC(key, hashes.SHA256())
114
+ computed_hmac.update(message)
115
+ try:
116
+ computed_hmac.verify(hmac_value)
117
+ return True
118
+ except InvalidSignature:
119
+ return False
120
+
121
+
122
+ def ssh_types_to_elliptic_curve(
123
+ private_key: serialization.SSHPrivateKeyTypes,
124
+ public_key: serialization.SSHPublicKeyTypes,
125
+ ) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
126
+ """Cast SSH key types to elliptic curve."""
127
+ if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
128
+ public_key, ec.EllipticCurvePublicKey
129
+ ):
130
+ return (private_key, public_key)
131
+
132
+ raise TypeError(
133
+ "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
134
+ )
flwr/server/app.py CHANGED
@@ -16,15 +16,21 @@
16
16
 
17
17
  import argparse
18
18
  import asyncio
19
+ import csv
19
20
  import importlib.util
20
21
  import sys
21
22
  import threading
22
23
  from logging import ERROR, INFO, WARN
23
24
  from os.path import isfile
24
25
  from pathlib import Path
25
- from typing import List, Optional, Tuple
26
+ from typing import List, Optional, Sequence, Set, Tuple
26
27
 
27
28
  import grpc
29
+ from cryptography.hazmat.primitives.asymmetric import ec
30
+ from cryptography.hazmat.primitives.serialization import (
31
+ load_ssh_private_key,
32
+ load_ssh_public_key,
33
+ )
28
34
 
29
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
30
36
  from flwr.common.address import parse_address
@@ -36,6 +42,11 @@ from flwr.common.constant import (
36
42
  )
37
43
  from flwr.common.exit_handlers import register_exit_handlers
38
44
  from flwr.common.logger import log
45
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
46
+ private_key_to_bytes,
47
+ public_key_to_bytes,
48
+ ssh_types_to_elliptic_curve,
49
+ )
39
50
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
40
51
  add_FleetServicer_to_server,
41
52
  )
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
51
62
  start_grpc_server,
52
63
  )
53
64
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
65
+ from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
54
66
  from .superlink.fleet.vce import start_vce
55
67
  from .superlink.state import StateFactory
56
68
 
@@ -354,10 +366,33 @@ def run_superlink() -> None:
354
366
  sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
355
367
  host, port, is_v6 = parsed_address
356
368
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
369
+
370
+ maybe_keys = _try_setup_client_authentication(args, certificates)
371
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
372
+ if maybe_keys is not None:
373
+ (
374
+ client_public_keys,
375
+ server_private_key,
376
+ server_public_key,
377
+ ) = maybe_keys
378
+ state = state_factory.state()
379
+ state.store_client_public_keys(client_public_keys)
380
+ state.store_server_private_public_key(
381
+ private_key_to_bytes(server_private_key),
382
+ public_key_to_bytes(server_public_key),
383
+ )
384
+ log(
385
+ INFO,
386
+ "Client authentication enabled with %d known public keys",
387
+ len(client_public_keys),
388
+ )
389
+ interceptors = [AuthenticateServerInterceptor(state)]
390
+
357
391
  fleet_server = _run_fleet_api_grpc_rere(
358
392
  address=address,
359
393
  state_factory=state_factory,
360
394
  certificates=certificates,
395
+ interceptors=interceptors,
361
396
  )
362
397
  grpc_servers.append(fleet_server)
363
398
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
@@ -390,6 +425,70 @@ def run_superlink() -> None:
390
425
  driver_server.wait_for_termination(timeout=1)
391
426
 
392
427
 
428
+ def _try_setup_client_authentication(
429
+ args: argparse.Namespace,
430
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
431
+ ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
432
+ if not args.require_client_authentication:
433
+ return None
434
+
435
+ if certificates is None:
436
+ sys.exit(
437
+ "Client authentication only works over secure connections. "
438
+ "Please provide certificate paths using '--certificates' when "
439
+ "enabling '--require-client-authentication'."
440
+ )
441
+
442
+ client_keys_file_path = Path(args.require_client_authentication[0])
443
+ if not client_keys_file_path.exists():
444
+ sys.exit(
445
+ "The provided path to the client public keys CSV file does not exist: "
446
+ f"{client_keys_file_path}. "
447
+ "Please provide the CSV file path containing known client public keys "
448
+ "to '--require-client-authentication'."
449
+ )
450
+
451
+ client_public_keys: Set[bytes] = set()
452
+ ssh_private_key = load_ssh_private_key(
453
+ Path(args.require_client_authentication[1]).read_bytes(),
454
+ None,
455
+ )
456
+ ssh_public_key = load_ssh_public_key(
457
+ Path(args.require_client_authentication[2]).read_bytes()
458
+ )
459
+
460
+ try:
461
+ server_private_key, server_public_key = ssh_types_to_elliptic_curve(
462
+ ssh_private_key, ssh_public_key
463
+ )
464
+ except TypeError:
465
+ sys.exit(
466
+ "The file paths provided could not be read as a private and public "
467
+ "key pair. Client authentication requires an elliptic curve public and "
468
+ "private key pair. Please provide the file paths containing elliptic "
469
+ "curve private and public keys to '--require-client-authentication'."
470
+ )
471
+
472
+ with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
473
+ reader = csv.reader(csvfile)
474
+ for row in reader:
475
+ for element in row:
476
+ public_key = load_ssh_public_key(element.encode())
477
+ if isinstance(public_key, ec.EllipticCurvePublicKey):
478
+ client_public_keys.add(public_key_to_bytes(public_key))
479
+ else:
480
+ sys.exit(
481
+ "Error: Unable to parse the public keys in the .csv "
482
+ "file. Please ensure that the .csv file contains valid "
483
+ "SSH public keys and try again."
484
+ )
485
+ return (
486
+ client_public_keys,
487
+ server_private_key,
488
+ server_public_key,
489
+ )
490
+
491
+
393
492
  def _try_obtain_certificates(
394
493
  args: argparse.Namespace,
395
494
  ) -> Optional[Tuple[bytes, bytes, bytes]]:
@@ -417,6 +516,7 @@ def _run_fleet_api_grpc_rere(
417
516
  address: str,
418
517
  state_factory: StateFactory,
419
518
  certificates: Optional[Tuple[bytes, bytes, bytes]],
519
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
420
520
  ) -> grpc.Server:
421
521
  """Run Fleet API (gRPC, request-response)."""
422
522
  # Create Fleet API gRPC server
@@ -429,6 +529,7 @@ def _run_fleet_api_grpc_rere(
429
529
  server_address=address,
430
530
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
431
531
  certificates=certificates,
532
+ interceptors=interceptors,
432
533
  )
433
534
 
434
535
  log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
@@ -606,6 +707,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
606
707
  "Flower will just create a state in memory.",
607
708
  default=DATABASE,
608
709
  )
710
+ parser.add_argument(
711
+ "--require-client-authentication",
712
+ nargs=3,
713
+ metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
714
+ type=str,
715
+ help="Provide three file paths: (1) a .csv file containing a list of "
716
+ "known client public keys for authentication, (2) the server's private "
717
+ "key file, and (3) the server's public key file.",
718
+ )
609
719
 
610
720
 
611
721
  def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
flwr/server/compat/app.py CHANGED
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
29
29
  from flwr.server.server_config import ServerConfig
30
30
  from flwr.server.strategy import Strategy
31
31
 
32
- from ..driver import Driver
32
+ from ..driver import Driver, GrpcDriver
33
33
  from .app_utils import start_update_client_manager_thread
34
34
 
35
35
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
114
114
  # Create the Driver
115
115
  if isinstance(root_certificates, str):
116
116
  root_certificates = Path(root_certificates).read_bytes()
117
- driver = Driver(
117
+ driver = GrpcDriver(
118
118
  driver_service_address=address, root_certificates=root_certificates
119
119
  )
120
120
 
@@ -89,7 +89,7 @@ def _update_client_manager(
89
89
  for node_id in new_nodes:
90
90
  client_proxy = DriverClientProxy(
91
91
  node_id=node_id,
92
- driver=driver.grpc_driver_helper, # type: ignore
92
+ driver=driver,
93
93
  anonymous=False,
94
94
  run_id=driver.run_id, # type: ignore
95
95
  )