flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -26,29 +26,37 @@ from logging import DEBUG, ERROR, WARNING
26
26
  from typing import Any, Optional, Union, cast
27
27
  from uuid import UUID, uuid4
28
28
 
29
- from flwr.common import Context, log, now
29
+ from flwr.common import Context, Message, Metadata, log, now
30
30
  from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
+ PING_PATIENCE,
33
34
  RUN_ID_NUM_BYTES,
34
35
  SUPERLINK_NODE_ID,
35
36
  Status,
36
37
  )
37
- from flwr.common.record import ConfigsRecord
38
+ from flwr.common.message import make_message
39
+ from flwr.common.record import ConfigRecord
40
+ from flwr.common.serde import (
41
+ error_from_proto,
42
+ error_to_proto,
43
+ recorddict_from_proto,
44
+ recorddict_to_proto,
45
+ )
38
46
  from flwr.common.typing import Run, RunStatus, UserConfig
39
47
 
40
48
  # pylint: disable=E0611
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
43
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
49
+ from flwr.proto.error_pb2 import Error as ProtoError
50
+ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
44
51
 
45
52
  # pylint: enable=E0611
46
- from flwr.server.utils.validator import validate_task_ins_or_res
53
+ from flwr.server.utils.validator import validate_message
47
54
 
48
55
  from .linkstate import LinkState
49
56
  from .utils import (
50
- configsrecord_from_bytes,
51
- configsrecord_to_bytes,
57
+ check_node_availability_for_in_message,
58
+ configrecord_from_bytes,
59
+ configrecord_to_bytes,
52
60
  context_from_bytes,
53
61
  context_to_bytes,
54
62
  convert_sint64_to_uint64,
@@ -58,8 +66,8 @@ from .utils import (
58
66
  generate_rand_int_from_bytes,
59
67
  has_valid_sub_status,
60
68
  is_valid_transition,
61
- verify_found_taskres,
62
- verify_taskins_ids,
69
+ verify_found_message_replies,
70
+ verify_message_ids,
63
71
  )
64
72
 
65
73
  SQL_CREATE_TABLE_NODE = """
@@ -117,36 +125,39 @@ CREATE TABLE IF NOT EXISTS context(
117
125
  );
118
126
  """
119
127
 
120
- SQL_CREATE_TABLE_TASK_INS = """
121
- CREATE TABLE IF NOT EXISTS task_ins(
122
- task_id TEXT UNIQUE,
128
+ SQL_CREATE_TABLE_MESSAGE_INS = """
129
+ CREATE TABLE IF NOT EXISTS message_ins(
130
+ message_id TEXT UNIQUE,
123
131
  group_id TEXT,
124
132
  run_id INTEGER,
125
- producer_node_id INTEGER,
126
- consumer_node_id INTEGER,
133
+ src_node_id INTEGER,
134
+ dst_node_id INTEGER,
135
+ reply_to_message_id TEXT,
127
136
  created_at REAL,
128
137
  delivered_at TEXT,
129
138
  ttl REAL,
130
- ancestry TEXT,
131
- task_type TEXT,
132
- recordset BLOB,
139
+ message_type TEXT,
140
+ content BLOB NULL,
141
+ error BLOB NULL,
133
142
  FOREIGN KEY(run_id) REFERENCES run(run_id)
134
143
  );
135
144
  """
136
145
 
137
- SQL_CREATE_TABLE_TASK_RES = """
138
- CREATE TABLE IF NOT EXISTS task_res(
139
- task_id TEXT UNIQUE,
146
+
147
+ SQL_CREATE_TABLE_MESSAGE_RES = """
148
+ CREATE TABLE IF NOT EXISTS message_res(
149
+ message_id TEXT UNIQUE,
140
150
  group_id TEXT,
141
151
  run_id INTEGER,
142
- producer_node_id INTEGER,
143
- consumer_node_id INTEGER,
152
+ src_node_id INTEGER,
153
+ dst_node_id INTEGER,
154
+ reply_to_message_id TEXT,
144
155
  created_at REAL,
145
156
  delivered_at TEXT,
146
157
  ttl REAL,
147
- ancestry TEXT,
148
- task_type TEXT,
149
- recordset BLOB,
158
+ message_type TEXT,
159
+ content BLOB NULL,
160
+ error BLOB NULL,
150
161
  FOREIGN KEY(run_id) REFERENCES run(run_id)
151
162
  );
152
163
  """
@@ -196,8 +207,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
196
207
  cur.execute(SQL_CREATE_TABLE_RUN)
197
208
  cur.execute(SQL_CREATE_TABLE_LOGS)
198
209
  cur.execute(SQL_CREATE_TABLE_CONTEXT)
199
- cur.execute(SQL_CREATE_TABLE_TASK_INS)
200
- cur.execute(SQL_CREATE_TABLE_TASK_RES)
210
+ cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
211
+ cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
201
212
  cur.execute(SQL_CREATE_TABLE_NODE)
202
213
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
203
214
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
@@ -239,88 +250,62 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
239
250
 
240
251
  return result
241
252
 
242
- def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
243
- """Store one TaskIns.
244
-
245
- Usually, the ServerAppIo API calls this to schedule instructions.
246
-
247
- Stores the value of the task_ins in the link state and, if successful,
248
- returns the task_id (UUID) of the task_ins. If, for any reason, storing
249
- the task_ins fails, `None` is returned.
250
-
251
- Constraints
252
- -----------
253
-
254
- `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
255
- """
256
- # Validate task
257
- errors = validate_task_ins_or_res(task_ins)
253
+ def store_message_ins(self, message: Message) -> Optional[UUID]:
254
+ """Store one Message."""
255
+ # Validate message
256
+ errors = validate_message(message=message, is_reply_message=False)
258
257
  if any(errors):
259
258
  log(ERROR, errors)
260
259
  return None
261
- # Create task_id
262
- task_id = uuid4()
260
+ # Create message_id
261
+ message_id = uuid4()
263
262
 
264
- # Store TaskIns
265
- task_ins.task_id = str(task_id)
266
- data = (task_ins_to_dict(task_ins),)
263
+ # Store Message
264
+ # pylint: disable-next=W0212
265
+ message.metadata._message_id = str(message_id) # type: ignore
266
+ data = (message_to_dict(message),)
267
267
 
268
268
  # Convert values from uint64 to sint64 for SQLite
269
269
  convert_uint64_values_in_dict_to_sint64(
270
- data[0], ["run_id", "producer_node_id", "consumer_node_id"]
270
+ data[0], ["run_id", "src_node_id", "dst_node_id"]
271
271
  )
272
272
 
273
273
  # Validate run_id
274
274
  query = "SELECT run_id FROM run WHERE run_id = ?;"
275
275
  if not self.query(query, (data[0]["run_id"],)):
276
- log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
276
+ log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
277
277
  return None
278
+
278
279
  # Validate source node ID
279
- if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
280
+ if message.metadata.src_node_id != SUPERLINK_NODE_ID:
280
281
  log(
281
282
  ERROR,
282
- "Invalid source node ID for TaskIns: %s",
283
- task_ins.task.producer.node_id,
283
+ "Invalid source node ID for Message: %s",
284
+ message.metadata.src_node_id,
284
285
  )
285
286
  return None
287
+
286
288
  # Validate destination node ID
287
289
  query = "SELECT node_id FROM node WHERE node_id = ?;"
288
- if not self.query(query, (data[0]["consumer_node_id"],)):
290
+ if not self.query(query, (data[0]["dst_node_id"],)):
289
291
  log(
290
292
  ERROR,
291
- "Invalid destination node ID for TaskIns: %s",
292
- task_ins.task.consumer.node_id,
293
+ "Invalid destination node ID for Message: %s",
294
+ message.metadata.dst_node_id,
293
295
  )
294
296
  return None
295
297
 
296
298
  columns = ", ".join([f":{key}" for key in data[0]])
297
- query = f"INSERT INTO task_ins VALUES({columns});"
299
+ query = f"INSERT INTO message_ins VALUES({columns});"
298
300
 
299
301
  # Only invalid run_id can trigger IntegrityError.
300
302
  # This may need to be changed in the future version with more integrity checks.
301
303
  self.query(query, data)
302
304
 
303
- return task_id
304
-
305
- def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
306
- """Get undelivered TaskIns for one node.
307
-
308
- Usually, the Fleet API calls this for Nodes planning to work on one or more
309
- TaskIns.
305
+ return message_id
310
306
 
311
- Constraints
312
- -----------
313
- Retrieve all TaskIns where
314
-
315
- 1. the `task_ins.task.consumer.node_id` equals `node_id` AND
316
- 2. the `task_ins.task.delivered_at` equals `""`.
317
-
318
- `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
319
- the result.
320
-
321
- If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
322
- `limit` is set, it has to be greater than zero.
323
- """
307
+ def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
308
+ """Get all Messages that have not been delivered yet."""
324
309
  if limit is not None and limit < 1:
325
310
  raise AssertionError("`limit` must be >= 1")
326
311
 
@@ -333,11 +318,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
333
318
  # Convert the uint64 value to sint64 for SQLite
334
319
  data["node_id"] = convert_uint64_to_sint64(node_id)
335
320
 
336
- # Retrieve all TaskIns for node_id
321
+ # Retrieve all Messages for node_id
337
322
  query = """
338
- SELECT task_id
339
- FROM task_ins
340
- WHERE consumer_node_id == :node_id
323
+ SELECT message_id
324
+ FROM message_ins
325
+ WHERE dst_node_id == :node_id
341
326
  AND delivered_at = ""
342
327
  AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
343
328
  """
@@ -352,20 +337,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
352
337
 
353
338
  if rows:
354
339
  # Prepare query
355
- task_ids = [row["task_id"] for row in rows]
356
- placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
340
+ message_ids = [row["message_id"] for row in rows]
341
+ placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
357
342
  query = f"""
358
- UPDATE task_ins
343
+ UPDATE message_ins
359
344
  SET delivered_at = :delivered_at
360
- WHERE task_id IN ({placeholders})
345
+ WHERE message_id IN ({placeholders})
361
346
  RETURNING *;
362
347
  """
363
348
 
364
349
  # Prepare data for query
365
350
  delivered_at = now().isoformat()
366
351
  data = {"delivered_at": delivered_at}
367
- for index, task_id in enumerate(task_ids):
368
- data[f"id_{index}"] = str(task_id)
352
+ for index, msg_id in enumerate(message_ids):
353
+ data[f"id_{index}"] = str(msg_id)
369
354
 
370
355
  # Run query
371
356
  rows = self.query(query, data)
@@ -373,86 +358,80 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
373
358
  for row in rows:
374
359
  # Convert values from sint64 to uint64
375
360
  convert_sint64_values_in_dict_to_uint64(
376
- row, ["run_id", "producer_node_id", "consumer_node_id"]
361
+ row, ["run_id", "src_node_id", "dst_node_id"]
377
362
  )
378
363
 
379
- result = [dict_to_task_ins(row) for row in rows]
364
+ result = [dict_to_message(row) for row in rows]
380
365
 
381
366
  return result
382
367
 
383
- def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
384
- """Store one TaskRes.
385
-
386
- Usually, the Fleet API calls this when Nodes return their results.
387
-
388
- Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
389
- the `task_res`. If storing the `task_res` fails, `None` is returned.
390
-
391
- Constraints
392
- -----------
393
- `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
394
- """
395
- # Validate task
396
- errors = validate_task_ins_or_res(task_res)
368
+ def store_message_res(self, message: Message) -> Optional[UUID]:
369
+ """Store one Message."""
370
+ # Validate message
371
+ errors = validate_message(message=message, is_reply_message=True)
397
372
  if any(errors):
398
373
  log(ERROR, errors)
399
374
  return None
400
375
 
401
- # Create task_id
402
- task_id = uuid4()
403
-
404
- task_ins_id = task_res.task.ancestry[0]
405
- task_ins = self.get_valid_task_ins(task_ins_id)
406
- if task_ins is None:
376
+ res_metadata = message.metadata
377
+ msg_ins_id = res_metadata.reply_to_message_id
378
+ msg_ins = self.get_valid_message_ins(msg_ins_id)
379
+ if msg_ins is None:
407
380
  log(
408
381
  ERROR,
409
- "Failed to store TaskRes: "
410
- "TaskIns with task_id %s does not exist or has expired.",
411
- task_ins_id,
382
+ "Failed to store Message reply: "
383
+ "The message it replies to with message_id %s does not exist or "
384
+ "has expired.",
385
+ msg_ins_id,
412
386
  )
413
387
  return None
414
388
 
415
- # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
389
+ # Ensure that the dst_node_id of the original message matches the src_node_id of
390
+ # reply being processed.
416
391
  if (
417
- task_ins
418
- and task_res
419
- and convert_sint64_to_uint64(task_ins["consumer_node_id"])
420
- != task_res.task.producer.node_id
392
+ msg_ins
393
+ and message
394
+ and convert_sint64_to_uint64(msg_ins["dst_node_id"])
395
+ != res_metadata.src_node_id
421
396
  ):
422
397
  return None
423
398
 
424
- # Fail if the TaskRes TTL exceeds the
425
- # expiration time of the TaskIns it replies to.
426
- # Condition: TaskIns.created_at + TaskIns.ttl ≥
427
- # TaskRes.created_at + TaskRes.ttl
399
+ # Fail if the Message TTL exceeds the
400
+ # expiration time of the Message it replies to.
401
+ # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
402
+ # res_metadata.created_at + res_metadata.ttl
428
403
  # A small tolerance is introduced to account
429
404
  # for floating-point precision issues.
430
405
  max_allowed_ttl = (
431
- task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
406
+ msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
432
407
  )
433
- if task_res.task.ttl and (
434
- task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
408
+ if res_metadata.ttl and (
409
+ res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
435
410
  ):
436
411
  log(
437
412
  WARNING,
438
- "Received TaskRes with TTL %.2f "
439
- "exceeding the allowed maximum TTL %.2f.",
440
- task_res.task.ttl,
413
+ "Received Message with TTL %.2f exceeding the allowed maximum "
414
+ "TTL %.2f.",
415
+ res_metadata.ttl,
441
416
  max_allowed_ttl,
442
417
  )
443
418
  return None
444
419
 
445
- # Store TaskRes
446
- task_res.task_id = str(task_id)
447
- data = (task_res_to_dict(task_res),)
420
+ # Create message_id
421
+ message_id = uuid4()
422
+
423
+ # Store Message
424
+ # pylint: disable-next=W0212
425
+ message.metadata._message_id = str(message_id) # type: ignore
426
+ data = (message_to_dict(message),)
448
427
 
449
428
  # Convert values from uint64 to sint64 for SQLite
450
429
  convert_uint64_values_in_dict_to_sint64(
451
- data[0], ["run_id", "producer_node_id", "consumer_node_id"]
430
+ data[0], ["run_id", "src_node_id", "dst_node_id"]
452
431
  )
453
432
 
454
433
  columns = ", ".join([f":{key}" for key in data[0]])
455
- query = f"INSERT INTO task_res VALUES({columns});"
434
+ query = f"INSERT INTO message_res VALUES({columns});"
456
435
 
457
436
  # Only invalid run_id can trigger IntegrityError.
458
437
  # This may need to be changed in the future version with more integrity checks.
@@ -462,124 +441,149 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
462
441
  log(ERROR, "`run` is invalid")
463
442
  return None
464
443
 
465
- return task_id
444
+ return message_id
466
445
 
467
- # pylint: disable-next=R0912,R0915,R0914
468
- def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
469
- """Get TaskRes for the given TaskIns IDs."""
470
- ret: dict[UUID, TaskRes] = {}
446
+ def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
447
+ """Get reply Messages for the given Message IDs."""
448
+ # pylint: disable-msg=too-many-locals
449
+ ret: dict[UUID, Message] = {}
471
450
 
472
- # Verify TaskIns IDs
451
+ # Verify Message IDs
473
452
  current = time.time()
474
453
  query = f"""
475
454
  SELECT *
476
- FROM task_ins
477
- WHERE task_id IN ({",".join(["?"] * len(task_ids))});
455
+ FROM message_ins
456
+ WHERE message_id IN ({",".join(["?"] * len(message_ids))});
478
457
  """
479
- rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
480
- found_task_ins_dict: dict[UUID, TaskIns] = {}
458
+ rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
459
+ found_message_ins_dict: dict[UUID, Message] = {}
481
460
  for row in rows:
482
461
  convert_sint64_values_in_dict_to_uint64(
483
- row, ["run_id", "producer_node_id", "consumer_node_id"]
462
+ row, ["run_id", "src_node_id", "dst_node_id"]
484
463
  )
485
- found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
464
+ found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)
465
+
466
+ ret = verify_message_ids(
467
+ inquired_message_ids=message_ids,
468
+ found_message_ins_dict=found_message_ins_dict,
469
+ current_time=current,
470
+ )
486
471
 
487
- ret = verify_taskins_ids(
488
- inquired_taskins_ids=task_ids,
489
- found_taskins_dict=found_task_ins_dict,
472
+ # Check node availability
473
+ dst_node_ids: set[int] = set()
474
+ for message_id in message_ids:
475
+ in_message = found_message_ins_dict[message_id]
476
+ sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
477
+ dst_node_ids.add(sint_node_id)
478
+ query = f"""
479
+ SELECT node_id, online_until
480
+ FROM node
481
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
482
+ """
483
+ rows = self.query(query, tuple(dst_node_ids))
484
+ tmp_ret_dict = check_node_availability_for_in_message(
485
+ inquired_in_message_ids=message_ids,
486
+ found_in_message_dict=found_message_ins_dict,
487
+ node_id_to_online_until={
488
+ convert_sint64_to_uint64(row["node_id"]): row["online_until"]
489
+ for row in rows
490
+ },
490
491
  current_time=current,
491
492
  )
493
+ ret.update(tmp_ret_dict)
492
494
 
493
- # Find all TaskRes
495
+ # Find all reply Messages
494
496
  query = f"""
495
497
  SELECT *
496
- FROM task_res
497
- WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
498
+ FROM message_res
499
+ WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
498
500
  AND delivered_at = "";
499
501
  """
500
- rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
502
+ rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
501
503
  for row in rows:
502
504
  convert_sint64_values_in_dict_to_uint64(
503
- row, ["run_id", "producer_node_id", "consumer_node_id"]
505
+ row, ["run_id", "src_node_id", "dst_node_id"]
504
506
  )
505
- tmp_ret_dict = verify_found_taskres(
506
- inquired_taskins_ids=task_ids,
507
- found_taskins_dict=found_task_ins_dict,
508
- found_taskres_list=[dict_to_task_res(row) for row in rows],
507
+ tmp_ret_dict = verify_found_message_replies(
508
+ inquired_message_ids=message_ids,
509
+ found_message_ins_dict=found_message_ins_dict,
510
+ found_message_res_list=[dict_to_message(row) for row in rows],
509
511
  current_time=current,
510
512
  )
511
513
  ret.update(tmp_ret_dict)
512
514
 
513
- # Mark existing TaskRes to be returned as delivered
515
+ # Mark existing reply Messages to be returned as delivered
514
516
  delivered_at = now().isoformat()
515
- for task_res in ret.values():
516
- task_res.task.delivered_at = delivered_at
517
- task_res_ids = [task_res.task_id for task_res in ret.values()]
517
+ for message_res in ret.values():
518
+ message_res.metadata.delivered_at = delivered_at
519
+ message_res_ids = [
520
+ message_res.metadata.message_id for message_res in ret.values()
521
+ ]
518
522
  query = f"""
519
- UPDATE task_res
523
+ UPDATE message_res
520
524
  SET delivered_at = ?
521
- WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
525
+ WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
522
526
  """
523
- data: list[Any] = [delivered_at] + task_res_ids
527
+ data: list[Any] = [delivered_at] + message_res_ids
524
528
  self.query(query, data)
525
529
 
526
530
  return list(ret.values())
527
531
 
528
- def num_task_ins(self) -> int:
529
- """Calculate the number of task_ins in store.
532
+ def num_message_ins(self) -> int:
533
+ """Calculate the number of instruction Messages in store.
530
534
 
531
- This includes delivered but not yet deleted task_ins.
535
+ This includes delivered but not yet deleted.
532
536
  """
533
- query = "SELECT count(*) AS num FROM task_ins;"
537
+ query = "SELECT count(*) AS num FROM message_ins;"
534
538
  rows = self.query(query)
535
539
  result = rows[0]
536
540
  num = cast(int, result["num"])
537
541
  return num
538
542
 
539
- def num_task_res(self) -> int:
540
- """Calculate the number of task_res in store.
543
+ def num_message_res(self) -> int:
544
+ """Calculate the number of reply Messages in store.
541
545
 
542
- This includes delivered but not yet deleted task_res.
546
+ This includes delivered but not yet deleted.
543
547
  """
544
- query = "SELECT count(*) AS num FROM task_res;"
548
+ query = "SELECT count(*) AS num FROM message_res;"
545
549
  rows = self.query(query)
546
550
  result: dict[str, int] = rows[0]
547
551
  return result["num"]
548
552
 
549
- def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
550
- """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
551
- if not task_ins_ids:
553
+ def delete_messages(self, message_ins_ids: set[UUID]) -> None:
554
+ """Delete a Message and its reply based on provided Message IDs."""
555
+ if not message_ins_ids:
552
556
  return
553
557
  if self.conn is None:
554
558
  raise AttributeError("LinkState not initialized")
555
559
 
556
- placeholders = ",".join(["?"] * len(task_ins_ids))
557
- data = tuple(str(task_id) for task_id in task_ins_ids)
560
+ placeholders = ",".join(["?"] * len(message_ins_ids))
561
+ data = tuple(str(message_id) for message_id in message_ins_ids)
558
562
 
559
- # Delete task_ins
563
+ # Delete Message
560
564
  query_1 = f"""
561
- DELETE FROM task_ins
562
- WHERE task_id IN ({placeholders});
565
+ DELETE FROM message_ins
566
+ WHERE message_id IN ({placeholders});
563
567
  """
564
568
 
565
- # Delete task_res
569
+ # Delete reply Message
566
570
  query_2 = f"""
567
- DELETE FROM task_res
568
- WHERE ancestry IN ({placeholders});
571
+ DELETE FROM message_res
572
+ WHERE reply_to_message_id IN ({placeholders});
569
573
  """
570
574
 
571
575
  with self.conn:
572
576
  self.conn.execute(query_1, data)
573
577
  self.conn.execute(query_2, data)
574
578
 
575
- def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
576
- """Get all TaskIns IDs for the given run_id."""
579
+ def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
580
+ """Get all instruction Message IDs for the given run_id."""
577
581
  if self.conn is None:
578
582
  raise AttributeError("LinkState not initialized")
579
583
 
580
584
  query = """
581
- SELECT task_id
582
- FROM task_ins
585
+ SELECT message_id
586
+ FROM message_ins
583
587
  WHERE run_id = :run_id;
584
588
  """
585
589
 
@@ -589,7 +593,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
589
593
  with self.conn:
590
594
  rows = self.conn.execute(query, data).fetchall()
591
595
 
592
- return {UUID(row["task_id"]) for row in rows}
596
+ return {UUID(row["message_id"]) for row in rows}
593
597
 
594
598
  def create_node(self, ping_interval: float) -> int:
595
599
  """Create, store in the link state, and return `node_id`."""
@@ -607,6 +611,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
607
611
  "VALUES (?, ?, ?, ?)"
608
612
  )
609
613
 
614
 + # Mark the node online until time.time() + ping_interval
610
615
  try:
611
616
  self.query(
612
617
  query,
@@ -722,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
722
727
  fab_version: Optional[str],
723
728
  fab_hash: Optional[str],
724
729
  override_config: UserConfig,
725
- federation_options: ConfigsRecord,
730
+ federation_options: ConfigRecord,
726
731
  ) -> int:
727
732
  """Create a new run for the specified `fab_id` and `fab_version`."""
728
733
  # Sample a random int64 as run_id
@@ -748,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
748
753
  fab_version,
749
754
  fab_hash,
750
755
  override_config_json,
751
- configsrecord_to_bytes(federation_options),
756
+ configrecord_to_bytes(federation_options),
752
757
  ]
753
758
  data += [
754
759
  now().isoformat(),
@@ -906,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
906
911
 
907
912
  return pending_run_id
908
913
 
909
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
914
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
910
915
  """Retrieve the federation options for the specified `run_id`."""
911
916
  # Convert the uint64 value to sint64 for SQLite
912
917
  sint64_run_id = convert_uint64_to_sint64(run_id)
@@ -919,10 +924,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
919
924
  return None
920
925
 
921
926
  row = rows[0]
922
- return configsrecord_from_bytes(row["federation_options"])
927
+ return configrecord_from_bytes(row["federation_options"])
923
928
 
924
929
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
925
- """Acknowledge a ping received from a node, serving as a heartbeat."""
930
+ """Acknowledge a ping received from a node, serving as a heartbeat.
931
+
932
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
933
+ marking the node as offline, where PING_PATIENCE = 2 in default.
934
+ """
926
935
  sint64_node_id = convert_uint64_to_sint64(node_id)
927
936
 
928
937
  # Check if the node exists in the `node` table
@@ -932,7 +941,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
932
941
 
933
942
  # Update `online_until` and `ping_interval` for the given `node_id`
934
943
  query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
935
- self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
944
+ self.query(
945
+ query,
946
+ (
947
+ time.time() + PING_PATIENCE * ping_interval,
948
+ ping_interval,
949
+ sint64_node_id,
950
+ ),
951
+ )
936
952
  return True
937
953
 
938
954
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
@@ -1001,32 +1017,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1001
1017
  latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1002
1018
  return "".join(row["log"] for row in rows), latest_timestamp
1003
1019
 
1004
- def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
1005
- """Check if the TaskIns exists and is valid (not expired).
1020
+ def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
1021
+ """Check if the Message exists and is valid (not expired).
1006
1022
 
1007
- Return TaskIns if valid.
1023
+ Return Message if valid.
1008
1024
  """
1009
1025
  query = """
1010
1026
  SELECT *
1011
- FROM task_ins
1012
- WHERE task_id = :task_id
1027
+ FROM message_ins
1028
+ WHERE message_id = :message_id
1013
1029
  """
1014
- data = {"task_id": task_id}
1030
+ data = {"message_id": message_id}
1015
1031
  rows = self.query(query, data)
1016
1032
  if not rows:
1017
- # TaskIns does not exist
1033
+ # Message does not exist
1018
1034
  return None
1019
1035
 
1020
- task_ins = rows[0]
1021
- created_at = task_ins["created_at"]
1022
- ttl = task_ins["ttl"]
1036
+ message_ins = rows[0]
1037
+ created_at = message_ins["created_at"]
1038
+ ttl = message_ins["ttl"]
1023
1039
  current_time = time.time()
1024
1040
 
1025
- # Check if TaskIns is expired
1041
+ # Check if Message is expired
1026
1042
  if ttl is not None and created_at + ttl <= current_time:
1027
1043
  return None
1028
1044
 
1029
- return task_ins
1045
+ return message_ins
1030
1046
 
1031
1047
 
1032
1048
  def dict_factory(
@@ -1041,94 +1057,46 @@ def dict_factory(
1041
1057
  return dict(zip(fields, row))
1042
1058
 
1043
1059
 
1044
- def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
1045
- """Transform TaskIns to dict."""
1046
- result = {
1047
- "task_id": task_msg.task_id,
1048
- "group_id": task_msg.group_id,
1049
- "run_id": task_msg.run_id,
1050
- "producer_node_id": task_msg.task.producer.node_id,
1051
- "consumer_node_id": task_msg.task.consumer.node_id,
1052
- "created_at": task_msg.task.created_at,
1053
- "delivered_at": task_msg.task.delivered_at,
1054
- "ttl": task_msg.task.ttl,
1055
- "ancestry": ",".join(task_msg.task.ancestry),
1056
- "task_type": task_msg.task.task_type,
1057
- "recordset": task_msg.task.recordset.SerializeToString(),
1058
- }
1059
- return result
1060
-
1061
-
1062
- def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1063
- """Transform TaskRes to dict."""
1060
+ def message_to_dict(message: Message) -> dict[str, Any]:
1061
+ """Transform Message to dict."""
1064
1062
  result = {
1065
- "task_id": task_msg.task_id,
1066
- "group_id": task_msg.group_id,
1067
- "run_id": task_msg.run_id,
1068
- "producer_node_id": task_msg.task.producer.node_id,
1069
- "consumer_node_id": task_msg.task.consumer.node_id,
1070
- "created_at": task_msg.task.created_at,
1071
- "delivered_at": task_msg.task.delivered_at,
1072
- "ttl": task_msg.task.ttl,
1073
- "ancestry": ",".join(task_msg.task.ancestry),
1074
- "task_type": task_msg.task.task_type,
1075
- "recordset": task_msg.task.recordset.SerializeToString(),
1063
+ "message_id": message.metadata.message_id,
1064
+ "group_id": message.metadata.group_id,
1065
+ "run_id": message.metadata.run_id,
1066
+ "src_node_id": message.metadata.src_node_id,
1067
+ "dst_node_id": message.metadata.dst_node_id,
1068
+ "reply_to_message_id": message.metadata.reply_to_message_id,
1069
+ "created_at": message.metadata.created_at,
1070
+ "delivered_at": message.metadata.delivered_at,
1071
+ "ttl": message.metadata.ttl,
1072
+ "message_type": message.metadata.message_type,
1073
+ "content": None,
1074
+ "error": None,
1076
1075
  }
1077
- return result
1078
-
1079
1076
 
1080
- def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1081
- """Turn task_dict into protobuf message."""
1082
- recordset = ProtoRecordSet()
1083
- recordset.ParseFromString(task_dict["recordset"])
1077
+ if message.has_content():
1078
+ result["content"] = recorddict_to_proto(message.content).SerializeToString()
1079
+ else:
1080
+ result["error"] = error_to_proto(message.error).SerializeToString()
1084
1081
 
1085
- result = TaskIns(
1086
- task_id=task_dict["task_id"],
1087
- group_id=task_dict["group_id"],
1088
- run_id=task_dict["run_id"],
1089
- task=Task(
1090
- producer=Node(
1091
- node_id=task_dict["producer_node_id"],
1092
- ),
1093
- consumer=Node(
1094
- node_id=task_dict["consumer_node_id"],
1095
- ),
1096
- created_at=task_dict["created_at"],
1097
- delivered_at=task_dict["delivered_at"],
1098
- ttl=task_dict["ttl"],
1099
- ancestry=task_dict["ancestry"].split(","),
1100
- task_type=task_dict["task_type"],
1101
- recordset=recordset,
1102
- ),
1103
- )
1104
1082
  return result
1105
1083
 
1106
1084
 
1107
- def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1108
- """Turn task_dict into protobuf message."""
1109
- recordset = ProtoRecordSet()
1110
- recordset.ParseFromString(task_dict["recordset"])
1085
+ def dict_to_message(message_dict: dict[str, Any]) -> Message:
1086
+ """Transform dict to Message."""
1087
+ content, error = None, None
1088
+ if (b_content := message_dict.pop("content")) is not None:
1089
+ content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1090
+ if (b_error := message_dict.pop("error")) is not None:
1091
+ error = error_from_proto(ProtoError.FromString(b_error))
1111
1092
 
1112
- result = TaskRes(
1113
- task_id=task_dict["task_id"],
1114
- group_id=task_dict["group_id"],
1115
- run_id=task_dict["run_id"],
1116
- task=Task(
1117
- producer=Node(
1118
- node_id=task_dict["producer_node_id"],
1119
- ),
1120
- consumer=Node(
1121
- node_id=task_dict["consumer_node_id"],
1122
- ),
1123
- created_at=task_dict["created_at"],
1124
- delivered_at=task_dict["delivered_at"],
1125
- ttl=task_dict["ttl"],
1126
- ancestry=task_dict["ancestry"].split(","),
1127
- task_type=task_dict["task_type"],
1128
- recordset=recordset,
1129
- ),
1093
+ # Metadata constructor doesn't allow passing created_at. We set it later
1094
+ metadata = Metadata(
1095
+ **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1130
1096
  )
1131
- return result
1097
+ msg = make_message(metadata=metadata, content=content, error=error)
1098
+ msg.metadata.delivered_at = message_dict["delivered_at"]
1099
+ return msg
1132
1100
 
1133
1101
 
1134
1102
  def determine_run_status(row: dict[str, Any]) -> str: