flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +90 -0
- flwr/cli/config_utils.py +43 -149
- flwr/cli/constant.py +27 -0
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +2 -1
- flwr/cli/log.py +34 -37
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +116 -0
- flwr/cli/ls.py +214 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +103 -43
- flwr/cli/stop.py +139 -0
- flwr/cli/utils.py +186 -8
- flwr/client/app.py +49 -50
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +23 -26
- flwr/client/clientapp/utils.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +2 -13
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +59 -43
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +37 -29
- flwr/client/supernode/app.py +20 -74
- flwr/common/address.py +1 -0
- flwr/common/args.py +26 -47
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +122 -0
- flwr/common/config.py +169 -17
- flwr/common/constant.py +38 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +66 -7
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/record/recordset.py +1 -1
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/serde.py +6 -4
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +32 -0
- flwr/common/version.py +1 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +27 -15
- flwr/proto/exec_pb2.pyi +80 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +5 -5
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +32 -32
- flwr/proto/serverappio_pb2_grpc.py +62 -28
- flwr/proto/serverappio_pb2_grpc.pyi +29 -16
- flwr/proto/simulationio_pb2.py +3 -3
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +152 -112
- flwr/server/compat/app_utils.py +7 -2
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +38 -85
- flwr/server/driver/inmemory_driver.py +7 -2
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +37 -13
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +2 -1
- flwr/server/superlink/driver/serverappio_servicer.py +148 -63
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
- flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
- flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
- flwr/server/superlink/linkstate/linkstate.py +30 -36
- flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +20 -10
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +5 -1
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +20 -2
- flwr/superexec/exec_servicer.py +97 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
|
@@ -17,13 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
-
import sys
|
|
21
20
|
from collections.abc import Awaitable
|
|
22
21
|
from typing import Callable, TypeVar, cast
|
|
23
22
|
|
|
24
23
|
from google.protobuf.message import Message as GrpcMessage
|
|
25
24
|
|
|
26
|
-
from flwr.common.
|
|
25
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
27
26
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
28
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
29
28
|
CreateNodeRequest,
|
|
@@ -32,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
32
31
|
DeleteNodeResponse,
|
|
33
32
|
PingRequest,
|
|
34
33
|
PingResponse,
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
34
|
+
PullMessagesRequest,
|
|
35
|
+
PullMessagesResponse,
|
|
36
|
+
PushMessagesRequest,
|
|
37
|
+
PushMessagesResponse,
|
|
39
38
|
)
|
|
40
39
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
40
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
@@ -51,7 +50,7 @@ try:
|
|
|
51
50
|
from starlette.responses import Response
|
|
52
51
|
from starlette.routing import Route
|
|
53
52
|
except ModuleNotFoundError:
|
|
54
|
-
|
|
53
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
55
54
|
|
|
56
55
|
|
|
57
56
|
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
|
@@ -107,25 +106,24 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
107
106
|
return message_handler.delete_node(request=request, state=state)
|
|
108
107
|
|
|
109
108
|
|
|
110
|
-
@rest_request_response(
|
|
111
|
-
async def
|
|
112
|
-
"""Pull
|
|
109
|
+
@rest_request_response(PullMessagesRequest)
|
|
110
|
+
async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
111
|
+
"""Pull PullMessages."""
|
|
113
112
|
# Get state from app
|
|
114
113
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
115
114
|
|
|
116
115
|
# Handle message
|
|
117
|
-
return message_handler.
|
|
116
|
+
return message_handler.pull_messages(request=request, state=state)
|
|
118
117
|
|
|
119
118
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
"""Push TaskRes."""
|
|
119
|
+
@rest_request_response(PushMessagesRequest)
|
|
120
|
+
async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
121
|
+
"""Pull PushMessages."""
|
|
124
122
|
# Get state from app
|
|
125
123
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
126
124
|
|
|
127
125
|
# Handle message
|
|
128
|
-
return message_handler.
|
|
126
|
+
return message_handler.push_messages(request=request, state=state)
|
|
129
127
|
|
|
130
128
|
|
|
131
129
|
@rest_request_response(PingRequest)
|
|
@@ -154,15 +152,18 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
152
|
# Get ffs from app
|
|
155
153
|
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
156
154
|
|
|
155
|
+
# Get state from app
|
|
156
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
157
|
+
|
|
157
158
|
# Handle message
|
|
158
|
-
return message_handler.get_fab(request=request, ffs=ffs)
|
|
159
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
|
159
160
|
|
|
160
161
|
|
|
161
162
|
routes = [
|
|
162
163
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
|
163
164
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
|
164
|
-
Route("/api/v0/fleet/pull-
|
|
165
|
-
Route("/api/v0/fleet/push-
|
|
165
|
+
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
|
166
|
+
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
|
166
167
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
167
168
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
|
168
169
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
|
@@ -182,8 +182,8 @@ def run_api(
|
|
|
182
182
|
f_stop: threading.Event,
|
|
183
183
|
) -> None:
|
|
184
184
|
"""Run the VCE."""
|
|
185
|
-
taskins_queue:
|
|
186
|
-
taskres_queue:
|
|
185
|
+
taskins_queue: Queue[TaskIns] = Queue()
|
|
186
|
+
taskres_queue: Queue[TaskRes] = Queue()
|
|
187
187
|
|
|
188
188
|
try:
|
|
189
189
|
|
|
@@ -28,6 +28,7 @@ from flwr.common.constant import (
|
|
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
|
29
29
|
NODE_ID_NUM_BYTES,
|
|
30
30
|
RUN_ID_NUM_BYTES,
|
|
31
|
+
SUPERLINK_NODE_ID,
|
|
31
32
|
Status,
|
|
32
33
|
)
|
|
33
34
|
from flwr.common.record import ConfigsRecord
|
|
@@ -62,6 +63,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
62
63
|
# Map node_id to (online_until, ping_interval)
|
|
63
64
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
64
65
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
66
|
+
self.node_id_to_public_key: dict[int, bytes] = {}
|
|
65
67
|
|
|
66
68
|
# Map run_id to RunRecord
|
|
67
69
|
self.run_ids: dict[int, RunRecord] = {}
|
|
@@ -72,8 +74,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
72
74
|
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
|
|
73
75
|
|
|
74
76
|
self.node_public_keys: set[bytes] = set()
|
|
75
|
-
self.server_public_key: Optional[bytes] = None
|
|
76
|
-
self.server_private_key: Optional[bytes] = None
|
|
77
77
|
|
|
78
78
|
self.lock = threading.RLock()
|
|
79
79
|
|
|
@@ -89,7 +89,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
89
89
|
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
90
90
|
return None
|
|
91
91
|
# Validate source node ID
|
|
92
|
-
if task_ins.task.producer.node_id !=
|
|
92
|
+
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
|
93
93
|
log(
|
|
94
94
|
ERROR,
|
|
95
95
|
"Invalid source node ID for TaskIns: %s",
|
|
@@ -97,14 +97,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
97
97
|
)
|
|
98
98
|
return None
|
|
99
99
|
# Validate destination node ID
|
|
100
|
-
if
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
return None
|
|
100
|
+
if task_ins.task.consumer.node_id not in self.node_ids:
|
|
101
|
+
log(
|
|
102
|
+
ERROR,
|
|
103
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
104
|
+
task_ins.task.consumer.node_id,
|
|
105
|
+
)
|
|
106
|
+
return None
|
|
108
107
|
|
|
109
108
|
# Create task_id
|
|
110
109
|
task_id = uuid4()
|
|
@@ -117,9 +116,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
117
116
|
# Return the new task_id
|
|
118
117
|
return task_id
|
|
119
118
|
|
|
120
|
-
def get_task_ins(
|
|
121
|
-
self, node_id: Optional[int], limit: Optional[int]
|
|
122
|
-
) -> list[TaskIns]:
|
|
119
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
|
123
120
|
"""Get all TaskIns that have not been delivered yet."""
|
|
124
121
|
if limit is not None and limit < 1:
|
|
125
122
|
raise AssertionError("`limit` must be >= 1")
|
|
@@ -129,17 +126,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
129
126
|
current_time = time.time()
|
|
130
127
|
with self.lock:
|
|
131
128
|
for _, task_ins in self.task_ins_store.items():
|
|
132
|
-
# pylint: disable=too-many-boolean-expressions
|
|
133
129
|
if (
|
|
134
|
-
node_id
|
|
135
|
-
and task_ins.task.consumer.anonymous is False
|
|
136
|
-
and task_ins.task.consumer.node_id == node_id
|
|
137
|
-
and task_ins.task.delivered_at == ""
|
|
138
|
-
and task_ins.task.created_at + task_ins.task.ttl > current_time
|
|
139
|
-
) or (
|
|
140
|
-
node_id is None # Anonymous
|
|
141
|
-
and task_ins.task.consumer.anonymous is True
|
|
142
|
-
and task_ins.task.consumer.node_id == 0
|
|
130
|
+
task_ins.task.consumer.node_id == node_id
|
|
143
131
|
and task_ins.task.delivered_at == ""
|
|
144
132
|
and task_ins.task.created_at + task_ins.task.ttl > current_time
|
|
145
133
|
):
|
|
@@ -173,9 +161,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
173
161
|
if (
|
|
174
162
|
task_ins
|
|
175
163
|
and task_res
|
|
176
|
-
and not (
|
|
177
|
-
task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
|
|
178
|
-
)
|
|
179
164
|
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
|
|
180
165
|
):
|
|
181
166
|
return None
|
|
@@ -265,41 +250,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
265
250
|
for task_res in task_res_found:
|
|
266
251
|
task_res.task.delivered_at = delivered_at
|
|
267
252
|
|
|
268
|
-
# Cleanup
|
|
269
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
-
|
|
271
253
|
return list(ret.values())
|
|
272
254
|
|
|
273
|
-
def delete_tasks(self,
|
|
274
|
-
"""Delete
|
|
275
|
-
|
|
276
|
-
task_res_to_be_deleted: set[UUID] = set()
|
|
277
|
-
|
|
278
|
-
with self.lock:
|
|
279
|
-
for task_ins_id in task_ids:
|
|
280
|
-
# Find the task_id of the matching task_res
|
|
281
|
-
for task_res_id, task_res in self.task_res_store.items():
|
|
282
|
-
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
|
283
|
-
continue
|
|
284
|
-
if task_res.task.delivered_at == "":
|
|
285
|
-
continue
|
|
286
|
-
|
|
287
|
-
task_ins_to_be_deleted.add(task_ins_id)
|
|
288
|
-
task_res_to_be_deleted.add(task_res_id)
|
|
289
|
-
|
|
290
|
-
for task_id in task_ins_to_be_deleted:
|
|
291
|
-
del self.task_ins_store[task_id]
|
|
292
|
-
del self.task_ins_id_to_task_res_id[task_id]
|
|
293
|
-
for task_id in task_res_to_be_deleted:
|
|
294
|
-
del self.task_res_store[task_id]
|
|
295
|
-
|
|
296
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
-
if not task_ids:
|
|
255
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
256
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
257
|
+
if not task_ins_ids:
|
|
299
258
|
return
|
|
300
259
|
|
|
301
260
|
with self.lock:
|
|
302
|
-
for task_id in
|
|
261
|
+
for task_id in task_ins_ids:
|
|
303
262
|
# Delete TaskIns
|
|
304
263
|
if task_id in self.task_ins_store:
|
|
305
264
|
del self.task_ins_store[task_id]
|
|
@@ -308,6 +267,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
308
267
|
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
268
|
del self.task_res_store[task_res_id]
|
|
310
269
|
|
|
270
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
271
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
272
|
+
task_id_list: set[UUID] = set()
|
|
273
|
+
with self.lock:
|
|
274
|
+
for task_id, task_ins in self.task_ins_store.items():
|
|
275
|
+
if task_ins.run_id == run_id:
|
|
276
|
+
task_id_list.add(task_id)
|
|
277
|
+
|
|
278
|
+
return task_id_list
|
|
279
|
+
|
|
311
280
|
def num_task_ins(self) -> int:
|
|
312
281
|
"""Calculate the number of task_ins in store.
|
|
313
282
|
|
|
@@ -322,45 +291,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
322
291
|
"""
|
|
323
292
|
return len(self.task_res_store)
|
|
324
293
|
|
|
325
|
-
def create_node(
|
|
326
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
327
|
-
) -> int:
|
|
294
|
+
def create_node(self, ping_interval: float) -> int:
|
|
328
295
|
"""Create, store in the link state, and return `node_id`."""
|
|
329
296
|
# Sample a random int64 as node_id
|
|
330
|
-
node_id = generate_rand_int_from_bytes(
|
|
297
|
+
node_id = generate_rand_int_from_bytes(
|
|
298
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
|
299
|
+
)
|
|
331
300
|
|
|
332
301
|
with self.lock:
|
|
333
302
|
if node_id in self.node_ids:
|
|
334
303
|
log(ERROR, "Unexpected node registration failure.")
|
|
335
304
|
return 0
|
|
336
305
|
|
|
337
|
-
if public_key is not None:
|
|
338
|
-
if (
|
|
339
|
-
public_key in self.public_key_to_node_id
|
|
340
|
-
or node_id in self.public_key_to_node_id.values()
|
|
341
|
-
):
|
|
342
|
-
log(ERROR, "Unexpected node registration failure.")
|
|
343
|
-
return 0
|
|
344
|
-
|
|
345
|
-
self.public_key_to_node_id[public_key] = node_id
|
|
346
|
-
|
|
347
306
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
348
307
|
return node_id
|
|
349
308
|
|
|
350
|
-
def delete_node(self, node_id: int
|
|
309
|
+
def delete_node(self, node_id: int) -> None:
|
|
351
310
|
"""Delete a node."""
|
|
352
311
|
with self.lock:
|
|
353
312
|
if node_id not in self.node_ids:
|
|
354
313
|
raise ValueError(f"Node {node_id} not found")
|
|
355
314
|
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
or node_id not in self.public_key_to_node_id.values()
|
|
360
|
-
):
|
|
361
|
-
raise ValueError("Public key or node_id not found")
|
|
362
|
-
|
|
363
|
-
del self.public_key_to_node_id[public_key]
|
|
315
|
+
# Remove node ID <> public key mappings
|
|
316
|
+
if pk := self.node_id_to_public_key.pop(node_id, None):
|
|
317
|
+
del self.public_key_to_node_id[pk]
|
|
364
318
|
|
|
365
319
|
del self.node_ids[node_id]
|
|
366
320
|
|
|
@@ -382,6 +336,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
382
336
|
if online_until > current_time
|
|
383
337
|
}
|
|
384
338
|
|
|
339
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
|
340
|
+
"""Set `public_key` for the specified `node_id`."""
|
|
341
|
+
with self.lock:
|
|
342
|
+
if node_id not in self.node_ids:
|
|
343
|
+
raise ValueError(f"Node {node_id} not found")
|
|
344
|
+
|
|
345
|
+
if public_key in self.public_key_to_node_id:
|
|
346
|
+
raise ValueError("Public key already in use")
|
|
347
|
+
|
|
348
|
+
self.public_key_to_node_id[public_key] = node_id
|
|
349
|
+
self.node_id_to_public_key[node_id] = public_key
|
|
350
|
+
|
|
351
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
|
352
|
+
"""Get `public_key` for the specified `node_id`."""
|
|
353
|
+
with self.lock:
|
|
354
|
+
if node_id not in self.node_ids:
|
|
355
|
+
raise ValueError(f"Node {node_id} not found")
|
|
356
|
+
|
|
357
|
+
return self.node_id_to_public_key.get(node_id)
|
|
358
|
+
|
|
385
359
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
386
360
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
387
361
|
return self.public_key_to_node_id.get(node_public_key)
|
|
@@ -427,29 +401,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
427
401
|
log(ERROR, "Unexpected run creation failure.")
|
|
428
402
|
return 0
|
|
429
403
|
|
|
430
|
-
def
|
|
431
|
-
|
|
432
|
-
) -> None:
|
|
433
|
-
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
404
|
+
def clear_supernode_auth_keys(self) -> None:
|
|
405
|
+
"""Clear stored `node_public_keys` in the link state if any."""
|
|
434
406
|
with self.lock:
|
|
435
|
-
|
|
436
|
-
self.server_private_key = private_key
|
|
437
|
-
self.server_public_key = public_key
|
|
438
|
-
else:
|
|
439
|
-
raise RuntimeError("Server private and public key already set")
|
|
440
|
-
|
|
441
|
-
def get_server_private_key(self) -> Optional[bytes]:
|
|
442
|
-
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
443
|
-
return self.server_private_key
|
|
444
|
-
|
|
445
|
-
def get_server_public_key(self) -> Optional[bytes]:
|
|
446
|
-
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
447
|
-
return self.server_public_key
|
|
407
|
+
self.node_public_keys.clear()
|
|
448
408
|
|
|
449
409
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
450
410
|
"""Store a set of `node_public_keys` in the link state."""
|
|
451
411
|
with self.lock:
|
|
452
|
-
self.node_public_keys
|
|
412
|
+
self.node_public_keys.update(public_keys)
|
|
453
413
|
|
|
454
414
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
455
415
|
"""Store a `node_public_key` in the link state."""
|
|
@@ -458,7 +418,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
458
418
|
|
|
459
419
|
def get_node_public_keys(self) -> set[bytes]:
|
|
460
420
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
461
|
-
|
|
421
|
+
with self.lock:
|
|
422
|
+
return self.node_public_keys.copy()
|
|
462
423
|
|
|
463
424
|
def get_run_ids(self) -> set[int]:
|
|
464
425
|
"""Retrieve all run IDs."""
|
|
@@ -40,20 +40,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
40
40
|
|
|
41
41
|
Constraints
|
|
42
42
|
-----------
|
|
43
|
-
|
|
44
|
-
`task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
45
|
-
|
|
46
|
-
If `task_ins.task.consumer.anonymous` is `False`, then
|
|
47
|
-
`task_ins.task.consumer.node_id` MUST be set (not 0)
|
|
43
|
+
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
48
44
|
|
|
49
45
|
If `task_ins.run_id` is invalid, then
|
|
50
46
|
storing the `task_ins` MUST fail.
|
|
51
47
|
"""
|
|
52
48
|
|
|
53
49
|
@abc.abstractmethod
|
|
54
|
-
def get_task_ins(
|
|
55
|
-
self, node_id: Optional[int], limit: Optional[int]
|
|
56
|
-
) -> list[TaskIns]:
|
|
50
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
|
57
51
|
"""Get TaskIns optionally filtered by node_id.
|
|
58
52
|
|
|
59
53
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
@@ -61,15 +55,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
61
55
|
|
|
62
56
|
Constraints
|
|
63
57
|
-----------
|
|
64
|
-
|
|
58
|
+
Retrieve all TaskIns where
|
|
65
59
|
|
|
66
60
|
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
|
67
|
-
2. the `task_ins.task.
|
|
68
|
-
3. the `task_ins.task.delivered_at` equals `""`.
|
|
61
|
+
2. the `task_ins.task.delivered_at` equals `""`.
|
|
69
62
|
|
|
70
|
-
If `node_id` is `None`, retrieve all TaskIns where the
|
|
71
|
-
`task_ins.task.consumer.node_id` equals `0` and
|
|
72
|
-
`task_ins.task.consumer.anonymous` is set to `True`.
|
|
73
63
|
|
|
74
64
|
If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
|
|
75
65
|
the result.
|
|
@@ -89,11 +79,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
89
79
|
|
|
90
80
|
Constraints
|
|
91
81
|
-----------
|
|
92
|
-
If `task_res.task.consumer.anonymous` is `True`, then
|
|
93
|
-
`task_res.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
94
82
|
|
|
95
|
-
|
|
96
|
-
`task_res.task.consumer.node_id` MUST be set (not 0)
|
|
83
|
+
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
97
84
|
|
|
98
85
|
If `task_res.run_id` is invalid, then
|
|
99
86
|
storing the `task_res` MUST fail.
|
|
@@ -139,17 +126,26 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
139
126
|
"""
|
|
140
127
|
|
|
141
128
|
@abc.abstractmethod
|
|
142
|
-
def delete_tasks(self,
|
|
143
|
-
"""Delete
|
|
129
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
130
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
|
131
|
+
|
|
132
|
+
Parameters
|
|
133
|
+
----------
|
|
134
|
+
task_ins_ids : set[UUID]
|
|
135
|
+
A set of TaskIns IDs. For each ID in the set, the corresponding
|
|
136
|
+
TaskIns and its associated TaskRes will be deleted.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
@abc.abstractmethod
|
|
140
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
141
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
144
142
|
|
|
145
143
|
@abc.abstractmethod
|
|
146
|
-
def create_node(
|
|
147
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
148
|
-
) -> int:
|
|
144
|
+
def create_node(self, ping_interval: float) -> int:
|
|
149
145
|
"""Create, store in the link state, and return `node_id`."""
|
|
150
146
|
|
|
151
147
|
@abc.abstractmethod
|
|
152
|
-
def delete_node(self, node_id: int
|
|
148
|
+
def delete_node(self, node_id: int) -> None:
|
|
153
149
|
"""Remove `node_id` from the link state."""
|
|
154
150
|
|
|
155
151
|
@abc.abstractmethod
|
|
@@ -162,6 +158,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
162
158
|
an empty `Set` MUST be returned.
|
|
163
159
|
"""
|
|
164
160
|
|
|
161
|
+
@abc.abstractmethod
|
|
162
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
|
163
|
+
"""Set `public_key` for the specified `node_id`."""
|
|
164
|
+
|
|
165
|
+
@abc.abstractmethod
|
|
166
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
|
167
|
+
"""Get `public_key` for the specified `node_id`."""
|
|
168
|
+
|
|
165
169
|
@abc.abstractmethod
|
|
166
170
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
167
171
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
@@ -260,18 +264,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
260
264
|
"""
|
|
261
265
|
|
|
262
266
|
@abc.abstractmethod
|
|
263
|
-
def
|
|
264
|
-
|
|
265
|
-
) -> None:
|
|
266
|
-
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
267
|
-
|
|
268
|
-
@abc.abstractmethod
|
|
269
|
-
def get_server_private_key(self) -> Optional[bytes]:
|
|
270
|
-
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
271
|
-
|
|
272
|
-
@abc.abstractmethod
|
|
273
|
-
def get_server_public_key(self) -> Optional[bytes]:
|
|
274
|
-
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
267
|
+
def clear_supernode_auth_keys(self) -> None:
|
|
268
|
+
"""Clear stored `node_public_keys` in the link state if any."""
|
|
275
269
|
|
|
276
270
|
@abc.abstractmethod
|
|
277
271
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|