flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. See this release's page on the package registry for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,27 +15,40 @@
|
|
|
15
15
|
"""In-memory State implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import os
|
|
19
18
|
import threading
|
|
20
|
-
|
|
19
|
+
import time
|
|
21
20
|
from logging import ERROR
|
|
22
|
-
from typing import Dict, List, Optional, Set
|
|
21
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
23
22
|
from uuid import UUID, uuid4
|
|
24
23
|
|
|
25
24
|
from flwr.common import log, now
|
|
25
|
+
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
26
|
+
from flwr.common.typing import Run, UserConfig
|
|
26
27
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
27
28
|
from flwr.server.superlink.state.state import State
|
|
28
29
|
from flwr.server.utils import validate_task_ins_or_res
|
|
29
30
|
|
|
31
|
+
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
|
|
30
32
|
|
|
31
|
-
|
|
33
|
+
|
|
34
|
+
class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
32
35
|
"""In-memory State implementation."""
|
|
33
36
|
|
|
34
37
|
def __init__(self) -> None:
|
|
35
|
-
|
|
36
|
-
|
|
38
|
+
|
|
39
|
+
# Map node_id to (online_until, ping_interval)
|
|
40
|
+
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
41
|
+
self.public_key_to_node_id: Dict[bytes, int] = {}
|
|
42
|
+
|
|
43
|
+
# Map run_id to (fab_id, fab_version)
|
|
44
|
+
self.run_ids: Dict[int, Run] = {}
|
|
37
45
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
38
46
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
47
|
+
|
|
48
|
+
self.client_public_keys: Set[bytes] = set()
|
|
49
|
+
self.server_public_key: Optional[bytes] = None
|
|
50
|
+
self.server_private_key: Optional[bytes] = None
|
|
51
|
+
|
|
39
52
|
self.lock = threading.Lock()
|
|
40
53
|
|
|
41
54
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
@@ -50,15 +63,11 @@ class InMemoryState(State):
|
|
|
50
63
|
log(ERROR, "`run_id` is invalid")
|
|
51
64
|
return None
|
|
52
65
|
|
|
53
|
-
# Create task_id
|
|
66
|
+
# Create task_id
|
|
54
67
|
task_id = uuid4()
|
|
55
|
-
created_at: datetime = now()
|
|
56
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
57
68
|
|
|
58
69
|
# Store TaskIns
|
|
59
70
|
task_ins.task_id = str(task_id)
|
|
60
|
-
task_ins.task.created_at = created_at.isoformat()
|
|
61
|
-
task_ins.task.ttl = ttl.isoformat()
|
|
62
71
|
with self.lock:
|
|
63
72
|
self.task_ins_store[task_id] = task_ins
|
|
64
73
|
|
|
@@ -113,15 +122,11 @@ class InMemoryState(State):
|
|
|
113
122
|
log(ERROR, "`run_id` is invalid")
|
|
114
123
|
return None
|
|
115
124
|
|
|
116
|
-
# Create task_id
|
|
125
|
+
# Create task_id
|
|
117
126
|
task_id = uuid4()
|
|
118
|
-
created_at: datetime = now()
|
|
119
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
120
127
|
|
|
121
128
|
# Store TaskRes
|
|
122
129
|
task_res.task_id = str(task_id)
|
|
123
|
-
task_res.task.created_at = created_at.isoformat()
|
|
124
|
-
task_res.task.ttl = ttl.isoformat()
|
|
125
130
|
with self.lock:
|
|
126
131
|
self.task_res_store[task_id] = task_res
|
|
127
132
|
|
|
@@ -136,14 +141,31 @@ class InMemoryState(State):
|
|
|
136
141
|
with self.lock:
|
|
137
142
|
# Find TaskRes that were not delivered yet
|
|
138
143
|
task_res_list: List[TaskRes] = []
|
|
144
|
+
replied_task_ids: Set[UUID] = set()
|
|
139
145
|
for _, task_res in self.task_res_store.items():
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
and task_res.task.delivered_at == ""
|
|
143
|
-
):
|
|
146
|
+
reply_to = UUID(task_res.task.ancestry[0])
|
|
147
|
+
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
144
148
|
task_res_list.append(task_res)
|
|
149
|
+
replied_task_ids.add(reply_to)
|
|
150
|
+
if limit and len(task_res_list) == limit:
|
|
151
|
+
break
|
|
152
|
+
|
|
153
|
+
# Check if the node is offline
|
|
154
|
+
for task_id in task_ids - replied_task_ids:
|
|
145
155
|
if limit and len(task_res_list) == limit:
|
|
146
156
|
break
|
|
157
|
+
task_ins = self.task_ins_store.get(task_id)
|
|
158
|
+
if task_ins is None:
|
|
159
|
+
continue
|
|
160
|
+
node_id = task_ins.task.consumer.node_id
|
|
161
|
+
online_until, _ = self.node_ids[node_id]
|
|
162
|
+
# Generate a TaskRes containing an error reply if the node is offline.
|
|
163
|
+
if online_until < time.time():
|
|
164
|
+
err_taskres = make_node_unavailable_taskres(
|
|
165
|
+
ref_taskins=task_ins,
|
|
166
|
+
)
|
|
167
|
+
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
|
|
168
|
+
task_res_list.append(err_taskres)
|
|
147
169
|
|
|
148
170
|
# Mark all of them as delivered
|
|
149
171
|
delivered_at = now().isoformat()
|
|
@@ -189,22 +211,47 @@ class InMemoryState(State):
|
|
|
189
211
|
"""
|
|
190
212
|
return len(self.task_res_store)
|
|
191
213
|
|
|
192
|
-
def create_node(
|
|
214
|
+
def create_node(
|
|
215
|
+
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
216
|
+
) -> int:
|
|
193
217
|
"""Create, store in state, and return `node_id`."""
|
|
194
218
|
# Sample a random int64 as node_id
|
|
195
|
-
node_id
|
|
219
|
+
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
196
220
|
|
|
197
|
-
|
|
198
|
-
self.node_ids
|
|
221
|
+
with self.lock:
|
|
222
|
+
if node_id in self.node_ids:
|
|
223
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
224
|
+
return 0
|
|
225
|
+
|
|
226
|
+
if public_key is not None:
|
|
227
|
+
if (
|
|
228
|
+
public_key in self.public_key_to_node_id
|
|
229
|
+
or node_id in self.public_key_to_node_id.values()
|
|
230
|
+
):
|
|
231
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
232
|
+
return 0
|
|
233
|
+
|
|
234
|
+
self.public_key_to_node_id[public_key] = node_id
|
|
235
|
+
|
|
236
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
199
237
|
return node_id
|
|
200
|
-
log(ERROR, "Unexpected node registration failure.")
|
|
201
|
-
return 0
|
|
202
238
|
|
|
203
|
-
def delete_node(self, node_id: int) -> None:
|
|
239
|
+
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
204
240
|
"""Delete a client node."""
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
241
|
+
with self.lock:
|
|
242
|
+
if node_id not in self.node_ids:
|
|
243
|
+
raise ValueError(f"Node {node_id} not found")
|
|
244
|
+
|
|
245
|
+
if public_key is not None:
|
|
246
|
+
if (
|
|
247
|
+
public_key not in self.public_key_to_node_id
|
|
248
|
+
or node_id not in self.public_key_to_node_id.values()
|
|
249
|
+
):
|
|
250
|
+
raise ValueError("Public key or node_id not found")
|
|
251
|
+
|
|
252
|
+
del self.public_key_to_node_id[public_key]
|
|
253
|
+
|
|
254
|
+
del self.node_ids[node_id]
|
|
208
255
|
|
|
209
256
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
210
257
|
"""Return all available client nodes.
|
|
@@ -214,17 +261,87 @@ class InMemoryState(State):
|
|
|
214
261
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
215
262
|
an empty `Set` MUST be returned.
|
|
216
263
|
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
264
|
+
with self.lock:
|
|
265
|
+
if run_id not in self.run_ids:
|
|
266
|
+
return set()
|
|
267
|
+
current_time = time.time()
|
|
268
|
+
return {
|
|
269
|
+
node_id
|
|
270
|
+
for node_id, (online_until, _) in self.node_ids.items()
|
|
271
|
+
if online_until > current_time
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
def get_node_id(self, client_public_key: bytes) -> Optional[int]:
|
|
275
|
+
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
276
|
+
return self.public_key_to_node_id.get(client_public_key)
|
|
277
|
+
|
|
278
|
+
def create_run(
|
|
279
|
+
self,
|
|
280
|
+
fab_id: str,
|
|
281
|
+
fab_version: str,
|
|
282
|
+
override_config: UserConfig,
|
|
283
|
+
) -> int:
|
|
284
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
223
285
|
# Sample a random int64 as run_id
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
self.run_ids
|
|
228
|
-
|
|
286
|
+
with self.lock:
|
|
287
|
+
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
288
|
+
|
|
289
|
+
if run_id not in self.run_ids:
|
|
290
|
+
self.run_ids[run_id] = Run(
|
|
291
|
+
run_id=run_id,
|
|
292
|
+
fab_id=fab_id,
|
|
293
|
+
fab_version=fab_version,
|
|
294
|
+
override_config=override_config,
|
|
295
|
+
)
|
|
296
|
+
return run_id
|
|
229
297
|
log(ERROR, "Unexpected run creation failure.")
|
|
230
298
|
return 0
|
|
299
|
+
|
|
300
|
+
def store_server_private_public_key(
|
|
301
|
+
self, private_key: bytes, public_key: bytes
|
|
302
|
+
) -> None:
|
|
303
|
+
"""Store `server_private_key` and `server_public_key` in state."""
|
|
304
|
+
with self.lock:
|
|
305
|
+
if self.server_private_key is None and self.server_public_key is None:
|
|
306
|
+
self.server_private_key = private_key
|
|
307
|
+
self.server_public_key = public_key
|
|
308
|
+
else:
|
|
309
|
+
raise RuntimeError("Server private and public key already set")
|
|
310
|
+
|
|
311
|
+
def get_server_private_key(self) -> Optional[bytes]:
|
|
312
|
+
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
313
|
+
return self.server_private_key
|
|
314
|
+
|
|
315
|
+
def get_server_public_key(self) -> Optional[bytes]:
|
|
316
|
+
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
317
|
+
return self.server_public_key
|
|
318
|
+
|
|
319
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
320
|
+
"""Store a set of `client_public_keys` in state."""
|
|
321
|
+
with self.lock:
|
|
322
|
+
self.client_public_keys = public_keys
|
|
323
|
+
|
|
324
|
+
def store_client_public_key(self, public_key: bytes) -> None:
|
|
325
|
+
"""Store a `client_public_key` in state."""
|
|
326
|
+
with self.lock:
|
|
327
|
+
self.client_public_keys.add(public_key)
|
|
328
|
+
|
|
329
|
+
def get_client_public_keys(self) -> Set[bytes]:
|
|
330
|
+
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
331
|
+
return self.client_public_keys
|
|
332
|
+
|
|
333
|
+
def get_run(self, run_id: int) -> Optional[Run]:
|
|
334
|
+
"""Retrieve information about the run with the specified `run_id`."""
|
|
335
|
+
with self.lock:
|
|
336
|
+
if run_id not in self.run_ids:
|
|
337
|
+
log(ERROR, "`run_id` is invalid")
|
|
338
|
+
return None
|
|
339
|
+
return self.run_ids[run_id]
|
|
340
|
+
|
|
341
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
342
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
343
|
+
with self.lock:
|
|
344
|
+
if node_id in self.node_ids:
|
|
345
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
346
|
+
return True
|
|
347
|
+
return False
|