flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +6 -4
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +23 -20
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +2 -0
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +2 -2
- flwr/common/record/metricsrecord.py +1 -1
- flwr/common/record/parametersrecord.py +1 -1
- flwr/common/record/{recordset.py → recorddict.py} +57 -17
- flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
- flwr/common/serde.py +33 -37
- flwr/proto/exec_pb2.py +32 -32
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
- flwr/proto/run_pb2.py +32 -32
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +30 -30
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +14 -18
- flwr/server/superlink/linkstate/utils.py +10 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +7 -7
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +3 -3
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +49 -49
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
@@ -18,8 +18,8 @@
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
20
|
from flwr import common
|
21
|
-
from flwr.common import Message, MessageType, MessageTypeLegacy,
|
22
|
-
from flwr.common import
|
21
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
|
22
|
+
from flwr.common import recorddict_compat as compat
|
23
23
|
from flwr.server.client_proxy import ClientProxy
|
24
24
|
|
25
25
|
from ..grid.grid import Grid
|
@@ -41,14 +41,14 @@ class GridClientProxy(ClientProxy):
|
|
41
41
|
group_id: Optional[int],
|
42
42
|
) -> common.GetPropertiesRes:
|
43
43
|
"""Return client's properties."""
|
44
|
-
# Ins to
|
45
|
-
|
44
|
+
# Ins to RecordDict
|
45
|
+
out_recorddict = compat.getpropertiesins_to_recorddict(ins)
|
46
46
|
# Fetch response
|
47
|
-
|
48
|
-
|
47
|
+
in_recorddict = self._send_receive_recorddict(
|
48
|
+
out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
49
49
|
)
|
50
|
-
#
|
51
|
-
return compat.
|
50
|
+
# RecordDict to Res
|
51
|
+
return compat.recorddict_to_getpropertiesres(in_recorddict)
|
52
52
|
|
53
53
|
def get_parameters(
|
54
54
|
self,
|
@@ -57,40 +57,40 @@ class GridClientProxy(ClientProxy):
|
|
57
57
|
group_id: Optional[int],
|
58
58
|
) -> common.GetParametersRes:
|
59
59
|
"""Return the current local model parameters."""
|
60
|
-
# Ins to
|
61
|
-
|
60
|
+
# Ins to RecordDict
|
61
|
+
out_recorddict = compat.getparametersins_to_recorddict(ins)
|
62
62
|
# Fetch response
|
63
|
-
|
64
|
-
|
63
|
+
in_recorddict = self._send_receive_recorddict(
|
64
|
+
out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
65
65
|
)
|
66
|
-
#
|
67
|
-
return compat.
|
66
|
+
# RecordDict to Res
|
67
|
+
return compat.recorddict_to_getparametersres(in_recorddict, False)
|
68
68
|
|
69
69
|
def fit(
|
70
70
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
71
71
|
) -> common.FitRes:
|
72
72
|
"""Train model parameters on the locally held dataset."""
|
73
|
-
# Ins to
|
74
|
-
|
73
|
+
# Ins to RecordDict
|
74
|
+
out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
|
75
75
|
# Fetch response
|
76
|
-
|
77
|
-
|
76
|
+
in_recorddict = self._send_receive_recorddict(
|
77
|
+
out_recorddict, MessageType.TRAIN, timeout, group_id
|
78
78
|
)
|
79
|
-
#
|
80
|
-
return compat.
|
79
|
+
# RecordDict to Res
|
80
|
+
return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
|
81
81
|
|
82
82
|
def evaluate(
|
83
83
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
84
84
|
) -> common.EvaluateRes:
|
85
85
|
"""Evaluate model parameters on the locally held dataset."""
|
86
|
-
# Ins to
|
87
|
-
|
86
|
+
# Ins to RecordDict
|
87
|
+
out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
|
88
88
|
# Fetch response
|
89
|
-
|
90
|
-
|
89
|
+
in_recorddict = self._send_receive_recorddict(
|
90
|
+
out_recorddict, MessageType.EVALUATE, timeout, group_id
|
91
91
|
)
|
92
|
-
#
|
93
|
-
return compat.
|
92
|
+
# RecordDict to Res
|
93
|
+
return compat.recorddict_to_evaluateres(in_recorddict)
|
94
94
|
|
95
95
|
def reconnect(
|
96
96
|
self,
|
@@ -101,17 +101,17 @@ class GridClientProxy(ClientProxy):
|
|
101
101
|
"""Disconnect and (optionally) reconnect later."""
|
102
102
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
103
103
|
|
104
|
-
def
|
104
|
+
def _send_receive_recorddict(
|
105
105
|
self,
|
106
|
-
|
106
|
+
recorddict: RecordDict,
|
107
107
|
message_type: str,
|
108
108
|
timeout: Optional[float],
|
109
109
|
group_id: Optional[int],
|
110
|
-
) ->
|
110
|
+
) -> RecordDict:
|
111
111
|
|
112
112
|
# Create message
|
113
113
|
message = self.grid.create_message(
|
114
|
-
content=
|
114
|
+
content=recorddict,
|
115
115
|
message_type=message_type,
|
116
116
|
dst_node_id=self.node_id,
|
117
117
|
group_id=str(group_id) if group_id else "",
|
flwr/server/grid/grid.py
CHANGED
@@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
|
|
19
19
|
from collections.abc import Iterable
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
-
from flwr.common import Message,
|
22
|
+
from flwr.common import Message, RecordDict
|
23
23
|
from flwr.common.typing import Run
|
24
24
|
|
25
25
|
|
@@ -48,7 +48,7 @@ class Grid(ABC):
|
|
48
48
|
@abstractmethod
|
49
49
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
50
50
|
self,
|
51
|
-
content:
|
51
|
+
content: RecordDict,
|
52
52
|
message_type: str,
|
53
53
|
dst_node_id: int,
|
54
54
|
group_id: str,
|
@@ -61,7 +61,7 @@ class Grid(ABC):
|
|
61
61
|
|
62
62
|
Parameters
|
63
63
|
----------
|
64
|
-
content :
|
64
|
+
content : RecordDict
|
65
65
|
The content for the new message. This holds records that are to be sent
|
66
66
|
to the destination node.
|
67
67
|
message_type : str
|
flwr/server/grid/grpc_grid.py
CHANGED
@@ -22,13 +22,13 @@ from typing import Optional, cast
|
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
|
-
from flwr.common import
|
25
|
+
from flwr.common import Message, RecordDict
|
26
26
|
from flwr.common.constant import (
|
27
27
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
28
28
|
SUPERLINK_NODE_ID,
|
29
29
|
)
|
30
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
31
|
-
from flwr.common.logger import log
|
31
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
32
32
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
33
33
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
34
34
|
from flwr.common.typing import Run
|
@@ -161,18 +161,15 @@ class GrpcGrid(Grid):
|
|
161
161
|
def _check_message(self, message: Message) -> None:
|
162
162
|
# Check if the message is valid
|
163
163
|
if not (
|
164
|
-
|
165
|
-
message.metadata.
|
166
|
-
and message.metadata.src_node_id == self.node.node_id
|
167
|
-
and message.metadata.message_id == ""
|
168
|
-
and message.metadata.reply_to_message == ""
|
164
|
+
message.metadata.message_id == ""
|
165
|
+
and message.metadata.reply_to_message_id == ""
|
169
166
|
and message.metadata.ttl > 0
|
170
167
|
):
|
171
168
|
raise ValueError(f"Invalid message: {message}")
|
172
169
|
|
173
170
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
174
171
|
self,
|
175
|
-
content:
|
172
|
+
content: RecordDict,
|
176
173
|
message_type: str,
|
177
174
|
dst_node_id: int,
|
178
175
|
group_id: str,
|
@@ -183,18 +180,11 @@ class GrpcGrid(Grid):
|
|
183
180
|
This method constructs a new `Message` with given content and metadata.
|
184
181
|
The `run_id` and `src_node_id` will be set automatically.
|
185
182
|
"""
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
message_id="", # Will be set by the server
|
190
|
-
src_node_id=self.node.node_id,
|
191
|
-
dst_node_id=dst_node_id,
|
192
|
-
reply_to_message="",
|
193
|
-
group_id=group_id,
|
194
|
-
ttl=ttl_,
|
195
|
-
message_type=message_type,
|
183
|
+
warn_deprecated_feature(
|
184
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
185
|
+
"Use `Message` constructor instead."
|
196
186
|
)
|
197
|
-
return Message(
|
187
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
198
188
|
|
199
189
|
def get_node_ids(self) -> Iterable[int]:
|
200
190
|
"""Get node IDs."""
|
@@ -211,8 +201,12 @@ class GrpcGrid(Grid):
|
|
211
201
|
to the node specified in `dst_node_id`.
|
212
202
|
"""
|
213
203
|
# Construct Messages
|
204
|
+
run_id = cast(Run, self._run).run_id
|
214
205
|
message_proto_list: list[ProtoMessage] = []
|
215
206
|
for msg in messages:
|
207
|
+
# Populate metadata
|
208
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
209
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
216
210
|
# Check message
|
217
211
|
self._check_message(msg)
|
218
212
|
# Convert to proto
|
@@ -223,9 +217,7 @@ class GrpcGrid(Grid):
|
|
223
217
|
try:
|
224
218
|
# Call GrpcServerAppIoStub method
|
225
219
|
res: PushInsMessagesResponse = self._stub.PushMessages(
|
226
|
-
PushInsMessagesRequest(
|
227
|
-
messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
|
228
|
-
)
|
220
|
+
PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
|
229
221
|
)
|
230
222
|
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
231
223
|
message_proto_list
|
@@ -289,7 +281,7 @@ class GrpcGrid(Grid):
|
|
289
281
|
res_msgs = self.pull_messages(msg_ids)
|
290
282
|
ret.extend(res_msgs)
|
291
283
|
msg_ids.difference_update(
|
292
|
-
{msg.metadata.
|
284
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
293
285
|
)
|
294
286
|
if len(msg_ids) == 0:
|
295
287
|
break
|
@@ -20,8 +20,9 @@ from collections.abc import Iterable
|
|
20
20
|
from typing import Optional, cast
|
21
21
|
from uuid import UUID
|
22
22
|
|
23
|
-
from flwr.common import
|
23
|
+
from flwr.common import Message, RecordDict
|
24
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
25
|
+
from flwr.common.logger import warn_deprecated_feature
|
25
26
|
from flwr.common.typing import Run
|
26
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
27
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
@@ -53,10 +54,8 @@ class InMemoryGrid(Grid):
|
|
53
54
|
def _check_message(self, message: Message) -> None:
|
54
55
|
# Check if the message is valid
|
55
56
|
if not (
|
56
|
-
message.metadata.
|
57
|
-
and message.metadata.
|
58
|
-
and message.metadata.message_id == ""
|
59
|
-
and message.metadata.reply_to_message == ""
|
57
|
+
message.metadata.message_id == ""
|
58
|
+
and message.metadata.reply_to_message_id == ""
|
60
59
|
and message.metadata.ttl > 0
|
61
60
|
and message.metadata.delivered_at == ""
|
62
61
|
):
|
@@ -76,7 +75,7 @@ class InMemoryGrid(Grid):
|
|
76
75
|
|
77
76
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
78
77
|
self,
|
79
|
-
content:
|
78
|
+
content: RecordDict,
|
80
79
|
message_type: str,
|
81
80
|
dst_node_id: int,
|
82
81
|
group_id: str,
|
@@ -87,19 +86,11 @@ class InMemoryGrid(Grid):
|
|
87
86
|
This method constructs a new `Message` with given content and metadata.
|
88
87
|
The `run_id` and `src_node_id` will be set automatically.
|
89
88
|
"""
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
run_id=cast(Run, self._run).run_id,
|
94
|
-
message_id="", # Will be set by the server
|
95
|
-
src_node_id=self.node.node_id,
|
96
|
-
dst_node_id=dst_node_id,
|
97
|
-
reply_to_message="",
|
98
|
-
group_id=group_id,
|
99
|
-
ttl=ttl_,
|
100
|
-
message_type=message_type,
|
89
|
+
warn_deprecated_feature(
|
90
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
91
|
+
"Use `Message` constructor instead."
|
101
92
|
)
|
102
|
-
return Message(
|
93
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
103
94
|
|
104
95
|
def get_node_ids(self) -> Iterable[int]:
|
105
96
|
"""Get node IDs."""
|
@@ -113,6 +104,9 @@ class InMemoryGrid(Grid):
|
|
113
104
|
"""
|
114
105
|
msg_ids: list[str] = []
|
115
106
|
for msg in messages:
|
107
|
+
# Populate metadata
|
108
|
+
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
109
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
116
110
|
# Check message
|
117
111
|
self._check_message(msg)
|
118
112
|
# Store in state
|
@@ -133,7 +127,7 @@ class InMemoryGrid(Grid):
|
|
133
127
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
134
128
|
# Get IDs of Messages these replies are for
|
135
129
|
message_ins_ids_to_delete = {
|
136
|
-
UUID(msg_res.metadata.
|
130
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
|
137
131
|
}
|
138
132
|
# Delete
|
139
133
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
@@ -162,7 +156,7 @@ class InMemoryGrid(Grid):
|
|
162
156
|
res_msgs = self.pull_messages(msg_ids)
|
163
157
|
ret.extend(res_msgs)
|
164
158
|
msg_ids.difference_update(
|
165
|
-
{msg.metadata.
|
159
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
166
160
|
)
|
167
161
|
if len(msg_ids) == 0:
|
168
162
|
break
|
@@ -130,9 +130,7 @@ def worker(
|
|
130
130
|
e_code = ErrorCode.UNKNOWN
|
131
131
|
|
132
132
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
133
|
-
out_mssg = message
|
134
|
-
error=Error(code=e_code, reason=reason)
|
135
|
-
)
|
133
|
+
out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
|
136
134
|
|
137
135
|
finally:
|
138
136
|
if out_mssg:
|
@@ -158,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
158
158
|
res_metadata = message.metadata
|
159
159
|
with self.lock:
|
160
160
|
# Check if the Message it is replying to exists and is valid
|
161
|
-
msg_ins_id = res_metadata.
|
161
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
162
162
|
msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
|
163
163
|
|
164
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
@@ -35,18 +35,19 @@ from flwr.common.constant import (
|
|
35
35
|
SUPERLINK_NODE_ID,
|
36
36
|
Status,
|
37
37
|
)
|
38
|
+
from flwr.common.message import make_message
|
38
39
|
from flwr.common.record import ConfigsRecord
|
39
40
|
from flwr.common.serde import (
|
40
41
|
error_from_proto,
|
41
42
|
error_to_proto,
|
42
|
-
|
43
|
-
|
43
|
+
recorddict_from_proto,
|
44
|
+
recorddict_to_proto,
|
44
45
|
)
|
45
46
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
46
47
|
|
47
48
|
# pylint: disable=E0611
|
48
49
|
from flwr.proto.error_pb2 import Error as ProtoError
|
49
|
-
from flwr.proto.
|
50
|
+
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
50
51
|
|
51
52
|
# pylint: enable=E0611
|
52
53
|
from flwr.server.utils.validator import validate_message
|
@@ -131,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
|
|
131
132
|
run_id INTEGER,
|
132
133
|
src_node_id INTEGER,
|
133
134
|
dst_node_id INTEGER,
|
134
|
-
|
135
|
+
reply_to_message_id TEXT,
|
135
136
|
created_at REAL,
|
136
137
|
delivered_at TEXT,
|
137
138
|
ttl REAL,
|
@@ -150,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
150
151
|
run_id INTEGER,
|
151
152
|
src_node_id INTEGER,
|
152
153
|
dst_node_id INTEGER,
|
153
|
-
|
154
|
+
reply_to_message_id TEXT,
|
154
155
|
created_at REAL,
|
155
156
|
delivered_at TEXT,
|
156
157
|
ttl REAL,
|
@@ -373,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
373
374
|
return None
|
374
375
|
|
375
376
|
res_metadata = message.metadata
|
376
|
-
msg_ins_id = res_metadata.
|
377
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
377
378
|
msg_ins = self.get_valid_message_ins(msg_ins_id)
|
378
379
|
if msg_ins is None:
|
379
380
|
log(
|
@@ -495,7 +496,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
495
496
|
query = f"""
|
496
497
|
SELECT *
|
497
498
|
FROM message_res
|
498
|
-
WHERE
|
499
|
+
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
499
500
|
AND delivered_at = "";
|
500
501
|
"""
|
501
502
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
@@ -568,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
568
569
|
# Delete reply Message
|
569
570
|
query_2 = f"""
|
570
571
|
DELETE FROM message_res
|
571
|
-
WHERE
|
572
|
+
WHERE reply_to_message_id IN ({placeholders});
|
572
573
|
"""
|
573
574
|
|
574
575
|
with self.conn:
|
@@ -1064,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1064
1065
|
"run_id": message.metadata.run_id,
|
1065
1066
|
"src_node_id": message.metadata.src_node_id,
|
1066
1067
|
"dst_node_id": message.metadata.dst_node_id,
|
1067
|
-
"
|
1068
|
+
"reply_to_message_id": message.metadata.reply_to_message_id,
|
1068
1069
|
"created_at": message.metadata.created_at,
|
1069
1070
|
"delivered_at": message.metadata.delivered_at,
|
1070
1071
|
"ttl": message.metadata.ttl,
|
@@ -1074,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1074
1075
|
}
|
1075
1076
|
|
1076
1077
|
if message.has_content():
|
1077
|
-
result["content"] =
|
1078
|
+
result["content"] = recorddict_to_proto(message.content).SerializeToString()
|
1078
1079
|
else:
|
1079
1080
|
result["error"] = error_to_proto(message.error).SerializeToString()
|
1080
1081
|
|
@@ -1085,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
1085
1086
|
"""Transform dict to Message."""
|
1086
1087
|
content, error = None, None
|
1087
1088
|
if (b_content := message_dict.pop("content")) is not None:
|
1088
|
-
content =
|
1089
|
+
content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
|
1089
1090
|
if (b_error := message_dict.pop("error")) is not None:
|
1090
1091
|
error = error_from_proto(ProtoError.FromString(b_error))
|
1091
1092
|
|
1092
1093
|
# Metadata constructor doesn't allow passing created_at. We set it later
|
1093
1094
|
metadata = Metadata(
|
1094
|
-
**{
|
1095
|
-
k: v
|
1096
|
-
for k, v in message_dict.items()
|
1097
|
-
if k not in ["created_at", "delivered_at"]
|
1098
|
-
}
|
1095
|
+
**{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
|
1099
1096
|
)
|
1100
|
-
msg =
|
1101
|
-
msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
|
1097
|
+
msg = make_message(metadata=metadata, content=content, error=error)
|
1102
1098
|
msg.metadata.delivered_at = message_dict["delivered_at"]
|
1103
1099
|
return msg
|
1104
1100
|
|
@@ -27,11 +27,12 @@ from flwr.common.constant import (
|
|
27
27
|
Status,
|
28
28
|
SubStatus,
|
29
29
|
)
|
30
|
+
from flwr.common.message import make_message
|
30
31
|
from flwr.common.typing import RunStatus
|
31
32
|
|
32
33
|
# pylint: disable=E0611
|
33
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
34
|
-
from flwr.proto.
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigsRecord as ProtoConfigsRecord
|
35
36
|
|
36
37
|
# pylint: enable=E0611
|
37
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
@@ -247,13 +248,14 @@ def create_message_error_unavailable_res_message(
|
|
247
248
|
message_id=str(uuid4()),
|
248
249
|
src_node_id=SUPERLINK_NODE_ID,
|
249
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
250
|
-
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
251
252
|
group_id=ins_metadata.group_id,
|
252
253
|
message_type=ins_metadata.message_type,
|
254
|
+
created_at=current_time,
|
253
255
|
ttl=ttl,
|
254
256
|
)
|
255
257
|
|
256
|
-
return
|
258
|
+
return make_message(
|
257
259
|
metadata=metadata,
|
258
260
|
error=Error(
|
259
261
|
code=(
|
@@ -270,7 +272,7 @@ def create_message_error_unavailable_res_message(
|
|
270
272
|
)
|
271
273
|
|
272
274
|
|
273
|
-
def create_message_error_unavailable_ins_message(
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
274
276
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
275
277
|
that it isn't found."""
|
276
278
|
metadata = Metadata(
|
@@ -278,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
|
|
278
280
|
message_id=str(uuid4()),
|
279
281
|
src_node_id=SUPERLINK_NODE_ID,
|
280
282
|
dst_node_id=SUPERLINK_NODE_ID,
|
281
|
-
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
282
284
|
group_id="", # Unknown
|
283
285
|
message_type=MessageType.SYSTEM,
|
286
|
+
created_at=now().timestamp(),
|
284
287
|
ttl=0,
|
285
288
|
)
|
286
289
|
|
287
|
-
return
|
290
|
+
return make_message(
|
288
291
|
metadata=metadata,
|
289
292
|
error=Error(
|
290
293
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
@@ -372,7 +375,7 @@ def verify_found_message_replies(
|
|
372
375
|
ret_dict: dict[UUID, Message] = {}
|
373
376
|
current = current_time if current_time else now().timestamp()
|
374
377
|
for message_res in found_message_res_list:
|
375
|
-
message_ins_id = UUID(message_res.metadata.
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
376
379
|
if update_set:
|
377
380
|
inquired_message_ids.remove(message_ins_id)
|
378
381
|
# Check if the reply Message has expired
|
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
206
206
|
|
207
207
|
# Delete the instruction Messages and their replies if found
|
208
208
|
message_ins_ids_to_delete = {
|
209
|
-
UUID(msg_res.metadata.
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
210
210
|
}
|
211
211
|
|
212
212
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
flwr/server/utils/validator.py
CHANGED
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
68
68
|
|
69
69
|
# Link respose to original message
|
70
70
|
if not is_reply_message:
|
71
|
-
if metadata.
|
72
|
-
validation_errors.append("`metadata.
|
71
|
+
if metadata.reply_to_message_id != "":
|
72
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
|
73
73
|
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
74
74
|
validation_errors.append(
|
75
75
|
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
79
79
|
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
80
80
|
)
|
81
81
|
else:
|
82
|
-
if metadata.
|
83
|
-
validation_errors.append("`metadata.
|
82
|
+
if metadata.reply_to_message_id == "":
|
83
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
|
84
84
|
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
85
85
|
validation_errors.append(
|
86
86
|
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -20,7 +20,7 @@ import timeit
|
|
20
20
|
from logging import INFO, WARN
|
21
21
|
from typing import Optional, Union, cast
|
22
22
|
|
23
|
-
import flwr.common.
|
23
|
+
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
25
|
Code,
|
26
26
|
ConfigsRecord,
|
@@ -137,7 +137,7 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
137
137
|
log(INFO, "Requesting initial parameters from one random client")
|
138
138
|
random_client = context.client_manager.sample(1)[0]
|
139
139
|
# Send GetParametersIns and get the response
|
140
|
-
content = compat.
|
140
|
+
content = compat.getparametersins_to_recorddict(GetParametersIns({}))
|
141
141
|
messages = grid.send_and_receive(
|
142
142
|
[
|
143
143
|
grid.create_message(
|
@@ -152,7 +152,7 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
152
152
|
|
153
153
|
if (
|
154
154
|
msg.has_content()
|
155
|
-
and compat.
|
155
|
+
and compat._extract_status_from_recorddict( # pylint: disable=W0212
|
156
156
|
"getparametersres", msg.content
|
157
157
|
).code
|
158
158
|
== Code.OK
|
@@ -254,7 +254,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
254
254
|
# Build out messages
|
255
255
|
out_messages = [
|
256
256
|
grid.create_message(
|
257
|
-
content=compat.
|
257
|
+
content=compat.fitins_to_recorddict(fitins, True),
|
258
258
|
message_type=MessageType.TRAIN,
|
259
259
|
dst_node_id=proxy.node_id,
|
260
260
|
group_id=str(current_round),
|
@@ -282,7 +282,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
282
282
|
for msg in messages:
|
283
283
|
if msg.has_content():
|
284
284
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
285
|
-
fitres = compat.
|
285
|
+
fitres = compat.recorddict_to_fitres(msg.content, False)
|
286
286
|
if fitres.status.code == Code.OK:
|
287
287
|
results.append((proxy, fitres))
|
288
288
|
else:
|
@@ -340,7 +340,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
340
340
|
# Build out messages
|
341
341
|
out_messages = [
|
342
342
|
grid.create_message(
|
343
|
-
content=compat.
|
343
|
+
content=compat.evaluateins_to_recorddict(evalins, True),
|
344
344
|
message_type=MessageType.EVALUATE,
|
345
345
|
dst_node_id=proxy.node_id,
|
346
346
|
group_id=str(current_round),
|
@@ -368,7 +368,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
368
368
|
for msg in messages:
|
369
369
|
if msg.has_content():
|
370
370
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
371
|
-
evalres = compat.
|
371
|
+
evalres = compat.recorddict_to_evaluateres(msg.content)
|
372
372
|
if evalres.status.code == Code.OK:
|
373
373
|
results.append((proxy, evalres))
|
374
374
|
else:
|