flwr-nightly 1.16.0.dev20250306__py3-none-any.whl → 1.16.0.dev20250308__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. flwr/client/message_handler/message_handler.py +1 -1
  2. flwr/client/rest_client/connection.py +4 -6
  3. flwr/common/message.py +7 -7
  4. flwr/common/record/recordset.py +4 -12
  5. flwr/common/serde.py +8 -126
  6. flwr/server/compat/driver_client_proxy.py +2 -2
  7. flwr/server/driver/inmemory_driver.py +15 -18
  8. flwr/server/superlink/driver/serverappio_servicer.py +18 -23
  9. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  10. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  11. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  12. flwr/server/superlink/fleet/vce/vce_api.py +32 -35
  13. flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -221
  14. flwr/server/superlink/linkstate/linkstate.py +0 -113
  15. flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -511
  16. flwr/server/superlink/linkstate/utils.py +2 -179
  17. flwr/server/utils/__init__.py +0 -2
  18. flwr/server/utils/validator.py +0 -88
  19. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
  20. flwr/superexec/exec_servicer.py +3 -3
  21. {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/METADATA +1 -1
  22. {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/RECORD +25 -30
  23. flwr/client/message_handler/task_handler.py +0 -37
  24. flwr/proto/task_pb2.py +0 -33
  25. flwr/proto/task_pb2.pyi +0 -100
  26. flwr/proto/task_pb2_grpc.py +0 -4
  27. flwr/proto/task_pb2_grpc.pyi +0 -4
  28. {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/entry_points.txt +0 -0
@@ -82,7 +82,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
82
82
  recordset = RecordSet()
83
83
  recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
84
  out_message = message.create_reply(recordset)
85
- # Return TaskRes and sleep duration
85
+ # Return Message and sleep duration
86
86
  return out_message, sleep_duration
87
87
 
88
88
  # Any other message
@@ -66,9 +66,7 @@ except ModuleNotFoundError:
66
66
 
67
67
  PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
68
68
  PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
69
- PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
70
69
  PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
71
- PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
72
70
  PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
73
71
  PATH_PING: str = "api/v0/fleet/ping"
74
72
  PATH_GET_RUN: str = "/api/v0/fleet/get-run"
@@ -280,7 +278,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
280
278
  node = None
281
279
 
282
280
  def receive() -> Optional[Message]:
283
- """Receive next task from server."""
281
+ """Receive next Message from server."""
284
282
  # Get Node
285
283
  if node is None:
286
284
  log(ERROR, "Node instance missing")
@@ -309,11 +307,11 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
309
307
  if message_proto is not None:
310
308
  message = message_from_proto(message_proto)
311
309
  metadata = copy(message.metadata)
312
- log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS)
310
+ log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
313
311
  return message
314
312
 
315
313
  def send(message: Message) -> None:
316
- """Send task result back to server."""
314
+ """Send Message result back to server."""
317
315
  # Get Node
318
316
  if node is None:
319
317
  log(ERROR, "Node instance missing")
@@ -345,7 +343,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
345
343
  log(
346
344
  INFO,
347
345
  "[Node] POST /%s: success, created result %s",
348
- PATH_PUSH_TASK_RES,
346
+ PATH_PUSH_MESSAGES,
349
347
  res.results, # pylint: disable=no-member
350
348
  )
351
349
 
flwr/common/message.py CHANGED
@@ -25,7 +25,7 @@ from .constant import MESSAGE_TTL_TOLERANCE
25
25
  from .logger import log
26
26
  from .record import RecordSet
27
27
 
28
- DEFAULT_TTL = 3600
28
+ DEFAULT_TTL = 43200 # This is 12 hours
29
29
 
30
30
 
31
31
  class Metadata: # pylint: disable=too-many-instance-attributes
@@ -321,7 +321,7 @@ class Message:
321
321
  )
322
322
  message.metadata.ttl = ttl
323
323
 
324
- self._limit_task_res_ttl(message)
324
+ self._limit_message_res_ttl(message)
325
325
 
326
326
  return message
327
327
 
@@ -364,7 +364,7 @@ class Message:
364
364
  )
365
365
  message.metadata.ttl = ttl
366
366
 
367
- self._limit_task_res_ttl(message)
367
+ self._limit_message_res_ttl(message)
368
368
 
369
369
  return message
370
370
 
@@ -379,14 +379,14 @@ class Message:
379
379
  )
380
380
  return f"{self.__class__.__qualname__}({view})"
381
381
 
382
- def _limit_task_res_ttl(self, message: Message) -> None:
383
- """Limit the TaskRes TTL to not exceed the expiration time of the TaskIns it
384
- replies to.
382
+ def _limit_message_res_ttl(self, message: Message) -> None:
383
+ """Limit the TTL of the provided Message to not exceed the expiration time of
384
+ this Message it replies to.
385
385
 
386
386
  Parameters
387
387
  ----------
388
388
  message : Message
389
- The message to which the TaskRes is replying.
389
+ The reply Message to limit the TTL for.
390
390
  """
391
391
  # Calculate the maximum allowed TTL
392
392
  max_allowed_ttl = (
@@ -155,19 +155,11 @@ class RecordSet(TypedDict[str, RecordType]):
155
155
  :code:`MetricsRecord` and :code:`ParametersRecord`.
156
156
  """
157
157
 
158
- def __init__(
159
- self,
160
- parameters_records: dict[str, ParametersRecord] | None = None,
161
- metrics_records: dict[str, MetricsRecord] | None = None,
162
- configs_records: dict[str, ConfigsRecord] | None = None,
163
- ) -> None:
158
+ def __init__(self, records: dict[str, RecordType] | None = None) -> None:
164
159
  super().__init__(_check_key, _check_value)
165
- for key, p_record in (parameters_records or {}).items():
166
- self[key] = p_record
167
- for key, m_record in (metrics_records or {}).items():
168
- self[key] = m_record
169
- for key, c_record in (configs_records or {}).items():
170
- self[key] = c_record
160
+ if records is not None:
161
+ for key, record in records.items():
162
+ self[key] = record
171
163
 
172
164
  @property
173
165
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
flwr/common/serde.py CHANGED
@@ -21,8 +21,6 @@ from typing import Any, TypeVar, cast
21
21
 
22
22
  from google.protobuf.message import Message as GrpcMessage
23
23
 
24
- from flwr.common.constant import SUPERLINK_NODE_ID
25
-
26
24
  # pylint: disable=E0611
27
25
  from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
28
26
  from flwr.proto.error_pb2 import Error as ProtoError
@@ -30,7 +28,6 @@ from flwr.proto.fab_pb2 import Fab as ProtoFab
30
28
  from flwr.proto.message_pb2 import Context as ProtoContext
31
29
  from flwr.proto.message_pb2 import Message as ProtoMessage
32
30
  from flwr.proto.message_pb2 import Metadata as ProtoMetadata
33
- from flwr.proto.node_pb2 import Node
34
31
  from flwr.proto.recordset_pb2 import Array as ProtoArray
35
32
  from flwr.proto.recordset_pb2 import BoolList, BytesList
36
33
  from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
@@ -43,7 +40,6 @@ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
43
40
  from flwr.proto.recordset_pb2 import SintList, StringList, UintList
44
41
  from flwr.proto.run_pb2 import Run as ProtoRun
45
42
  from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
46
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
47
43
  from flwr.proto.transport_pb2 import (
48
44
  ClientMessage,
49
45
  Code,
@@ -583,128 +579,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
583
579
 
584
580
  def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
585
581
  """Deserialize RecordSet from ProtoBuf."""
586
- return RecordSet(
587
- parameters_records={
588
- k: parameters_record_from_proto(v)
589
- for k, v in recordset_proto.parameters.items()
590
- },
591
- metrics_records={
592
- k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items()
593
- },
594
- configs_records={
595
- k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
596
- },
597
- )
598
-
599
-
600
- # === Message ===
601
-
602
-
603
- def message_to_taskins(message: Message) -> TaskIns:
604
- """Create a TaskIns from the Message."""
605
- md = message.metadata
606
- return TaskIns(
607
- group_id=md.group_id,
608
- run_id=md.run_id,
609
- task=Task(
610
- producer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
611
- consumer=Node(node_id=md.dst_node_id),
612
- created_at=md.created_at,
613
- ttl=md.ttl,
614
- ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
615
- task_type=md.message_type,
616
- recordset=(
617
- recordset_to_proto(message.content) if message.has_content() else None
618
- ),
619
- error=error_to_proto(message.error) if message.has_error() else None,
620
- ),
621
- )
622
-
623
-
624
- def message_from_taskins(taskins: TaskIns) -> Message:
625
- """Create a Message from the TaskIns."""
626
- # Retrieve the Metadata
627
- metadata = Metadata(
628
- run_id=taskins.run_id,
629
- message_id=taskins.task_id,
630
- src_node_id=taskins.task.producer.node_id,
631
- dst_node_id=taskins.task.consumer.node_id,
632
- reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "",
633
- group_id=taskins.group_id,
634
- ttl=taskins.task.ttl,
635
- message_type=taskins.task.task_type,
636
- )
637
-
638
- # Construct Message
639
- message = Message(
640
- metadata=metadata,
641
- content=(
642
- recordset_from_proto(taskins.task.recordset)
643
- if taskins.task.HasField("recordset")
644
- else None
645
- ),
646
- error=(
647
- error_from_proto(taskins.task.error)
648
- if taskins.task.HasField("error")
649
- else None
650
- ),
651
- )
652
- message.metadata.created_at = taskins.task.created_at
653
- return message
654
-
655
-
656
- def message_to_taskres(message: Message) -> TaskRes:
657
- """Create a TaskRes from the Message."""
658
- md = message.metadata
659
- return TaskRes(
660
- task_id="", # This will be generated by the server
661
- group_id=md.group_id,
662
- run_id=md.run_id,
663
- task=Task(
664
- producer=Node(node_id=md.src_node_id),
665
- consumer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
666
- created_at=md.created_at,
667
- ttl=md.ttl,
668
- ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
669
- task_type=md.message_type,
670
- recordset=(
671
- recordset_to_proto(message.content) if message.has_content() else None
672
- ),
673
- error=error_to_proto(message.error) if message.has_error() else None,
674
- ),
675
- )
676
-
677
-
678
- def message_from_taskres(taskres: TaskRes) -> Message:
679
- """Create a Message from the TaskIns."""
680
- # Retrieve the MetaData
681
- metadata = Metadata(
682
- run_id=taskres.run_id,
683
- message_id=taskres.task_id,
684
- src_node_id=taskres.task.producer.node_id,
685
- dst_node_id=taskres.task.consumer.node_id,
686
- reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "",
687
- group_id=taskres.group_id,
688
- ttl=taskres.task.ttl,
689
- message_type=taskres.task.task_type,
690
- )
691
-
692
- # Construct the Message
693
- message = Message(
694
- metadata=metadata,
695
- content=(
696
- recordset_from_proto(taskres.task.recordset)
697
- if taskres.task.HasField("recordset")
698
- else None
699
- ),
700
- error=(
701
- error_from_proto(taskres.task.error)
702
- if taskres.task.HasField("error")
703
- else None
704
- ),
705
- )
706
- message.metadata.created_at = taskres.task.created_at
707
- return message
582
+ ret = RecordSet()
583
+ for k, p_record_proto in recordset_proto.parameters.items():
584
+ ret[k] = parameters_record_from_proto(p_record_proto)
585
+ for k, m_record_proto in recordset_proto.metrics.items():
586
+ ret[k] = metrics_record_from_proto(m_record_proto)
587
+ for k, c_record_proto in recordset_proto.configs.items():
588
+ ret[k] = configs_record_from_proto(c_record_proto)
589
+ return ret
708
590
 
709
591
 
710
592
  # === FAB ===
@@ -104,7 +104,7 @@ class DriverClientProxy(ClientProxy):
104
104
  def _send_receive_recordset(
105
105
  self,
106
106
  recordset: RecordSet,
107
- task_type: str,
107
+ message_type: str,
108
108
  timeout: Optional[float],
109
109
  group_id: Optional[int],
110
110
  ) -> RecordSet:
@@ -112,7 +112,7 @@ class DriverClientProxy(ClientProxy):
112
112
  # Create message
113
113
  message = self.driver.create_message(
114
114
  content=recordset,
115
- message_type=task_type,
115
+ message_type=message_type,
116
116
  dst_node_id=self.node_id,
117
117
  group_id=str(group_id) if group_id else "",
118
118
  ttl=timeout,
@@ -23,7 +23,6 @@ from uuid import UUID
23
23
 
24
24
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
25
  from flwr.common.constant import SUPERLINK_NODE_ID
26
- from flwr.common.serde import message_from_taskres, message_to_taskins
27
26
  from flwr.common.typing import Run
28
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
28
  from flwr.server.superlink.linkstate import LinkStateFactory
@@ -60,6 +59,7 @@ class InMemoryDriver(Driver):
60
59
  and message.metadata.message_id == ""
61
60
  and message.metadata.reply_to_message == ""
62
61
  and message.metadata.ttl > 0
62
+ and message.metadata.delivered_at == ""
63
63
  ):
64
64
  raise ValueError(f"Invalid message: {message}")
65
65
 
@@ -119,18 +119,16 @@ class InMemoryDriver(Driver):
119
119
  This method takes an iterable of messages and sends each message
120
120
  to the node specified in `dst_node_id`.
121
121
  """
122
- task_ids: list[str] = []
122
+ msg_ids: list[str] = []
123
123
  for msg in messages:
124
124
  # Check message
125
125
  self._check_message(msg)
126
- # Convert Message to TaskIns
127
- taskins = message_to_taskins(msg)
128
126
  # Store in state
129
- task_id = self.state.store_task_ins(taskins)
130
- if task_id:
131
- task_ids.append(str(task_id))
127
+ msg_id = self.state.store_message_ins(msg)
128
+ if msg_id:
129
+ msg_ids.append(str(msg_id))
132
130
 
133
- return task_ids
131
+ return msg_ids
134
132
 
135
133
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
136
134
  """Pull messages based on message IDs.
@@ -139,17 +137,16 @@ class InMemoryDriver(Driver):
139
137
  set of given message IDs.
140
138
  """
141
139
  msg_ids = {UUID(msg_id) for msg_id in message_ids}
142
- # Pull TaskRes
143
- task_res_list = self.state.get_task_res(task_ids=msg_ids)
144
- # Delete tasks in state
145
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
146
- task_ins_ids_to_delete = {
147
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
140
+ # Pull Messages
141
+ message_res_list = self.state.get_message_res(message_ids=msg_ids)
142
+ # Get IDs of Messages these replies are for
143
+ message_ins_ids_to_delete = {
144
+ UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
148
145
  }
149
- self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
150
- # Convert TaskRes to Message
151
- msgs = [message_from_taskres(taskres) for taskres in task_res_list]
152
- return msgs
146
+ # Delete
147
+ self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
148
+
149
+ return message_res_list
153
150
 
154
151
  def send_and_receive(
155
152
  self,
@@ -22,7 +22,7 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord
25
+ from flwr.common import ConfigsRecord, Message
26
26
  from flwr.common.constant import Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
@@ -31,9 +31,7 @@ from flwr.common.serde import (
31
31
  fab_from_proto,
32
32
  fab_to_proto,
33
33
  message_from_proto,
34
- message_from_taskres,
35
34
  message_to_proto,
36
- message_to_taskins,
37
35
  run_status_from_proto,
38
36
  run_status_to_proto,
39
37
  run_to_proto,
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
69
67
  PushServerAppOutputsRequest,
70
68
  PushServerAppOutputsResponse,
71
69
  )
72
- from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
73
70
  from flwr.server.superlink.ffs.ffs import Ffs
74
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
75
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
76
73
  from flwr.server.superlink.utils import abort_if
77
- from flwr.server.utils.validator import validate_task_ins_or_res
74
+ from flwr.server.utils.validator import validate_message
78
75
 
79
76
 
80
77
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -161,20 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
161
158
  while request.messages_list:
162
159
  message_proto = request.messages_list.pop(0)
163
160
  message = message_from_proto(message_proto=message_proto)
164
- task_ins = message_to_taskins(message=message)
165
- validation_errors = validate_task_ins_or_res(task_ins)
161
+ validation_errors = validate_message(message, is_reply_message=False)
166
162
  _raise_if(
167
163
  validation_error=bool(validation_errors),
168
164
  request_name="PushMessages",
169
165
  detail=", ".join(validation_errors),
170
166
  )
171
167
  _raise_if(
172
- validation_error=request.run_id != task_ins.run_id,
168
+ validation_error=request.run_id != message.metadata.run_id,
173
169
  request_name="PushMessages",
174
- detail="`task_ins` has mismatched `run_id`",
170
+ detail="`Message.metadata` has mismatched `run_id`",
175
171
  )
176
172
  # Store
177
- message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
173
+ message_id: Optional[UUID] = state.store_message_ins(message=message)
178
174
  message_ids.append(message_id)
179
175
 
180
176
  return PushInsMessagesResponse(
@@ -200,32 +196,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
200
196
  context,
201
197
  )
202
198
 
203
- # Convert each task_id str to UUID
199
+ # Convert each message_id str to UUID
204
200
  message_ids: set[UUID] = {
205
201
  UUID(message_id) for message_id in request.message_ids
206
202
  }
207
203
 
208
204
  # Read from state
209
- task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
205
+ messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
210
206
 
211
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
212
- task_ins_ids_to_delete = {
213
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
207
+ # Delete the instruction Messages and their replies if found
208
+ message_ins_ids_to_delete = {
209
+ UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
214
210
  }
215
211
 
216
- state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
212
+ state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
217
213
 
218
- # Convert to Messages
214
+ # Convert Messages to proto
219
215
  messages_list = []
220
- while task_res_list:
221
- task_res = task_res_list.pop(0)
216
+ while messages_res:
217
+ msg = messages_res.pop(0)
222
218
  _raise_if(
223
- validation_error=request.run_id != task_res.run_id,
219
+ validation_error=request.run_id != msg.metadata.run_id,
224
220
  request_name="PullMessages",
225
- detail="`task_res` has mismatched `run_id`",
221
+ detail="`message.metadata` has mismatched `run_id`",
226
222
  )
227
- message = message_from_taskres(taskres=task_res)
228
- messages_list.append(message_to_proto(message))
223
+ messages_list.append(message_to_proto(msg))
229
224
 
230
225
  return PullResMessagesResponse(messages_list=messages_list)
231
226
 
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
103
103
  if request.messages_list:
104
104
  log(
105
105
  INFO,
106
- "[Fleet.PushMessages] Push results from node_id=%s",
106
+ "[Fleet.PushMessages] Push replies from node_id=%s",
107
107
  request.messages_list[0].metadata.src_node_id,
108
108
  )
109
109
  else:
110
- log(INFO, "[Fleet.PushMessages] No task results to push")
110
+ log(INFO, "[Fleet.PushMessages] No replies to push")
111
111
 
112
112
  try:
113
113
  res = message_handler.push_messages(
@@ -18,13 +18,12 @@
18
18
  from typing import Optional
19
19
  from uuid import UUID
20
20
 
21
+ from flwr.common import Message
21
22
  from flwr.common.constant import Status
22
23
  from flwr.common.serde import (
23
24
  fab_to_proto,
24
25
  message_from_proto,
25
- message_from_taskins,
26
26
  message_to_proto,
27
- message_to_taskres,
28
27
  user_config_to_proto,
29
28
  )
30
29
  from flwr.common.typing import Fab, InvalidRunStatusException
@@ -48,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
47
  GetRunResponse,
49
48
  Run,
50
49
  )
51
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
52
50
  from flwr.server.superlink.ffs.ffs import Ffs
53
51
  from flwr.server.superlink.linkstate import LinkState
54
52
  from flwr.server.superlink.utils import check_abort
@@ -92,13 +90,12 @@ def pull_messages(
92
90
  node = request.node # pylint: disable=no-member
93
91
  node_id: int = node.node_id
94
92
 
95
- # Retrieve TaskIns from State
96
- task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
93
+ # Retrieve Message from State
94
+ message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
97
95
 
98
96
  # Convert to Messages
99
97
  msg_proto = []
100
- for task_ins in task_ins_list:
101
- msg = message_from_taskins(task_ins)
98
+ for msg in message_list:
102
99
  msg_proto.append(message_to_proto(msg))
103
100
 
104
101
  return PullMessagesResponse(messages_list=msg_proto)
@@ -108,21 +105,20 @@ def push_messages(
108
105
  request: PushMessagesRequest, state: LinkState
109
106
  ) -> PushMessagesResponse:
110
107
  """Push Messages handler."""
111
- # Convert Message to TaskRes
108
+ # Convert Message from proto
112
109
  msg = message_from_proto(message_proto=request.messages_list[0])
113
- task_res = message_to_taskres(msg)
114
110
 
115
111
  # Abort if the run is not running
116
112
  abort_msg = check_abort(
117
- task_res.run_id,
113
+ msg.metadata.run_id,
118
114
  [Status.PENDING, Status.STARTING, Status.FINISHED],
119
115
  state,
120
116
  )
121
117
  if abort_msg:
122
118
  raise InvalidRunStatusException(abort_msg)
123
119
 
124
- # Store TaskRes in State
125
- message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
120
+ # Store Message in State
121
+ message_id: Optional[UUID] = state.store_message_res(message=msg)
126
122
 
127
123
  # Build response
128
124
  response = PushMessagesResponse(
@@ -45,7 +45,7 @@ class Backend(ABC):
45
45
  def num_workers(self) -> int:
46
46
  """Return number of workers in the backend.
47
47
 
48
- This is the number of TaskIns that can be processed concurrently.
48
+ This is the number of Messages that can be processed concurrently.
49
49
  """
50
50
  return 0
51
51