flwr-nightly 1.16.0.dev20250305__py3-none-any.whl → 1.16.0.dev20250307__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. flwr/client/message_handler/message_handler.py +1 -1
  2. flwr/client/rest_client/connection.py +4 -6
  3. flwr/common/message.py +7 -7
  4. flwr/common/record/recordset.py +4 -12
  5. flwr/common/serde.py +8 -126
  6. flwr/server/compat/driver_client_proxy.py +2 -2
  7. flwr/server/driver/driver.py +1 -1
  8. flwr/server/driver/grpc_driver.py +1 -1
  9. flwr/server/driver/inmemory_driver.py +17 -20
  10. flwr/server/superlink/driver/serverappio_servicer.py +18 -23
  11. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  12. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  13. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  14. flwr/server/superlink/fleet/vce/vce_api.py +32 -35
  15. flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -221
  16. flwr/server/superlink/linkstate/linkstate.py +0 -113
  17. flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -511
  18. flwr/server/superlink/linkstate/utils.py +2 -179
  19. flwr/server/utils/__init__.py +0 -2
  20. flwr/server/utils/validator.py +0 -88
  21. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
  22. flwr/superexec/exec_servicer.py +3 -3
  23. {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/METADATA +1 -1
  24. {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/RECORD +27 -32
  25. flwr/client/message_handler/task_handler.py +0 -37
  26. flwr/proto/task_pb2.py +0 -33
  27. flwr/proto/task_pb2.pyi +0 -100
  28. flwr/proto/task_pb2_grpc.py +0 -4
  29. flwr/proto/task_pb2_grpc.pyi +0 -4
  30. {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/LICENSE +0 -0
  31. {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/WHEEL +0 -0
  32. {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/entry_points.txt +0 -0
@@ -45,12 +45,10 @@ from flwr.common.typing import Run, RunStatus, UserConfig
45
45
 
46
46
  # pylint: disable=E0611
47
47
  from flwr.proto.error_pb2 import Error as ProtoError
48
- from flwr.proto.node_pb2 import Node
49
48
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
50
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
51
49
 
52
50
  # pylint: enable=E0611
53
- from flwr.server.utils.validator import validate_message, validate_task_ins_or_res
51
+ from flwr.server.utils.validator import validate_message
54
52
 
55
53
  from .linkstate import LinkState
56
54
  from .utils import (
@@ -66,9 +64,7 @@ from .utils import (
66
64
  has_valid_sub_status,
67
65
  is_valid_transition,
68
66
  verify_found_message_replies,
69
- verify_found_taskres,
70
67
  verify_message_ids,
71
- verify_taskins_ids,
72
68
  )
73
69
 
74
70
  SQL_CREATE_TABLE_NODE = """
@@ -126,23 +122,6 @@ CREATE TABLE IF NOT EXISTS context(
126
122
  );
127
123
  """
128
124
 
129
- SQL_CREATE_TABLE_TASK_INS = """
130
- CREATE TABLE IF NOT EXISTS task_ins(
131
- task_id TEXT UNIQUE,
132
- group_id TEXT,
133
- run_id INTEGER,
134
- producer_node_id INTEGER,
135
- consumer_node_id INTEGER,
136
- created_at REAL,
137
- delivered_at TEXT,
138
- ttl REAL,
139
- ancestry TEXT,
140
- task_type TEXT,
141
- recordset BLOB,
142
- FOREIGN KEY(run_id) REFERENCES run(run_id)
143
- );
144
- """
145
-
146
125
  SQL_CREATE_TABLE_MESSAGE_INS = """
147
126
  CREATE TABLE IF NOT EXISTS message_ins(
148
127
  message_id TEXT UNIQUE,
@@ -161,23 +140,6 @@ CREATE TABLE IF NOT EXISTS message_ins(
161
140
  );
162
141
  """
163
142
 
164
- SQL_CREATE_TABLE_TASK_RES = """
165
- CREATE TABLE IF NOT EXISTS task_res(
166
- task_id TEXT UNIQUE,
167
- group_id TEXT,
168
- run_id INTEGER,
169
- producer_node_id INTEGER,
170
- consumer_node_id INTEGER,
171
- created_at REAL,
172
- delivered_at TEXT,
173
- ttl REAL,
174
- ancestry TEXT,
175
- task_type TEXT,
176
- recordset BLOB,
177
- FOREIGN KEY(run_id) REFERENCES run(run_id)
178
- );
179
- """
180
-
181
143
 
182
144
  SQL_CREATE_TABLE_MESSAGE_RES = """
183
145
  CREATE TABLE IF NOT EXISTS message_res(
@@ -242,8 +204,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
242
204
  cur.execute(SQL_CREATE_TABLE_RUN)
243
205
  cur.execute(SQL_CREATE_TABLE_LOGS)
244
206
  cur.execute(SQL_CREATE_TABLE_CONTEXT)
245
- cur.execute(SQL_CREATE_TABLE_TASK_INS)
246
- cur.execute(SQL_CREATE_TABLE_TASK_RES)
247
207
  cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
248
208
  cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
249
209
  cur.execute(SQL_CREATE_TABLE_NODE)
@@ -287,69 +247,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
287
247
 
288
248
  return result
289
249
 
290
- def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
291
- """Store one TaskIns.
292
-
293
- Usually, the ServerAppIo API calls this to schedule instructions.
294
-
295
- Stores the value of the task_ins in the link state and, if successful,
296
- returns the task_id (UUID) of the task_ins. If, for any reason, storing
297
- the task_ins fails, `None` is returned.
298
-
299
- Constraints
300
- -----------
301
-
302
- `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
303
- """
304
- # Validate task
305
- errors = validate_task_ins_or_res(task_ins)
306
- if any(errors):
307
- log(ERROR, errors)
308
- return None
309
- # Create task_id
310
- task_id = uuid4()
311
-
312
- # Store TaskIns
313
- task_ins.task_id = str(task_id)
314
- data = (task_ins_to_dict(task_ins),)
315
-
316
- # Convert values from uint64 to sint64 for SQLite
317
- convert_uint64_values_in_dict_to_sint64(
318
- data[0], ["run_id", "producer_node_id", "consumer_node_id"]
319
- )
320
-
321
- # Validate run_id
322
- query = "SELECT run_id FROM run WHERE run_id = ?;"
323
- if not self.query(query, (data[0]["run_id"],)):
324
- log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
325
- return None
326
- # Validate source node ID
327
- if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
328
- log(
329
- ERROR,
330
- "Invalid source node ID for TaskIns: %s",
331
- task_ins.task.producer.node_id,
332
- )
333
- return None
334
- # Validate destination node ID
335
- query = "SELECT node_id FROM node WHERE node_id = ?;"
336
- if not self.query(query, (data[0]["consumer_node_id"],)):
337
- log(
338
- ERROR,
339
- "Invalid destination node ID for TaskIns: %s",
340
- task_ins.task.consumer.node_id,
341
- )
342
- return None
343
-
344
- columns = ", ".join([f":{key}" for key in data[0]])
345
- query = f"INSERT INTO task_ins VALUES({columns});"
346
-
347
- # Only invalid run_id can trigger IntegrityError.
348
- # This may need to be changed in the future version with more integrity checks.
349
- self.query(query, data)
350
-
351
- return task_id
352
-
353
250
  def store_message_ins(self, message: Message) -> Optional[UUID]:
354
251
  """Store one Message."""
355
252
  # Validate message
@@ -404,84 +301,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
404
301
 
405
302
  return message_id
406
303
 
407
- def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
408
- """Get undelivered TaskIns for one node.
409
-
410
- Usually, the Fleet API calls this for Nodes planning to work on one or more
411
- TaskIns.
412
-
413
- Constraints
414
- -----------
415
- Retrieve all TaskIns where
416
-
417
- 1. the `task_ins.task.consumer.node_id` equals `node_id` AND
418
- 2. the `task_ins.task.delivered_at` equals `""`.
419
-
420
- `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
421
- the result.
422
-
423
- If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
424
- `limit` is set, it has to be greater than zero.
425
- """
426
- if limit is not None and limit < 1:
427
- raise AssertionError("`limit` must be >= 1")
428
-
429
- if node_id == SUPERLINK_NODE_ID:
430
- msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
431
- raise AssertionError(msg)
432
-
433
- data: dict[str, Union[str, int]] = {}
434
-
435
- # Convert the uint64 value to sint64 for SQLite
436
- data["node_id"] = convert_uint64_to_sint64(node_id)
437
-
438
- # Retrieve all TaskIns for node_id
439
- query = """
440
- SELECT task_id
441
- FROM task_ins
442
- WHERE consumer_node_id == :node_id
443
- AND delivered_at = ""
444
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
445
- """
446
-
447
- if limit is not None:
448
- query += " LIMIT :limit"
449
- data["limit"] = limit
450
-
451
- query += ";"
452
-
453
- rows = self.query(query, data)
454
-
455
- if rows:
456
- # Prepare query
457
- task_ids = [row["task_id"] for row in rows]
458
- placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
459
- query = f"""
460
- UPDATE task_ins
461
- SET delivered_at = :delivered_at
462
- WHERE task_id IN ({placeholders})
463
- RETURNING *;
464
- """
465
-
466
- # Prepare data for query
467
- delivered_at = now().isoformat()
468
- data = {"delivered_at": delivered_at}
469
- for index, task_id in enumerate(task_ids):
470
- data[f"id_{index}"] = str(task_id)
471
-
472
- # Run query
473
- rows = self.query(query, data)
474
-
475
- for row in rows:
476
- # Convert values from sint64 to uint64
477
- convert_sint64_values_in_dict_to_uint64(
478
- row, ["run_id", "producer_node_id", "consumer_node_id"]
479
- )
480
-
481
- result = [dict_to_task_ins(row) for row in rows]
482
-
483
- return result
484
-
485
304
  def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
486
305
  """Get all Messages that have not been delivered yet."""
487
306
  if limit is not None and limit < 1:
@@ -543,90 +362,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
543
362
 
544
363
  return result
545
364
 
546
- def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
547
- """Store one TaskRes.
548
-
549
- Usually, the Fleet API calls this when Nodes return their results.
550
-
551
- Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
552
- the `task_res`. If storing the `task_res` fails, `None` is returned.
553
-
554
- Constraints
555
- -----------
556
- `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
557
- """
558
- # Validate task
559
- errors = validate_task_ins_or_res(task_res)
560
- if any(errors):
561
- log(ERROR, errors)
562
- return None
563
-
564
- # Create task_id
565
- task_id = uuid4()
566
-
567
- task_ins_id = task_res.task.ancestry[0]
568
- task_ins = self.get_valid_task_ins(task_ins_id)
569
- if task_ins is None:
570
- log(
571
- ERROR,
572
- "Failed to store TaskRes: "
573
- "TaskIns with task_id %s does not exist or has expired.",
574
- task_ins_id,
575
- )
576
- return None
577
-
578
- # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
579
- if (
580
- task_ins
581
- and task_res
582
- and convert_sint64_to_uint64(task_ins["consumer_node_id"])
583
- != task_res.task.producer.node_id
584
- ):
585
- return None
586
-
587
- # Fail if the TaskRes TTL exceeds the
588
- # expiration time of the TaskIns it replies to.
589
- # Condition: TaskIns.created_at + TaskIns.ttl ≥
590
- # TaskRes.created_at + TaskRes.ttl
591
- # A small tolerance is introduced to account
592
- # for floating-point precision issues.
593
- max_allowed_ttl = (
594
- task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
595
- )
596
- if task_res.task.ttl and (
597
- task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
598
- ):
599
- log(
600
- WARNING,
601
- "Received TaskRes with TTL %.2f "
602
- "exceeding the allowed maximum TTL %.2f.",
603
- task_res.task.ttl,
604
- max_allowed_ttl,
605
- )
606
- return None
607
-
608
- # Store TaskRes
609
- task_res.task_id = str(task_id)
610
- data = (task_res_to_dict(task_res),)
611
-
612
- # Convert values from uint64 to sint64 for SQLite
613
- convert_uint64_values_in_dict_to_sint64(
614
- data[0], ["run_id", "producer_node_id", "consumer_node_id"]
615
- )
616
-
617
- columns = ", ".join([f":{key}" for key in data[0]])
618
- query = f"INSERT INTO task_res VALUES({columns});"
619
-
620
- # Only invalid run_id can trigger IntegrityError.
621
- # This may need to be changed in the future version with more integrity checks.
622
- try:
623
- self.query(query, data)
624
- except sqlite3.IntegrityError:
625
- log(ERROR, "`run` is invalid")
626
- return None
627
-
628
- return task_id
629
-
630
365
  def store_message_res(self, message: Message) -> Optional[UUID]:
631
366
  """Store one Message."""
632
367
  # Validate message
@@ -705,67 +440,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
705
440
 
706
441
  return message_id
707
442
 
708
- # pylint: disable-next=R0912,R0915,R0914
709
- def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
710
- """Get TaskRes for the given TaskIns IDs."""
711
- ret: dict[UUID, TaskRes] = {}
712
-
713
- # Verify TaskIns IDs
714
- current = time.time()
715
- query = f"""
716
- SELECT *
717
- FROM task_ins
718
- WHERE task_id IN ({",".join(["?"] * len(task_ids))});
719
- """
720
- rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
721
- found_task_ins_dict: dict[UUID, TaskIns] = {}
722
- for row in rows:
723
- convert_sint64_values_in_dict_to_uint64(
724
- row, ["run_id", "producer_node_id", "consumer_node_id"]
725
- )
726
- found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
727
-
728
- ret = verify_taskins_ids(
729
- inquired_taskins_ids=task_ids,
730
- found_taskins_dict=found_task_ins_dict,
731
- current_time=current,
732
- )
733
-
734
- # Find all TaskRes
735
- query = f"""
736
- SELECT *
737
- FROM task_res
738
- WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
739
- AND delivered_at = "";
740
- """
741
- rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
742
- for row in rows:
743
- convert_sint64_values_in_dict_to_uint64(
744
- row, ["run_id", "producer_node_id", "consumer_node_id"]
745
- )
746
- tmp_ret_dict = verify_found_taskres(
747
- inquired_taskins_ids=task_ids,
748
- found_taskins_dict=found_task_ins_dict,
749
- found_taskres_list=[dict_to_task_res(row) for row in rows],
750
- current_time=current,
751
- )
752
- ret.update(tmp_ret_dict)
753
-
754
- # Mark existing TaskRes to be returned as delivered
755
- delivered_at = now().isoformat()
756
- for task_res in ret.values():
757
- task_res.task.delivered_at = delivered_at
758
- task_res_ids = [task_res.task_id for task_res in ret.values()]
759
- query = f"""
760
- UPDATE task_res
761
- SET delivered_at = ?
762
- WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
763
- """
764
- data: list[Any] = [delivered_at] + task_res_ids
765
- self.query(query, data)
766
-
767
- return list(ret.values())
768
-
769
443
  def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
770
444
  """Get reply Messages for the given Message IDs."""
771
445
  ret: dict[UUID, Message] = {}
@@ -828,17 +502,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
828
502
 
829
503
  return list(ret.values())
830
504
 
831
- def num_task_ins(self) -> int:
832
- """Calculate the number of task_ins in store.
833
-
834
- This includes delivered but not yet deleted task_ins.
835
- """
836
- query = "SELECT count(*) AS num FROM task_ins;"
837
- rows = self.query(query)
838
- result = rows[0]
839
- num = cast(int, result["num"])
840
- return num
841
-
842
505
  def num_message_ins(self) -> int:
843
506
  """Calculate the number of instruction Messages in store.
844
507
 
@@ -850,16 +513,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
850
513
  num = cast(int, result["num"])
851
514
  return num
852
515
 
853
- def num_task_res(self) -> int:
854
- """Calculate the number of task_res in store.
855
-
856
- This includes delivered but not yet deleted task_res.
857
- """
858
- query = "SELECT count(*) AS num FROM task_res;"
859
- rows = self.query(query)
860
- result: dict[str, int] = rows[0]
861
- return result["num"]
862
-
863
516
  def num_message_res(self) -> int:
864
517
  """Calculate the number of reply Messages in store.
865
518
 
@@ -870,32 +523,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
870
523
  result: dict[str, int] = rows[0]
871
524
  return result["num"]
872
525
 
873
- def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
874
- """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
875
- if not task_ins_ids:
876
- return
877
- if self.conn is None:
878
- raise AttributeError("LinkState not initialized")
879
-
880
- placeholders = ",".join(["?"] * len(task_ins_ids))
881
- data = tuple(str(task_id) for task_id in task_ins_ids)
882
-
883
- # Delete task_ins
884
- query_1 = f"""
885
- DELETE FROM task_ins
886
- WHERE task_id IN ({placeholders});
887
- """
888
-
889
- # Delete task_res
890
- query_2 = f"""
891
- DELETE FROM task_res
892
- WHERE ancestry IN ({placeholders});
893
- """
894
-
895
- with self.conn:
896
- self.conn.execute(query_1, data)
897
- self.conn.execute(query_2, data)
898
-
899
526
  def delete_messages(self, message_ins_ids: set[UUID]) -> None:
900
527
  """Delete a Message and its reply based on provided Message IDs."""
901
528
  if not message_ins_ids:
@@ -922,25 +549,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
922
549
  self.conn.execute(query_1, data)
923
550
  self.conn.execute(query_2, data)
924
551
 
925
- def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
926
- """Get all TaskIns IDs for the given run_id."""
927
- if self.conn is None:
928
- raise AttributeError("LinkState not initialized")
929
-
930
- query = """
931
- SELECT task_id
932
- FROM task_ins
933
- WHERE run_id = :run_id;
934
- """
935
-
936
- sint64_run_id = convert_uint64_to_sint64(run_id)
937
- data = {"run_id": sint64_run_id}
938
-
939
- with self.conn:
940
- rows = self.conn.execute(query, data).fetchall()
941
-
942
- return {UUID(row["task_id"]) for row in rows}
943
-
944
552
  def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
945
553
  """Get all instruction Message IDs for the given run_id."""
946
554
  if self.conn is None:
@@ -1370,33 +978,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1370
978
  latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1371
979
  return "".join(row["log"] for row in rows), latest_timestamp
1372
980
 
1373
- def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
1374
- """Check if the TaskIns exists and is valid (not expired).
1375
-
1376
- Return TaskIns if valid.
1377
- """
1378
- query = """
1379
- SELECT *
1380
- FROM task_ins
1381
- WHERE task_id = :task_id
1382
- """
1383
- data = {"task_id": task_id}
1384
- rows = self.query(query, data)
1385
- if not rows:
1386
- # TaskIns does not exist
1387
- return None
1388
-
1389
- task_ins = rows[0]
1390
- created_at = task_ins["created_at"]
1391
- ttl = task_ins["ttl"]
1392
- current_time = time.time()
1393
-
1394
- # Check if TaskIns is expired
1395
- if ttl is not None and created_at + ttl <= current_time:
1396
- return None
1397
-
1398
- return task_ins
1399
-
1400
981
  def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
1401
982
  """Check if the Message exists and is valid (not expired).
1402
983
 
@@ -1418,7 +999,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1418
999
  ttl = message_ins["ttl"]
1419
1000
  current_time = time.time()
1420
1001
 
1421
- # Check if TaskIns is expired
1002
+ # Check if Message is expired
1422
1003
  if ttl is not None and created_at + ttl <= current_time:
1423
1004
  return None
1424
1005
 
@@ -1437,42 +1018,6 @@ def dict_factory(
1437
1018
  return dict(zip(fields, row))
1438
1019
 
1439
1020
 
1440
- def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
1441
- """Transform TaskIns to dict."""
1442
- result = {
1443
- "task_id": task_msg.task_id,
1444
- "group_id": task_msg.group_id,
1445
- "run_id": task_msg.run_id,
1446
- "producer_node_id": task_msg.task.producer.node_id,
1447
- "consumer_node_id": task_msg.task.consumer.node_id,
1448
- "created_at": task_msg.task.created_at,
1449
- "delivered_at": task_msg.task.delivered_at,
1450
- "ttl": task_msg.task.ttl,
1451
- "ancestry": ",".join(task_msg.task.ancestry),
1452
- "task_type": task_msg.task.task_type,
1453
- "recordset": task_msg.task.recordset.SerializeToString(),
1454
- }
1455
- return result
1456
-
1457
-
1458
- def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1459
- """Transform TaskRes to dict."""
1460
- result = {
1461
- "task_id": task_msg.task_id,
1462
- "group_id": task_msg.group_id,
1463
- "run_id": task_msg.run_id,
1464
- "producer_node_id": task_msg.task.producer.node_id,
1465
- "consumer_node_id": task_msg.task.consumer.node_id,
1466
- "created_at": task_msg.task.created_at,
1467
- "delivered_at": task_msg.task.delivered_at,
1468
- "ttl": task_msg.task.ttl,
1469
- "ancestry": ",".join(task_msg.task.ancestry),
1470
- "task_type": task_msg.task.task_type,
1471
- "recordset": task_msg.task.recordset.SerializeToString(),
1472
- }
1473
- return result
1474
-
1475
-
1476
1021
  def message_to_dict(message: Message) -> dict[str, Any]:
1477
1022
  """Transform Message to dict."""
1478
1023
  result = {
@@ -1498,60 +1043,6 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1498
1043
  return result
1499
1044
 
1500
1045
 
1501
- def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1502
- """Turn task_dict into protobuf message."""
1503
- recordset = ProtoRecordSet()
1504
- recordset.ParseFromString(task_dict["recordset"])
1505
-
1506
- result = TaskIns(
1507
- task_id=task_dict["task_id"],
1508
- group_id=task_dict["group_id"],
1509
- run_id=task_dict["run_id"],
1510
- task=Task(
1511
- producer=Node(
1512
- node_id=task_dict["producer_node_id"],
1513
- ),
1514
- consumer=Node(
1515
- node_id=task_dict["consumer_node_id"],
1516
- ),
1517
- created_at=task_dict["created_at"],
1518
- delivered_at=task_dict["delivered_at"],
1519
- ttl=task_dict["ttl"],
1520
- ancestry=task_dict["ancestry"].split(","),
1521
- task_type=task_dict["task_type"],
1522
- recordset=recordset,
1523
- ),
1524
- )
1525
- return result
1526
-
1527
-
1528
- def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1529
- """Turn task_dict into protobuf message."""
1530
- recordset = ProtoRecordSet()
1531
- recordset.ParseFromString(task_dict["recordset"])
1532
-
1533
- result = TaskRes(
1534
- task_id=task_dict["task_id"],
1535
- group_id=task_dict["group_id"],
1536
- run_id=task_dict["run_id"],
1537
- task=Task(
1538
- producer=Node(
1539
- node_id=task_dict["producer_node_id"],
1540
- ),
1541
- consumer=Node(
1542
- node_id=task_dict["consumer_node_id"],
1543
- ),
1544
- created_at=task_dict["created_at"],
1545
- delivered_at=task_dict["delivered_at"],
1546
- ttl=task_dict["ttl"],
1547
- ancestry=task_dict["ancestry"].split(","),
1548
- task_type=task_dict["task_type"],
1549
- recordset=recordset,
1550
- ),
1551
- )
1552
- return result
1553
-
1554
-
1555
1046
  def dict_to_message(message_dict: dict[str, Any]) -> Message:
1556
1047
  """Transform dict to Message."""
1557
1048
  content, error = None, None