flwr 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  12. flwr/client/client_app.py +147 -36
  13. flwr/client/clientapp/app.py +4 -0
  14. flwr/client/message_handler/message_handler.py +1 -1
  15. flwr/client/rest_client/connection.py +4 -6
  16. flwr/client/supernode/__init__.py +0 -2
  17. flwr/client/supernode/app.py +1 -11
  18. flwr/common/address.py +35 -0
  19. flwr/common/args.py +8 -2
  20. flwr/common/auth_plugin/auth_plugin.py +2 -1
  21. flwr/common/constant.py +16 -0
  22. flwr/common/event_log_plugin/__init__.py +22 -0
  23. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  24. flwr/common/grpc.py +1 -1
  25. flwr/common/message.py +18 -7
  26. flwr/common/object_ref.py +0 -10
  27. flwr/common/record/conversion_utils.py +8 -17
  28. flwr/common/record/parametersrecord.py +151 -16
  29. flwr/common/record/recordset.py +95 -88
  30. flwr/common/secure_aggregation/quantization.py +5 -1
  31. flwr/common/serde.py +8 -126
  32. flwr/common/telemetry.py +0 -10
  33. flwr/common/typing.py +36 -0
  34. flwr/server/app.py +18 -2
  35. flwr/server/compat/app.py +4 -1
  36. flwr/server/compat/app_utils.py +10 -2
  37. flwr/server/compat/driver_client_proxy.py +2 -2
  38. flwr/server/driver/driver.py +1 -1
  39. flwr/server/driver/grpc_driver.py +10 -1
  40. flwr/server/driver/inmemory_driver.py +17 -20
  41. flwr/server/run_serverapp.py +2 -13
  42. flwr/server/server_app.py +93 -20
  43. flwr/server/superlink/driver/serverappio_servicer.py +25 -27
  44. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  45. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  46. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  47. flwr/server/superlink/fleet/vce/vce_api.py +32 -35
  48. flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
  49. flwr/server/superlink/linkstate/linkstate.py +47 -60
  50. flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -276
  51. flwr/server/superlink/linkstate/utils.py +91 -119
  52. flwr/server/utils/__init__.py +2 -2
  53. flwr/server/utils/validator.py +53 -68
  54. flwr/server/workflow/default_workflows.py +4 -1
  55. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
  56. flwr/superexec/app.py +0 -14
  57. flwr/superexec/exec_servicer.py +4 -4
  58. flwr/superexec/exec_user_auth_interceptor.py +5 -3
  59. {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/METADATA +4 -4
  60. {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/RECORD +63 -66
  61. {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
  62. flwr/client/message_handler/task_handler.py +0 -37
  63. flwr/proto/task_pb2.py +0 -33
  64. flwr/proto/task_pb2.pyi +0 -100
  65. flwr/proto/task_pb2_grpc.py +0 -4
  66. flwr/proto/task_pb2_grpc.pyi +0 -4
  67. {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
  68. {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
@@ -26,7 +26,7 @@ from logging import DEBUG, ERROR, WARNING
 from typing import Any, Optional, Union, cast
 from uuid import UUID, uuid4
 
-from flwr.common import Context, log, now
+from flwr.common import Context, Message, Metadata, log, now
 from flwr.common.constant import (
     MESSAGE_TTL_TOLERANCE,
     NODE_ID_NUM_BYTES,
@@ -35,15 +35,20 @@ from flwr.common.constant import (
     Status,
 )
 from flwr.common.record import ConfigsRecord
+from flwr.common.serde import (
+    error_from_proto,
+    error_to_proto,
+    recordset_from_proto,
+    recordset_to_proto,
+)
 from flwr.common.typing import Run, RunStatus, UserConfig
 
 # pylint: disable=E0611
-from flwr.proto.node_pb2 import Node
+from flwr.proto.error_pb2 import Error as ProtoError
 from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
-from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
 
 # pylint: enable=E0611
-from flwr.server.utils.validator import validate_task_ins_or_res
+from flwr.server.utils.validator import validate_message
 
 from .linkstate import LinkState
 from .utils import (
@@ -58,8 +63,8 @@ from .utils import (
     generate_rand_int_from_bytes,
     has_valid_sub_status,
     is_valid_transition,
-    verify_found_taskres,
-    verify_taskins_ids,
+    verify_found_message_replies,
+    verify_message_ids,
 )
 
 SQL_CREATE_TABLE_NODE = """
@@ -117,36 +122,39 @@ CREATE TABLE IF NOT EXISTS context(
 );
 """
 
-SQL_CREATE_TABLE_TASK_INS = """
-CREATE TABLE IF NOT EXISTS task_ins(
-    task_id TEXT UNIQUE,
+SQL_CREATE_TABLE_MESSAGE_INS = """
+CREATE TABLE IF NOT EXISTS message_ins(
+    message_id TEXT UNIQUE,
     group_id TEXT,
     run_id INTEGER,
-    producer_node_id INTEGER,
-    consumer_node_id INTEGER,
+    src_node_id INTEGER,
+    dst_node_id INTEGER,
+    reply_to_message TEXT,
     created_at REAL,
     delivered_at TEXT,
     ttl REAL,
-    ancestry TEXT,
-    task_type TEXT,
-    recordset BLOB,
+    message_type TEXT,
+    content BLOB NULL,
+    error BLOB NULL,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """
 
-SQL_CREATE_TABLE_TASK_RES = """
-CREATE TABLE IF NOT EXISTS task_res(
-    task_id TEXT UNIQUE,
+
+SQL_CREATE_TABLE_MESSAGE_RES = """
+CREATE TABLE IF NOT EXISTS message_res(
+    message_id TEXT UNIQUE,
     group_id TEXT,
     run_id INTEGER,
-    producer_node_id INTEGER,
-    consumer_node_id INTEGER,
+    src_node_id INTEGER,
+    dst_node_id INTEGER,
+    reply_to_message TEXT,
     created_at REAL,
     delivered_at TEXT,
     ttl REAL,
-    ancestry TEXT,
-    task_type TEXT,
-    recordset BLOB,
+    message_type TEXT,
+    content BLOB NULL,
+    error BLOB NULL,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """
@@ -196,8 +204,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
         cur.execute(SQL_CREATE_TABLE_RUN)
         cur.execute(SQL_CREATE_TABLE_LOGS)
         cur.execute(SQL_CREATE_TABLE_CONTEXT)
-        cur.execute(SQL_CREATE_TABLE_TASK_INS)
-        cur.execute(SQL_CREATE_TABLE_TASK_RES)
+        cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
+        cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
         cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
         cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
@@ -239,88 +247,62 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
 
         return result
 
-    def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
-        """Store one TaskIns.
-
-        Usually, the ServerAppIo API calls this to schedule instructions.
-
-        Stores the value of the task_ins in the link state and, if successful,
-        returns the task_id (UUID) of the task_ins. If, for any reason, storing
-        the task_ins fails, `None` is returned.
-
-        Constraints
-        -----------
-
-        `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_ins)
+    def store_message_ins(self, message: Message) -> Optional[UUID]:
+        """Store one Message."""
+        # Validate message
+        errors = validate_message(message=message, is_reply_message=False)
         if any(errors):
             log(ERROR, errors)
             return None
-        # Create task_id
-        task_id = uuid4()
+        # Create message_id
+        message_id = uuid4()
 
-        # Store TaskIns
-        task_ins.task_id = str(task_id)
-        data = (task_ins_to_dict(task_ins),)
+        # Store Message
+        # pylint: disable-next=W0212
+        message.metadata._message_id = str(message_id)  # type: ignore
+        data = (message_to_dict(message),)
 
         # Convert values from uint64 to sint64 for SQLite
         convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
+            data[0], ["run_id", "src_node_id", "dst_node_id"]
         )
 
         # Validate run_id
         query = "SELECT run_id FROM run WHERE run_id = ?;"
         if not self.query(query, (data[0]["run_id"],)):
-            log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
+            log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
             return None
+
         # Validate source node ID
-        if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
+        if message.metadata.src_node_id != SUPERLINK_NODE_ID:
             log(
                 ERROR,
-                "Invalid source node ID for TaskIns: %s",
-                task_ins.task.producer.node_id,
+                "Invalid source node ID for Message: %s",
+                message.metadata.src_node_id,
             )
             return None
+
         # Validate destination node ID
         query = "SELECT node_id FROM node WHERE node_id = ?;"
-        if not self.query(query, (data[0]["consumer_node_id"],)):
+        if not self.query(query, (data[0]["dst_node_id"],)):
             log(
                 ERROR,
-                "Invalid destination node ID for TaskIns: %s",
-                task_ins.task.consumer.node_id,
+                "Invalid destination node ID for Message: %s",
+                message.metadata.dst_node_id,
             )
             return None
 
         columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO task_ins VALUES({columns});"
+        query = f"INSERT INTO message_ins VALUES({columns});"
 
         # Only invalid run_id can trigger IntegrityError.
         # This may need to be changed in the future version with more integrity checks.
         self.query(query, data)
 
-        return task_id
-
-    def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
-        """Get undelivered TaskIns for one node.
-
-        Usually, the Fleet API calls this for Nodes planning to work on one or more
-        TaskIns.
-
-        Constraints
-        -----------
-        Retrieve all TaskIns where
-
-        1. the `task_ins.task.consumer.node_id` equals `node_id` AND
-        2. the `task_ins.task.delivered_at` equals `""`.
-
-        `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
-        the result.
+        return message_id
 
-        If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
-        `limit` is set, it has to be greater than zero.
-        """
+    def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
+        """Get all Messages that have not been delivered yet."""
         if limit is not None and limit < 1:
             raise AssertionError("`limit` must be >= 1")
 
@@ -333,11 +315,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
         # Convert the uint64 value to sint64 for SQLite
         data["node_id"] = convert_uint64_to_sint64(node_id)
 
-        # Retrieve all TaskIns for node_id
+        # Retrieve all Messages for node_id
         query = """
-            SELECT task_id
-            FROM task_ins
-            WHERE consumer_node_id == :node_id
+            SELECT message_id
+            FROM message_ins
+            WHERE dst_node_id == :node_id
             AND delivered_at = ""
             AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
         """
@@ -352,20 +334,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
 
         if rows:
             # Prepare query
-            task_ids = [row["task_id"] for row in rows]
-            placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
+            message_ids = [row["message_id"] for row in rows]
+            placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
             query = f"""
-                UPDATE task_ins
+                UPDATE message_ins
                 SET delivered_at = :delivered_at
-                WHERE task_id IN ({placeholders})
+                WHERE message_id IN ({placeholders})
                 RETURNING *;
             """
 
             # Prepare data for query
             delivered_at = now().isoformat()
             data = {"delivered_at": delivered_at}
-            for index, task_id in enumerate(task_ids):
-                data[f"id_{index}"] = str(task_id)
+            for index, msg_id in enumerate(message_ids):
+                data[f"id_{index}"] = str(msg_id)
 
             # Run query
             rows = self.query(query, data)
@@ -373,86 +355,80 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
         for row in rows:
             # Convert values from sint64 to uint64
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "producer_node_id", "consumer_node_id"]
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )
 
-        result = [dict_to_task_ins(row) for row in rows]
+        result = [dict_to_message(row) for row in rows]
 
         return result
 
-    def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
-        """Store one TaskRes.
-
-        Usually, the Fleet API calls this when Nodes return their results.
-
-        Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
-        the `task_res`. If storing the `task_res` fails, `None` is returned.
-
-        Constraints
-        -----------
-        `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_res)
+    def store_message_res(self, message: Message) -> Optional[UUID]:
+        """Store one Message."""
+        # Validate message
+        errors = validate_message(message=message, is_reply_message=True)
         if any(errors):
             log(ERROR, errors)
             return None
 
-        # Create task_id
-        task_id = uuid4()
-
-        task_ins_id = task_res.task.ancestry[0]
-        task_ins = self.get_valid_task_ins(task_ins_id)
-        if task_ins is None:
+        res_metadata = message.metadata
+        msg_ins_id = res_metadata.reply_to_message
+        msg_ins = self.get_valid_message_ins(msg_ins_id)
+        if msg_ins is None:
             log(
                 ERROR,
-                "Failed to store TaskRes: "
-                "TaskIns with task_id %s does not exist or has expired.",
-                task_ins_id,
+                "Failed to store Message reply: "
+                "The message it replies to with message_id %s does not exist or "
+                "has expired.",
+                msg_ins_id,
             )
             return None
 
-        # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
+        # Ensure that the dst_node_id of the original message matches the src_node_id of
+        # reply being processed.
         if (
-            task_ins
-            and task_res
-            and convert_sint64_to_uint64(task_ins["consumer_node_id"])
-            != task_res.task.producer.node_id
+            msg_ins
+            and message
+            and convert_sint64_to_uint64(msg_ins["dst_node_id"])
+            != res_metadata.src_node_id
         ):
             return None
 
-        # Fail if the TaskRes TTL exceeds the
-        # expiration time of the TaskIns it replies to.
-        # Condition: TaskIns.created_at + TaskIns.ttl ≥
-        #            TaskRes.created_at + TaskRes.ttl
+        # Fail if the Message TTL exceeds the
+        # expiration time of the Message it replies to.
+        # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
+        #            res_metadata.created_at + res_metadata.ttl
         # A small tolerance is introduced to account
         # for floating-point precision issues.
         max_allowed_ttl = (
-            task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
+            msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
         )
-        if task_res.task.ttl and (
-            task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
+        if res_metadata.ttl and (
+            res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
        ):
             log(
                 WARNING,
-                "Received TaskRes with TTL %.2f "
-                "exceeding the allowed maximum TTL %.2f.",
-                task_res.task.ttl,
+                "Received Message with TTL %.2f exceeding the allowed maximum "
+                "TTL %.2f.",
+                res_metadata.ttl,
                 max_allowed_ttl,
             )
             return None
 
-        # Store TaskRes
-        task_res.task_id = str(task_id)
-        data = (task_res_to_dict(task_res),)
+        # Create message_id
+        message_id = uuid4()
+
+        # Store Message
+        # pylint: disable-next=W0212
+        message.metadata._message_id = str(message_id)  # type: ignore
+        data = (message_to_dict(message),)
 
         # Convert values from uint64 to sint64 for SQLite
         convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
+            data[0], ["run_id", "src_node_id", "dst_node_id"]
         )
 
         columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO task_res VALUES({columns});"
+        query = f"INSERT INTO message_res VALUES({columns});"
 
         # Only invalid run_id can trigger IntegrityError.
         # This may need to be changed in the future version with more integrity checks.
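
To make the TTL guard in store_message_res concrete, a small worked example with illustrative numbers (MESSAGE_TTL_TOLERANCE is assumed here for illustration; the real constant is defined in flwr.common.constant):

MESSAGE_TTL_TOLERANCE = 1e-1  # assumed value, for illustration only

ins_created_at, ins_ttl = 100.0, 60.0  # instruction expires at t = 160.0
res_created_at, res_ttl = 130.0, 40.0  # reply would only expire at t = 170.0

max_allowed_ttl = ins_created_at + ins_ttl - res_created_at  # 30.0
rejected = res_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE  # 10.0 > tolerance
print(rejected)  # True: the reply would outlive its instruction, so the store returns None
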
@@ -462,124 +438,125 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
             log(ERROR, "`run` is invalid")
             return None
 
-        return task_id
+        return message_id
 
-    # pylint: disable-next=R0912,R0915,R0914
-    def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
-        """Get TaskRes for the given TaskIns IDs."""
-        ret: dict[UUID, TaskRes] = {}
+    def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
+        """Get reply Messages for the given Message IDs."""
+        ret: dict[UUID, Message] = {}
 
-        # Verify TaskIns IDs
+        # Verify Message IDs
         current = time.time()
         query = f"""
             SELECT *
-            FROM task_ins
-            WHERE task_id IN ({",".join(["?"] * len(task_ids))});
+            FROM message_ins
+            WHERE message_id IN ({",".join(["?"] * len(message_ids))});
         """
-        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
-        found_task_ins_dict: dict[UUID, TaskIns] = {}
+        rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
+        found_message_ins_dict: dict[UUID, Message] = {}
         for row in rows:
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "producer_node_id", "consumer_node_id"]
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )
-            found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
+            found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)
 
-        ret = verify_taskins_ids(
-            inquired_taskins_ids=task_ids,
-            found_taskins_dict=found_task_ins_dict,
+        ret = verify_message_ids(
+            inquired_message_ids=message_ids,
+            found_message_ins_dict=found_message_ins_dict,
             current_time=current,
         )
 
-        # Find all TaskRes
+        # Find all reply Messages
         query = f"""
             SELECT *
-            FROM task_res
-            WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
+            FROM message_res
+            WHERE reply_to_message IN ({",".join(["?"] * len(message_ids))})
             AND delivered_at = "";
         """
-        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
+        rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
         for row in rows:
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "producer_node_id", "consumer_node_id"]
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )
-        tmp_ret_dict = verify_found_taskres(
-            inquired_taskins_ids=task_ids,
-            found_taskins_dict=found_task_ins_dict,
-            found_taskres_list=[dict_to_task_res(row) for row in rows],
+        tmp_ret_dict = verify_found_message_replies(
+            inquired_message_ids=message_ids,
+            found_message_ins_dict=found_message_ins_dict,
+            found_message_res_list=[dict_to_message(row) for row in rows],
             current_time=current,
         )
         ret.update(tmp_ret_dict)
 
-        # Mark existing TaskRes to be returned as delivered
+        # Mark existing reply Messages to be returned as delivered
         delivered_at = now().isoformat()
-        for task_res in ret.values():
-            task_res.task.delivered_at = delivered_at
-        task_res_ids = [task_res.task_id for task_res in ret.values()]
+        for message_res in ret.values():
+            message_res.metadata.delivered_at = delivered_at
+        message_res_ids = [
+            message_res.metadata.message_id for message_res in ret.values()
+        ]
         query = f"""
-            UPDATE task_res
+            UPDATE message_res
             SET delivered_at = ?
-            WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
+            WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
         """
-        data: list[Any] = [delivered_at] + task_res_ids
+        data: list[Any] = [delivered_at] + message_res_ids
         self.query(query, data)
 
         return list(ret.values())
 
-    def num_task_ins(self) -> int:
-        """Calculate the number of task_ins in store.
+    def num_message_ins(self) -> int:
+        """Calculate the number of instruction Messages in store.
 
-        This includes delivered but not yet deleted task_ins.
+        This includes delivered but not yet deleted.
         """
-        query = "SELECT count(*) AS num FROM task_ins;"
+        query = "SELECT count(*) AS num FROM message_ins;"
         rows = self.query(query)
         result = rows[0]
         num = cast(int, result["num"])
         return num
 
-    def num_task_res(self) -> int:
-        """Calculate the number of task_res in store.
+    def num_message_res(self) -> int:
+        """Calculate the number of reply Messages in store.
 
-        This includes delivered but not yet deleted task_res.
+        This includes delivered but not yet deleted.
         """
-        query = "SELECT count(*) AS num FROM task_res;"
+        query = "SELECT count(*) AS num FROM message_res;"
         rows = self.query(query)
         result: dict[str, int] = rows[0]
         return result["num"]
 
-    def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
-        """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
-        if not task_ins_ids:
+    def delete_messages(self, message_ins_ids: set[UUID]) -> None:
+        """Delete a Message and its reply based on provided Message IDs."""
+        if not message_ins_ids:
             return
         if self.conn is None:
             raise AttributeError("LinkState not initialized")
 
-        placeholders = ",".join(["?"] * len(task_ins_ids))
-        data = tuple(str(task_id) for task_id in task_ins_ids)
+        placeholders = ",".join(["?"] * len(message_ins_ids))
+        data = tuple(str(message_id) for message_id in message_ins_ids)
 
-        # Delete task_ins
+        # Delete Message
         query_1 = f"""
-            DELETE FROM task_ins
-            WHERE task_id IN ({placeholders});
+            DELETE FROM message_ins
+            WHERE message_id IN ({placeholders});
         """
 
-        # Delete task_res
+        # Delete reply Message
         query_2 = f"""
-            DELETE FROM task_res
-            WHERE ancestry IN ({placeholders});
+            DELETE FROM message_res
+            WHERE reply_to_message IN ({placeholders});
         """
 
         with self.conn:
             self.conn.execute(query_1, data)
             self.conn.execute(query_2, data)
 
-    def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
-        """Get all TaskIns IDs for the given run_id."""
+    def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
+        """Get all instruction Message IDs for the given run_id."""
         if self.conn is None:
             raise AttributeError("LinkState not initialized")
 
         query = """
-            SELECT task_id
-            FROM task_ins
+            SELECT message_id
+            FROM message_ins
             WHERE run_id = :run_id;
         """
 
@@ -589,7 +566,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
         with self.conn:
             rows = self.conn.execute(query, data).fetchall()
 
-        return {UUID(row["task_id"]) for row in rows}
+        return {UUID(row["message_id"]) for row in rows}
 
     def create_node(self, ping_interval: float) -> int:
         """Create, store in the link state, and return `node_id`."""
@@ -1001,32 +978,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
         latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
         return "".join(row["log"] for row in rows), latest_timestamp
 
-    def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
-        """Check if the TaskIns exists and is valid (not expired).
+    def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
+        """Check if the Message exists and is valid (not expired).
 
-        Return TaskIns if valid.
+        Return Message if valid.
         """
         query = """
             SELECT *
-            FROM task_ins
-            WHERE task_id = :task_id
+            FROM message_ins
+            WHERE message_id = :message_id
         """
-        data = {"task_id": task_id}
+        data = {"message_id": message_id}
         rows = self.query(query, data)
         if not rows:
-            # TaskIns does not exist
+            # Message does not exist
             return None
 
-        task_ins = rows[0]
-        created_at = task_ins["created_at"]
-        ttl = task_ins["ttl"]
+        message_ins = rows[0]
+        created_at = message_ins["created_at"]
+        ttl = message_ins["ttl"]
         current_time = time.time()
 
-        # Check if TaskIns is expired
+        # Check if Message is expired
         if ttl is not None and created_at + ttl <= current_time:
             return None
 
-        return task_ins
+        return message_ins
 
 
 def dict_factory(
@@ -1041,94 +1018,51 @@ def dict_factory(
     return dict(zip(fields, row))
 
 
-def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
-    """Transform TaskIns to dict."""
+def message_to_dict(message: Message) -> dict[str, Any]:
+    """Transform Message to dict."""
     result = {
-        "task_id": task_msg.task_id,
-        "group_id": task_msg.group_id,
-        "run_id": task_msg.run_id,
-        "producer_node_id": task_msg.task.producer.node_id,
-        "consumer_node_id": task_msg.task.consumer.node_id,
-        "created_at": task_msg.task.created_at,
-        "delivered_at": task_msg.task.delivered_at,
-        "ttl": task_msg.task.ttl,
-        "ancestry": ",".join(task_msg.task.ancestry),
-        "task_type": task_msg.task.task_type,
-        "recordset": task_msg.task.recordset.SerializeToString(),
+        "message_id": message.metadata.message_id,
+        "group_id": message.metadata.group_id,
+        "run_id": message.metadata.run_id,
+        "src_node_id": message.metadata.src_node_id,
+        "dst_node_id": message.metadata.dst_node_id,
+        "reply_to_message": message.metadata.reply_to_message,
+        "created_at": message.metadata.created_at,
+        "delivered_at": message.metadata.delivered_at,
+        "ttl": message.metadata.ttl,
+        "message_type": message.metadata.message_type,
+        "content": None,
+        "error": None,
     }
-    return result
 
+    if message.has_content():
+        result["content"] = recordset_to_proto(message.content).SerializeToString()
+    else:
+        result["error"] = error_to_proto(message.error).SerializeToString()
 
-def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
-    """Transform TaskRes to dict."""
-    result = {
-        "task_id": task_msg.task_id,
-        "group_id": task_msg.group_id,
-        "run_id": task_msg.run_id,
-        "producer_node_id": task_msg.task.producer.node_id,
-        "consumer_node_id": task_msg.task.consumer.node_id,
-        "created_at": task_msg.task.created_at,
-        "delivered_at": task_msg.task.delivered_at,
-        "ttl": task_msg.task.ttl,
-        "ancestry": ",".join(task_msg.task.ancestry),
-        "task_type": task_msg.task.task_type,
-        "recordset": task_msg.task.recordset.SerializeToString(),
-    }
     return result
 
 
-def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
-    """Turn task_dict into protobuf message."""
-    recordset = ProtoRecordSet()
-    recordset.ParseFromString(task_dict["recordset"])
-
-    result = TaskIns(
-        task_id=task_dict["task_id"],
-        group_id=task_dict["group_id"],
-        run_id=task_dict["run_id"],
-        task=Task(
-            producer=Node(
-                node_id=task_dict["producer_node_id"],
-            ),
-            consumer=Node(
-                node_id=task_dict["consumer_node_id"],
-            ),
-            created_at=task_dict["created_at"],
-            delivered_at=task_dict["delivered_at"],
-            ttl=task_dict["ttl"],
-            ancestry=task_dict["ancestry"].split(","),
-            task_type=task_dict["task_type"],
-            recordset=recordset,
-        ),
-    )
-    return result
-
-
-def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
-    """Turn task_dict into protobuf message."""
-    recordset = ProtoRecordSet()
-    recordset.ParseFromString(task_dict["recordset"])
-
-    result = TaskRes(
-        task_id=task_dict["task_id"],
-        group_id=task_dict["group_id"],
-        run_id=task_dict["run_id"],
-        task=Task(
-            producer=Node(
-                node_id=task_dict["producer_node_id"],
-            ),
-            consumer=Node(
-                node_id=task_dict["consumer_node_id"],
-            ),
-            created_at=task_dict["created_at"],
-            delivered_at=task_dict["delivered_at"],
-            ttl=task_dict["ttl"],
-            ancestry=task_dict["ancestry"].split(","),
-            task_type=task_dict["task_type"],
-            recordset=recordset,
-        ),
+def dict_to_message(message_dict: dict[str, Any]) -> Message:
+    """Transform dict to Message."""
+    content, error = None, None
+    if (b_content := message_dict.pop("content")) is not None:
+        content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
+    if (b_error := message_dict.pop("error")) is not None:
+        error = error_from_proto(ProtoError.FromString(b_error))
+
+    # Metadata constructor doesn't allow passing created_at. We set it later
+    metadata = Metadata(
+        **{
+            k: v
+            for k, v in message_dict.items()
+            if k not in ["created_at", "delivered_at"]
+        }
     )
-    return result
+    msg = Message(metadata=metadata, content=content, error=error)
+    msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
+    msg.metadata.delivered_at = message_dict["delivered_at"]
+    return msg
 
 
 def determine_run_status(row: dict[str, Any]) -> str:
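
Because exactly one of content/error is populated per row (both columns are declared BLOB NULL), message_to_dict serializes whichever side is present and dict_to_message restores it. A minimal round-trip sketch of the content path, using the serde helpers imported at the top of this file; it assumes RecordSet is importable from flwr.common and that an empty RecordSet is acceptable:

from flwr.common import RecordSet
from flwr.common.serde import recordset_from_proto, recordset_to_proto
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet

# Serialize the way message_to_dict() fills the `content` column ...
blob = recordset_to_proto(RecordSet()).SerializeToString()

# ... and restore it the way dict_to_message() reads it back
restored = recordset_from_proto(ProtoRecordSet.FromString(blob))
print(type(restored).__name__)  # RecordSet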