flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +10 -12
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +31 -31
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
- flwr/server/superlink/linkstate/utils.py +18 -15
- flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +34 -41
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +5 -5
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
@@ -130,9 +130,7 @@ def worker(
|
|
130
130
|
e_code = ErrorCode.UNKNOWN
|
131
131
|
|
132
132
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
133
|
-
out_mssg = message
|
134
|
-
error=Error(code=e_code, reason=reason)
|
135
|
-
)
|
133
|
+
out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
|
136
134
|
|
137
135
|
finally:
|
138
136
|
if out_mssg:
|
@@ -32,7 +32,7 @@ from flwr.common.constant import (
|
|
32
32
|
SUPERLINK_NODE_ID,
|
33
33
|
Status,
|
34
34
|
)
|
35
|
-
from flwr.common.record import
|
35
|
+
from flwr.common.record import ConfigRecord
|
36
36
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
37
37
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
38
38
|
from flwr.server.utils import validate_message
|
@@ -69,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
69
69
|
# Map run_id to RunRecord
|
70
70
|
self.run_ids: dict[int, RunRecord] = {}
|
71
71
|
self.contexts: dict[int, Context] = {}
|
72
|
-
self.federation_options: dict[int,
|
72
|
+
self.federation_options: dict[int, ConfigRecord] = {}
|
73
73
|
self.message_ins_store: dict[UUID, Message] = {}
|
74
74
|
self.message_res_store: dict[UUID, Message] = {}
|
75
75
|
self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
|
@@ -158,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
158
158
|
res_metadata = message.metadata
|
159
159
|
with self.lock:
|
160
160
|
# Check if the Message it is replying to exists and is valid
|
161
|
-
msg_ins_id = res_metadata.
|
161
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
162
162
|
msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
|
163
163
|
|
164
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
@@ -399,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
399
399
|
fab_version: Optional[str],
|
400
400
|
fab_hash: Optional[str],
|
401
401
|
override_config: UserConfig,
|
402
|
-
federation_options:
|
402
|
+
federation_options: ConfigRecord,
|
403
403
|
) -> int:
|
404
404
|
"""Create a new run for the specified `fab_hash`."""
|
405
405
|
# Sample a random int64 as run_id
|
@@ -528,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
528
528
|
|
529
529
|
return pending_run_id
|
530
530
|
|
531
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
531
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
532
532
|
"""Retrieve the federation options for the specified `run_id`."""
|
533
533
|
with self.lock:
|
534
534
|
if run_id not in self.run_ids:
|
@@ -20,7 +20,7 @@ from typing import Optional
|
|
20
20
|
from uuid import UUID
|
21
21
|
|
22
22
|
from flwr.common import Context, Message
|
23
|
-
from flwr.common.record import
|
23
|
+
from flwr.common.record import ConfigRecord
|
24
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
25
25
|
|
26
26
|
|
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
164
164
|
fab_version: Optional[str],
|
165
165
|
fab_hash: Optional[str],
|
166
166
|
override_config: UserConfig,
|
167
|
-
federation_options:
|
167
|
+
federation_options: ConfigRecord,
|
168
168
|
) -> int:
|
169
169
|
"""Create a new run for the specified `fab_hash`."""
|
170
170
|
|
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
236
236
|
"""
|
237
237
|
|
238
238
|
@abc.abstractmethod
|
239
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
239
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
240
240
|
"""Retrieve the federation options for the specified `run_id`.
|
241
241
|
|
242
242
|
Parameters
|
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
246
246
|
|
247
247
|
Returns
|
248
248
|
-------
|
249
|
-
Optional[
|
249
|
+
Optional[ConfigRecord]
|
250
250
|
The federation options for the run if it exists; None otherwise.
|
251
251
|
"""
|
252
252
|
|
@@ -35,18 +35,19 @@ from flwr.common.constant import (
|
|
35
35
|
SUPERLINK_NODE_ID,
|
36
36
|
Status,
|
37
37
|
)
|
38
|
-
from flwr.common.
|
38
|
+
from flwr.common.message import make_message
|
39
|
+
from flwr.common.record import ConfigRecord
|
39
40
|
from flwr.common.serde import (
|
40
41
|
error_from_proto,
|
41
42
|
error_to_proto,
|
42
|
-
|
43
|
-
|
43
|
+
recorddict_from_proto,
|
44
|
+
recorddict_to_proto,
|
44
45
|
)
|
45
46
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
46
47
|
|
47
48
|
# pylint: disable=E0611
|
48
49
|
from flwr.proto.error_pb2 import Error as ProtoError
|
49
|
-
from flwr.proto.
|
50
|
+
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
50
51
|
|
51
52
|
# pylint: enable=E0611
|
52
53
|
from flwr.server.utils.validator import validate_message
|
@@ -54,8 +55,8 @@ from flwr.server.utils.validator import validate_message
|
|
54
55
|
from .linkstate import LinkState
|
55
56
|
from .utils import (
|
56
57
|
check_node_availability_for_in_message,
|
57
|
-
|
58
|
-
|
58
|
+
configrecord_from_bytes,
|
59
|
+
configrecord_to_bytes,
|
59
60
|
context_from_bytes,
|
60
61
|
context_to_bytes,
|
61
62
|
convert_sint64_to_uint64,
|
@@ -131,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
|
|
131
132
|
run_id INTEGER,
|
132
133
|
src_node_id INTEGER,
|
133
134
|
dst_node_id INTEGER,
|
134
|
-
|
135
|
+
reply_to_message_id TEXT,
|
135
136
|
created_at REAL,
|
136
137
|
delivered_at TEXT,
|
137
138
|
ttl REAL,
|
@@ -150,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
150
151
|
run_id INTEGER,
|
151
152
|
src_node_id INTEGER,
|
152
153
|
dst_node_id INTEGER,
|
153
|
-
|
154
|
+
reply_to_message_id TEXT,
|
154
155
|
created_at REAL,
|
155
156
|
delivered_at TEXT,
|
156
157
|
ttl REAL,
|
@@ -373,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
373
374
|
return None
|
374
375
|
|
375
376
|
res_metadata = message.metadata
|
376
|
-
msg_ins_id = res_metadata.
|
377
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
377
378
|
msg_ins = self.get_valid_message_ins(msg_ins_id)
|
378
379
|
if msg_ins is None:
|
379
380
|
log(
|
@@ -495,7 +496,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
495
496
|
query = f"""
|
496
497
|
SELECT *
|
497
498
|
FROM message_res
|
498
|
-
WHERE
|
499
|
+
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
499
500
|
AND delivered_at = "";
|
500
501
|
"""
|
501
502
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
@@ -568,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
568
569
|
# Delete reply Message
|
569
570
|
query_2 = f"""
|
570
571
|
DELETE FROM message_res
|
571
|
-
WHERE
|
572
|
+
WHERE reply_to_message_id IN ({placeholders});
|
572
573
|
"""
|
573
574
|
|
574
575
|
with self.conn:
|
@@ -726,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
726
727
|
fab_version: Optional[str],
|
727
728
|
fab_hash: Optional[str],
|
728
729
|
override_config: UserConfig,
|
729
|
-
federation_options:
|
730
|
+
federation_options: ConfigRecord,
|
730
731
|
) -> int:
|
731
732
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
732
733
|
# Sample a random int64 as run_id
|
@@ -752,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
752
753
|
fab_version,
|
753
754
|
fab_hash,
|
754
755
|
override_config_json,
|
755
|
-
|
756
|
+
configrecord_to_bytes(federation_options),
|
756
757
|
]
|
757
758
|
data += [
|
758
759
|
now().isoformat(),
|
@@ -910,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
910
911
|
|
911
912
|
return pending_run_id
|
912
913
|
|
913
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
914
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
914
915
|
"""Retrieve the federation options for the specified `run_id`."""
|
915
916
|
# Convert the uint64 value to sint64 for SQLite
|
916
917
|
sint64_run_id = convert_uint64_to_sint64(run_id)
|
@@ -923,7 +924,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
923
924
|
return None
|
924
925
|
|
925
926
|
row = rows[0]
|
926
|
-
return
|
927
|
+
return configrecord_from_bytes(row["federation_options"])
|
927
928
|
|
928
929
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
929
930
|
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
@@ -1064,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1064
1065
|
"run_id": message.metadata.run_id,
|
1065
1066
|
"src_node_id": message.metadata.src_node_id,
|
1066
1067
|
"dst_node_id": message.metadata.dst_node_id,
|
1067
|
-
"
|
1068
|
+
"reply_to_message_id": message.metadata.reply_to_message_id,
|
1068
1069
|
"created_at": message.metadata.created_at,
|
1069
1070
|
"delivered_at": message.metadata.delivered_at,
|
1070
1071
|
"ttl": message.metadata.ttl,
|
@@ -1074,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1074
1075
|
}
|
1075
1076
|
|
1076
1077
|
if message.has_content():
|
1077
|
-
result["content"] =
|
1078
|
+
result["content"] = recorddict_to_proto(message.content).SerializeToString()
|
1078
1079
|
else:
|
1079
1080
|
result["error"] = error_to_proto(message.error).SerializeToString()
|
1080
1081
|
|
@@ -1085,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
1085
1086
|
"""Transform dict to Message."""
|
1086
1087
|
content, error = None, None
|
1087
1088
|
if (b_content := message_dict.pop("content")) is not None:
|
1088
|
-
content =
|
1089
|
+
content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
|
1089
1090
|
if (b_error := message_dict.pop("error")) is not None:
|
1090
1091
|
error = error_from_proto(ProtoError.FromString(b_error))
|
1091
1092
|
|
1092
1093
|
# Metadata constructor doesn't allow passing created_at. We set it later
|
1093
1094
|
metadata = Metadata(
|
1094
|
-
**{
|
1095
|
-
k: v
|
1096
|
-
for k, v in message_dict.items()
|
1097
|
-
if k not in ["created_at", "delivered_at"]
|
1098
|
-
}
|
1095
|
+
**{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
|
1099
1096
|
)
|
1100
|
-
msg =
|
1101
|
-
msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
|
1097
|
+
msg = make_message(metadata=metadata, content=content, error=error)
|
1102
1098
|
msg.metadata.delivered_at = message_dict["delivered_at"]
|
1103
1099
|
return msg
|
1104
1100
|
|
@@ -19,7 +19,7 @@ from os import urandom
|
|
19
19
|
from typing import Optional
|
20
20
|
from uuid import UUID, uuid4
|
21
21
|
|
22
|
-
from flwr.common import
|
22
|
+
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
23
23
|
from flwr.common.constant import (
|
24
24
|
SUPERLINK_NODE_ID,
|
25
25
|
ErrorCode,
|
@@ -27,11 +27,12 @@ from flwr.common.constant import (
|
|
27
27
|
Status,
|
28
28
|
SubStatus,
|
29
29
|
)
|
30
|
+
from flwr.common.message import make_message
|
30
31
|
from flwr.common.typing import RunStatus
|
31
32
|
|
32
33
|
# pylint: disable=E0611
|
33
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
34
|
-
from flwr.proto.
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
35
36
|
|
36
37
|
# pylint: enable=E0611
|
37
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
@@ -171,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
|
|
171
172
|
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
172
173
|
|
173
174
|
|
174
|
-
def
|
175
|
-
"""Serialize a `
|
176
|
-
return serde.
|
175
|
+
def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
|
176
|
+
"""Serialize a `ConfigRecord` to bytes."""
|
177
|
+
return serde.config_record_to_proto(config_record).SerializeToString()
|
177
178
|
|
178
179
|
|
179
|
-
def
|
180
|
-
"""Deserialize `
|
181
|
-
return serde.
|
182
|
-
|
180
|
+
def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
|
181
|
+
"""Deserialize `ConfigRecord` from bytes."""
|
182
|
+
return serde.config_record_from_proto(
|
183
|
+
ProtoConfigRecord.FromString(configrecord_bytes)
|
183
184
|
)
|
184
185
|
|
185
186
|
|
@@ -247,13 +248,14 @@ def create_message_error_unavailable_res_message(
|
|
247
248
|
message_id=str(uuid4()),
|
248
249
|
src_node_id=SUPERLINK_NODE_ID,
|
249
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
250
|
-
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
251
252
|
group_id=ins_metadata.group_id,
|
252
253
|
message_type=ins_metadata.message_type,
|
254
|
+
created_at=current_time,
|
253
255
|
ttl=ttl,
|
254
256
|
)
|
255
257
|
|
256
|
-
return
|
258
|
+
return make_message(
|
257
259
|
metadata=metadata,
|
258
260
|
error=Error(
|
259
261
|
code=(
|
@@ -270,7 +272,7 @@ def create_message_error_unavailable_res_message(
|
|
270
272
|
)
|
271
273
|
|
272
274
|
|
273
|
-
def create_message_error_unavailable_ins_message(
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
274
276
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
275
277
|
that it isn't found."""
|
276
278
|
metadata = Metadata(
|
@@ -278,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
|
|
278
280
|
message_id=str(uuid4()),
|
279
281
|
src_node_id=SUPERLINK_NODE_ID,
|
280
282
|
dst_node_id=SUPERLINK_NODE_ID,
|
281
|
-
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
282
284
|
group_id="", # Unknown
|
283
285
|
message_type=MessageType.SYSTEM,
|
286
|
+
created_at=now().timestamp(),
|
284
287
|
ttl=0,
|
285
288
|
)
|
286
289
|
|
287
|
-
return
|
290
|
+
return make_message(
|
288
291
|
metadata=metadata,
|
289
292
|
error=Error(
|
290
293
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
@@ -372,7 +375,7 @@ def verify_found_message_replies(
|
|
372
375
|
ret_dict: dict[UUID, Message] = {}
|
373
376
|
current = current_time if current_time else now().timestamp()
|
374
377
|
for message_res in found_message_res_list:
|
375
|
-
message_ins_id = UUID(message_res.metadata.
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
376
379
|
if update_set:
|
377
380
|
inquired_message_ids.remove(message_ins_id)
|
378
381
|
# Check if the reply Message has expired
|
@@ -22,7 +22,7 @@ from uuid import UUID
|
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
|
-
from flwr.common import
|
25
|
+
from flwr.common import ConfigRecord, Message
|
26
26
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.serde import (
|
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
127
127
|
request.fab_version,
|
128
128
|
fab_hash,
|
129
129
|
user_config_from_proto(request.override_config),
|
130
|
-
|
130
|
+
ConfigRecord(),
|
131
131
|
)
|
132
132
|
return CreateRunResponse(run_id=run_id)
|
133
133
|
|
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
206
206
|
|
207
207
|
# Delete the instruction Messages and their replies if found
|
208
208
|
message_ins_ids_to_delete = {
|
209
|
-
UUID(msg_res.metadata.
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
210
210
|
}
|
211
211
|
|
212
212
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
@@ -24,7 +24,7 @@ from grpc import ServicerContext
|
|
24
24
|
from flwr.common.constant import Status
|
25
25
|
from flwr.common.logger import log
|
26
26
|
from flwr.common.serde import (
|
27
|
-
|
27
|
+
config_record_to_proto,
|
28
28
|
context_from_proto,
|
29
29
|
context_to_proto,
|
30
30
|
fab_to_proto,
|
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
182
182
|
)
|
183
183
|
return GetFederationOptionsResponse()
|
184
184
|
return GetFederationOptionsResponse(
|
185
|
-
federation_options=
|
185
|
+
federation_options=config_record_to_proto(federation_options)
|
186
186
|
)
|
flwr/server/utils/validator.py
CHANGED
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
68
68
|
|
69
69
|
# Link respose to original message
|
70
70
|
if not is_reply_message:
|
71
|
-
if metadata.
|
72
|
-
validation_errors.append("`metadata.
|
71
|
+
if metadata.reply_to_message_id != "":
|
72
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
|
73
73
|
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
74
74
|
validation_errors.append(
|
75
75
|
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
79
79
|
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
80
80
|
)
|
81
81
|
else:
|
82
|
-
if metadata.
|
83
|
-
validation_errors.append("`metadata.
|
82
|
+
if metadata.reply_to_message_id == "":
|
83
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
|
84
84
|
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
85
85
|
validation_errors.append(
|
86
86
|
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -20,15 +20,16 @@ import timeit
|
|
20
20
|
from logging import INFO, WARN
|
21
21
|
from typing import Optional, Union, cast
|
22
22
|
|
23
|
-
import flwr.common.
|
23
|
+
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
|
+
ArrayRecord,
|
25
26
|
Code,
|
26
|
-
|
27
|
+
ConfigRecord,
|
27
28
|
Context,
|
28
29
|
EvaluateRes,
|
29
30
|
FitRes,
|
30
31
|
GetParametersIns,
|
31
|
-
|
32
|
+
Message,
|
32
33
|
log,
|
33
34
|
)
|
34
35
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
@@ -77,9 +78,9 @@ class DefaultWorkflow:
|
|
77
78
|
|
78
79
|
# Run federated learning for num_rounds
|
79
80
|
start_time = timeit.default_timer()
|
80
|
-
cfg =
|
81
|
+
cfg = ConfigRecord()
|
81
82
|
cfg[Key.START_TIME] = start_time
|
82
|
-
context.state.
|
83
|
+
context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
|
83
84
|
|
84
85
|
for current_round in range(1, context.config.num_rounds + 1):
|
85
86
|
log(INFO, "")
|
@@ -129,21 +130,19 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
129
130
|
)
|
130
131
|
if parameters is not None:
|
131
132
|
log(INFO, "Using initial global parameters provided by strategy")
|
132
|
-
|
133
|
-
parameters, keep_input=True
|
134
|
-
)
|
133
|
+
arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
|
135
134
|
else:
|
136
135
|
# Get initial parameters from one of the clients
|
137
136
|
log(INFO, "Requesting initial parameters from one random client")
|
138
137
|
random_client = context.client_manager.sample(1)[0]
|
139
138
|
# Send GetParametersIns and get the response
|
140
|
-
content = compat.
|
139
|
+
content = compat.getparametersins_to_recorddict(GetParametersIns({}))
|
141
140
|
messages = grid.send_and_receive(
|
142
141
|
[
|
143
|
-
|
142
|
+
Message(
|
144
143
|
content=content,
|
145
|
-
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
146
144
|
dst_node_id=random_client.node_id,
|
145
|
+
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
147
146
|
group_id="0",
|
148
147
|
)
|
149
148
|
]
|
@@ -152,26 +151,26 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
152
151
|
|
153
152
|
if (
|
154
153
|
msg.has_content()
|
155
|
-
and compat.
|
154
|
+
and compat._extract_status_from_recorddict( # pylint: disable=W0212
|
156
155
|
"getparametersres", msg.content
|
157
156
|
).code
|
158
157
|
== Code.OK
|
159
158
|
):
|
160
159
|
log(INFO, "Received initial parameters from one random client")
|
161
|
-
|
160
|
+
arr_record = next(iter(msg.content.array_records.values()))
|
162
161
|
else:
|
163
162
|
log(
|
164
163
|
WARN,
|
165
164
|
"Failed to receive initial parameters from the client."
|
166
165
|
" Empty initial parameters will be used.",
|
167
166
|
)
|
168
|
-
|
167
|
+
arr_record = ArrayRecord()
|
169
168
|
|
170
|
-
context.state.
|
169
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
171
170
|
|
172
171
|
# Evaluate initial parameters
|
173
172
|
log(INFO, "Starting evaluation of initial global parameters")
|
174
|
-
parameters = compat.
|
173
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
175
174
|
res = context.strategy.evaluate(0, parameters=parameters)
|
176
175
|
if res is not None:
|
177
176
|
log(
|
@@ -192,13 +191,13 @@ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
|
192
191
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
193
192
|
|
194
193
|
# Retrieve current_round and start_time from the context
|
195
|
-
cfg = context.state.
|
194
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
196
195
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
197
196
|
start_time = cast(float, cfg[Key.START_TIME])
|
198
197
|
|
199
198
|
# Centralized evaluation
|
200
|
-
parameters = compat.
|
201
|
-
record=context.state.
|
199
|
+
parameters = compat.arrayrecord_to_parameters(
|
200
|
+
record=context.state.array_records[MAIN_PARAMS_RECORD],
|
202
201
|
keep_input=True,
|
203
202
|
)
|
204
203
|
res_cen = context.strategy.evaluate(current_round, parameters=parameters)
|
@@ -224,12 +223,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
224
223
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
225
224
|
|
226
225
|
# Get current_round and parameters
|
227
|
-
cfg = context.state.
|
226
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
228
227
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
229
|
-
|
230
|
-
parameters = compat.
|
231
|
-
parametersrecord, keep_input=True
|
232
|
-
)
|
228
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
229
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
233
230
|
|
234
231
|
# Get clients and their respective instructions from strategy
|
235
232
|
client_instructions = context.strategy.configure_fit(
|
@@ -253,10 +250,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
253
250
|
|
254
251
|
# Build out messages
|
255
252
|
out_messages = [
|
256
|
-
|
257
|
-
content=compat.
|
258
|
-
message_type=MessageType.TRAIN,
|
253
|
+
Message(
|
254
|
+
content=compat.fitins_to_recorddict(fitins, True),
|
259
255
|
dst_node_id=proxy.node_id,
|
256
|
+
message_type=MessageType.TRAIN,
|
260
257
|
group_id=str(current_round),
|
261
258
|
)
|
262
259
|
for proxy, fitins in client_instructions
|
@@ -282,7 +279,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
282
279
|
for msg in messages:
|
283
280
|
if msg.has_content():
|
284
281
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
285
|
-
fitres = compat.
|
282
|
+
fitres = compat.recorddict_to_fitres(msg.content, False)
|
286
283
|
if fitres.status.code == Code.OK:
|
287
284
|
results.append((proxy, fitres))
|
288
285
|
else:
|
@@ -295,10 +292,8 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
295
292
|
|
296
293
|
# Update the parameters and write history
|
297
294
|
if parameters_aggregated:
|
298
|
-
|
299
|
-
|
300
|
-
)
|
301
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
295
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
296
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
302
297
|
context.history.add_metrics_distributed_fit(
|
303
298
|
server_round=current_round, metrics=metrics_aggregated
|
304
299
|
)
|
@@ -311,12 +306,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
311
306
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
312
307
|
|
313
308
|
# Get current_round and parameters
|
314
|
-
cfg = context.state.
|
309
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
315
310
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
316
|
-
|
317
|
-
parameters = compat.
|
318
|
-
parametersrecord, keep_input=True
|
319
|
-
)
|
311
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
312
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
320
313
|
|
321
314
|
# Get clients and their respective instructions from strategy
|
322
315
|
client_instructions = context.strategy.configure_evaluate(
|
@@ -339,10 +332,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
339
332
|
|
340
333
|
# Build out messages
|
341
334
|
out_messages = [
|
342
|
-
|
343
|
-
content=compat.
|
344
|
-
message_type=MessageType.EVALUATE,
|
335
|
+
Message(
|
336
|
+
content=compat.evaluateins_to_recorddict(evalins, True),
|
345
337
|
dst_node_id=proxy.node_id,
|
338
|
+
message_type=MessageType.EVALUATE,
|
346
339
|
group_id=str(current_round),
|
347
340
|
)
|
348
341
|
for proxy, evalins in client_instructions
|
@@ -368,7 +361,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
368
361
|
for msg in messages:
|
369
362
|
if msg.has_content():
|
370
363
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
371
|
-
evalres = compat.
|
364
|
+
evalres = compat.recorddict_to_evaluateres(msg.content)
|
372
365
|
if evalres.status.code == Code.OK:
|
373
366
|
results.append((proxy, evalres))
|
374
367
|
else:
|