flwr-nightly 1.8.0.dev20240323__py3-none-any.whl → 1.8.0.dev20240328__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +35 -24
- flwr/client/client_app.py +4 -4
- flwr/client/grpc_client/connection.py +2 -1
- flwr/client/message_handler/message_handler.py +3 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/common/__init__.py +2 -0
- flwr/common/message.py +65 -20
- flwr/common/serde.py +8 -2
- flwr/proto/fleet_pb2.py +19 -15
- flwr/proto/fleet_pb2.pyi +28 -0
- flwr/proto/fleet_pb2_grpc.py +33 -0
- flwr/proto/fleet_pb2_grpc.pyi +10 -0
- flwr/proto/task_pb2.py +6 -6
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/compat/driver_client_proxy.py +25 -1
- flwr/server/driver/driver.py +6 -5
- flwr/server/superlink/driver/driver_servicer.py +6 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +11 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +14 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
- flwr/server/superlink/fleet/vce/vce_api.py +41 -25
- flwr/server/superlink/state/in_memory_state.py +38 -26
- flwr/server/superlink/state/sqlite_state.py +42 -21
- flwr/server/superlink/state/state.py +19 -0
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +4 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +5 -4
- flwr/simulation/ray_transport/ray_actor.py +6 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/RECORD +34 -34
- {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/entry_points.txt +0 -0
|
@@ -19,7 +19,7 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import MessageType, MessageTypeLegacy, RecordSet
|
|
22
|
+
from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
24
|
from flwr.common import serde
|
|
25
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
@@ -129,8 +129,16 @@ class DriverClientProxy(ClientProxy):
|
|
|
129
129
|
),
|
|
130
130
|
task_type=task_type,
|
|
131
131
|
recordset=serde.recordset_to_proto(recordset),
|
|
132
|
+
ttl=DEFAULT_TTL,
|
|
132
133
|
),
|
|
133
134
|
)
|
|
135
|
+
|
|
136
|
+
# This would normally be recorded upon common.Message creation
|
|
137
|
+
# but this compatibility stack doesn't create Messages,
|
|
138
|
+
# so we need to inject `created_at` manually (needed for
|
|
139
|
+
# taskins validation by server.utils.validator)
|
|
140
|
+
task_ins.task.created_at = time.time()
|
|
141
|
+
|
|
134
142
|
push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
|
|
135
143
|
task_ins_list=[task_ins]
|
|
136
144
|
)
|
|
@@ -162,8 +170,24 @@ class DriverClientProxy(ClientProxy):
|
|
|
162
170
|
)
|
|
163
171
|
if len(task_res_list) == 1:
|
|
164
172
|
task_res = task_res_list[0]
|
|
173
|
+
|
|
174
|
+
# This will raise an Exception if task_res carries an `error`
|
|
175
|
+
validate_task_res(task_res=task_res)
|
|
176
|
+
|
|
165
177
|
return serde.recordset_from_proto(task_res.task.recordset)
|
|
166
178
|
|
|
167
179
|
if timeout is not None and time.time() > start_time + timeout:
|
|
168
180
|
raise RuntimeError("Timeout reached")
|
|
169
181
|
time.sleep(SLEEP_TIME)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def validate_task_res(
|
|
185
|
+
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
186
|
+
) -> None:
|
|
187
|
+
"""Validate if a TaskRes is empty or not."""
|
|
188
|
+
if not task_res.HasField("task"):
|
|
189
|
+
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
190
|
+
if task_res.task.HasField("error"):
|
|
191
|
+
raise ValueError("Exception during client-side task execution")
|
|
192
|
+
if not task_res.task.HasField("recordset"):
|
|
193
|
+
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|
flwr/server/driver/driver.py
CHANGED
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import time
|
|
19
19
|
from typing import Iterable, List, Optional, Tuple
|
|
20
20
|
|
|
21
|
-
from flwr.common import Message, Metadata, RecordSet
|
|
21
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
22
22
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
23
23
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
24
24
|
CreateRunRequest,
|
|
@@ -81,6 +81,7 @@ class Driver:
|
|
|
81
81
|
and message.metadata.src_node_id == self.node.node_id
|
|
82
82
|
and message.metadata.message_id == ""
|
|
83
83
|
and message.metadata.reply_to_message == ""
|
|
84
|
+
and message.metadata.ttl > 0
|
|
84
85
|
):
|
|
85
86
|
raise ValueError(f"Invalid message: {message}")
|
|
86
87
|
|
|
@@ -90,7 +91,7 @@ class Driver:
|
|
|
90
91
|
message_type: str,
|
|
91
92
|
dst_node_id: int,
|
|
92
93
|
group_id: str,
|
|
93
|
-
ttl:
|
|
94
|
+
ttl: float = DEFAULT_TTL,
|
|
94
95
|
) -> Message:
|
|
95
96
|
"""Create a new message with specified parameters.
|
|
96
97
|
|
|
@@ -110,10 +111,10 @@ class Driver:
|
|
|
110
111
|
group_id : str
|
|
111
112
|
The ID of the group to which this message is associated. In some settings,
|
|
112
113
|
this is used as the FL round.
|
|
113
|
-
ttl :
|
|
114
|
+
ttl : float (default: common.DEFAULT_TTL)
|
|
114
115
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
115
|
-
this message to receiving a reply. It specifies the duration for
|
|
116
|
-
message and its potential reply are considered valid.
|
|
116
|
+
this message to receiving a reply. It specifies in seconds the duration for
|
|
117
|
+
which the message and its potential reply are considered valid.
|
|
117
118
|
|
|
118
119
|
Returns
|
|
119
120
|
-------
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Driver API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import time
|
|
18
19
|
from logging import DEBUG, INFO
|
|
19
20
|
from typing import List, Optional, Set
|
|
20
21
|
from uuid import UUID
|
|
@@ -72,6 +73,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
72
73
|
"""Push a set of TaskIns."""
|
|
73
74
|
log(DEBUG, "DriverServicer.PushTaskIns")
|
|
74
75
|
|
|
76
|
+
# Set pushed_at (timestamp in seconds)
|
|
77
|
+
pushed_at = time.time()
|
|
78
|
+
for task_ins in request.task_ins_list:
|
|
79
|
+
task_ins.task.pushed_at = pushed_at
|
|
80
|
+
|
|
75
81
|
# Validate request
|
|
76
82
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
|
77
83
|
for task_ins in request.task_ins_list:
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Fleet API gRPC request-response servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import INFO
|
|
18
|
+
from logging import DEBUG, INFO
|
|
19
19
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
@@ -26,6 +26,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
26
26
|
CreateNodeResponse,
|
|
27
27
|
DeleteNodeRequest,
|
|
28
28
|
DeleteNodeResponse,
|
|
29
|
+
PingRequest,
|
|
30
|
+
PingResponse,
|
|
29
31
|
PullTaskInsRequest,
|
|
30
32
|
PullTaskInsResponse,
|
|
31
33
|
PushTaskResRequest,
|
|
@@ -61,6 +63,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
61
63
|
state=self.state_factory.state(),
|
|
62
64
|
)
|
|
63
65
|
|
|
66
|
+
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
67
|
+
"""."""
|
|
68
|
+
log(DEBUG, "FleetServicer.Ping")
|
|
69
|
+
return message_handler.ping(
|
|
70
|
+
request=request,
|
|
71
|
+
state=self.state_factory.state(),
|
|
72
|
+
)
|
|
73
|
+
|
|
64
74
|
def PullTaskIns(
|
|
65
75
|
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
|
66
76
|
) -> PullTaskInsResponse:
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import time
|
|
18
19
|
from typing import List, Optional
|
|
19
20
|
from uuid import UUID
|
|
20
21
|
|
|
@@ -23,6 +24,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
23
24
|
CreateNodeResponse,
|
|
24
25
|
DeleteNodeRequest,
|
|
25
26
|
DeleteNodeResponse,
|
|
27
|
+
PingRequest,
|
|
28
|
+
PingResponse,
|
|
26
29
|
PullTaskInsRequest,
|
|
27
30
|
PullTaskInsResponse,
|
|
28
31
|
PushTaskResRequest,
|
|
@@ -55,6 +58,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
55
58
|
return DeleteNodeResponse()
|
|
56
59
|
|
|
57
60
|
|
|
61
|
+
def ping(
|
|
62
|
+
request: PingRequest, # pylint: disable=unused-argument
|
|
63
|
+
state: State, # pylint: disable=unused-argument
|
|
64
|
+
) -> PingResponse:
|
|
65
|
+
"""."""
|
|
66
|
+
return PingResponse(success=True)
|
|
67
|
+
|
|
68
|
+
|
|
58
69
|
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
59
70
|
"""Pull TaskIns handler."""
|
|
60
71
|
# Get node_id if client node is not anonymous
|
|
@@ -77,6 +88,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
77
88
|
task_res: TaskRes = request.task_res_list[0]
|
|
78
89
|
# pylint: enable=no-member
|
|
79
90
|
|
|
91
|
+
# Set pushed_at (timestamp in seconds)
|
|
92
|
+
task_res.task.pushed_at = time.time()
|
|
93
|
+
|
|
80
94
|
# Store TaskRes in State
|
|
81
95
|
task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
|
82
96
|
|
|
@@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Tuple, Union
|
|
|
20
20
|
|
|
21
21
|
import ray
|
|
22
22
|
|
|
23
|
-
from flwr.client.client_app import ClientApp
|
|
23
|
+
from flwr.client.client_app import ClientApp
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.message import Message
|
|
@@ -151,7 +151,6 @@ class RayBackend(Backend):
|
|
|
151
151
|
)
|
|
152
152
|
|
|
153
153
|
await future
|
|
154
|
-
|
|
155
154
|
# Fetch result
|
|
156
155
|
(
|
|
157
156
|
out_mssg,
|
|
@@ -160,13 +159,15 @@ class RayBackend(Backend):
|
|
|
160
159
|
|
|
161
160
|
return out_mssg, updated_context
|
|
162
161
|
|
|
163
|
-
except
|
|
162
|
+
except Exception as ex:
|
|
164
163
|
log(
|
|
165
164
|
ERROR,
|
|
166
165
|
"An exception was raised when processing a message by %s",
|
|
167
166
|
self.__class__.__name__,
|
|
168
167
|
)
|
|
169
|
-
|
|
168
|
+
# add actor back into pool
|
|
169
|
+
await self.pool.add_actor_back_to_pool(future)
|
|
170
|
+
raise ex
|
|
170
171
|
|
|
171
172
|
async def terminate(self) -> None:
|
|
172
173
|
"""Terminate all actors in actor pool."""
|
|
@@ -14,9 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Fleet Simulation Engine API."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
import asyncio
|
|
19
18
|
import json
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
20
21
|
import traceback
|
|
21
22
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
22
23
|
from typing import Callable, Dict, List, Optional
|
|
@@ -24,6 +25,7 @@ from typing import Callable, Dict, List, Optional
|
|
|
24
25
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
25
26
|
from flwr.client.node_state import NodeState
|
|
26
27
|
from flwr.common.logger import log
|
|
28
|
+
from flwr.common.message import Error
|
|
27
29
|
from flwr.common.object_ref import load_app
|
|
28
30
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
29
31
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
@@ -59,6 +61,7 @@ async def worker(
|
|
|
59
61
|
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
|
|
60
62
|
state = state_factory.state()
|
|
61
63
|
while True:
|
|
64
|
+
out_mssg = None
|
|
62
65
|
try:
|
|
63
66
|
task_ins: TaskIns = await queue.get()
|
|
64
67
|
node_id = task_ins.task.consumer.node_id
|
|
@@ -82,24 +85,25 @@ async def worker(
|
|
|
82
85
|
task_ins.run_id, context=updated_context
|
|
83
86
|
)
|
|
84
87
|
|
|
85
|
-
# Convert to TaskRes
|
|
86
|
-
task_res = message_to_taskres(out_mssg)
|
|
87
|
-
# Store TaskRes in state
|
|
88
|
-
state.store_task_res(task_res)
|
|
89
|
-
|
|
90
88
|
except asyncio.CancelledError as e:
|
|
91
|
-
log(DEBUG, "
|
|
89
|
+
log(DEBUG, "Terminating async worker: %s", e)
|
|
92
90
|
break
|
|
93
91
|
|
|
94
|
-
|
|
95
|
-
log(ERROR, "Async worker: %s", app_ex)
|
|
96
|
-
log(ERROR, traceback.format_exc())
|
|
97
|
-
raise
|
|
98
|
-
|
|
92
|
+
# Exceptions aren't raised but reported as an error message
|
|
99
93
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
100
94
|
log(ERROR, ex)
|
|
101
95
|
log(ERROR, traceback.format_exc())
|
|
102
|
-
|
|
96
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
97
|
+
error = Error(code=0, reason=reason)
|
|
98
|
+
out_mssg = message.create_error_reply(error=error)
|
|
99
|
+
|
|
100
|
+
finally:
|
|
101
|
+
if out_mssg:
|
|
102
|
+
# Convert to TaskRes
|
|
103
|
+
task_res = message_to_taskres(out_mssg)
|
|
104
|
+
# Store TaskRes in state
|
|
105
|
+
task_res.task.pushed_at = time.time()
|
|
106
|
+
state.store_task_res(task_res)
|
|
103
107
|
|
|
104
108
|
|
|
105
109
|
async def add_taskins_to_queue(
|
|
@@ -218,7 +222,7 @@ async def run(
|
|
|
218
222
|
await backend.terminate()
|
|
219
223
|
|
|
220
224
|
|
|
221
|
-
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
|
225
|
+
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
222
226
|
def start_vce(
|
|
223
227
|
backend_name: str,
|
|
224
228
|
backend_config_json_stream: str,
|
|
@@ -300,12 +304,14 @@ def start_vce(
|
|
|
300
304
|
"""Instantiate a Backend."""
|
|
301
305
|
return backend_type(backend_config, work_dir=app_dir)
|
|
302
306
|
|
|
303
|
-
log(INFO, "client_app_attr = %s", client_app_attr)
|
|
304
|
-
|
|
305
307
|
# Load ClientApp if needed
|
|
306
308
|
def _load() -> ClientApp:
|
|
307
309
|
|
|
308
310
|
if client_app_attr:
|
|
311
|
+
|
|
312
|
+
if app_dir is not None:
|
|
313
|
+
sys.path.insert(0, app_dir)
|
|
314
|
+
|
|
309
315
|
app: ClientApp = load_app(client_app_attr, LoadClientAppError)
|
|
310
316
|
|
|
311
317
|
if not isinstance(app, ClientApp):
|
|
@@ -319,13 +325,23 @@ def start_vce(
|
|
|
319
325
|
|
|
320
326
|
app_fn = _load
|
|
321
327
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
328
|
+
try:
|
|
329
|
+
# Test if ClientApp can be loaded
|
|
330
|
+
_ = app_fn()
|
|
331
|
+
|
|
332
|
+
# Run main simulation loop
|
|
333
|
+
asyncio.run(
|
|
334
|
+
run(
|
|
335
|
+
app_fn,
|
|
336
|
+
backend_fn,
|
|
337
|
+
nodes_mapping,
|
|
338
|
+
state_factory,
|
|
339
|
+
node_states,
|
|
340
|
+
f_stop,
|
|
341
|
+
)
|
|
330
342
|
)
|
|
331
|
-
|
|
343
|
+
except LoadClientAppError as loadapp_ex:
|
|
344
|
+
f_stop.set() # set termination event
|
|
345
|
+
raise loadapp_ex
|
|
346
|
+
except Exception as ex:
|
|
347
|
+
raise ex
|
|
@@ -17,9 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
19
|
import threading
|
|
20
|
-
|
|
20
|
+
import time
|
|
21
21
|
from logging import ERROR
|
|
22
|
-
from typing import Dict, List, Optional, Set
|
|
22
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
23
23
|
from uuid import UUID, uuid4
|
|
24
24
|
|
|
25
25
|
from flwr.common import log, now
|
|
@@ -32,7 +32,8 @@ class InMemoryState(State):
|
|
|
32
32
|
"""In-memory State implementation."""
|
|
33
33
|
|
|
34
34
|
def __init__(self) -> None:
|
|
35
|
-
|
|
35
|
+
# Map node_id to (online_until, ping_interval)
|
|
36
|
+
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
36
37
|
self.run_ids: Set[int] = set()
|
|
37
38
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
38
39
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
@@ -50,15 +51,11 @@ class InMemoryState(State):
|
|
|
50
51
|
log(ERROR, "`run_id` is invalid")
|
|
51
52
|
return None
|
|
52
53
|
|
|
53
|
-
# Create task_id
|
|
54
|
+
# Create task_id
|
|
54
55
|
task_id = uuid4()
|
|
55
|
-
created_at: datetime = now()
|
|
56
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
57
56
|
|
|
58
57
|
# Store TaskIns
|
|
59
58
|
task_ins.task_id = str(task_id)
|
|
60
|
-
task_ins.task.created_at = created_at.isoformat()
|
|
61
|
-
task_ins.task.ttl = ttl.isoformat()
|
|
62
59
|
with self.lock:
|
|
63
60
|
self.task_ins_store[task_id] = task_ins
|
|
64
61
|
|
|
@@ -113,15 +110,11 @@ class InMemoryState(State):
|
|
|
113
110
|
log(ERROR, "`run_id` is invalid")
|
|
114
111
|
return None
|
|
115
112
|
|
|
116
|
-
# Create task_id
|
|
113
|
+
# Create task_id
|
|
117
114
|
task_id = uuid4()
|
|
118
|
-
created_at: datetime = now()
|
|
119
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
120
115
|
|
|
121
116
|
# Store TaskRes
|
|
122
117
|
task_res.task_id = str(task_id)
|
|
123
|
-
task_res.task.created_at = created_at.isoformat()
|
|
124
|
-
task_res.task.ttl = ttl.isoformat()
|
|
125
118
|
with self.lock:
|
|
126
119
|
self.task_res_store[task_id] = task_res
|
|
127
120
|
|
|
@@ -194,17 +187,21 @@ class InMemoryState(State):
|
|
|
194
187
|
# Sample a random int64 as node_id
|
|
195
188
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
196
189
|
|
|
197
|
-
|
|
198
|
-
self.node_ids
|
|
199
|
-
|
|
190
|
+
with self.lock:
|
|
191
|
+
if node_id not in self.node_ids:
|
|
192
|
+
# Default ping interval is 30s
|
|
193
|
+
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
194
|
+
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
|
|
195
|
+
return node_id
|
|
200
196
|
log(ERROR, "Unexpected node registration failure.")
|
|
201
197
|
return 0
|
|
202
198
|
|
|
203
199
|
def delete_node(self, node_id: int) -> None:
|
|
204
200
|
"""Delete a client node."""
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
201
|
+
with self.lock:
|
|
202
|
+
if node_id not in self.node_ids:
|
|
203
|
+
raise ValueError(f"Node {node_id} not found")
|
|
204
|
+
del self.node_ids[node_id]
|
|
208
205
|
|
|
209
206
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
210
207
|
"""Return all available client nodes.
|
|
@@ -214,17 +211,32 @@ class InMemoryState(State):
|
|
|
214
211
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
215
212
|
an empty `Set` MUST be returned.
|
|
216
213
|
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
214
|
+
with self.lock:
|
|
215
|
+
if run_id not in self.run_ids:
|
|
216
|
+
return set()
|
|
217
|
+
current_time = time.time()
|
|
218
|
+
return {
|
|
219
|
+
node_id
|
|
220
|
+
for node_id, (online_until, _) in self.node_ids.items()
|
|
221
|
+
if online_until > current_time
|
|
222
|
+
}
|
|
220
223
|
|
|
221
224
|
def create_run(self) -> int:
|
|
222
225
|
"""Create one run."""
|
|
223
226
|
# Sample a random int64 as run_id
|
|
224
|
-
|
|
227
|
+
with self.lock:
|
|
228
|
+
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
225
229
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
230
|
+
if run_id not in self.run_ids:
|
|
231
|
+
self.run_ids.add(run_id)
|
|
232
|
+
return run_id
|
|
229
233
|
log(ERROR, "Unexpected run creation failure.")
|
|
230
234
|
return 0
|
|
235
|
+
|
|
236
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
237
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
238
|
+
with self.lock:
|
|
239
|
+
if node_id in self.node_ids:
|
|
240
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
241
|
+
return True
|
|
242
|
+
return False
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import os
|
|
19
19
|
import re
|
|
20
20
|
import sqlite3
|
|
21
|
-
|
|
21
|
+
import time
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
|
24
24
|
from uuid import UUID, uuid4
|
|
@@ -33,10 +33,16 @@ from .state import State
|
|
|
33
33
|
|
|
34
34
|
SQL_CREATE_TABLE_NODE = """
|
|
35
35
|
CREATE TABLE IF NOT EXISTS node(
|
|
36
|
-
node_id
|
|
36
|
+
node_id INTEGER UNIQUE,
|
|
37
|
+
online_until REAL,
|
|
38
|
+
ping_interval REAL
|
|
37
39
|
);
|
|
38
40
|
"""
|
|
39
41
|
|
|
42
|
+
SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
43
|
+
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
44
|
+
"""
|
|
45
|
+
|
|
40
46
|
SQL_CREATE_TABLE_RUN = """
|
|
41
47
|
CREATE TABLE IF NOT EXISTS run(
|
|
42
48
|
run_id INTEGER UNIQUE
|
|
@@ -52,9 +58,10 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
52
58
|
producer_node_id INTEGER,
|
|
53
59
|
consumer_anonymous BOOLEAN,
|
|
54
60
|
consumer_node_id INTEGER,
|
|
55
|
-
created_at
|
|
61
|
+
created_at REAL,
|
|
56
62
|
delivered_at TEXT,
|
|
57
|
-
|
|
63
|
+
pushed_at REAL,
|
|
64
|
+
ttl REAL,
|
|
58
65
|
ancestry TEXT,
|
|
59
66
|
task_type TEXT,
|
|
60
67
|
recordset BLOB,
|
|
@@ -72,9 +79,10 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
72
79
|
producer_node_id INTEGER,
|
|
73
80
|
consumer_anonymous BOOLEAN,
|
|
74
81
|
consumer_node_id INTEGER,
|
|
75
|
-
created_at
|
|
82
|
+
created_at REAL,
|
|
76
83
|
delivered_at TEXT,
|
|
77
|
-
|
|
84
|
+
pushed_at REAL,
|
|
85
|
+
ttl REAL,
|
|
78
86
|
ancestry TEXT,
|
|
79
87
|
task_type TEXT,
|
|
80
88
|
recordset BLOB,
|
|
@@ -82,7 +90,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
82
90
|
);
|
|
83
91
|
"""
|
|
84
92
|
|
|
85
|
-
DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
|
|
93
|
+
DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
|
|
86
94
|
|
|
87
95
|
|
|
88
96
|
class SqliteState(State):
|
|
@@ -123,6 +131,7 @@ class SqliteState(State):
|
|
|
123
131
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
124
132
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
125
133
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
134
|
+
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
126
135
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
127
136
|
|
|
128
137
|
return res.fetchall()
|
|
@@ -185,15 +194,11 @@ class SqliteState(State):
|
|
|
185
194
|
log(ERROR, errors)
|
|
186
195
|
return None
|
|
187
196
|
|
|
188
|
-
# Create task_id
|
|
197
|
+
# Create task_id
|
|
189
198
|
task_id = uuid4()
|
|
190
|
-
created_at: datetime = now()
|
|
191
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
192
199
|
|
|
193
200
|
# Store TaskIns
|
|
194
201
|
task_ins.task_id = str(task_id)
|
|
195
|
-
task_ins.task.created_at = created_at.isoformat()
|
|
196
|
-
task_ins.task.ttl = ttl.isoformat()
|
|
197
202
|
data = (task_ins_to_dict(task_ins),)
|
|
198
203
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
199
204
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
@@ -320,15 +325,11 @@ class SqliteState(State):
|
|
|
320
325
|
log(ERROR, errors)
|
|
321
326
|
return None
|
|
322
327
|
|
|
323
|
-
# Create task_id
|
|
328
|
+
# Create task_id
|
|
324
329
|
task_id = uuid4()
|
|
325
|
-
created_at: datetime = now()
|
|
326
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
327
330
|
|
|
328
331
|
# Store TaskIns
|
|
329
332
|
task_res.task_id = str(task_id)
|
|
330
|
-
task_res.task.created_at = created_at.isoformat()
|
|
331
|
-
task_res.task.ttl = ttl.isoformat()
|
|
332
333
|
data = (task_res_to_dict(task_res),)
|
|
333
334
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
334
335
|
query = f"INSERT INTO task_res VALUES({columns});"
|
|
@@ -472,9 +473,14 @@ class SqliteState(State):
|
|
|
472
473
|
# Sample a random int64 as node_id
|
|
473
474
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
474
475
|
|
|
475
|
-
query =
|
|
476
|
+
query = (
|
|
477
|
+
"INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)"
|
|
478
|
+
)
|
|
479
|
+
|
|
476
480
|
try:
|
|
477
|
-
|
|
481
|
+
# Default ping interval is 30s
|
|
482
|
+
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
483
|
+
self.query(query, (node_id, time.time() + 1e9, 1e9))
|
|
478
484
|
except sqlite3.IntegrityError:
|
|
479
485
|
log(ERROR, "Unexpected node registration failure.")
|
|
480
486
|
return 0
|
|
@@ -499,8 +505,8 @@ class SqliteState(State):
|
|
|
499
505
|
return set()
|
|
500
506
|
|
|
501
507
|
# Get nodes
|
|
502
|
-
query = "SELECT
|
|
503
|
-
rows = self.query(query)
|
|
508
|
+
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
|
509
|
+
rows = self.query(query, (time.time(),))
|
|
504
510
|
result: Set[int] = {row["node_id"] for row in rows}
|
|
505
511
|
return result
|
|
506
512
|
|
|
@@ -519,6 +525,17 @@ class SqliteState(State):
|
|
|
519
525
|
log(ERROR, "Unexpected run creation failure.")
|
|
520
526
|
return 0
|
|
521
527
|
|
|
528
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
529
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
530
|
+
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
531
|
+
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
|
|
532
|
+
try:
|
|
533
|
+
self.query(query, (time.time() + ping_interval, ping_interval, node_id))
|
|
534
|
+
return True
|
|
535
|
+
except sqlite3.IntegrityError:
|
|
536
|
+
log(ERROR, "`node_id` does not exist.")
|
|
537
|
+
return False
|
|
538
|
+
|
|
522
539
|
|
|
523
540
|
def dict_factory(
|
|
524
541
|
cursor: sqlite3.Cursor,
|
|
@@ -544,6 +561,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
|
|
|
544
561
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
545
562
|
"created_at": task_msg.task.created_at,
|
|
546
563
|
"delivered_at": task_msg.task.delivered_at,
|
|
564
|
+
"pushed_at": task_msg.task.pushed_at,
|
|
547
565
|
"ttl": task_msg.task.ttl,
|
|
548
566
|
"ancestry": ",".join(task_msg.task.ancestry),
|
|
549
567
|
"task_type": task_msg.task.task_type,
|
|
@@ -564,6 +582,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
|
|
|
564
582
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
565
583
|
"created_at": task_msg.task.created_at,
|
|
566
584
|
"delivered_at": task_msg.task.delivered_at,
|
|
585
|
+
"pushed_at": task_msg.task.pushed_at,
|
|
567
586
|
"ttl": task_msg.task.ttl,
|
|
568
587
|
"ancestry": ",".join(task_msg.task.ancestry),
|
|
569
588
|
"task_type": task_msg.task.task_type,
|
|
@@ -592,6 +611,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
|
|
|
592
611
|
),
|
|
593
612
|
created_at=task_dict["created_at"],
|
|
594
613
|
delivered_at=task_dict["delivered_at"],
|
|
614
|
+
pushed_at=task_dict["pushed_at"],
|
|
595
615
|
ttl=task_dict["ttl"],
|
|
596
616
|
ancestry=task_dict["ancestry"].split(","),
|
|
597
617
|
task_type=task_dict["task_type"],
|
|
@@ -621,6 +641,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
|
|
|
621
641
|
),
|
|
622
642
|
created_at=task_dict["created_at"],
|
|
623
643
|
delivered_at=task_dict["delivered_at"],
|
|
644
|
+
pushed_at=task_dict["pushed_at"],
|
|
624
645
|
ttl=task_dict["ttl"],
|
|
625
646
|
ancestry=task_dict["ancestry"].split(","),
|
|
626
647
|
task_type=task_dict["task_type"],
|
|
@@ -152,3 +152,22 @@ class State(abc.ABC):
|
|
|
152
152
|
@abc.abstractmethod
|
|
153
153
|
def create_run(self) -> int:
|
|
154
154
|
"""Create one run."""
|
|
155
|
+
|
|
156
|
+
@abc.abstractmethod
|
|
157
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
158
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
node_id : int
|
|
163
|
+
The `node_id` from which the ping was received.
|
|
164
|
+
ping_interval : float
|
|
165
|
+
The interval (in seconds) from the current timestamp within which the next
|
|
166
|
+
ping from this node must be received. This acts as a hard deadline to ensure
|
|
167
|
+
an accurate assessment of the node's availability.
|
|
168
|
+
|
|
169
|
+
Returns
|
|
170
|
+
-------
|
|
171
|
+
is_acknowledged : bool
|
|
172
|
+
True if the ping is successfully acknowledged; otherwise, False.
|
|
173
|
+
"""
|