flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +42 -18
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/app.py +111 -1
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/common/message.py
CHANGED
|
@@ -18,14 +18,13 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
import time
|
|
20
20
|
import warnings
|
|
21
|
-
from dataclasses import dataclass
|
|
21
|
+
from typing import Optional, cast
|
|
22
22
|
|
|
23
23
|
from .record import RecordSet
|
|
24
24
|
|
|
25
25
|
DEFAULT_TTL = 3600
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
@dataclass
|
|
29
28
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
|
30
29
|
"""A dataclass holding metadata associated with the current message.
|
|
31
30
|
|
|
@@ -55,17 +54,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
55
54
|
is more relevant when conducting simulations.
|
|
56
55
|
"""
|
|
57
56
|
|
|
58
|
-
_run_id: int
|
|
59
|
-
_message_id: str
|
|
60
|
-
_src_node_id: int
|
|
61
|
-
_dst_node_id: int
|
|
62
|
-
_reply_to_message: str
|
|
63
|
-
_group_id: str
|
|
64
|
-
_ttl: float
|
|
65
|
-
_message_type: str
|
|
66
|
-
_partition_id: int | None
|
|
67
|
-
_created_at: float # Unix timestamp (in seconds) to be set upon message creation
|
|
68
|
-
|
|
69
57
|
def __init__( # pylint: disable=too-many-arguments
|
|
70
58
|
self,
|
|
71
59
|
run_id: int,
|
|
@@ -78,98 +66,111 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
78
66
|
message_type: str,
|
|
79
67
|
partition_id: int | None = None,
|
|
80
68
|
) -> None:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
69
|
+
var_dict = {
|
|
70
|
+
"_run_id": run_id,
|
|
71
|
+
"_message_id": message_id,
|
|
72
|
+
"_src_node_id": src_node_id,
|
|
73
|
+
"_dst_node_id": dst_node_id,
|
|
74
|
+
"_reply_to_message": reply_to_message,
|
|
75
|
+
"_group_id": group_id,
|
|
76
|
+
"_ttl": ttl,
|
|
77
|
+
"_message_type": message_type,
|
|
78
|
+
"_partition_id": partition_id,
|
|
79
|
+
}
|
|
80
|
+
self.__dict__.update(var_dict)
|
|
90
81
|
|
|
91
82
|
@property
|
|
92
83
|
def run_id(self) -> int:
|
|
93
84
|
"""An identifier for the current run."""
|
|
94
|
-
return self._run_id
|
|
85
|
+
return cast(int, self.__dict__["_run_id"])
|
|
95
86
|
|
|
96
87
|
@property
|
|
97
88
|
def message_id(self) -> str:
|
|
98
89
|
"""An identifier for the current message."""
|
|
99
|
-
return self._message_id
|
|
90
|
+
return cast(str, self.__dict__["_message_id"])
|
|
100
91
|
|
|
101
92
|
@property
|
|
102
93
|
def src_node_id(self) -> int:
|
|
103
94
|
"""An identifier for the node sending this message."""
|
|
104
|
-
return self._src_node_id
|
|
95
|
+
return cast(int, self.__dict__["_src_node_id"])
|
|
105
96
|
|
|
106
97
|
@property
|
|
107
98
|
def reply_to_message(self) -> str:
|
|
108
99
|
"""An identifier for the message this message replies to."""
|
|
109
|
-
return self._reply_to_message
|
|
100
|
+
return cast(str, self.__dict__["_reply_to_message"])
|
|
110
101
|
|
|
111
102
|
@property
|
|
112
103
|
def dst_node_id(self) -> int:
|
|
113
104
|
"""An identifier for the node receiving this message."""
|
|
114
|
-
return self._dst_node_id
|
|
105
|
+
return cast(int, self.__dict__["_dst_node_id"])
|
|
115
106
|
|
|
116
107
|
@dst_node_id.setter
|
|
117
108
|
def dst_node_id(self, value: int) -> None:
|
|
118
109
|
"""Set dst_node_id."""
|
|
119
|
-
self._dst_node_id = value
|
|
110
|
+
self.__dict__["_dst_node_id"] = value
|
|
120
111
|
|
|
121
112
|
@property
|
|
122
113
|
def group_id(self) -> str:
|
|
123
114
|
"""An identifier for grouping messages."""
|
|
124
|
-
return self._group_id
|
|
115
|
+
return cast(str, self.__dict__["_group_id"])
|
|
125
116
|
|
|
126
117
|
@group_id.setter
|
|
127
118
|
def group_id(self, value: str) -> None:
|
|
128
119
|
"""Set group_id."""
|
|
129
|
-
self._group_id = value
|
|
120
|
+
self.__dict__["_group_id"] = value
|
|
130
121
|
|
|
131
122
|
@property
|
|
132
123
|
def created_at(self) -> float:
|
|
133
124
|
"""Unix timestamp when the message was created."""
|
|
134
|
-
return self._created_at
|
|
125
|
+
return cast(float, self.__dict__["_created_at"])
|
|
135
126
|
|
|
136
127
|
@created_at.setter
|
|
137
128
|
def created_at(self, value: float) -> None:
|
|
138
|
-
"""Set creation timestamp for this
|
|
139
|
-
self._created_at = value
|
|
129
|
+
"""Set creation timestamp for this message."""
|
|
130
|
+
self.__dict__["_created_at"] = value
|
|
140
131
|
|
|
141
132
|
@property
|
|
142
133
|
def ttl(self) -> float:
|
|
143
134
|
"""Time-to-live for this message."""
|
|
144
|
-
return self._ttl
|
|
135
|
+
return cast(float, self.__dict__["_ttl"])
|
|
145
136
|
|
|
146
137
|
@ttl.setter
|
|
147
138
|
def ttl(self, value: float) -> None:
|
|
148
139
|
"""Set ttl."""
|
|
149
|
-
self._ttl = value
|
|
140
|
+
self.__dict__["_ttl"] = value
|
|
150
141
|
|
|
151
142
|
@property
|
|
152
143
|
def message_type(self) -> str:
|
|
153
144
|
"""A string that encodes the action to be executed on the receiving end."""
|
|
154
|
-
return self._message_type
|
|
145
|
+
return cast(str, self.__dict__["_message_type"])
|
|
155
146
|
|
|
156
147
|
@message_type.setter
|
|
157
148
|
def message_type(self, value: str) -> None:
|
|
158
149
|
"""Set message_type."""
|
|
159
|
-
self._message_type = value
|
|
150
|
+
self.__dict__["_message_type"] = value
|
|
160
151
|
|
|
161
152
|
@property
|
|
162
153
|
def partition_id(self) -> int | None:
|
|
163
154
|
"""An identifier telling which data partition a ClientApp should use."""
|
|
164
|
-
return self._partition_id
|
|
155
|
+
return cast(int, self.__dict__["_partition_id"])
|
|
165
156
|
|
|
166
157
|
@partition_id.setter
|
|
167
158
|
def partition_id(self, value: int) -> None:
|
|
168
|
-
"""Set
|
|
169
|
-
self._partition_id = value
|
|
159
|
+
"""Set partition_id."""
|
|
160
|
+
self.__dict__["_partition_id"] = value
|
|
161
|
+
|
|
162
|
+
def __repr__(self) -> str:
|
|
163
|
+
"""Return a string representation of this instance."""
|
|
164
|
+
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
165
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
166
|
+
|
|
167
|
+
def __eq__(self, other: object) -> bool:
|
|
168
|
+
"""Compare two instances of the class."""
|
|
169
|
+
if not isinstance(other, self.__class__):
|
|
170
|
+
raise NotImplementedError
|
|
171
|
+
return self.__dict__ == other.__dict__
|
|
170
172
|
|
|
171
173
|
|
|
172
|
-
@dataclass
|
|
173
174
|
class Error:
|
|
174
175
|
"""A dataclass that stores information about an error that occurred.
|
|
175
176
|
|
|
@@ -181,25 +182,35 @@ class Error:
|
|
|
181
182
|
A reason for why the error arose (e.g. an exception stack-trace)
|
|
182
183
|
"""
|
|
183
184
|
|
|
184
|
-
_code: int
|
|
185
|
-
_reason: str | None = None
|
|
186
|
-
|
|
187
185
|
def __init__(self, code: int, reason: str | None = None) -> None:
|
|
188
|
-
|
|
189
|
-
|
|
186
|
+
var_dict = {
|
|
187
|
+
"_code": code,
|
|
188
|
+
"_reason": reason,
|
|
189
|
+
}
|
|
190
|
+
self.__dict__.update(var_dict)
|
|
190
191
|
|
|
191
192
|
@property
|
|
192
193
|
def code(self) -> int:
|
|
193
194
|
"""Error code."""
|
|
194
|
-
return self._code
|
|
195
|
+
return cast(int, self.__dict__["_code"])
|
|
195
196
|
|
|
196
197
|
@property
|
|
197
198
|
def reason(self) -> str | None:
|
|
198
199
|
"""Reason reported about the error."""
|
|
199
|
-
return self._reason
|
|
200
|
+
return cast(Optional[str], self.__dict__["_reason"])
|
|
201
|
+
|
|
202
|
+
def __repr__(self) -> str:
|
|
203
|
+
"""Return a string representation of this instance."""
|
|
204
|
+
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
205
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
206
|
+
|
|
207
|
+
def __eq__(self, other: object) -> bool:
|
|
208
|
+
"""Compare two instances of the class."""
|
|
209
|
+
if not isinstance(other, self.__class__):
|
|
210
|
+
raise NotImplementedError
|
|
211
|
+
return self.__dict__ == other.__dict__
|
|
200
212
|
|
|
201
213
|
|
|
202
|
-
@dataclass
|
|
203
214
|
class Message:
|
|
204
215
|
"""State of your application from the viewpoint of the entity using it.
|
|
205
216
|
|
|
@@ -215,88 +226,70 @@ class Message:
|
|
|
215
226
|
when processing another message.
|
|
216
227
|
"""
|
|
217
228
|
|
|
218
|
-
_metadata: Metadata
|
|
219
|
-
_content: RecordSet | None = None
|
|
220
|
-
_error: Error | None = None
|
|
221
|
-
|
|
222
229
|
def __init__(
|
|
223
230
|
self,
|
|
224
231
|
metadata: Metadata,
|
|
225
232
|
content: RecordSet | None = None,
|
|
226
233
|
error: Error | None = None,
|
|
227
234
|
) -> None:
|
|
228
|
-
self._metadata = metadata
|
|
229
|
-
|
|
230
|
-
# Set message creation timestamp
|
|
231
|
-
self._metadata.created_at = time.time()
|
|
232
|
-
|
|
233
235
|
if not (content is None) ^ (error is None):
|
|
234
236
|
raise ValueError("Either `content` or `error` must be set, but not both.")
|
|
235
237
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
+
metadata.created_at = time.time() # Set the message creation timestamp
|
|
239
|
+
var_dict = {
|
|
240
|
+
"_metadata": metadata,
|
|
241
|
+
"_content": content,
|
|
242
|
+
"_error": error,
|
|
243
|
+
}
|
|
244
|
+
self.__dict__.update(var_dict)
|
|
238
245
|
|
|
239
246
|
@property
|
|
240
247
|
def metadata(self) -> Metadata:
|
|
241
248
|
"""A dataclass including information about the message to be executed."""
|
|
242
|
-
return self._metadata
|
|
249
|
+
return cast(Metadata, self.__dict__["_metadata"])
|
|
243
250
|
|
|
244
251
|
@property
|
|
245
252
|
def content(self) -> RecordSet:
|
|
246
253
|
"""The content of this message."""
|
|
247
|
-
if self._content is None:
|
|
254
|
+
if self.__dict__["_content"] is None:
|
|
248
255
|
raise ValueError(
|
|
249
256
|
"Message content is None. Use <message>.has_content() "
|
|
250
257
|
"to check if a message has content."
|
|
251
258
|
)
|
|
252
|
-
return self._content
|
|
259
|
+
return cast(RecordSet, self.__dict__["_content"])
|
|
253
260
|
|
|
254
261
|
@content.setter
|
|
255
262
|
def content(self, value: RecordSet) -> None:
|
|
256
263
|
"""Set content."""
|
|
257
|
-
if self._error is None:
|
|
258
|
-
self._content = value
|
|
264
|
+
if self.__dict__["_error"] is None:
|
|
265
|
+
self.__dict__["_content"] = value
|
|
259
266
|
else:
|
|
260
267
|
raise ValueError("A message with an error set cannot have content.")
|
|
261
268
|
|
|
262
269
|
@property
|
|
263
270
|
def error(self) -> Error:
|
|
264
271
|
"""Error captured by this message."""
|
|
265
|
-
if self._error is None:
|
|
272
|
+
if self.__dict__["_error"] is None:
|
|
266
273
|
raise ValueError(
|
|
267
274
|
"Message error is None. Use <message>.has_error() "
|
|
268
275
|
"to check first if a message carries an error."
|
|
269
276
|
)
|
|
270
|
-
return self._error
|
|
277
|
+
return cast(Error, self.__dict__["_error"])
|
|
271
278
|
|
|
272
279
|
@error.setter
|
|
273
280
|
def error(self, value: Error) -> None:
|
|
274
281
|
"""Set error."""
|
|
275
282
|
if self.has_content():
|
|
276
283
|
raise ValueError("A message with content set cannot carry an error.")
|
|
277
|
-
self._error = value
|
|
284
|
+
self.__dict__["_error"] = value
|
|
278
285
|
|
|
279
286
|
def has_content(self) -> bool:
|
|
280
287
|
"""Return True if message has content, else False."""
|
|
281
|
-
return self._content is not None
|
|
288
|
+
return self.__dict__["_content"] is not None
|
|
282
289
|
|
|
283
290
|
def has_error(self) -> bool:
|
|
284
291
|
"""Return True if message has an error, else False."""
|
|
285
|
-
return self._error is not None
|
|
286
|
-
|
|
287
|
-
def _create_reply_metadata(self, ttl: float) -> Metadata:
|
|
288
|
-
"""Construct metadata for a reply message."""
|
|
289
|
-
return Metadata(
|
|
290
|
-
run_id=self.metadata.run_id,
|
|
291
|
-
message_id="",
|
|
292
|
-
src_node_id=self.metadata.dst_node_id,
|
|
293
|
-
dst_node_id=self.metadata.src_node_id,
|
|
294
|
-
reply_to_message=self.metadata.message_id,
|
|
295
|
-
group_id=self.metadata.group_id,
|
|
296
|
-
ttl=ttl,
|
|
297
|
-
message_type=self.metadata.message_type,
|
|
298
|
-
partition_id=self.metadata.partition_id,
|
|
299
|
-
)
|
|
292
|
+
return self.__dict__["_error"] is not None
|
|
300
293
|
|
|
301
294
|
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
|
|
302
295
|
"""Construct a reply message indicating an error happened.
|
|
@@ -323,7 +316,7 @@ class Message:
|
|
|
323
316
|
# message creation)
|
|
324
317
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
325
318
|
# Create reply with error
|
|
326
|
-
message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
|
|
319
|
+
message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
|
|
327
320
|
|
|
328
321
|
if ttl is None:
|
|
329
322
|
# Set TTL equal to the remaining time for the received message to expire
|
|
@@ -369,7 +362,7 @@ class Message:
|
|
|
369
362
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
370
363
|
|
|
371
364
|
message = Message(
|
|
372
|
-
metadata=self._create_reply_metadata(ttl_),
|
|
365
|
+
metadata=_create_reply_metadata(self, ttl_),
|
|
373
366
|
content=content,
|
|
374
367
|
)
|
|
375
368
|
|
|
@@ -381,3 +374,29 @@ class Message:
|
|
|
381
374
|
message.metadata.ttl = ttl
|
|
382
375
|
|
|
383
376
|
return message
|
|
377
|
+
|
|
378
|
+
def __repr__(self) -> str:
|
|
379
|
+
"""Return a string representation of this instance."""
|
|
380
|
+
view = ", ".join(
|
|
381
|
+
[
|
|
382
|
+
f"{k.lstrip('_')}={v!r}"
|
|
383
|
+
for k, v in self.__dict__.items()
|
|
384
|
+
if v is not None
|
|
385
|
+
]
|
|
386
|
+
)
|
|
387
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
|
|
391
|
+
"""Construct metadata for a reply message."""
|
|
392
|
+
return Metadata(
|
|
393
|
+
run_id=msg.metadata.run_id,
|
|
394
|
+
message_id="",
|
|
395
|
+
src_node_id=msg.metadata.dst_node_id,
|
|
396
|
+
dst_node_id=msg.metadata.src_node_id,
|
|
397
|
+
reply_to_message=msg.metadata.message_id,
|
|
398
|
+
group_id=msg.metadata.group_id,
|
|
399
|
+
ttl=ttl,
|
|
400
|
+
message_type=msg.metadata.message_type,
|
|
401
|
+
partition_id=msg.metadata.partition_id,
|
|
402
|
+
)
|
flwr/common/record/recordset.py
CHANGED
|
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
|
|
|
24
24
|
from .typeddict import TypedDict
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
@dataclass
|
|
27
28
|
class RecordSetData:
|
|
28
29
|
"""Inner data container for the RecordSet class."""
|
|
29
30
|
|
|
@@ -82,7 +83,6 @@ class RecordSetData:
|
|
|
82
83
|
)
|
|
83
84
|
|
|
84
85
|
|
|
85
|
-
@dataclass
|
|
86
86
|
class RecordSet:
|
|
87
87
|
"""RecordSet stores groups of parameters, metrics and configs."""
|
|
88
88
|
|
|
@@ -97,22 +97,34 @@ class RecordSet:
|
|
|
97
97
|
metrics_records=metrics_records,
|
|
98
98
|
configs_records=configs_records,
|
|
99
99
|
)
|
|
100
|
-
|
|
100
|
+
self.__dict__["_data"] = data
|
|
101
101
|
|
|
102
102
|
@property
|
|
103
103
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
|
104
104
|
"""Dictionary holding ParametersRecord instances."""
|
|
105
|
-
data = cast(RecordSetData,
|
|
105
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
106
106
|
return data.parameters_records
|
|
107
107
|
|
|
108
108
|
@property
|
|
109
109
|
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
|
|
110
110
|
"""Dictionary holding MetricsRecord instances."""
|
|
111
|
-
data = cast(RecordSetData,
|
|
111
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
112
112
|
return data.metrics_records
|
|
113
113
|
|
|
114
114
|
@property
|
|
115
115
|
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
|
|
116
116
|
"""Dictionary holding ConfigsRecord instances."""
|
|
117
|
-
data = cast(RecordSetData,
|
|
117
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
118
118
|
return data.configs_records
|
|
119
|
+
|
|
120
|
+
def __repr__(self) -> str:
|
|
121
|
+
"""Return a string representation of this instance."""
|
|
122
|
+
flds = ("parameters_records", "metrics_records", "configs_records")
|
|
123
|
+
view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
|
|
124
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
125
|
+
|
|
126
|
+
def __eq__(self, other: object) -> bool:
|
|
127
|
+
"""Compare two instances of the class."""
|
|
128
|
+
if not isinstance(other, self.__class__):
|
|
129
|
+
raise NotImplementedError
|
|
130
|
+
return self.__dict__ == other.__dict__
|
flwr/common/secure_aggregation/crypto/symmetric_encryption.py
CHANGED
|
@@ -18,8 +18,9 @@
|
|
|
18
18
|
import base64
|
|
19
19
|
from typing import Tuple, cast
|
|
20
20
|
|
|
21
|
+
from cryptography.exceptions import InvalidSignature
|
|
21
22
|
from cryptography.fernet import Fernet
|
|
22
|
-
from cryptography.hazmat.primitives import hashes, serialization
|
|
23
|
+
from cryptography.hazmat.primitives import hashes, hmac, serialization
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
25
26
|
|
|
@@ -98,3 +99,36 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
|
|
|
98
99
|
# The input key must be url safe
|
|
99
100
|
fernet = Fernet(key)
|
|
100
101
|
return fernet.decrypt(ciphertext)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_hmac(key: bytes, message: bytes) -> bytes:
|
|
105
|
+
"""Compute hmac of a message using key as hash."""
|
|
106
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
107
|
+
computed_hmac.update(message)
|
|
108
|
+
return computed_hmac.finalize()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
112
|
+
"""Verify hmac of a message using key as hash."""
|
|
113
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
114
|
+
computed_hmac.update(message)
|
|
115
|
+
try:
|
|
116
|
+
computed_hmac.verify(hmac_value)
|
|
117
|
+
return True
|
|
118
|
+
except InvalidSignature:
|
|
119
|
+
return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def ssh_types_to_elliptic_curve(
|
|
123
|
+
private_key: serialization.SSHPrivateKeyTypes,
|
|
124
|
+
public_key: serialization.SSHPublicKeyTypes,
|
|
125
|
+
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
|
|
126
|
+
"""Cast SSH key types to elliptic curve."""
|
|
127
|
+
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
|
|
128
|
+
public_key, ec.EllipticCurvePublicKey
|
|
129
|
+
):
|
|
130
|
+
return (private_key, public_key)
|
|
131
|
+
|
|
132
|
+
raise TypeError(
|
|
133
|
+
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
|
|
134
|
+
)
|
flwr/server/app.py
CHANGED
|
@@ -16,15 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import asyncio
|
|
19
|
+
import csv
|
|
19
20
|
import importlib.util
|
|
20
21
|
import sys
|
|
21
22
|
import threading
|
|
22
23
|
from logging import ERROR, INFO, WARN
|
|
23
24
|
from os.path import isfile
|
|
24
25
|
from pathlib import Path
|
|
25
|
-
from typing import List, Optional, Tuple
|
|
26
|
+
from typing import List, Optional, Sequence, Set, Tuple
|
|
26
27
|
|
|
27
28
|
import grpc
|
|
29
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
31
|
+
load_ssh_private_key,
|
|
32
|
+
load_ssh_public_key,
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
30
36
|
from flwr.common.address import parse_address
|
|
@@ -36,6 +42,11 @@ from flwr.common.constant import (
|
|
|
36
42
|
)
|
|
37
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
38
44
|
from flwr.common.logger import log
|
|
45
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
|
+
private_key_to_bytes,
|
|
47
|
+
public_key_to_bytes,
|
|
48
|
+
ssh_types_to_elliptic_curve,
|
|
49
|
+
)
|
|
39
50
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
40
51
|
add_FleetServicer_to_server,
|
|
41
52
|
)
|
|
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
51
62
|
start_grpc_server,
|
|
52
63
|
)
|
|
53
64
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
65
|
+
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
54
66
|
from .superlink.fleet.vce import start_vce
|
|
55
67
|
from .superlink.state import StateFactory
|
|
56
68
|
|
|
@@ -354,10 +366,33 @@ def run_superlink() -> None:
|
|
|
354
366
|
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
355
367
|
host, port, is_v6 = parsed_address
|
|
356
368
|
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
369
|
+
|
|
370
|
+
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
371
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
372
|
+
if maybe_keys is not None:
|
|
373
|
+
(
|
|
374
|
+
client_public_keys,
|
|
375
|
+
server_private_key,
|
|
376
|
+
server_public_key,
|
|
377
|
+
) = maybe_keys
|
|
378
|
+
state = state_factory.state()
|
|
379
|
+
state.store_client_public_keys(client_public_keys)
|
|
380
|
+
state.store_server_private_public_key(
|
|
381
|
+
private_key_to_bytes(server_private_key),
|
|
382
|
+
public_key_to_bytes(server_public_key),
|
|
383
|
+
)
|
|
384
|
+
log(
|
|
385
|
+
INFO,
|
|
386
|
+
"Client authentication enabled with %d known public keys",
|
|
387
|
+
len(client_public_keys),
|
|
388
|
+
)
|
|
389
|
+
interceptors = [AuthenticateServerInterceptor(state)]
|
|
390
|
+
|
|
357
391
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
358
392
|
address=address,
|
|
359
393
|
state_factory=state_factory,
|
|
360
394
|
certificates=certificates,
|
|
395
|
+
interceptors=interceptors,
|
|
361
396
|
)
|
|
362
397
|
grpc_servers.append(fleet_server)
|
|
363
398
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
|
@@ -390,6 +425,70 @@ def run_superlink() -> None:
|
|
|
390
425
|
driver_server.wait_for_termination(timeout=1)
|
|
391
426
|
|
|
392
427
|
|
|
428
|
+
def _try_setup_client_authentication(
|
|
429
|
+
args: argparse.Namespace,
|
|
430
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
431
|
+
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
432
|
+
if not args.require_client_authentication:
|
|
433
|
+
return None
|
|
434
|
+
|
|
435
|
+
if certificates is None:
|
|
436
|
+
sys.exit(
|
|
437
|
+
"Client authentication only works over secure connections. "
|
|
438
|
+
"Please provide certificate paths using '--certificates' when "
|
|
439
|
+
"enabling '--require-client-authentication'."
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
client_keys_file_path = Path(args.require_client_authentication[0])
|
|
443
|
+
if not client_keys_file_path.exists():
|
|
444
|
+
sys.exit(
|
|
445
|
+
"The provided path to the client public keys CSV file does not exist: "
|
|
446
|
+
f"{client_keys_file_path}. "
|
|
447
|
+
"Please provide the CSV file path containing known client public keys "
|
|
448
|
+
"to '--require-client-authentication'."
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
client_public_keys: Set[bytes] = set()
|
|
452
|
+
ssh_private_key = load_ssh_private_key(
|
|
453
|
+
Path(args.require_client_authentication[1]).read_bytes(),
|
|
454
|
+
None,
|
|
455
|
+
)
|
|
456
|
+
ssh_public_key = load_ssh_public_key(
|
|
457
|
+
Path(args.require_client_authentication[2]).read_bytes()
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
try:
|
|
461
|
+
server_private_key, server_public_key = ssh_types_to_elliptic_curve(
|
|
462
|
+
ssh_private_key, ssh_public_key
|
|
463
|
+
)
|
|
464
|
+
except TypeError:
|
|
465
|
+
sys.exit(
|
|
466
|
+
"The file paths provided could not be read as a private and public "
|
|
467
|
+
"key pair. Client authentication requires an elliptic curve public and "
|
|
468
|
+
"private key pair. Please provide the file paths containing elliptic "
|
|
469
|
+
"curve private and public keys to '--require-client-authentication'."
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
473
|
+
reader = csv.reader(csvfile)
|
|
474
|
+
for row in reader:
|
|
475
|
+
for element in row:
|
|
476
|
+
public_key = load_ssh_public_key(element.encode())
|
|
477
|
+
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
478
|
+
client_public_keys.add(public_key_to_bytes(public_key))
|
|
479
|
+
else:
|
|
480
|
+
sys.exit(
|
|
481
|
+
"Error: Unable to parse the public keys in the .csv "
|
|
482
|
+
"file. Please ensure that the .csv file contains valid "
|
|
483
|
+
"SSH public keys and try again."
|
|
484
|
+
)
|
|
485
|
+
return (
|
|
486
|
+
client_public_keys,
|
|
487
|
+
server_private_key,
|
|
488
|
+
server_public_key,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
|
|
393
492
|
def _try_obtain_certificates(
|
|
394
493
|
args: argparse.Namespace,
|
|
395
494
|
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
@@ -417,6 +516,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
417
516
|
address: str,
|
|
418
517
|
state_factory: StateFactory,
|
|
419
518
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
519
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
420
520
|
) -> grpc.Server:
|
|
421
521
|
"""Run Fleet API (gRPC, request-response)."""
|
|
422
522
|
# Create Fleet API gRPC server
|
|
@@ -429,6 +529,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
429
529
|
server_address=address,
|
|
430
530
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
431
531
|
certificates=certificates,
|
|
532
|
+
interceptors=interceptors,
|
|
432
533
|
)
|
|
433
534
|
|
|
434
535
|
log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
|
|
@@ -606,6 +707,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
606
707
|
"Flower will just create a state in memory.",
|
|
607
708
|
default=DATABASE,
|
|
608
709
|
)
|
|
710
|
+
parser.add_argument(
|
|
711
|
+
"--require-client-authentication",
|
|
712
|
+
nargs=3,
|
|
713
|
+
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
|
|
714
|
+
type=str,
|
|
715
|
+
help="Provide three file paths: (1) a .csv file containing a list of "
|
|
716
|
+
"known client public keys for authentication, (2) the server's private "
|
|
717
|
+
"key file, and (3) the server's public key file.",
|
|
718
|
+
)
|
|
609
719
|
|
|
610
720
|
|
|
611
721
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
flwr/server/compat/app.py
CHANGED
|
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
|
|
|
29
29
|
from flwr.server.server_config import ServerConfig
|
|
30
30
|
from flwr.server.strategy import Strategy
|
|
31
31
|
|
|
32
|
-
from ..driver import Driver
|
|
32
|
+
from ..driver import Driver, GrpcDriver
|
|
33
33
|
from .app_utils import start_update_client_manager_thread
|
|
34
34
|
|
|
35
35
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
114
114
|
# Create the Driver
|
|
115
115
|
if isinstance(root_certificates, str):
|
|
116
116
|
root_certificates = Path(root_certificates).read_bytes()
|
|
117
|
-
driver = Driver(
|
|
117
|
+
driver = GrpcDriver(
|
|
118
118
|
driver_service_address=address, root_certificates=root_certificates
|
|
119
119
|
)
|
|
120
120
|
|
flwr/server/compat/app_utils.py
CHANGED