flwr-nightly 1.9.0.dev20240423__py3-none-any.whl → 1.9.0.dev20240425__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +37 -17
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +6 -3
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +6 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +6 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +6 -3
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +150 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/rest_client/connection.py +5 -1
- flwr/common/grpc.py +5 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +20 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/grpc_driver.py +17 -8
- flwr/server/run_serverapp.py +14 -0
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +37 -1
- flwr/server/superlink/state/sqlite_state.py +71 -4
- flwr/server/superlink/state/state.py +26 -0
- {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/RECORD +30 -29
- {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/entry_points.txt +0 -0
flwr/common/grpc.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import DEBUG
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional, Sequence
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
@@ -30,6 +30,7 @@ def create_channel(
|
|
|
30
30
|
insecure: bool,
|
|
31
31
|
root_certificates: Optional[bytes] = None,
|
|
32
32
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
33
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
|
|
33
34
|
) -> grpc.Channel:
|
|
34
35
|
"""Create a gRPC channel, either secure or insecure."""
|
|
35
36
|
# Check for conflicting parameters
|
|
@@ -57,4 +58,7 @@ def create_channel(
|
|
|
57
58
|
)
|
|
58
59
|
log(DEBUG, "Opened secure gRPC connection using certificates")
|
|
59
60
|
|
|
61
|
+
if interceptors is not None:
|
|
62
|
+
channel = grpc.intercept_channel(channel, interceptors)
|
|
63
|
+
|
|
60
64
|
return channel
|
|
@@ -18,8 +18,9 @@
|
|
|
18
18
|
import base64
|
|
19
19
|
from typing import Tuple, cast
|
|
20
20
|
|
|
21
|
+
from cryptography.exceptions import InvalidSignature
|
|
21
22
|
from cryptography.fernet import Fernet
|
|
22
|
-
from cryptography.hazmat.primitives import hashes, serialization
|
|
23
|
+
from cryptography.hazmat.primitives import hashes, hmac, serialization
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
25
26
|
|
|
@@ -98,3 +99,21 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
|
|
|
98
99
|
# The input key must be url safe
|
|
99
100
|
fernet = Fernet(key)
|
|
100
101
|
return fernet.decrypt(ciphertext)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_hmac(key: bytes, message: bytes) -> bytes:
    """Compute the HMAC-SHA256 tag of `message` keyed with `key`.

    Parameters
    ----------
    key : bytes
        Secret key for the HMAC (the hash function is always SHA-256).
    message : bytes
        Data to authenticate.

    Returns
    -------
    bytes
        The finalized HMAC-SHA256 digest of `message`.
    """
    computed_hmac = hmac.HMAC(key, hashes.SHA256())
    computed_hmac.update(message)
    return computed_hmac.finalize()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
    """Check that `hmac_value` is the HMAC-SHA256 tag of `message` under `key`.

    The expected tag is recomputed and compared via `HMAC.verify`, which
    performs a secure comparison instead of a plain `==` on the digests.

    Parameters
    ----------
    key : bytes
        Secret key used for the HMAC (hash function: SHA-256).
    message : bytes
        The data whose tag is being checked.
    hmac_value : bytes
        The tag received alongside `message`.

    Returns
    -------
    bool
        True if the tag matches, False if verification fails with
        `InvalidSignature`.
    """
    computed_hmac = hmac.HMAC(key, hashes.SHA256())
    computed_hmac.update(message)
    try:
        computed_hmac.verify(hmac_value)
        return True
    except InvalidSignature:
        return False
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,16 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
|
-
from flwr.common import serde
|
|
25
|
-
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
24
|
from flwr.server.client_proxy import ClientProxy
|
|
27
25
|
|
|
28
|
-
from ..driver.
|
|
26
|
+
from ..driver.driver import Driver
|
|
29
27
|
|
|
30
28
|
# Seconds to sleep between polls while waiting for a reply message
SLEEP_TIME = 1
|
|
31
29
|
|
|
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
|
|
|
33
31
|
class DriverClientProxy(ClientProxy):
|
|
34
32
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
33
|
|
|
36
|
-
def __init__(
|
|
37
|
-
self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
|
|
38
|
-
):
|
|
34
|
+
def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
|
|
39
35
|
super().__init__(str(node_id))
|
|
40
36
|
self.node_id = node_id
|
|
41
37
|
self.driver = driver
|
|
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
|
|
|
116
112
|
timeout: Optional[float],
|
|
117
113
|
group_id: Optional[int],
|
|
118
114
|
) -> RecordSet:
|
|
119
|
-
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
|
120
|
-
task_id="",
|
|
121
|
-
group_id=str(group_id) if group_id is not None else "",
|
|
122
|
-
run_id=self.run_id,
|
|
123
|
-
task=task_pb2.Task( # pylint: disable=E1101
|
|
124
|
-
producer=node_pb2.Node( # pylint: disable=E1101
|
|
125
|
-
node_id=0,
|
|
126
|
-
anonymous=True,
|
|
127
|
-
),
|
|
128
|
-
consumer=node_pb2.Node( # pylint: disable=E1101
|
|
129
|
-
node_id=self.node_id,
|
|
130
|
-
anonymous=self.anonymous,
|
|
131
|
-
),
|
|
132
|
-
task_type=task_type,
|
|
133
|
-
recordset=serde.recordset_to_proto(recordset),
|
|
134
|
-
ttl=DEFAULT_TTL,
|
|
135
|
-
),
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
# This would normally be recorded upon common.Message creation
|
|
139
|
-
# but this compatibility stack doesn't create Messages,
|
|
140
|
-
# so we need to inject `created_at` manually (needed for
|
|
141
|
-
# taskins validation by server.utils.validator)
|
|
142
|
-
task_ins.task.created_at = time.time()
|
|
143
115
|
|
|
144
|
-
|
|
145
|
-
|
|
116
|
+
# Create message
|
|
117
|
+
message = self.driver.create_message(
|
|
118
|
+
content=recordset,
|
|
119
|
+
message_type=task_type,
|
|
120
|
+
dst_node_id=self.node_id,
|
|
121
|
+
group_id=str(group_id) if group_id else "",
|
|
122
|
+
ttl=timeout,
|
|
146
123
|
)
|
|
147
124
|
|
|
148
|
-
#
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
raise ValueError("Unexpected number of task_ids")
|
|
125
|
+
# Push message
|
|
126
|
+
message_ids = list(self.driver.push_messages(messages=[message]))
|
|
127
|
+
if len(message_ids) != 1:
|
|
128
|
+
raise ValueError("Unexpected number of message_ids")
|
|
153
129
|
|
|
154
|
-
|
|
155
|
-
if
|
|
156
|
-
raise ValueError(f"Failed to
|
|
130
|
+
message_id = message_ids[0]
|
|
131
|
+
if message_id == "":
|
|
132
|
+
raise ValueError(f"Failed to send message to node {self.node_id}")
|
|
157
133
|
|
|
158
134
|
if timeout:
|
|
159
135
|
start_time = time.time()
|
|
160
136
|
|
|
161
137
|
while True:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
pull_task_res_res.task_res_list
|
|
172
|
-
)
|
|
173
|
-
if len(task_res_list) == 1:
|
|
174
|
-
task_res = task_res_list[0]
|
|
175
|
-
|
|
176
|
-
# This will raise an Exception if task_res carries an `error`
|
|
177
|
-
validate_task_res(task_res=task_res)
|
|
178
|
-
|
|
179
|
-
return serde.recordset_from_proto(task_res.task.recordset)
|
|
138
|
+
messages = list(self.driver.pull_messages(message_ids))
|
|
139
|
+
if len(messages) == 1:
|
|
140
|
+
msg: Message = messages[0]
|
|
141
|
+
if msg.has_error():
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
+
"It originated during client-side execution of a message."
|
|
145
|
+
)
|
|
146
|
+
return msg.content
|
|
180
147
|
|
|
181
148
|
if timeout is not None and time.time() > start_time + timeout:
|
|
182
149
|
raise RuntimeError("Timeout reached")
|
|
183
150
|
time.sleep(SLEEP_TIME)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def validate_task_res(
|
|
187
|
-
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
188
|
-
) -> None:
|
|
189
|
-
"""Validate if a TaskRes is empty or not."""
|
|
190
|
-
if not task_res.HasField("task"):
|
|
191
|
-
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
192
|
-
if task_res.task.HasField("error"):
|
|
193
|
-
raise ValueError("Exception during client-side task execution")
|
|
194
|
-
if not task_res.task.HasField("recordset"):
|
|
195
|
-
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|
|
@@ -151,31 +151,40 @@ class GrpcDriver(Driver):
|
|
|
151
151
|
* CA certificate.
|
|
152
152
|
* server certificate.
|
|
153
153
|
* server private key.
|
|
154
|
+
fab_id : str (default: None)
|
|
155
|
+
The identifier of the FAB used in the run.
|
|
156
|
+
fab_version : str (default: None)
|
|
157
|
+
The version of the FAB used in the run.
|
|
154
158
|
"""
|
|
155
159
|
|
|
156
160
|
def __init__(
|
|
157
161
|
self,
|
|
158
162
|
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
159
163
|
root_certificates: Optional[bytes] = None,
|
|
164
|
+
fab_id: Optional[str] = None,
|
|
165
|
+
fab_version: Optional[str] = None,
|
|
160
166
|
) -> None:
|
|
161
167
|
self.addr = driver_service_address
|
|
162
168
|
self.root_certificates = root_certificates
|
|
163
|
-
self.
|
|
169
|
+
self.driver_helper: Optional[GrpcDriverHelper] = None
|
|
164
170
|
self.run_id: Optional[int] = None
|
|
171
|
+
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
+
self.fab_version = fab_version if fab_version is not None else ""
|
|
165
173
|
self.node = Node(node_id=0, anonymous=True)
|
|
166
174
|
|
|
167
175
|
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
    """Return the connected `GrpcDriverHelper` and run ID, creating them lazily.

    On first use this connects to the Driver API at `self.addr` and requests a
    new run with `self.fab_id` / `self.fab_version`. The helper is stored only
    once `connect()` has returned, and it is reused on later calls. If
    `create_run` fails, the next call therefore retries on the existing
    connection instead of opening a second one and orphaning the first.

    Returns
    -------
    Tuple[GrpcDriverHelper, int]
        The connected helper and the ID of the run created for this driver.
    """
    # Connect once; keep the helper for the lifetime of this driver
    if self.driver_helper is None:
        driver_helper = GrpcDriverHelper(
            driver_service_address=self.addr,
            root_certificates=self.root_certificates,
        )
        driver_helper.connect()
        self.driver_helper = driver_helper

    # Create the run once, reusing the existing connection
    if self.run_id is None:
        req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
        res = self.driver_helper.create_run(req)
        self.run_id = res.run_id

    return self.driver_helper, self.run_id
|
|
179
188
|
|
|
180
189
|
def _check_message(self, message: Message) -> None:
|
|
181
190
|
# Check if the message is valid
|
|
@@ -300,7 +309,7 @@ class GrpcDriver(Driver):
|
|
|
300
309
|
def close(self) -> None:
    """Disconnect from the SuperLink, doing nothing if never connected."""
    helper = self.driver_helper
    if helper is not None:
        # Tear down the connection held by the helper
        helper.disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -132,6 +132,8 @@ def run_server_app() -> None:
|
|
|
132
132
|
driver = GrpcDriver(
|
|
133
133
|
driver_service_address=args.server,
|
|
134
134
|
root_certificates=root_certificates,
|
|
135
|
+
fab_id=args.fab_id,
|
|
136
|
+
fab_version=args.fab_version,
|
|
135
137
|
)
|
|
136
138
|
|
|
137
139
|
# Run the ServerApp with the Driver
|
|
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
183
185
|
"app from there."
|
|
184
186
|
" Default: current working directory.",
|
|
185
187
|
)
|
|
188
|
+
parser.add_argument(
|
|
189
|
+
"--fab-id",
|
|
190
|
+
default=None,
|
|
191
|
+
type=str,
|
|
192
|
+
help="The identifier of the FAB used in the run.",
|
|
193
|
+
)
|
|
194
|
+
parser.add_argument(
|
|
195
|
+
"--fab-version",
|
|
196
|
+
default=None,
|
|
197
|
+
type=str,
|
|
198
|
+
help="The version of the FAB used in the run.",
|
|
199
|
+
)
|
|
186
200
|
|
|
187
201
|
return parser
|
|
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
64
64
|
"""Create run ID."""
|
|
65
65
|
log(INFO, "DriverServicer.CreateRun")
|
|
66
66
|
state: State = self.state_factory.state()
|
|
67
|
-
run_id = state.create_run(
|
|
67
|
+
run_id = state.create_run(request.fab_id, request.fab_version)
|
|
68
68
|
return CreateRunResponse(run_id=run_id)
|
|
69
69
|
|
|
70
70
|
def PushTaskIns(
|
|
@@ -30,7 +30,7 @@ from flwr.server.utils import validate_task_ins_or_res
|
|
|
30
30
|
from .utils import make_node_unavailable_taskres
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class InMemoryState(State):
|
|
33
|
+
class InMemoryState(State): # pylint: disable=R0902
|
|
34
34
|
"""In-memory State implementation."""
|
|
35
35
|
|
|
36
36
|
def __init__(self) -> None:
|
|
@@ -40,6 +40,9 @@ class InMemoryState(State):
|
|
|
40
40
|
self.run_ids: Dict[int, Tuple[str, str]] = {}
|
|
41
41
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
42
42
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
43
|
+
self.client_public_keys: Set[bytes] = set()
|
|
44
|
+
self.server_public_key: Optional[bytes] = None
|
|
45
|
+
self.server_private_key: Optional[bytes] = None
|
|
43
46
|
self.lock = threading.Lock()
|
|
44
47
|
|
|
45
48
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
@@ -251,6 +254,39 @@ class InMemoryState(State):
|
|
|
251
254
|
log(ERROR, "Unexpected run creation failure.")
|
|
252
255
|
return 0
|
|
253
256
|
|
|
257
|
+
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Save the server key pair, refusing to overwrite an existing one.

    Raises
    ------
    RuntimeError
        If either server key has already been stored.
    """
    with self.lock:
        keys_present = (
            self.server_private_key is not None or self.server_public_key is not None
        )
        if keys_present:
            raise RuntimeError("Server public and private key already set")
        self.server_private_key = private_key
        self.server_public_key = public_key
|
|
267
|
+
|
|
268
|
+
def get_server_private_key(self) -> Optional[bytes]:
    """Return the stored server private key, or None if none has been stored."""
    private_key: Optional[bytes] = self.server_private_key
    return private_key
|
|
271
|
+
|
|
272
|
+
def get_server_public_key(self) -> Optional[bytes]:
    """Return the stored server public key, or None if none has been stored."""
    public_key: Optional[bytes] = self.server_public_key
    return public_key
|
|
275
|
+
|
|
276
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Replace the stored client public keys with `public_keys`.

    A copy of the given set is stored. Otherwise the caller would keep a
    reference to the state's own set and could mutate it later without
    holding `self.lock`.

    Parameters
    ----------
    public_keys : Set[bytes]
        The client public keys to store.
    """
    keys = set(public_keys)
    with self.lock:
        self.client_public_keys = keys
|
|
280
|
+
|
|
281
|
+
def store_client_public_key(self, public_key: bytes) -> None:
    """Add one client public key to the stored set (under `self.lock`)."""
    lock = self.lock
    with lock:
        stored_keys = self.client_public_keys
        stored_keys.add(public_key)
|
|
285
|
+
|
|
286
|
+
def get_client_public_keys(self) -> Set[bytes]:
    """Return a snapshot of all currently stored client public keys.

    The set is copied while holding `self.lock`, which is the lock writers use
    when adding keys. Callers can therefore iterate the result while other
    threads store keys, and mutating the returned set does not change the
    state.

    Returns
    -------
    Set[bytes]
        A new set containing the stored client public keys.
    """
    with self.lock:
        return set(self.client_public_keys)
|
|
289
|
+
|
|
254
290
|
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
255
291
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
256
292
|
with self.lock:
|
|
@@ -20,7 +20,7 @@ import re
|
|
|
20
20
|
import sqlite3
|
|
21
21
|
import time
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
|
-
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
|
23
|
+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
|
|
24
24
|
from uuid import UUID, uuid4
|
|
25
25
|
|
|
26
26
|
from flwr.common import log, now
|
|
@@ -40,6 +40,19 @@ CREATE TABLE IF NOT EXISTS node(
|
|
|
40
40
|
);
|
|
41
41
|
"""
|
|
42
42
|
|
|
43
|
+
SQL_CREATE_TABLE_CREDENTIAL = """
|
|
44
|
+
CREATE TABLE IF NOT EXISTS credential(
|
|
45
|
+
public_key BLOB PRIMARY KEY,
|
|
46
|
+
private_key BLOB
|
|
47
|
+
);
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
51
|
+
CREATE TABLE IF NOT EXISTS public_key(
|
|
52
|
+
public_key BLOB UNIQUE
|
|
53
|
+
);
|
|
54
|
+
"""
|
|
55
|
+
|
|
43
56
|
SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
44
57
|
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
45
58
|
"""
|
|
@@ -72,7 +85,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
72
85
|
);
|
|
73
86
|
"""
|
|
74
87
|
|
|
75
|
-
|
|
76
88
|
SQL_CREATE_TABLE_TASK_RES = """
|
|
77
89
|
CREATE TABLE IF NOT EXISTS task_res(
|
|
78
90
|
task_id TEXT UNIQUE,
|
|
@@ -96,7 +108,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
96
108
|
# Bindings for one SQL statement: positional (tuple) or named (dict) parameters
DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
|
|
97
109
|
|
|
98
110
|
|
|
99
|
-
class SqliteState(State):
|
|
111
|
+
class SqliteState(State): # pylint: disable=R0904
|
|
100
112
|
"""SQLite-based state implementation."""
|
|
101
113
|
|
|
102
114
|
def __init__(
|
|
@@ -134,6 +146,8 @@ class SqliteState(State):
|
|
|
134
146
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
135
147
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
136
148
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
149
|
+
cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
|
|
150
|
+
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
137
151
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
138
152
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
139
153
|
|
|
@@ -142,7 +156,7 @@ class SqliteState(State):
|
|
|
142
156
|
def query(
|
|
143
157
|
self,
|
|
144
158
|
query: str,
|
|
145
|
-
data: Optional[Union[
|
|
159
|
+
data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
|
|
146
160
|
) -> List[Dict[str, Any]]:
|
|
147
161
|
"""Execute a SQL query."""
|
|
148
162
|
if self.conn is None:
|
|
@@ -575,6 +589,59 @@ class SqliteState(State):
|
|
|
575
589
|
log(ERROR, "Unexpected run creation failure.")
|
|
576
590
|
return 0
|
|
577
591
|
|
|
592
|
+
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Insert the server key pair into the `credential` table.

    Raises
    ------
    RuntimeError
        If the `credential` table already contains a row.
    """
    existing = self.query("SELECT COUNT(*) FROM credential")[0]["COUNT(*)"]
    if existing >= 1:
        raise RuntimeError("Server public and private key already set")
    insert_query = (
        "INSERT OR REPLACE INTO credential (public_key, private_key) "
        "VALUES (:public_key, :private_key)"
    )
    params = {"public_key": public_key, "private_key": private_key}
    self.query(insert_query, params)
|
|
606
|
+
|
|
607
|
+
def get_server_private_key(self) -> Optional[bytes]:
    """Return the server private key from `credential`, or None if no row exists."""
    rows = self.query("SELECT private_key FROM credential")
    if len(rows) == 0:
        return None
    private_key: Optional[bytes] = rows[0]["private_key"]
    return private_key
|
|
616
|
+
|
|
617
|
+
def get_server_public_key(self) -> Optional[bytes]:
    """Return the server public key from `credential`, or None if no row exists."""
    rows = self.query("SELECT public_key FROM credential")
    if len(rows) == 0:
        return None
    public_key: Optional[bytes] = rows[0]["public_key"]
    return public_key
|
|
626
|
+
|
|
627
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Add every key in `public_keys` to the `public_key` table.

    The `public_key` column is UNIQUE. Keys that are already stored are
    therefore skipped with ``INSERT OR IGNORE`` instead of aborting the whole
    batch with a constraint error, so storing a known key is idempotent. An
    empty input returns early, so `query` is never handed an empty parameter
    list.

    Parameters
    ----------
    public_keys : Set[bytes]
        The client public keys to store.
    """
    if not public_keys:
        return
    query = "INSERT OR IGNORE INTO public_key (public_key) VALUES (?)"
    data = [(key,) for key in public_keys]
    self.query(query, data)
|
|
632
|
+
|
|
633
|
+
def store_client_public_key(self, public_key: bytes) -> None:
    """Add a single client public key to the `public_key` table.

    The column is UNIQUE, so ``INSERT OR IGNORE`` makes re-storing an
    already-known key a no-op instead of a constraint error.

    Parameters
    ----------
    public_key : bytes
        The client public key to store.
    """
    query = "INSERT OR IGNORE INTO public_key (public_key) VALUES (:public_key)"
    self.query(query, {"public_key": public_key})
|
|
637
|
+
|
|
638
|
+
def get_client_public_keys(self) -> Set[bytes]:
    """Return every key stored in the `public_key` table as a new set."""
    rows = self.query("SELECT public_key FROM public_key")
    keys: Set[bytes] = set()
    for row in rows:
        keys.add(row["public_key"])
    return keys
|
|
644
|
+
|
|
578
645
|
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
579
646
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
580
647
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
@@ -171,6 +171,32 @@ class State(abc.ABC):
|
|
|
171
171
|
- `fab_version`: The version of the FAB used in the specified run.
|
|
172
172
|
"""
|
|
173
173
|
|
|
174
|
+
@abc.abstractmethod
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Persist the server's key pair in state.

    Parameters
    ----------
    public_key : bytes
        The server public key.
    private_key : bytes
        The server private key.
    """
|
|
179
|
+
|
|
180
|
+
@abc.abstractmethod
def get_server_private_key(self) -> Optional[bytes]:
    """Return the stored server private key as bytes, or None if not stored."""
|
|
183
|
+
|
|
184
|
+
@abc.abstractmethod
def get_server_public_key(self) -> Optional[bytes]:
    """Return the stored server public key as bytes, or None if not stored."""
|
|
187
|
+
|
|
188
|
+
@abc.abstractmethod
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Persist a set of client public keys in state.

    NOTE(review): it is not specified whether this replaces or extends the
    keys already stored; confirm the intended contract with implementers.

    Parameters
    ----------
    public_keys : Set[bytes]
        The client public keys to store.
    """
|
|
191
|
+
|
|
192
|
+
@abc.abstractmethod
def store_client_public_key(self, public_key: bytes) -> None:
    """Persist one additional client public key in state.

    Parameters
    ----------
    public_key : bytes
        The client public key to store.
    """
|
|
195
|
+
|
|
196
|
+
@abc.abstractmethod
def get_client_public_keys(self) -> Set[bytes]:
    """Return every client public key currently held in state, as a set."""
|
|
199
|
+
|
|
174
200
|
@abc.abstractmethod
|
|
175
201
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
176
202
|
"""Acknowledge a ping received from a node, serving as a heartbeat.
|