flwr-nightly 1.9.0.dev20240426__py3-none-any.whl → 1.9.0.dev20240430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_rere_client/client_interceptor.py +9 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +166 -4
- flwr/common/logger.py +26 -0
- flwr/common/message.py +72 -82
- flwr/common/record/recordset.py +5 -4
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +15 -0
- flwr/server/app.py +105 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +174 -0
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +3 -0
- {flwr_nightly-1.9.0.dev20240426.dist-info → flwr_nightly-1.9.0.dev20240430.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240426.dist-info → flwr_nightly-1.9.0.dev20240430.dist-info}/RECORD +19 -18
- {flwr_nightly-1.9.0.dev20240426.dist-info → flwr_nightly-1.9.0.dev20240430.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240426.dist-info → flwr_nightly-1.9.0.dev20240430.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.9.0.dev20240426.dist-info → flwr_nightly-1.9.0.dev20240430.dist-info}/entry_points.txt +0 -0
flwr/common/message.py
CHANGED
|
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
|
19
19
|
import time
|
|
20
20
|
import warnings
|
|
21
21
|
from dataclasses import dataclass
|
|
22
|
+
from typing import Optional, cast
|
|
22
23
|
|
|
23
24
|
from .record import RecordSet
|
|
24
25
|
|
|
@@ -55,17 +56,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
55
56
|
is more relevant when conducting simulations.
|
|
56
57
|
"""
|
|
57
58
|
|
|
58
|
-
_run_id: int
|
|
59
|
-
_message_id: str
|
|
60
|
-
_src_node_id: int
|
|
61
|
-
_dst_node_id: int
|
|
62
|
-
_reply_to_message: str
|
|
63
|
-
_group_id: str
|
|
64
|
-
_ttl: float
|
|
65
|
-
_message_type: str
|
|
66
|
-
_partition_id: int | None
|
|
67
|
-
_created_at: float # Unix timestamp (in seconds) to be set upon message creation
|
|
68
|
-
|
|
69
59
|
def __init__( # pylint: disable=too-many-arguments
|
|
70
60
|
self,
|
|
71
61
|
run_id: int,
|
|
@@ -78,95 +68,98 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
78
68
|
message_type: str,
|
|
79
69
|
partition_id: int | None = None,
|
|
80
70
|
) -> None:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
71
|
+
var_dict = {
|
|
72
|
+
"_run_id": run_id,
|
|
73
|
+
"_message_id": message_id,
|
|
74
|
+
"_src_node_id": src_node_id,
|
|
75
|
+
"_dst_node_id": dst_node_id,
|
|
76
|
+
"_reply_to_message": reply_to_message,
|
|
77
|
+
"_group_id": group_id,
|
|
78
|
+
"_ttl": ttl,
|
|
79
|
+
"_message_type": message_type,
|
|
80
|
+
"_partition_id": partition_id,
|
|
81
|
+
}
|
|
82
|
+
self.__dict__.update(var_dict)
|
|
90
83
|
|
|
91
84
|
@property
|
|
92
85
|
def run_id(self) -> int:
|
|
93
86
|
"""An identifier for the current run."""
|
|
94
|
-
return self._run_id
|
|
87
|
+
return cast(int, self.__dict__["_run_id"])
|
|
95
88
|
|
|
96
89
|
@property
|
|
97
90
|
def message_id(self) -> str:
|
|
98
91
|
"""An identifier for the current message."""
|
|
99
|
-
return self._message_id
|
|
92
|
+
return cast(str, self.__dict__["_message_id"])
|
|
100
93
|
|
|
101
94
|
@property
|
|
102
95
|
def src_node_id(self) -> int:
|
|
103
96
|
"""An identifier for the node sending this message."""
|
|
104
|
-
return self._src_node_id
|
|
97
|
+
return cast(int, self.__dict__["_src_node_id"])
|
|
105
98
|
|
|
106
99
|
@property
|
|
107
100
|
def reply_to_message(self) -> str:
|
|
108
101
|
"""An identifier for the message this message replies to."""
|
|
109
|
-
return self._reply_to_message
|
|
102
|
+
return cast(str, self.__dict__["_reply_to_message"])
|
|
110
103
|
|
|
111
104
|
@property
|
|
112
105
|
def dst_node_id(self) -> int:
|
|
113
106
|
"""An identifier for the node receiving this message."""
|
|
114
|
-
return self._dst_node_id
|
|
107
|
+
return cast(int, self.__dict__["_dst_node_id"])
|
|
115
108
|
|
|
116
109
|
@dst_node_id.setter
|
|
117
110
|
def dst_node_id(self, value: int) -> None:
|
|
118
111
|
"""Set dst_node_id."""
|
|
119
|
-
self._dst_node_id = value
|
|
112
|
+
self.__dict__["_dst_node_id"] = value
|
|
120
113
|
|
|
121
114
|
@property
|
|
122
115
|
def group_id(self) -> str:
|
|
123
116
|
"""An identifier for grouping messages."""
|
|
124
|
-
return self._group_id
|
|
117
|
+
return cast(str, self.__dict__["_group_id"])
|
|
125
118
|
|
|
126
119
|
@group_id.setter
|
|
127
120
|
def group_id(self, value: str) -> None:
|
|
128
121
|
"""Set group_id."""
|
|
129
|
-
self._group_id = value
|
|
122
|
+
self.__dict__["_group_id"] = value
|
|
130
123
|
|
|
131
124
|
@property
|
|
132
125
|
def created_at(self) -> float:
|
|
133
126
|
"""Unix timestamp when the message was created."""
|
|
134
|
-
return self._created_at
|
|
127
|
+
return cast(float, self.__dict__["_created_at"])
|
|
135
128
|
|
|
136
129
|
@created_at.setter
|
|
137
130
|
def created_at(self, value: float) -> None:
|
|
138
|
-
"""Set creation timestamp for this
|
|
139
|
-
self._created_at = value
|
|
131
|
+
"""Set creation timestamp for this message."""
|
|
132
|
+
self.__dict__["_created_at"] = value
|
|
140
133
|
|
|
141
134
|
@property
|
|
142
135
|
def ttl(self) -> float:
|
|
143
136
|
"""Time-to-live for this message."""
|
|
144
|
-
return self._ttl
|
|
137
|
+
return cast(float, self.__dict__["_ttl"])
|
|
145
138
|
|
|
146
139
|
@ttl.setter
|
|
147
140
|
def ttl(self, value: float) -> None:
|
|
148
141
|
"""Set ttl."""
|
|
149
|
-
self._ttl = value
|
|
142
|
+
self.__dict__["_ttl"] = value
|
|
150
143
|
|
|
151
144
|
@property
|
|
152
145
|
def message_type(self) -> str:
|
|
153
146
|
"""A string that encodes the action to be executed on the receiving end."""
|
|
154
|
-
return self._message_type
|
|
147
|
+
return cast(str, self.__dict__["_message_type"])
|
|
155
148
|
|
|
156
149
|
@message_type.setter
|
|
157
150
|
def message_type(self, value: str) -> None:
|
|
158
151
|
"""Set message_type."""
|
|
159
|
-
self._message_type = value
|
|
152
|
+
self.__dict__["_message_type"] = value
|
|
160
153
|
|
|
161
154
|
@property
|
|
162
155
|
def partition_id(self) -> int | None:
|
|
163
156
|
"""An identifier telling which data partition a ClientApp should use."""
|
|
164
|
-
return self._partition_id
|
|
157
|
+
return cast(int, self.__dict__["_partition_id"])
|
|
165
158
|
|
|
166
159
|
@partition_id.setter
|
|
167
160
|
def partition_id(self, value: int) -> None:
|
|
168
|
-
"""Set
|
|
169
|
-
self._partition_id = value
|
|
161
|
+
"""Set partition_id."""
|
|
162
|
+
self.__dict__["_partition_id"] = value
|
|
170
163
|
|
|
171
164
|
|
|
172
165
|
@dataclass
|
|
@@ -181,22 +174,22 @@ class Error:
|
|
|
181
174
|
A reason for why the error arose (e.g. an exception stack-trace)
|
|
182
175
|
"""
|
|
183
176
|
|
|
184
|
-
_code: int
|
|
185
|
-
_reason: str | None = None
|
|
186
|
-
|
|
187
177
|
def __init__(self, code: int, reason: str | None = None) -> None:
|
|
188
|
-
|
|
189
|
-
|
|
178
|
+
var_dict = {
|
|
179
|
+
"_code": code,
|
|
180
|
+
"_reason": reason,
|
|
181
|
+
}
|
|
182
|
+
self.__dict__.update(var_dict)
|
|
190
183
|
|
|
191
184
|
@property
|
|
192
185
|
def code(self) -> int:
|
|
193
186
|
"""Error code."""
|
|
194
|
-
return self._code
|
|
187
|
+
return cast(int, self.__dict__["_code"])
|
|
195
188
|
|
|
196
189
|
@property
|
|
197
190
|
def reason(self) -> str | None:
|
|
198
191
|
"""Reason reported about the error."""
|
|
199
|
-
return self._reason
|
|
192
|
+
return cast(Optional[str], self.__dict__["_reason"])
|
|
200
193
|
|
|
201
194
|
|
|
202
195
|
@dataclass
|
|
@@ -215,88 +208,70 @@ class Message:
|
|
|
215
208
|
when processing another message.
|
|
216
209
|
"""
|
|
217
210
|
|
|
218
|
-
_metadata: Metadata
|
|
219
|
-
_content: RecordSet | None = None
|
|
220
|
-
_error: Error | None = None
|
|
221
|
-
|
|
222
211
|
def __init__(
|
|
223
212
|
self,
|
|
224
213
|
metadata: Metadata,
|
|
225
214
|
content: RecordSet | None = None,
|
|
226
215
|
error: Error | None = None,
|
|
227
216
|
) -> None:
|
|
228
|
-
self._metadata = metadata
|
|
229
|
-
|
|
230
|
-
# Set message creation timestamp
|
|
231
|
-
self._metadata.created_at = time.time()
|
|
232
|
-
|
|
233
217
|
if not (content is None) ^ (error is None):
|
|
234
218
|
raise ValueError("Either `content` or `error` must be set, but not both.")
|
|
235
219
|
|
|
236
|
-
|
|
237
|
-
|
|
220
|
+
metadata.created_at = time.time() # Set the message creation timestamp
|
|
221
|
+
var_dict = {
|
|
222
|
+
"_metadata": metadata,
|
|
223
|
+
"_content": content,
|
|
224
|
+
"_error": error,
|
|
225
|
+
}
|
|
226
|
+
self.__dict__.update(var_dict)
|
|
238
227
|
|
|
239
228
|
@property
|
|
240
229
|
def metadata(self) -> Metadata:
|
|
241
230
|
"""A dataclass including information about the message to be executed."""
|
|
242
|
-
return self._metadata
|
|
231
|
+
return cast(Metadata, self.__dict__["_metadata"])
|
|
243
232
|
|
|
244
233
|
@property
|
|
245
234
|
def content(self) -> RecordSet:
|
|
246
235
|
"""The content of this message."""
|
|
247
|
-
if self._content is None:
|
|
236
|
+
if self.__dict__["_content"] is None:
|
|
248
237
|
raise ValueError(
|
|
249
238
|
"Message content is None. Use <message>.has_content() "
|
|
250
239
|
"to check if a message has content."
|
|
251
240
|
)
|
|
252
|
-
return self._content
|
|
241
|
+
return cast(RecordSet, self.__dict__["_content"])
|
|
253
242
|
|
|
254
243
|
@content.setter
|
|
255
244
|
def content(self, value: RecordSet) -> None:
|
|
256
245
|
"""Set content."""
|
|
257
|
-
if self._error is None:
|
|
258
|
-
self._content = value
|
|
246
|
+
if self.__dict__["_error"] is None:
|
|
247
|
+
self.__dict__["_content"] = value
|
|
259
248
|
else:
|
|
260
249
|
raise ValueError("A message with an error set cannot have content.")
|
|
261
250
|
|
|
262
251
|
@property
|
|
263
252
|
def error(self) -> Error:
|
|
264
253
|
"""Error captured by this message."""
|
|
265
|
-
if self._error is None:
|
|
254
|
+
if self.__dict__["_error"] is None:
|
|
266
255
|
raise ValueError(
|
|
267
256
|
"Message error is None. Use <message>.has_error() "
|
|
268
257
|
"to check first if a message carries an error."
|
|
269
258
|
)
|
|
270
|
-
return self._error
|
|
259
|
+
return cast(Error, self.__dict__["_error"])
|
|
271
260
|
|
|
272
261
|
@error.setter
|
|
273
262
|
def error(self, value: Error) -> None:
|
|
274
263
|
"""Set error."""
|
|
275
264
|
if self.has_content():
|
|
276
265
|
raise ValueError("A message with content set cannot carry an error.")
|
|
277
|
-
self._error = value
|
|
266
|
+
self.__dict__["_error"] = value
|
|
278
267
|
|
|
279
268
|
def has_content(self) -> bool:
|
|
280
269
|
"""Return True if message has content, else False."""
|
|
281
|
-
return self._content is not None
|
|
270
|
+
return self.__dict__["_content"] is not None
|
|
282
271
|
|
|
283
272
|
def has_error(self) -> bool:
|
|
284
273
|
"""Return True if message has an error, else False."""
|
|
285
|
-
return self._error is not None
|
|
286
|
-
|
|
287
|
-
def _create_reply_metadata(self, ttl: float) -> Metadata:
|
|
288
|
-
"""Construct metadata for a reply message."""
|
|
289
|
-
return Metadata(
|
|
290
|
-
run_id=self.metadata.run_id,
|
|
291
|
-
message_id="",
|
|
292
|
-
src_node_id=self.metadata.dst_node_id,
|
|
293
|
-
dst_node_id=self.metadata.src_node_id,
|
|
294
|
-
reply_to_message=self.metadata.message_id,
|
|
295
|
-
group_id=self.metadata.group_id,
|
|
296
|
-
ttl=ttl,
|
|
297
|
-
message_type=self.metadata.message_type,
|
|
298
|
-
partition_id=self.metadata.partition_id,
|
|
299
|
-
)
|
|
274
|
+
return self.__dict__["_error"] is not None
|
|
300
275
|
|
|
301
276
|
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
|
|
302
277
|
"""Construct a reply message indicating an error happened.
|
|
@@ -323,7 +298,7 @@ class Message:
|
|
|
323
298
|
# message creation)
|
|
324
299
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
325
300
|
# Create reply with error
|
|
326
|
-
message = Message(metadata=
|
|
301
|
+
message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
|
|
327
302
|
|
|
328
303
|
if ttl is None:
|
|
329
304
|
# Set TTL equal to the remaining time for the received message to expire
|
|
@@ -369,7 +344,7 @@ class Message:
|
|
|
369
344
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
370
345
|
|
|
371
346
|
message = Message(
|
|
372
|
-
metadata=
|
|
347
|
+
metadata=_create_reply_metadata(self, ttl_),
|
|
373
348
|
content=content,
|
|
374
349
|
)
|
|
375
350
|
|
|
@@ -381,3 +356,18 @@ class Message:
|
|
|
381
356
|
message.metadata.ttl = ttl
|
|
382
357
|
|
|
383
358
|
return message
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
|
|
362
|
+
"""Construct metadata for a reply message."""
|
|
363
|
+
return Metadata(
|
|
364
|
+
run_id=msg.metadata.run_id,
|
|
365
|
+
message_id="",
|
|
366
|
+
src_node_id=msg.metadata.dst_node_id,
|
|
367
|
+
dst_node_id=msg.metadata.src_node_id,
|
|
368
|
+
reply_to_message=msg.metadata.message_id,
|
|
369
|
+
group_id=msg.metadata.group_id,
|
|
370
|
+
ttl=ttl,
|
|
371
|
+
message_type=msg.metadata.message_type,
|
|
372
|
+
partition_id=msg.metadata.partition_id,
|
|
373
|
+
)
|
flwr/common/record/recordset.py
CHANGED
|
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
|
|
|
24
24
|
from .typeddict import TypedDict
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
@dataclass
|
|
27
28
|
class RecordSetData:
|
|
28
29
|
"""Inner data container for the RecordSet class."""
|
|
29
30
|
|
|
@@ -97,22 +98,22 @@ class RecordSet:
|
|
|
97
98
|
metrics_records=metrics_records,
|
|
98
99
|
configs_records=configs_records,
|
|
99
100
|
)
|
|
100
|
-
|
|
101
|
+
self.__dict__["_data"] = data
|
|
101
102
|
|
|
102
103
|
@property
|
|
103
104
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
|
104
105
|
"""Dictionary holding ParametersRecord instances."""
|
|
105
|
-
data = cast(RecordSetData,
|
|
106
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
106
107
|
return data.parameters_records
|
|
107
108
|
|
|
108
109
|
@property
|
|
109
110
|
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
|
|
110
111
|
"""Dictionary holding MetricsRecord instances."""
|
|
111
|
-
data = cast(RecordSetData,
|
|
112
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
112
113
|
return data.metrics_records
|
|
113
114
|
|
|
114
115
|
@property
|
|
115
116
|
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
|
|
116
117
|
"""Dictionary holding ConfigsRecord instances."""
|
|
117
|
-
data = cast(RecordSetData,
|
|
118
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
118
119
|
return data.configs_records
|
|
@@ -117,3 +117,18 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
|
117
117
|
return True
|
|
118
118
|
except InvalidSignature:
|
|
119
119
|
return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def ssh_types_to_elliptic_curve(
|
|
123
|
+
private_key: serialization.SSHPrivateKeyTypes,
|
|
124
|
+
public_key: serialization.SSHPublicKeyTypes,
|
|
125
|
+
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
|
|
126
|
+
"""Cast SSH key types to elliptic curve."""
|
|
127
|
+
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
|
|
128
|
+
public_key, ec.EllipticCurvePublicKey
|
|
129
|
+
):
|
|
130
|
+
return (private_key, public_key)
|
|
131
|
+
|
|
132
|
+
raise TypeError(
|
|
133
|
+
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
|
|
134
|
+
)
|
flwr/server/app.py
CHANGED
|
@@ -16,15 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import asyncio
|
|
19
|
+
import csv
|
|
19
20
|
import importlib.util
|
|
20
21
|
import sys
|
|
21
22
|
import threading
|
|
22
23
|
from logging import ERROR, INFO, WARN
|
|
23
24
|
from os.path import isfile
|
|
24
25
|
from pathlib import Path
|
|
25
|
-
from typing import List, Optional, Tuple
|
|
26
|
+
from typing import List, Optional, Sequence, Set, Tuple
|
|
26
27
|
|
|
27
28
|
import grpc
|
|
29
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
31
|
+
load_ssh_private_key,
|
|
32
|
+
load_ssh_public_key,
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
30
36
|
from flwr.common.address import parse_address
|
|
@@ -36,6 +42,10 @@ from flwr.common.constant import (
|
|
|
36
42
|
)
|
|
37
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
38
44
|
from flwr.common.logger import log
|
|
45
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
|
+
public_key_to_bytes,
|
|
47
|
+
ssh_types_to_elliptic_curve,
|
|
48
|
+
)
|
|
39
49
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
40
50
|
add_FleetServicer_to_server,
|
|
41
51
|
)
|
|
@@ -51,6 +61,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
51
61
|
start_grpc_server,
|
|
52
62
|
)
|
|
53
63
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
64
|
+
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
54
65
|
from .superlink.fleet.vce import start_vce
|
|
55
66
|
from .superlink.state import StateFactory
|
|
56
67
|
|
|
@@ -354,10 +365,28 @@ def run_superlink() -> None:
|
|
|
354
365
|
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
355
366
|
host, port, is_v6 = parsed_address
|
|
356
367
|
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
368
|
+
|
|
369
|
+
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
370
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
371
|
+
if maybe_keys is not None:
|
|
372
|
+
(
|
|
373
|
+
client_public_keys,
|
|
374
|
+
server_private_key,
|
|
375
|
+
server_public_key,
|
|
376
|
+
) = maybe_keys
|
|
377
|
+
interceptors = [
|
|
378
|
+
AuthenticateServerInterceptor(
|
|
379
|
+
client_public_keys,
|
|
380
|
+
server_private_key,
|
|
381
|
+
server_public_key,
|
|
382
|
+
)
|
|
383
|
+
]
|
|
384
|
+
|
|
357
385
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
358
386
|
address=address,
|
|
359
387
|
state_factory=state_factory,
|
|
360
388
|
certificates=certificates,
|
|
389
|
+
interceptors=interceptors,
|
|
361
390
|
)
|
|
362
391
|
grpc_servers.append(fleet_server)
|
|
363
392
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
|
@@ -390,6 +419,70 @@ def run_superlink() -> None:
|
|
|
390
419
|
driver_server.wait_for_termination(timeout=1)
|
|
391
420
|
|
|
392
421
|
|
|
422
|
+
def _try_setup_client_authentication(
|
|
423
|
+
args: argparse.Namespace,
|
|
424
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
425
|
+
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
426
|
+
if not args.require_client_authentication:
|
|
427
|
+
return None
|
|
428
|
+
|
|
429
|
+
if certificates is None:
|
|
430
|
+
sys.exit(
|
|
431
|
+
"Client authentication only works over secure connections. "
|
|
432
|
+
"Please provide certificate paths using '--certificates' when "
|
|
433
|
+
"enabling '--require-client-authentication'."
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
client_keys_file_path = Path(args.require_client_authentication[0])
|
|
437
|
+
if not client_keys_file_path.exists():
|
|
438
|
+
sys.exit(
|
|
439
|
+
"The provided path to the client public keys CSV file does not exist: "
|
|
440
|
+
f"{client_keys_file_path}. "
|
|
441
|
+
"Please provide the CSV file path containing known client public keys "
|
|
442
|
+
"to '--require-client-authentication'."
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
client_public_keys: Set[bytes] = set()
|
|
446
|
+
ssh_private_key = load_ssh_private_key(
|
|
447
|
+
Path(args.require_client_authentication[1]).read_bytes(),
|
|
448
|
+
None,
|
|
449
|
+
)
|
|
450
|
+
ssh_public_key = load_ssh_public_key(
|
|
451
|
+
Path(args.require_client_authentication[2]).read_bytes()
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
try:
|
|
455
|
+
server_private_key, server_public_key = ssh_types_to_elliptic_curve(
|
|
456
|
+
ssh_private_key, ssh_public_key
|
|
457
|
+
)
|
|
458
|
+
except TypeError:
|
|
459
|
+
sys.exit(
|
|
460
|
+
"The file paths provided could not be read as a private and public "
|
|
461
|
+
"key pair. Client authentication requires an elliptic curve public and "
|
|
462
|
+
"private key pair. Please provide the file paths containing elliptic "
|
|
463
|
+
"curve private and public keys to '--require-client-authentication'."
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
467
|
+
reader = csv.reader(csvfile)
|
|
468
|
+
for row in reader:
|
|
469
|
+
for element in row:
|
|
470
|
+
public_key = load_ssh_public_key(element.encode())
|
|
471
|
+
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
472
|
+
client_public_keys.add(public_key_to_bytes(public_key))
|
|
473
|
+
else:
|
|
474
|
+
sys.exit(
|
|
475
|
+
"Error: Unable to parse the public keys in the .csv "
|
|
476
|
+
"file. Please ensure that the .csv file contains valid "
|
|
477
|
+
"SSH public keys and try again."
|
|
478
|
+
)
|
|
479
|
+
return (
|
|
480
|
+
client_public_keys,
|
|
481
|
+
server_private_key,
|
|
482
|
+
server_public_key,
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
|
|
393
486
|
def _try_obtain_certificates(
|
|
394
487
|
args: argparse.Namespace,
|
|
395
488
|
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
@@ -417,6 +510,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
417
510
|
address: str,
|
|
418
511
|
state_factory: StateFactory,
|
|
419
512
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
513
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
420
514
|
) -> grpc.Server:
|
|
421
515
|
"""Run Fleet API (gRPC, request-response)."""
|
|
422
516
|
# Create Fleet API gRPC server
|
|
@@ -429,6 +523,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
429
523
|
server_address=address,
|
|
430
524
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
431
525
|
certificates=certificates,
|
|
526
|
+
interceptors=interceptors,
|
|
432
527
|
)
|
|
433
528
|
|
|
434
529
|
log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
|
|
@@ -606,6 +701,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
606
701
|
"Flower will just create a state in memory.",
|
|
607
702
|
default=DATABASE,
|
|
608
703
|
)
|
|
704
|
+
parser.add_argument(
|
|
705
|
+
"--require-client-authentication",
|
|
706
|
+
nargs=3,
|
|
707
|
+
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
|
|
708
|
+
type=str,
|
|
709
|
+
help="Provide three file paths: (1) a .csv file containing a list of "
|
|
710
|
+
"known client public keys for authentication, (2) the server's private "
|
|
711
|
+
"key file, and (3) the server's public key file.",
|
|
712
|
+
)
|
|
609
713
|
|
|
610
714
|
|
|
611
715
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import concurrent.futures
|
|
19
19
|
import sys
|
|
20
20
|
from logging import ERROR
|
|
21
|
-
from typing import Any, Callable, Optional, Tuple, Union
|
|
21
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
162
162
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
163
163
|
keepalive_time_ms: int = 210000,
|
|
164
164
|
certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
|
|
165
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
165
166
|
) -> grpc.Server:
|
|
166
167
|
"""Create a gRPC server with a single servicer.
|
|
167
168
|
|
|
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
249
250
|
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
|
|
250
251
|
maximum_concurrent_rpcs=max_concurrent_workers,
|
|
251
252
|
options=options,
|
|
253
|
+
interceptors=interceptors,
|
|
252
254
|
)
|
|
253
255
|
add_servicer_to_server_fn(servicer, server)
|
|
254
256
|
|