flwr-nightly 1.8.0.dev20240323__py3-none-any.whl → 1.8.0.dev20240328__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic.

Files changed (34)
  1. flwr/client/app.py +35 -24
  2. flwr/client/client_app.py +4 -4
  3. flwr/client/grpc_client/connection.py +2 -1
  4. flwr/client/message_handler/message_handler.py +3 -2
  5. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  6. flwr/common/__init__.py +2 -0
  7. flwr/common/message.py +65 -20
  8. flwr/common/serde.py +8 -2
  9. flwr/proto/fleet_pb2.py +19 -15
  10. flwr/proto/fleet_pb2.pyi +28 -0
  11. flwr/proto/fleet_pb2_grpc.py +33 -0
  12. flwr/proto/fleet_pb2_grpc.pyi +10 -0
  13. flwr/proto/task_pb2.py +6 -6
  14. flwr/proto/task_pb2.pyi +8 -5
  15. flwr/server/compat/driver_client_proxy.py +25 -1
  16. flwr/server/driver/driver.py +6 -5
  17. flwr/server/superlink/driver/driver_servicer.py +6 -0
  18. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +11 -1
  19. flwr/server/superlink/fleet/message_handler/message_handler.py +14 -0
  20. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
  21. flwr/server/superlink/fleet/vce/vce_api.py +41 -25
  22. flwr/server/superlink/state/in_memory_state.py +38 -26
  23. flwr/server/superlink/state/sqlite_state.py +42 -21
  24. flwr/server/superlink/state/state.py +19 -0
  25. flwr/server/utils/validator.py +23 -9
  26. flwr/server/workflow/default_workflows.py +4 -4
  27. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +5 -4
  28. flwr/simulation/ray_transport/ray_actor.py +6 -2
  29. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  30. {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/METADATA +1 -1
  31. {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/RECORD +34 -34
  32. {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/LICENSE +0 -0
  33. {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/WHEEL +0 -0
  34. {flwr_nightly-1.8.0.dev20240323.dist-info → flwr_nightly-1.8.0.dev20240328.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,7 @@ import time
19
19
  from typing import List, Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
24
  from flwr.common import serde
25
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
@@ -129,8 +129,16 @@ class DriverClientProxy(ClientProxy):
129
129
  ),
130
130
  task_type=task_type,
131
131
  recordset=serde.recordset_to_proto(recordset),
132
+ ttl=DEFAULT_TTL,
132
133
  ),
133
134
  )
135
+
136
+ # This would normally be recorded upon common.Message creation
137
+ # but this compatibility stack doesn't create Messages,
138
+ # so we need to inject `created_at` manually (needed for
139
+ # taskins validation by server.utils.validator)
140
+ task_ins.task.created_at = time.time()
141
+
134
142
  push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
135
143
  task_ins_list=[task_ins]
136
144
  )
@@ -162,8 +170,24 @@ class DriverClientProxy(ClientProxy):
162
170
  )
163
171
  if len(task_res_list) == 1:
164
172
  task_res = task_res_list[0]
173
+
174
+ # This will raise an Exception if task_res carries an `error`
175
+ validate_task_res(task_res=task_res)
176
+
165
177
  return serde.recordset_from_proto(task_res.task.recordset)
166
178
 
167
179
  if timeout is not None and time.time() > start_time + timeout:
168
180
  raise RuntimeError("Timeout reached")
169
181
  time.sleep(SLEEP_TIME)
182
+
183
+
184
+ def validate_task_res(
185
+ task_res: task_pb2.TaskRes, # pylint: disable=E1101
186
+ ) -> None:
187
+ """Validate if a TaskRes is empty or not."""
188
+ if not task_res.HasField("task"):
189
+ raise ValueError("Invalid TaskRes, field `task` missing")
190
+ if task_res.task.HasField("error"):
191
+ raise ValueError("Exception during client-side task execution")
192
+ if not task_res.task.HasField("recordset"):
193
+ raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
@@ -18,7 +18,7 @@
18
18
  import time
19
19
  from typing import Iterable, List, Optional, Tuple
20
20
 
21
- from flwr.common import Message, Metadata, RecordSet
21
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
22
22
  from flwr.common.serde import message_from_taskres, message_to_taskins
23
23
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
24
  CreateRunRequest,
@@ -81,6 +81,7 @@ class Driver:
81
81
  and message.metadata.src_node_id == self.node.node_id
82
82
  and message.metadata.message_id == ""
83
83
  and message.metadata.reply_to_message == ""
84
+ and message.metadata.ttl > 0
84
85
  ):
85
86
  raise ValueError(f"Invalid message: {message}")
86
87
 
@@ -90,7 +91,7 @@ class Driver:
90
91
  message_type: str,
91
92
  dst_node_id: int,
92
93
  group_id: str,
93
- ttl: str,
94
+ ttl: float = DEFAULT_TTL,
94
95
  ) -> Message:
95
96
  """Create a new message with specified parameters.
96
97
 
@@ -110,10 +111,10 @@ class Driver:
110
111
  group_id : str
111
112
  The ID of the group to which this message is associated. In some settings,
112
113
  this is used as the FL round.
113
- ttl : str
114
+ ttl : float (default: common.DEFAULT_TTL)
114
115
  Time-to-live for the round trip of this message, i.e., the time from sending
115
- this message to receiving a reply. It specifies the duration for which the
116
- message and its potential reply are considered valid.
116
+ this message to receiving a reply. It specifies in seconds the duration for
117
+ which the message and its potential reply are considered valid.
117
118
 
118
119
  Returns
119
120
  -------
@@ -15,6 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
+ import time
18
19
  from logging import DEBUG, INFO
19
20
  from typing import List, Optional, Set
20
21
  from uuid import UUID
@@ -72,6 +73,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
72
73
  """Push a set of TaskIns."""
73
74
  log(DEBUG, "DriverServicer.PushTaskIns")
74
75
 
76
+ # Set pushed_at (timestamp in seconds)
77
+ pushed_at = time.time()
78
+ for task_ins in request.task_ins_list:
79
+ task_ins.task.pushed_at = pushed_at
80
+
75
81
  # Validate request
76
82
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
77
83
  for task_ins in request.task_ins_list:
@@ -15,7 +15,7 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
 
20
20
  import grpc
21
21
 
@@ -26,6 +26,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
26
  CreateNodeResponse,
27
27
  DeleteNodeRequest,
28
28
  DeleteNodeResponse,
29
+ PingRequest,
30
+ PingResponse,
29
31
  PullTaskInsRequest,
30
32
  PullTaskInsResponse,
31
33
  PushTaskResRequest,
@@ -61,6 +63,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
61
63
  state=self.state_factory.state(),
62
64
  )
63
65
 
66
+ def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
67
+ """."""
68
+ log(DEBUG, "FleetServicer.Ping")
69
+ return message_handler.ping(
70
+ request=request,
71
+ state=self.state_factory.state(),
72
+ )
73
+
64
74
  def PullTaskIns(
65
75
  self, request: PullTaskInsRequest, context: grpc.ServicerContext
66
76
  ) -> PullTaskInsResponse:
@@ -15,6 +15,7 @@
15
15
  """Fleet API message handlers."""
16
16
 
17
17
 
18
+ import time
18
19
  from typing import List, Optional
19
20
  from uuid import UUID
20
21
 
@@ -23,6 +24,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
23
24
  CreateNodeResponse,
24
25
  DeleteNodeRequest,
25
26
  DeleteNodeResponse,
27
+ PingRequest,
28
+ PingResponse,
26
29
  PullTaskInsRequest,
27
30
  PullTaskInsResponse,
28
31
  PushTaskResRequest,
@@ -55,6 +58,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
55
58
  return DeleteNodeResponse()
56
59
 
57
60
 
61
+ def ping(
62
+ request: PingRequest, # pylint: disable=unused-argument
63
+ state: State, # pylint: disable=unused-argument
64
+ ) -> PingResponse:
65
+ """."""
66
+ return PingResponse(success=True)
67
+
68
+
58
69
  def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
59
70
  """Pull TaskIns handler."""
60
71
  # Get node_id if client node is not anonymous
@@ -77,6 +88,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
77
88
  task_res: TaskRes = request.task_res_list[0]
78
89
  # pylint: enable=no-member
79
90
 
91
+ # Set pushed_at (timestamp in seconds)
92
+ task_res.task.pushed_at = time.time()
93
+
80
94
  # Store TaskRes in State
81
95
  task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
82
96
 
@@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Tuple, Union
20
20
 
21
21
  import ray
22
22
 
23
- from flwr.client.client_app import ClientApp, LoadClientAppError
23
+ from flwr.client.client_app import ClientApp
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.message import Message
@@ -151,7 +151,6 @@ class RayBackend(Backend):
151
151
  )
152
152
 
153
153
  await future
154
-
155
154
  # Fetch result
156
155
  (
157
156
  out_mssg,
@@ -160,13 +159,15 @@ class RayBackend(Backend):
160
159
 
161
160
  return out_mssg, updated_context
162
161
 
163
- except LoadClientAppError as load_ex:
162
+ except Exception as ex:
164
163
  log(
165
164
  ERROR,
166
165
  "An exception was raised when processing a message by %s",
167
166
  self.__class__.__name__,
168
167
  )
169
- raise load_ex
168
+ # add actor back into pool
169
+ await self.pool.add_actor_back_to_pool(future)
170
+ raise ex
170
171
 
171
172
  async def terminate(self) -> None:
172
173
  """Terminate all actors in actor pool."""
@@ -14,9 +14,10 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine API."""
16
16
 
17
-
18
17
  import asyncio
19
18
  import json
19
+ import sys
20
+ import time
20
21
  import traceback
21
22
  from logging import DEBUG, ERROR, INFO, WARN
22
23
  from typing import Callable, Dict, List, Optional
@@ -24,6 +25,7 @@ from typing import Callable, Dict, List, Optional
24
25
  from flwr.client.client_app import ClientApp, LoadClientAppError
25
26
  from flwr.client.node_state import NodeState
26
27
  from flwr.common.logger import log
28
+ from flwr.common.message import Error
27
29
  from flwr.common.object_ref import load_app
28
30
  from flwr.common.serde import message_from_taskins, message_to_taskres
29
31
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
@@ -59,6 +61,7 @@ async def worker(
59
61
  """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
60
62
  state = state_factory.state()
61
63
  while True:
64
+ out_mssg = None
62
65
  try:
63
66
  task_ins: TaskIns = await queue.get()
64
67
  node_id = task_ins.task.consumer.node_id
@@ -82,24 +85,25 @@ async def worker(
82
85
  task_ins.run_id, context=updated_context
83
86
  )
84
87
 
85
- # Convert to TaskRes
86
- task_res = message_to_taskres(out_mssg)
87
- # Store TaskRes in state
88
- state.store_task_res(task_res)
89
-
90
88
  except asyncio.CancelledError as e:
91
- log(DEBUG, "Async worker: %s", e)
89
+ log(DEBUG, "Terminating async worker: %s", e)
92
90
  break
93
91
 
94
- except LoadClientAppError as app_ex:
95
- log(ERROR, "Async worker: %s", app_ex)
96
- log(ERROR, traceback.format_exc())
97
- raise
98
-
92
+ # Exceptions aren't raised but reported as an error message
99
93
  except Exception as ex: # pylint: disable=broad-exception-caught
100
94
  log(ERROR, ex)
101
95
  log(ERROR, traceback.format_exc())
102
- break
96
+ reason = str(type(ex)) + ":<'" + str(ex) + "'>"
97
+ error = Error(code=0, reason=reason)
98
+ out_mssg = message.create_error_reply(error=error)
99
+
100
+ finally:
101
+ if out_mssg:
102
+ # Convert to TaskRes
103
+ task_res = message_to_taskres(out_mssg)
104
+ # Store TaskRes in state
105
+ task_res.task.pushed_at = time.time()
106
+ state.store_task_res(task_res)
103
107
 
104
108
 
105
109
  async def add_taskins_to_queue(
@@ -218,7 +222,7 @@ async def run(
218
222
  await backend.terminate()
219
223
 
220
224
 
221
- # pylint: disable=too-many-arguments,unused-argument,too-many-locals
225
+ # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
222
226
  def start_vce(
223
227
  backend_name: str,
224
228
  backend_config_json_stream: str,
@@ -300,12 +304,14 @@ def start_vce(
300
304
  """Instantiate a Backend."""
301
305
  return backend_type(backend_config, work_dir=app_dir)
302
306
 
303
- log(INFO, "client_app_attr = %s", client_app_attr)
304
-
305
307
  # Load ClientApp if needed
306
308
  def _load() -> ClientApp:
307
309
 
308
310
  if client_app_attr:
311
+
312
+ if app_dir is not None:
313
+ sys.path.insert(0, app_dir)
314
+
309
315
  app: ClientApp = load_app(client_app_attr, LoadClientAppError)
310
316
 
311
317
  if not isinstance(app, ClientApp):
@@ -319,13 +325,23 @@ def start_vce(
319
325
 
320
326
  app_fn = _load
321
327
 
322
- asyncio.run(
323
- run(
324
- app_fn,
325
- backend_fn,
326
- nodes_mapping,
327
- state_factory,
328
- node_states,
329
- f_stop,
328
+ try:
329
+ # Test if ClientApp can be loaded
330
+ _ = app_fn()
331
+
332
+ # Run main simulation loop
333
+ asyncio.run(
334
+ run(
335
+ app_fn,
336
+ backend_fn,
337
+ nodes_mapping,
338
+ state_factory,
339
+ node_states,
340
+ f_stop,
341
+ )
330
342
  )
331
- )
343
+ except LoadClientAppError as loadapp_ex:
344
+ f_stop.set() # set termination event
345
+ raise loadapp_ex
346
+ except Exception as ex:
347
+ raise ex
@@ -17,9 +17,9 @@
17
17
 
18
18
  import os
19
19
  import threading
20
- from datetime import datetime, timedelta
20
+ import time
21
21
  from logging import ERROR
22
- from typing import Dict, List, Optional, Set
22
+ from typing import Dict, List, Optional, Set, Tuple
23
23
  from uuid import UUID, uuid4
24
24
 
25
25
  from flwr.common import log, now
@@ -32,7 +32,8 @@ class InMemoryState(State):
32
32
  """In-memory State implementation."""
33
33
 
34
34
  def __init__(self) -> None:
35
- self.node_ids: Set[int] = set()
35
+ # Map node_id to (online_until, ping_interval)
36
+ self.node_ids: Dict[int, Tuple[float, float]] = {}
36
37
  self.run_ids: Set[int] = set()
37
38
  self.task_ins_store: Dict[UUID, TaskIns] = {}
38
39
  self.task_res_store: Dict[UUID, TaskRes] = {}
@@ -50,15 +51,11 @@ class InMemoryState(State):
50
51
  log(ERROR, "`run_id` is invalid")
51
52
  return None
52
53
 
53
- # Create task_id, created_at and ttl
54
+ # Create task_id
54
55
  task_id = uuid4()
55
- created_at: datetime = now()
56
- ttl: datetime = created_at + timedelta(hours=24)
57
56
 
58
57
  # Store TaskIns
59
58
  task_ins.task_id = str(task_id)
60
- task_ins.task.created_at = created_at.isoformat()
61
- task_ins.task.ttl = ttl.isoformat()
62
59
  with self.lock:
63
60
  self.task_ins_store[task_id] = task_ins
64
61
 
@@ -113,15 +110,11 @@ class InMemoryState(State):
113
110
  log(ERROR, "`run_id` is invalid")
114
111
  return None
115
112
 
116
- # Create task_id, created_at and ttl
113
+ # Create task_id
117
114
  task_id = uuid4()
118
- created_at: datetime = now()
119
- ttl: datetime = created_at + timedelta(hours=24)
120
115
 
121
116
  # Store TaskRes
122
117
  task_res.task_id = str(task_id)
123
- task_res.task.created_at = created_at.isoformat()
124
- task_res.task.ttl = ttl.isoformat()
125
118
  with self.lock:
126
119
  self.task_res_store[task_id] = task_res
127
120
 
@@ -194,17 +187,21 @@ class InMemoryState(State):
194
187
  # Sample a random int64 as node_id
195
188
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
196
189
 
197
- if node_id not in self.node_ids:
198
- self.node_ids.add(node_id)
199
- return node_id
190
+ with self.lock:
191
+ if node_id not in self.node_ids:
192
+ # Default ping interval is 30s
193
+ # TODO: change 1e9 to 30s # pylint: disable=W0511
194
+ self.node_ids[node_id] = (time.time() + 1e9, 1e9)
195
+ return node_id
200
196
  log(ERROR, "Unexpected node registration failure.")
201
197
  return 0
202
198
 
203
199
  def delete_node(self, node_id: int) -> None:
204
200
  """Delete a client node."""
205
- if node_id not in self.node_ids:
206
- raise ValueError(f"Node {node_id} not found")
207
- self.node_ids.remove(node_id)
201
+ with self.lock:
202
+ if node_id not in self.node_ids:
203
+ raise ValueError(f"Node {node_id} not found")
204
+ del self.node_ids[node_id]
208
205
 
209
206
  def get_nodes(self, run_id: int) -> Set[int]:
210
207
  """Return all available client nodes.
@@ -214,17 +211,32 @@ class InMemoryState(State):
214
211
  If the provided `run_id` does not exist or has no matching nodes,
215
212
  an empty `Set` MUST be returned.
216
213
  """
217
- if run_id not in self.run_ids:
218
- return set()
219
- return self.node_ids
214
+ with self.lock:
215
+ if run_id not in self.run_ids:
216
+ return set()
217
+ current_time = time.time()
218
+ return {
219
+ node_id
220
+ for node_id, (online_until, _) in self.node_ids.items()
221
+ if online_until > current_time
222
+ }
220
223
 
221
224
  def create_run(self) -> int:
222
225
  """Create one run."""
223
226
  # Sample a random int64 as run_id
224
- run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
227
+ with self.lock:
228
+ run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
225
229
 
226
- if run_id not in self.run_ids:
227
- self.run_ids.add(run_id)
228
- return run_id
230
+ if run_id not in self.run_ids:
231
+ self.run_ids.add(run_id)
232
+ return run_id
229
233
  log(ERROR, "Unexpected run creation failure.")
230
234
  return 0
235
+
236
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
237
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
238
+ with self.lock:
239
+ if node_id in self.node_ids:
240
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
241
+ return True
242
+ return False
@@ -18,7 +18,7 @@
18
18
  import os
19
19
  import re
20
20
  import sqlite3
21
- from datetime import datetime, timedelta
21
+ import time
22
22
  from logging import DEBUG, ERROR
23
23
  from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
24
24
  from uuid import UUID, uuid4
@@ -33,10 +33,16 @@ from .state import State
33
33
 
34
34
  SQL_CREATE_TABLE_NODE = """
35
35
  CREATE TABLE IF NOT EXISTS node(
36
- node_id INTEGER UNIQUE
36
+ node_id INTEGER UNIQUE,
37
+ online_until REAL,
38
+ ping_interval REAL
37
39
  );
38
40
  """
39
41
 
42
+ SQL_CREATE_INDEX_ONLINE_UNTIL = """
43
+ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
44
+ """
45
+
40
46
  SQL_CREATE_TABLE_RUN = """
41
47
  CREATE TABLE IF NOT EXISTS run(
42
48
  run_id INTEGER UNIQUE
@@ -52,9 +58,10 @@ CREATE TABLE IF NOT EXISTS task_ins(
52
58
  producer_node_id INTEGER,
53
59
  consumer_anonymous BOOLEAN,
54
60
  consumer_node_id INTEGER,
55
- created_at TEXT,
61
+ created_at REAL,
56
62
  delivered_at TEXT,
57
- ttl TEXT,
63
+ pushed_at REAL,
64
+ ttl REAL,
58
65
  ancestry TEXT,
59
66
  task_type TEXT,
60
67
  recordset BLOB,
@@ -72,9 +79,10 @@ CREATE TABLE IF NOT EXISTS task_res(
72
79
  producer_node_id INTEGER,
73
80
  consumer_anonymous BOOLEAN,
74
81
  consumer_node_id INTEGER,
75
- created_at TEXT,
82
+ created_at REAL,
76
83
  delivered_at TEXT,
77
- ttl TEXT,
84
+ pushed_at REAL,
85
+ ttl REAL,
78
86
  ancestry TEXT,
79
87
  task_type TEXT,
80
88
  recordset BLOB,
@@ -82,7 +90,7 @@ CREATE TABLE IF NOT EXISTS task_res(
82
90
  );
83
91
  """
84
92
 
85
- DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
93
+ DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
86
94
 
87
95
 
88
96
  class SqliteState(State):
@@ -123,6 +131,7 @@ class SqliteState(State):
123
131
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
124
132
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
125
133
  cur.execute(SQL_CREATE_TABLE_NODE)
134
+ cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
126
135
  res = cur.execute("SELECT name FROM sqlite_schema;")
127
136
 
128
137
  return res.fetchall()
@@ -185,15 +194,11 @@ class SqliteState(State):
185
194
  log(ERROR, errors)
186
195
  return None
187
196
 
188
- # Create task_id, created_at and ttl
197
+ # Create task_id
189
198
  task_id = uuid4()
190
- created_at: datetime = now()
191
- ttl: datetime = created_at + timedelta(hours=24)
192
199
 
193
200
  # Store TaskIns
194
201
  task_ins.task_id = str(task_id)
195
- task_ins.task.created_at = created_at.isoformat()
196
- task_ins.task.ttl = ttl.isoformat()
197
202
  data = (task_ins_to_dict(task_ins),)
198
203
  columns = ", ".join([f":{key}" for key in data[0]])
199
204
  query = f"INSERT INTO task_ins VALUES({columns});"
@@ -320,15 +325,11 @@ class SqliteState(State):
320
325
  log(ERROR, errors)
321
326
  return None
322
327
 
323
- # Create task_id, created_at and ttl
328
+ # Create task_id
324
329
  task_id = uuid4()
325
- created_at: datetime = now()
326
- ttl: datetime = created_at + timedelta(hours=24)
327
330
 
328
331
  # Store TaskIns
329
332
  task_res.task_id = str(task_id)
330
- task_res.task.created_at = created_at.isoformat()
331
- task_res.task.ttl = ttl.isoformat()
332
333
  data = (task_res_to_dict(task_res),)
333
334
  columns = ", ".join([f":{key}" for key in data[0]])
334
335
  query = f"INSERT INTO task_res VALUES({columns});"
@@ -472,9 +473,14 @@ class SqliteState(State):
472
473
  # Sample a random int64 as node_id
473
474
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
474
475
 
475
- query = "INSERT INTO node VALUES(:node_id);"
476
+ query = (
477
+ "INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)"
478
+ )
479
+
476
480
  try:
477
- self.query(query, {"node_id": node_id})
481
+ # Default ping interval is 30s
482
+ # TODO: change 1e9 to 30s # pylint: disable=W0511
483
+ self.query(query, (node_id, time.time() + 1e9, 1e9))
478
484
  except sqlite3.IntegrityError:
479
485
  log(ERROR, "Unexpected node registration failure.")
480
486
  return 0
@@ -499,8 +505,8 @@ class SqliteState(State):
499
505
  return set()
500
506
 
501
507
  # Get nodes
502
- query = "SELECT * FROM node;"
503
- rows = self.query(query)
508
+ query = "SELECT node_id FROM node WHERE online_until > ?;"
509
+ rows = self.query(query, (time.time(),))
504
510
  result: Set[int] = {row["node_id"] for row in rows}
505
511
  return result
506
512
 
@@ -519,6 +525,17 @@ class SqliteState(State):
519
525
  log(ERROR, "Unexpected run creation failure.")
520
526
  return 0
521
527
 
528
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
529
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
530
+ # Update `online_until` and `ping_interval` for the given `node_id`
531
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
532
+ try:
533
+ self.query(query, (time.time() + ping_interval, ping_interval, node_id))
534
+ return True
535
+ except sqlite3.IntegrityError:
536
+ log(ERROR, "`node_id` does not exist.")
537
+ return False
538
+
522
539
 
523
540
  def dict_factory(
524
541
  cursor: sqlite3.Cursor,
@@ -544,6 +561,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
544
561
  "consumer_node_id": task_msg.task.consumer.node_id,
545
562
  "created_at": task_msg.task.created_at,
546
563
  "delivered_at": task_msg.task.delivered_at,
564
+ "pushed_at": task_msg.task.pushed_at,
547
565
  "ttl": task_msg.task.ttl,
548
566
  "ancestry": ",".join(task_msg.task.ancestry),
549
567
  "task_type": task_msg.task.task_type,
@@ -564,6 +582,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
564
582
  "consumer_node_id": task_msg.task.consumer.node_id,
565
583
  "created_at": task_msg.task.created_at,
566
584
  "delivered_at": task_msg.task.delivered_at,
585
+ "pushed_at": task_msg.task.pushed_at,
567
586
  "ttl": task_msg.task.ttl,
568
587
  "ancestry": ",".join(task_msg.task.ancestry),
569
588
  "task_type": task_msg.task.task_type,
@@ -592,6 +611,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
592
611
  ),
593
612
  created_at=task_dict["created_at"],
594
613
  delivered_at=task_dict["delivered_at"],
614
+ pushed_at=task_dict["pushed_at"],
595
615
  ttl=task_dict["ttl"],
596
616
  ancestry=task_dict["ancestry"].split(","),
597
617
  task_type=task_dict["task_type"],
@@ -621,6 +641,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
621
641
  ),
622
642
  created_at=task_dict["created_at"],
623
643
  delivered_at=task_dict["delivered_at"],
644
+ pushed_at=task_dict["pushed_at"],
624
645
  ttl=task_dict["ttl"],
625
646
  ancestry=task_dict["ancestry"].split(","),
626
647
  task_type=task_dict["task_type"],
@@ -152,3 +152,22 @@ class State(abc.ABC):
152
152
  @abc.abstractmethod
153
153
  def create_run(self) -> int:
154
154
  """Create one run."""
155
+
156
+ @abc.abstractmethod
157
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
158
+ """Acknowledge a ping received from a node, serving as a heartbeat.
159
+
160
+ Parameters
161
+ ----------
162
+ node_id : int
163
+ The `node_id` from which the ping was received.
164
+ ping_interval : float
165
+ The interval (in seconds) from the current timestamp within which the next
166
+ ping from this node must be received. This acts as a hard deadline to ensure
167
+ an accurate assessment of the node's availability.
168
+
169
+ Returns
170
+ -------
171
+ is_acknowledged : bool
172
+ True if the ping is successfully acknowledged; otherwise, False.
173
+ """