flwr-nightly 1.16.0.dev20250306__py3-none-any.whl → 1.16.0.dev20250308__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/common/message.py +7 -7
- flwr/common/record/recordset.py +4 -12
- flwr/common/serde.py +8 -126
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/inmemory_driver.py +15 -18
- flwr/server/superlink/driver/serverappio_servicer.py +18 -23
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -221
- flwr/server/superlink/linkstate/linkstate.py +0 -113
- flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -511
- flwr/server/superlink/linkstate/utils.py +2 -179
- flwr/server/utils/__init__.py +0 -2
- flwr/server/utils/validator.py +0 -88
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/exec_servicer.py +3 -3
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/METADATA +1 -1
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/RECORD +25 -30
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/entry_points.txt +0 -0
flwr/client/message_handler/message_handler.py
CHANGED
@@ -82,7 +82,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
82
82
|
recordset = RecordSet()
|
83
83
|
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
|
84
84
|
out_message = message.create_reply(recordset)
|
85
|
-
# Return
|
85
|
+
# Return Message and sleep duration
|
86
86
|
return out_message, sleep_duration
|
87
87
|
|
88
88
|
# Any other message
|
flwr/client/rest_client/connection.py
CHANGED
@@ -66,9 +66,7 @@ except ModuleNotFoundError:
|
|
66
66
|
|
67
67
|
PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
68
68
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
69
|
-
PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
|
70
69
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
71
|
-
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
72
70
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
73
71
|
PATH_PING: str = "api/v0/fleet/ping"
|
74
72
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
@@ -280,7 +278,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
280
278
|
node = None
|
281
279
|
|
282
280
|
def receive() -> Optional[Message]:
|
283
|
-
"""Receive next
|
281
|
+
"""Receive next Message from server."""
|
284
282
|
# Get Node
|
285
283
|
if node is None:
|
286
284
|
log(ERROR, "Node instance missing")
|
@@ -309,11 +307,11 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
309
307
|
if message_proto is not None:
|
310
308
|
message = message_from_proto(message_proto)
|
311
309
|
metadata = copy(message.metadata)
|
312
|
-
log(INFO, "[Node] POST /%s: success",
|
310
|
+
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
313
311
|
return message
|
314
312
|
|
315
313
|
def send(message: Message) -> None:
|
316
|
-
"""Send
|
314
|
+
"""Send Message result back to server."""
|
317
315
|
# Get Node
|
318
316
|
if node is None:
|
319
317
|
log(ERROR, "Node instance missing")
|
@@ -345,7 +343,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
345
343
|
log(
|
346
344
|
INFO,
|
347
345
|
"[Node] POST /%s: success, created result %s",
|
348
|
-
|
346
|
+
PATH_PUSH_MESSAGES,
|
349
347
|
res.results, # pylint: disable=no-member
|
350
348
|
)
|
351
349
|
|
flwr/common/message.py
CHANGED
@@ -25,7 +25,7 @@ from .constant import MESSAGE_TTL_TOLERANCE
|
|
25
25
|
from .logger import log
|
26
26
|
from .record import RecordSet
|
27
27
|
|
28
|
-
DEFAULT_TTL =
|
28
|
+
DEFAULT_TTL = 43200 # This is 12 hours
|
29
29
|
|
30
30
|
|
31
31
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
@@ -321,7 +321,7 @@ class Message:
|
|
321
321
|
)
|
322
322
|
message.metadata.ttl = ttl
|
323
323
|
|
324
|
-
self.
|
324
|
+
self._limit_message_res_ttl(message)
|
325
325
|
|
326
326
|
return message
|
327
327
|
|
@@ -364,7 +364,7 @@ class Message:
|
|
364
364
|
)
|
365
365
|
message.metadata.ttl = ttl
|
366
366
|
|
367
|
-
self.
|
367
|
+
self._limit_message_res_ttl(message)
|
368
368
|
|
369
369
|
return message
|
370
370
|
|
@@ -379,14 +379,14 @@ class Message:
|
|
379
379
|
)
|
380
380
|
return f"{self.__class__.__qualname__}({view})"
|
381
381
|
|
382
|
-
def
|
383
|
-
"""Limit the
|
384
|
-
replies to.
|
382
|
+
def _limit_message_res_ttl(self, message: Message) -> None:
|
383
|
+
"""Limit the TTL of the provided Message to not exceed the expiration time of
|
384
|
+
this Message it replies to.
|
385
385
|
|
386
386
|
Parameters
|
387
387
|
----------
|
388
388
|
message : Message
|
389
|
-
The
|
389
|
+
The reply Message to limit the TTL for.
|
390
390
|
"""
|
391
391
|
# Calculate the maximum allowed TTL
|
392
392
|
max_allowed_ttl = (
|
flwr/common/record/recordset.py
CHANGED
@@ -155,19 +155,11 @@ class RecordSet(TypedDict[str, RecordType]):
|
|
155
155
|
:code:`MetricsRecord` and :code:`ParametersRecord`.
|
156
156
|
"""
|
157
157
|
|
158
|
-
def __init__(
|
159
|
-
self,
|
160
|
-
parameters_records: dict[str, ParametersRecord] | None = None,
|
161
|
-
metrics_records: dict[str, MetricsRecord] | None = None,
|
162
|
-
configs_records: dict[str, ConfigsRecord] | None = None,
|
163
|
-
) -> None:
|
158
|
+
def __init__(self, records: dict[str, RecordType] | None = None) -> None:
|
164
159
|
super().__init__(_check_key, _check_value)
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
self[key] = m_record
|
169
|
-
for key, c_record in (configs_records or {}).items():
|
170
|
-
self[key] = c_record
|
160
|
+
if records is not None:
|
161
|
+
for key, record in records.items():
|
162
|
+
self[key] = record
|
171
163
|
|
172
164
|
@property
|
173
165
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
flwr/common/serde.py
CHANGED
@@ -21,8 +21,6 @@ from typing import Any, TypeVar, cast
|
|
21
21
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
23
23
|
|
24
|
-
from flwr.common.constant import SUPERLINK_NODE_ID
|
25
|
-
|
26
24
|
# pylint: disable=E0611
|
27
25
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
28
26
|
from flwr.proto.error_pb2 import Error as ProtoError
|
@@ -30,7 +28,6 @@ from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
30
28
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
31
29
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
32
30
|
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
33
|
-
from flwr.proto.node_pb2 import Node
|
34
31
|
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
35
32
|
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
36
33
|
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
@@ -43,7 +40,6 @@ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
43
40
|
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
|
44
41
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
45
42
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
46
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
47
43
|
from flwr.proto.transport_pb2 import (
|
48
44
|
ClientMessage,
|
49
45
|
Code,
|
@@ -583,128 +579,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
|
|
583
579
|
|
584
580
|
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
|
585
581
|
"""Deserialize RecordSet from ProtoBuf."""
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
configs_records={
|
595
|
-
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
|
596
|
-
},
|
597
|
-
)
|
598
|
-
|
599
|
-
|
600
|
-
# === Message ===
|
601
|
-
|
602
|
-
|
603
|
-
def message_to_taskins(message: Message) -> TaskIns:
|
604
|
-
"""Create a TaskIns from the Message."""
|
605
|
-
md = message.metadata
|
606
|
-
return TaskIns(
|
607
|
-
group_id=md.group_id,
|
608
|
-
run_id=md.run_id,
|
609
|
-
task=Task(
|
610
|
-
producer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
|
611
|
-
consumer=Node(node_id=md.dst_node_id),
|
612
|
-
created_at=md.created_at,
|
613
|
-
ttl=md.ttl,
|
614
|
-
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
615
|
-
task_type=md.message_type,
|
616
|
-
recordset=(
|
617
|
-
recordset_to_proto(message.content) if message.has_content() else None
|
618
|
-
),
|
619
|
-
error=error_to_proto(message.error) if message.has_error() else None,
|
620
|
-
),
|
621
|
-
)
|
622
|
-
|
623
|
-
|
624
|
-
def message_from_taskins(taskins: TaskIns) -> Message:
|
625
|
-
"""Create a Message from the TaskIns."""
|
626
|
-
# Retrieve the Metadata
|
627
|
-
metadata = Metadata(
|
628
|
-
run_id=taskins.run_id,
|
629
|
-
message_id=taskins.task_id,
|
630
|
-
src_node_id=taskins.task.producer.node_id,
|
631
|
-
dst_node_id=taskins.task.consumer.node_id,
|
632
|
-
reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "",
|
633
|
-
group_id=taskins.group_id,
|
634
|
-
ttl=taskins.task.ttl,
|
635
|
-
message_type=taskins.task.task_type,
|
636
|
-
)
|
637
|
-
|
638
|
-
# Construct Message
|
639
|
-
message = Message(
|
640
|
-
metadata=metadata,
|
641
|
-
content=(
|
642
|
-
recordset_from_proto(taskins.task.recordset)
|
643
|
-
if taskins.task.HasField("recordset")
|
644
|
-
else None
|
645
|
-
),
|
646
|
-
error=(
|
647
|
-
error_from_proto(taskins.task.error)
|
648
|
-
if taskins.task.HasField("error")
|
649
|
-
else None
|
650
|
-
),
|
651
|
-
)
|
652
|
-
message.metadata.created_at = taskins.task.created_at
|
653
|
-
return message
|
654
|
-
|
655
|
-
|
656
|
-
def message_to_taskres(message: Message) -> TaskRes:
|
657
|
-
"""Create a TaskRes from the Message."""
|
658
|
-
md = message.metadata
|
659
|
-
return TaskRes(
|
660
|
-
task_id="", # This will be generated by the server
|
661
|
-
group_id=md.group_id,
|
662
|
-
run_id=md.run_id,
|
663
|
-
task=Task(
|
664
|
-
producer=Node(node_id=md.src_node_id),
|
665
|
-
consumer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
|
666
|
-
created_at=md.created_at,
|
667
|
-
ttl=md.ttl,
|
668
|
-
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
669
|
-
task_type=md.message_type,
|
670
|
-
recordset=(
|
671
|
-
recordset_to_proto(message.content) if message.has_content() else None
|
672
|
-
),
|
673
|
-
error=error_to_proto(message.error) if message.has_error() else None,
|
674
|
-
),
|
675
|
-
)
|
676
|
-
|
677
|
-
|
678
|
-
def message_from_taskres(taskres: TaskRes) -> Message:
|
679
|
-
"""Create a Message from the TaskIns."""
|
680
|
-
# Retrieve the MetaData
|
681
|
-
metadata = Metadata(
|
682
|
-
run_id=taskres.run_id,
|
683
|
-
message_id=taskres.task_id,
|
684
|
-
src_node_id=taskres.task.producer.node_id,
|
685
|
-
dst_node_id=taskres.task.consumer.node_id,
|
686
|
-
reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "",
|
687
|
-
group_id=taskres.group_id,
|
688
|
-
ttl=taskres.task.ttl,
|
689
|
-
message_type=taskres.task.task_type,
|
690
|
-
)
|
691
|
-
|
692
|
-
# Construct the Message
|
693
|
-
message = Message(
|
694
|
-
metadata=metadata,
|
695
|
-
content=(
|
696
|
-
recordset_from_proto(taskres.task.recordset)
|
697
|
-
if taskres.task.HasField("recordset")
|
698
|
-
else None
|
699
|
-
),
|
700
|
-
error=(
|
701
|
-
error_from_proto(taskres.task.error)
|
702
|
-
if taskres.task.HasField("error")
|
703
|
-
else None
|
704
|
-
),
|
705
|
-
)
|
706
|
-
message.metadata.created_at = taskres.task.created_at
|
707
|
-
return message
|
582
|
+
ret = RecordSet()
|
583
|
+
for k, p_record_proto in recordset_proto.parameters.items():
|
584
|
+
ret[k] = parameters_record_from_proto(p_record_proto)
|
585
|
+
for k, m_record_proto in recordset_proto.metrics.items():
|
586
|
+
ret[k] = metrics_record_from_proto(m_record_proto)
|
587
|
+
for k, c_record_proto in recordset_proto.configs.items():
|
588
|
+
ret[k] = configs_record_from_proto(c_record_proto)
|
589
|
+
return ret
|
708
590
|
|
709
591
|
|
710
592
|
# === FAB ===
|
flwr/server/compat/driver_client_proxy.py
CHANGED
@@ -104,7 +104,7 @@ class DriverClientProxy(ClientProxy):
|
|
104
104
|
def _send_receive_recordset(
|
105
105
|
self,
|
106
106
|
recordset: RecordSet,
|
107
|
-
|
107
|
+
message_type: str,
|
108
108
|
timeout: Optional[float],
|
109
109
|
group_id: Optional[int],
|
110
110
|
) -> RecordSet:
|
@@ -112,7 +112,7 @@ class DriverClientProxy(ClientProxy):
|
|
112
112
|
# Create message
|
113
113
|
message = self.driver.create_message(
|
114
114
|
content=recordset,
|
115
|
-
message_type=
|
115
|
+
message_type=message_type,
|
116
116
|
dst_node_id=self.node_id,
|
117
117
|
group_id=str(group_id) if group_id else "",
|
118
118
|
ttl=timeout,
|
flwr/server/driver/inmemory_driver.py
CHANGED
@@ -23,7 +23,6 @@ from uuid import UUID
|
|
23
23
|
|
24
24
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
25
25
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
26
|
-
from flwr.common.serde import message_from_taskres, message_to_taskins
|
27
26
|
from flwr.common.typing import Run
|
28
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
29
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
@@ -60,6 +59,7 @@ class InMemoryDriver(Driver):
|
|
60
59
|
and message.metadata.message_id == ""
|
61
60
|
and message.metadata.reply_to_message == ""
|
62
61
|
and message.metadata.ttl > 0
|
62
|
+
and message.metadata.delivered_at == ""
|
63
63
|
):
|
64
64
|
raise ValueError(f"Invalid message: {message}")
|
65
65
|
|
@@ -119,18 +119,16 @@ class InMemoryDriver(Driver):
|
|
119
119
|
This method takes an iterable of messages and sends each message
|
120
120
|
to the node specified in `dst_node_id`.
|
121
121
|
"""
|
122
|
-
|
122
|
+
msg_ids: list[str] = []
|
123
123
|
for msg in messages:
|
124
124
|
# Check message
|
125
125
|
self._check_message(msg)
|
126
|
-
# Convert Message to TaskIns
|
127
|
-
taskins = message_to_taskins(msg)
|
128
126
|
# Store in state
|
129
|
-
|
130
|
-
if
|
131
|
-
|
127
|
+
msg_id = self.state.store_message_ins(msg)
|
128
|
+
if msg_id:
|
129
|
+
msg_ids.append(str(msg_id))
|
132
130
|
|
133
|
-
return
|
131
|
+
return msg_ids
|
134
132
|
|
135
133
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
136
134
|
"""Pull messages based on message IDs.
|
@@ -139,17 +137,16 @@ class InMemoryDriver(Driver):
|
|
139
137
|
set of given message IDs.
|
140
138
|
"""
|
141
139
|
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
142
|
-
# Pull
|
143
|
-
|
144
|
-
#
|
145
|
-
|
146
|
-
|
147
|
-
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
140
|
+
# Pull Messages
|
141
|
+
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
142
|
+
# Get IDs of Messages these replies are for
|
143
|
+
message_ins_ids_to_delete = {
|
144
|
+
UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
|
148
145
|
}
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
return
|
146
|
+
# Delete
|
147
|
+
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
148
|
+
|
149
|
+
return message_res_list
|
153
150
|
|
154
151
|
def send_and_receive(
|
155
152
|
self,
|
flwr/server/superlink/driver/serverappio_servicer.py
CHANGED
@@ -22,7 +22,7 @@ from uuid import UUID
|
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
|
-
from flwr.common import ConfigsRecord
|
25
|
+
from flwr.common import ConfigsRecord, Message
|
26
26
|
from flwr.common.constant import Status
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.serde import (
|
@@ -31,9 +31,7 @@ from flwr.common.serde import (
|
|
31
31
|
fab_from_proto,
|
32
32
|
fab_to_proto,
|
33
33
|
message_from_proto,
|
34
|
-
message_from_taskres,
|
35
34
|
message_to_proto,
|
36
|
-
message_to_taskins,
|
37
35
|
run_status_from_proto,
|
38
36
|
run_status_to_proto,
|
39
37
|
run_to_proto,
|
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
69
67
|
PushServerAppOutputsRequest,
|
70
68
|
PushServerAppOutputsResponse,
|
71
69
|
)
|
72
|
-
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
73
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
74
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
75
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
76
73
|
from flwr.server.superlink.utils import abort_if
|
77
|
-
from flwr.server.utils.validator import
|
74
|
+
from flwr.server.utils.validator import validate_message
|
78
75
|
|
79
76
|
|
80
77
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
@@ -161,20 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
161
158
|
while request.messages_list:
|
162
159
|
message_proto = request.messages_list.pop(0)
|
163
160
|
message = message_from_proto(message_proto=message_proto)
|
164
|
-
|
165
|
-
validation_errors = validate_task_ins_or_res(task_ins)
|
161
|
+
validation_errors = validate_message(message, is_reply_message=False)
|
166
162
|
_raise_if(
|
167
163
|
validation_error=bool(validation_errors),
|
168
164
|
request_name="PushMessages",
|
169
165
|
detail=", ".join(validation_errors),
|
170
166
|
)
|
171
167
|
_raise_if(
|
172
|
-
validation_error=request.run_id !=
|
168
|
+
validation_error=request.run_id != message.metadata.run_id,
|
173
169
|
request_name="PushMessages",
|
174
|
-
detail="`
|
170
|
+
detail="`Message.metadata` has mismatched `run_id`",
|
175
171
|
)
|
176
172
|
# Store
|
177
|
-
message_id: Optional[UUID] = state.
|
173
|
+
message_id: Optional[UUID] = state.store_message_ins(message=message)
|
178
174
|
message_ids.append(message_id)
|
179
175
|
|
180
176
|
return PushInsMessagesResponse(
|
@@ -200,32 +196,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
200
196
|
context,
|
201
197
|
)
|
202
198
|
|
203
|
-
# Convert each
|
199
|
+
# Convert each message_id str to UUID
|
204
200
|
message_ids: set[UUID] = {
|
205
201
|
UUID(message_id) for message_id in request.message_ids
|
206
202
|
}
|
207
203
|
|
208
204
|
# Read from state
|
209
|
-
|
205
|
+
messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
|
210
206
|
|
211
|
-
# Delete the
|
212
|
-
|
213
|
-
UUID(
|
207
|
+
# Delete the instruction Messages and their replies if found
|
208
|
+
message_ins_ids_to_delete = {
|
209
|
+
UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
|
214
210
|
}
|
215
211
|
|
216
|
-
state.
|
212
|
+
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
217
213
|
|
218
|
-
# Convert to
|
214
|
+
# Convert Messages to proto
|
219
215
|
messages_list = []
|
220
|
-
while
|
221
|
-
|
216
|
+
while messages_res:
|
217
|
+
msg = messages_res.pop(0)
|
222
218
|
_raise_if(
|
223
|
-
validation_error=request.run_id !=
|
219
|
+
validation_error=request.run_id != msg.metadata.run_id,
|
224
220
|
request_name="PullMessages",
|
225
|
-
detail="`
|
221
|
+
detail="`message.metadata` has mismatched `run_id`",
|
226
222
|
)
|
227
|
-
|
228
|
-
messages_list.append(message_to_proto(message))
|
223
|
+
messages_list.append(message_to_proto(msg))
|
229
224
|
|
230
225
|
return PullResMessagesResponse(messages_list=messages_list)
|
231
226
|
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
103
103
|
if request.messages_list:
|
104
104
|
log(
|
105
105
|
INFO,
|
106
|
-
"[Fleet.PushMessages] Push
|
106
|
+
"[Fleet.PushMessages] Push replies from node_id=%s",
|
107
107
|
request.messages_list[0].metadata.src_node_id,
|
108
108
|
)
|
109
109
|
else:
|
110
|
-
log(INFO, "[Fleet.PushMessages] No
|
110
|
+
log(INFO, "[Fleet.PushMessages] No replies to push")
|
111
111
|
|
112
112
|
try:
|
113
113
|
res = message_handler.push_messages(
|
flwr/server/superlink/fleet/message_handler/message_handler.py
CHANGED
@@ -18,13 +18,12 @@
|
|
18
18
|
from typing import Optional
|
19
19
|
from uuid import UUID
|
20
20
|
|
21
|
+
from flwr.common import Message
|
21
22
|
from flwr.common.constant import Status
|
22
23
|
from flwr.common.serde import (
|
23
24
|
fab_to_proto,
|
24
25
|
message_from_proto,
|
25
|
-
message_from_taskins,
|
26
26
|
message_to_proto,
|
27
|
-
message_to_taskres,
|
28
27
|
user_config_to_proto,
|
29
28
|
)
|
30
29
|
from flwr.common.typing import Fab, InvalidRunStatusException
|
@@ -48,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
48
47
|
GetRunResponse,
|
49
48
|
Run,
|
50
49
|
)
|
51
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
52
50
|
from flwr.server.superlink.ffs.ffs import Ffs
|
53
51
|
from flwr.server.superlink.linkstate import LinkState
|
54
52
|
from flwr.server.superlink.utils import check_abort
|
@@ -92,13 +90,12 @@ def pull_messages(
|
|
92
90
|
node = request.node # pylint: disable=no-member
|
93
91
|
node_id: int = node.node_id
|
94
92
|
|
95
|
-
# Retrieve
|
96
|
-
|
93
|
+
# Retrieve Message from State
|
94
|
+
message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
|
97
95
|
|
98
96
|
# Convert to Messages
|
99
97
|
msg_proto = []
|
100
|
-
for
|
101
|
-
msg = message_from_taskins(task_ins)
|
98
|
+
for msg in message_list:
|
102
99
|
msg_proto.append(message_to_proto(msg))
|
103
100
|
|
104
101
|
return PullMessagesResponse(messages_list=msg_proto)
|
@@ -108,21 +105,20 @@ def push_messages(
|
|
108
105
|
request: PushMessagesRequest, state: LinkState
|
109
106
|
) -> PushMessagesResponse:
|
110
107
|
"""Push Messages handler."""
|
111
|
-
# Convert Message
|
108
|
+
# Convert Message from proto
|
112
109
|
msg = message_from_proto(message_proto=request.messages_list[0])
|
113
|
-
task_res = message_to_taskres(msg)
|
114
110
|
|
115
111
|
# Abort if the run is not running
|
116
112
|
abort_msg = check_abort(
|
117
|
-
|
113
|
+
msg.metadata.run_id,
|
118
114
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
119
115
|
state,
|
120
116
|
)
|
121
117
|
if abort_msg:
|
122
118
|
raise InvalidRunStatusException(abort_msg)
|
123
119
|
|
124
|
-
# Store
|
125
|
-
message_id: Optional[UUID] = state.
|
120
|
+
# Store Message in State
|
121
|
+
message_id: Optional[UUID] = state.store_message_res(message=msg)
|
126
122
|
|
127
123
|
# Build response
|
128
124
|
response = PushMessagesResponse(
|
flwr/server/superlink/fleet/vce/backend/backend.py
CHANGED
@@ -45,7 +45,7 @@ class Backend(ABC):
|
|
45
45
|
def num_workers(self) -> int:
|
46
46
|
"""Return number of workers in the backend.
|
47
47
|
|
48
|
-
This is the number of
|
48
|
+
This is the number of Messages that can be processed concurrently.
|
49
49
|
"""
|
50
50
|
return 0
|
51
51
|
|