flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
|
@@ -30,30 +30,33 @@ from flwr.common import Context, Message, Metadata, log, now
|
|
|
30
30
|
from flwr.common.constant import (
|
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
|
32
32
|
NODE_ID_NUM_BYTES,
|
|
33
|
+
PING_PATIENCE,
|
|
33
34
|
RUN_ID_NUM_BYTES,
|
|
34
35
|
SUPERLINK_NODE_ID,
|
|
35
36
|
Status,
|
|
36
37
|
)
|
|
37
|
-
from flwr.common.
|
|
38
|
+
from flwr.common.message import make_message
|
|
39
|
+
from flwr.common.record import ConfigRecord
|
|
38
40
|
from flwr.common.serde import (
|
|
39
41
|
error_from_proto,
|
|
40
42
|
error_to_proto,
|
|
41
|
-
|
|
42
|
-
|
|
43
|
+
recorddict_from_proto,
|
|
44
|
+
recorddict_to_proto,
|
|
43
45
|
)
|
|
44
46
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
45
47
|
|
|
46
48
|
# pylint: disable=E0611
|
|
47
49
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
48
|
-
from flwr.proto.
|
|
50
|
+
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
49
51
|
|
|
50
52
|
# pylint: enable=E0611
|
|
51
53
|
from flwr.server.utils.validator import validate_message
|
|
52
54
|
|
|
53
55
|
from .linkstate import LinkState
|
|
54
56
|
from .utils import (
|
|
55
|
-
|
|
56
|
-
|
|
57
|
+
check_node_availability_for_in_message,
|
|
58
|
+
configrecord_from_bytes,
|
|
59
|
+
configrecord_to_bytes,
|
|
57
60
|
context_from_bytes,
|
|
58
61
|
context_to_bytes,
|
|
59
62
|
convert_sint64_to_uint64,
|
|
@@ -129,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
|
|
|
129
132
|
run_id INTEGER,
|
|
130
133
|
src_node_id INTEGER,
|
|
131
134
|
dst_node_id INTEGER,
|
|
132
|
-
|
|
135
|
+
reply_to_message_id TEXT,
|
|
133
136
|
created_at REAL,
|
|
134
137
|
delivered_at TEXT,
|
|
135
138
|
ttl REAL,
|
|
@@ -148,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
|
148
151
|
run_id INTEGER,
|
|
149
152
|
src_node_id INTEGER,
|
|
150
153
|
dst_node_id INTEGER,
|
|
151
|
-
|
|
154
|
+
reply_to_message_id TEXT,
|
|
152
155
|
created_at REAL,
|
|
153
156
|
delivered_at TEXT,
|
|
154
157
|
ttl REAL,
|
|
@@ -371,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
371
374
|
return None
|
|
372
375
|
|
|
373
376
|
res_metadata = message.metadata
|
|
374
|
-
msg_ins_id = res_metadata.
|
|
377
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
|
375
378
|
msg_ins = self.get_valid_message_ins(msg_ins_id)
|
|
376
379
|
if msg_ins is None:
|
|
377
380
|
log(
|
|
@@ -442,6 +445,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
442
445
|
|
|
443
446
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
|
444
447
|
"""Get reply Messages for the given Message IDs."""
|
|
448
|
+
# pylint: disable-msg=too-many-locals
|
|
445
449
|
ret: dict[UUID, Message] = {}
|
|
446
450
|
|
|
447
451
|
# Verify Message IDs
|
|
@@ -465,11 +469,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
465
469
|
current_time=current,
|
|
466
470
|
)
|
|
467
471
|
|
|
472
|
+
# Check node availability
|
|
473
|
+
dst_node_ids: set[int] = set()
|
|
474
|
+
for message_id in message_ids:
|
|
475
|
+
in_message = found_message_ins_dict[message_id]
|
|
476
|
+
sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
|
|
477
|
+
dst_node_ids.add(sint_node_id)
|
|
478
|
+
query = f"""
|
|
479
|
+
SELECT node_id, online_until
|
|
480
|
+
FROM node
|
|
481
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
|
|
482
|
+
"""
|
|
483
|
+
rows = self.query(query, tuple(dst_node_ids))
|
|
484
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
|
485
|
+
inquired_in_message_ids=message_ids,
|
|
486
|
+
found_in_message_dict=found_message_ins_dict,
|
|
487
|
+
node_id_to_online_until={
|
|
488
|
+
convert_sint64_to_uint64(row["node_id"]): row["online_until"]
|
|
489
|
+
for row in rows
|
|
490
|
+
},
|
|
491
|
+
current_time=current,
|
|
492
|
+
)
|
|
493
|
+
ret.update(tmp_ret_dict)
|
|
494
|
+
|
|
468
495
|
# Find all reply Messages
|
|
469
496
|
query = f"""
|
|
470
497
|
SELECT *
|
|
471
498
|
FROM message_res
|
|
472
|
-
WHERE
|
|
499
|
+
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
|
473
500
|
AND delivered_at = "";
|
|
474
501
|
"""
|
|
475
502
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
@@ -542,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
542
569
|
# Delete reply Message
|
|
543
570
|
query_2 = f"""
|
|
544
571
|
DELETE FROM message_res
|
|
545
|
-
WHERE
|
|
572
|
+
WHERE reply_to_message_id IN ({placeholders});
|
|
546
573
|
"""
|
|
547
574
|
|
|
548
575
|
with self.conn:
|
|
@@ -584,6 +611,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
584
611
|
"VALUES (?, ?, ?, ?)"
|
|
585
612
|
)
|
|
586
613
|
|
|
614
|
+
# Mark the node online util time.time() + ping_interval
|
|
587
615
|
try:
|
|
588
616
|
self.query(
|
|
589
617
|
query,
|
|
@@ -699,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
699
727
|
fab_version: Optional[str],
|
|
700
728
|
fab_hash: Optional[str],
|
|
701
729
|
override_config: UserConfig,
|
|
702
|
-
federation_options:
|
|
730
|
+
federation_options: ConfigRecord,
|
|
703
731
|
) -> int:
|
|
704
732
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
705
733
|
# Sample a random int64 as run_id
|
|
@@ -725,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
725
753
|
fab_version,
|
|
726
754
|
fab_hash,
|
|
727
755
|
override_config_json,
|
|
728
|
-
|
|
756
|
+
configrecord_to_bytes(federation_options),
|
|
729
757
|
]
|
|
730
758
|
data += [
|
|
731
759
|
now().isoformat(),
|
|
@@ -883,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
883
911
|
|
|
884
912
|
return pending_run_id
|
|
885
913
|
|
|
886
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
|
914
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
887
915
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
888
916
|
# Convert the uint64 value to sint64 for SQLite
|
|
889
917
|
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
@@ -896,10 +924,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
896
924
|
return None
|
|
897
925
|
|
|
898
926
|
row = rows[0]
|
|
899
|
-
return
|
|
927
|
+
return configrecord_from_bytes(row["federation_options"])
|
|
900
928
|
|
|
901
929
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
902
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
930
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
931
|
+
|
|
932
|
+
It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
|
|
933
|
+
marking the node as offline, where PING_PATIENCE = 2 in default.
|
|
934
|
+
"""
|
|
903
935
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
904
936
|
|
|
905
937
|
# Check if the node exists in the `node` table
|
|
@@ -909,7 +941,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
909
941
|
|
|
910
942
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
911
943
|
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
|
912
|
-
self.query(
|
|
944
|
+
self.query(
|
|
945
|
+
query,
|
|
946
|
+
(
|
|
947
|
+
time.time() + PING_PATIENCE * ping_interval,
|
|
948
|
+
ping_interval,
|
|
949
|
+
sint64_node_id,
|
|
950
|
+
),
|
|
951
|
+
)
|
|
913
952
|
return True
|
|
914
953
|
|
|
915
954
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
@@ -1026,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
|
1026
1065
|
"run_id": message.metadata.run_id,
|
|
1027
1066
|
"src_node_id": message.metadata.src_node_id,
|
|
1028
1067
|
"dst_node_id": message.metadata.dst_node_id,
|
|
1029
|
-
"
|
|
1068
|
+
"reply_to_message_id": message.metadata.reply_to_message_id,
|
|
1030
1069
|
"created_at": message.metadata.created_at,
|
|
1031
1070
|
"delivered_at": message.metadata.delivered_at,
|
|
1032
1071
|
"ttl": message.metadata.ttl,
|
|
@@ -1036,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
|
1036
1075
|
}
|
|
1037
1076
|
|
|
1038
1077
|
if message.has_content():
|
|
1039
|
-
result["content"] =
|
|
1078
|
+
result["content"] = recorddict_to_proto(message.content).SerializeToString()
|
|
1040
1079
|
else:
|
|
1041
1080
|
result["error"] = error_to_proto(message.error).SerializeToString()
|
|
1042
1081
|
|
|
@@ -1047,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
|
1047
1086
|
"""Transform dict to Message."""
|
|
1048
1087
|
content, error = None, None
|
|
1049
1088
|
if (b_content := message_dict.pop("content")) is not None:
|
|
1050
|
-
content =
|
|
1089
|
+
content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
|
|
1051
1090
|
if (b_error := message_dict.pop("error")) is not None:
|
|
1052
1091
|
error = error_from_proto(ProtoError.FromString(b_error))
|
|
1053
1092
|
|
|
1054
1093
|
# Metadata constructor doesn't allow passing created_at. We set it later
|
|
1055
1094
|
metadata = Metadata(
|
|
1056
|
-
**{
|
|
1057
|
-
k: v
|
|
1058
|
-
for k, v in message_dict.items()
|
|
1059
|
-
if k not in ["created_at", "delivered_at"]
|
|
1060
|
-
}
|
|
1095
|
+
**{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
|
|
1061
1096
|
)
|
|
1062
|
-
msg =
|
|
1063
|
-
msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
|
|
1097
|
+
msg = make_message(metadata=metadata, content=content, error=error)
|
|
1064
1098
|
msg.metadata.delivered_at = message_dict["delivered_at"]
|
|
1065
1099
|
return msg
|
|
1066
1100
|
|
|
@@ -19,21 +19,22 @@ from os import urandom
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID, uuid4
|
|
21
21
|
|
|
22
|
-
from flwr.common import
|
|
23
|
-
from flwr.common.constant import
|
|
22
|
+
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
23
|
+
from flwr.common.constant import (
|
|
24
|
+
SUPERLINK_NODE_ID,
|
|
25
|
+
ErrorCode,
|
|
26
|
+
MessageType,
|
|
27
|
+
Status,
|
|
28
|
+
SubStatus,
|
|
29
|
+
)
|
|
30
|
+
from flwr.common.message import make_message
|
|
24
31
|
from flwr.common.typing import RunStatus
|
|
25
32
|
|
|
26
33
|
# pylint: disable=E0611
|
|
27
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
28
|
-
from flwr.proto.
|
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
29
36
|
|
|
30
37
|
# pylint: enable=E0611
|
|
31
|
-
|
|
32
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
33
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
34
|
-
"It exceeds the time limit specified in its last ping."
|
|
35
|
-
)
|
|
36
|
-
|
|
37
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
38
39
|
(Status.PENDING, Status.STARTING),
|
|
39
40
|
(Status.STARTING, Status.RUNNING),
|
|
@@ -54,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
54
55
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
55
56
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
56
57
|
)
|
|
58
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
59
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
60
|
+
"It exceeds twice the time limit specified in its last ping."
|
|
61
|
+
)
|
|
57
62
|
|
|
58
63
|
|
|
59
64
|
def generate_rand_int_from_bytes(
|
|
@@ -167,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
|
|
|
167
172
|
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
|
168
173
|
|
|
169
174
|
|
|
170
|
-
def
|
|
171
|
-
"""Serialize a `
|
|
172
|
-
return serde.
|
|
175
|
+
def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
|
|
176
|
+
"""Serialize a `ConfigRecord` to bytes."""
|
|
177
|
+
return serde.config_record_to_proto(config_record).SerializeToString()
|
|
173
178
|
|
|
174
179
|
|
|
175
|
-
def
|
|
176
|
-
"""Deserialize `
|
|
177
|
-
return serde.
|
|
178
|
-
|
|
180
|
+
def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
|
|
181
|
+
"""Deserialize `ConfigRecord` from bytes."""
|
|
182
|
+
return serde.config_record_from_proto(
|
|
183
|
+
ProtoConfigRecord.FromString(configrecord_bytes)
|
|
179
184
|
)
|
|
180
185
|
|
|
181
186
|
|
|
@@ -231,7 +236,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
|
231
236
|
return status.sub_status == ""
|
|
232
237
|
|
|
233
238
|
|
|
234
|
-
def create_message_error_unavailable_res_message(
|
|
239
|
+
def create_message_error_unavailable_res_message(
|
|
240
|
+
ins_metadata: Metadata, error_type: str
|
|
241
|
+
) -> Message:
|
|
235
242
|
"""Generate an error Message that the SuperLink returns carrying the specified
|
|
236
243
|
error."""
|
|
237
244
|
current_time = now().timestamp()
|
|
@@ -241,22 +248,31 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
|
|
|
241
248
|
message_id=str(uuid4()),
|
|
242
249
|
src_node_id=SUPERLINK_NODE_ID,
|
|
243
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
244
|
-
|
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
|
245
252
|
group_id=ins_metadata.group_id,
|
|
246
253
|
message_type=ins_metadata.message_type,
|
|
254
|
+
created_at=current_time,
|
|
247
255
|
ttl=ttl,
|
|
248
256
|
)
|
|
249
257
|
|
|
250
|
-
return
|
|
258
|
+
return make_message(
|
|
251
259
|
metadata=metadata,
|
|
252
260
|
error=Error(
|
|
253
|
-
code=
|
|
254
|
-
|
|
261
|
+
code=(
|
|
262
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
|
263
|
+
if error_type == "msg_unavail"
|
|
264
|
+
else ErrorCode.NODE_UNAVAILABLE
|
|
265
|
+
),
|
|
266
|
+
reason=(
|
|
267
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
|
268
|
+
if error_type == "msg_unavail"
|
|
269
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
|
270
|
+
),
|
|
255
271
|
),
|
|
256
272
|
)
|
|
257
273
|
|
|
258
274
|
|
|
259
|
-
def create_message_error_unavailable_ins_message(
|
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
|
260
276
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
|
261
277
|
that it isn't found."""
|
|
262
278
|
metadata = Metadata(
|
|
@@ -264,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
|
|
|
264
280
|
message_id=str(uuid4()),
|
|
265
281
|
src_node_id=SUPERLINK_NODE_ID,
|
|
266
282
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
267
|
-
|
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
|
268
284
|
group_id="", # Unknown
|
|
269
|
-
message_type=
|
|
285
|
+
message_type=MessageType.SYSTEM,
|
|
286
|
+
created_at=now().timestamp(),
|
|
270
287
|
ttl=0,
|
|
271
288
|
)
|
|
272
289
|
|
|
273
|
-
return
|
|
290
|
+
return make_message(
|
|
274
291
|
metadata=metadata,
|
|
275
292
|
error=Error(
|
|
276
293
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
@@ -358,14 +375,63 @@ def verify_found_message_replies(
|
|
|
358
375
|
ret_dict: dict[UUID, Message] = {}
|
|
359
376
|
current = current_time if current_time else now().timestamp()
|
|
360
377
|
for message_res in found_message_res_list:
|
|
361
|
-
message_ins_id = UUID(message_res.metadata.
|
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
|
362
379
|
if update_set:
|
|
363
380
|
inquired_message_ids.remove(message_ins_id)
|
|
364
381
|
# Check if the reply Message has expired
|
|
365
382
|
if message_ttl_has_expired(message_res.metadata, current):
|
|
366
383
|
# No need to insert the error Message
|
|
367
384
|
message_res = create_message_error_unavailable_res_message(
|
|
368
|
-
found_message_ins_dict[message_ins_id].metadata
|
|
385
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
|
369
386
|
)
|
|
370
387
|
ret_dict[message_ins_id] = message_res
|
|
371
388
|
return ret_dict
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def check_node_availability_for_in_message(
|
|
392
|
+
inquired_in_message_ids: set[UUID],
|
|
393
|
+
found_in_message_dict: dict[UUID, Message],
|
|
394
|
+
node_id_to_online_until: dict[int, float],
|
|
395
|
+
current_time: Optional[float] = None,
|
|
396
|
+
update_set: bool = True,
|
|
397
|
+
) -> dict[UUID, Message]:
|
|
398
|
+
"""Check node availability for given Message and generate error reply Message if
|
|
399
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
|
400
|
+
each given Message whose destination node is offline or non-existent.
|
|
401
|
+
|
|
402
|
+
Parameters
|
|
403
|
+
----------
|
|
404
|
+
inquired_in_message_ids : set[UUID]
|
|
405
|
+
Set of Message IDs for which to check destination node availability.
|
|
406
|
+
found_in_message_dict : dict[UUID, Message]
|
|
407
|
+
Dictionary containing all found Message indexed by their IDs.
|
|
408
|
+
node_id_to_online_until : dict[int, float]
|
|
409
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
|
410
|
+
current_time : Optional[float] (default: None)
|
|
411
|
+
The current time to check for expiration. If set to `None`, the current time
|
|
412
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
413
|
+
update_set : bool (default: True)
|
|
414
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
|
415
|
+
by default True.
|
|
416
|
+
|
|
417
|
+
Returns
|
|
418
|
+
-------
|
|
419
|
+
dict[UUID, Message]
|
|
420
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
|
421
|
+
"""
|
|
422
|
+
ret_dict = {}
|
|
423
|
+
current = current_time if current_time else now().timestamp()
|
|
424
|
+
for in_message_id in list(inquired_in_message_ids):
|
|
425
|
+
in_message = found_in_message_dict[in_message_id]
|
|
426
|
+
node_id = in_message.metadata.dst_node_id
|
|
427
|
+
online_until = node_id_to_online_until.get(node_id)
|
|
428
|
+
# Generate a reply message containing an error reply
|
|
429
|
+
# if the node is offline or doesn't exist.
|
|
430
|
+
if online_until is None or online_until < current:
|
|
431
|
+
if update_set:
|
|
432
|
+
inquired_in_message_ids.remove(in_message_id)
|
|
433
|
+
reply_message = create_message_error_unavailable_res_message(
|
|
434
|
+
in_message.metadata, "node_unavail"
|
|
435
|
+
)
|
|
436
|
+
ret_dict[in_message_id] = reply_message
|
|
437
|
+
return ret_dict
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -22,7 +22,7 @@ from uuid import UUID
|
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
|
-
from flwr.common import
|
|
25
|
+
from flwr.common import ConfigRecord, Message
|
|
26
26
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.serde import (
|
|
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
127
127
|
request.fab_version,
|
|
128
128
|
fab_hash,
|
|
129
129
|
user_config_from_proto(request.override_config),
|
|
130
|
-
|
|
130
|
+
ConfigRecord(),
|
|
131
131
|
)
|
|
132
132
|
return CreateRunResponse(run_id=run_id)
|
|
133
133
|
|
|
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
206
206
|
|
|
207
207
|
# Delete the instruction Messages and their replies if found
|
|
208
208
|
message_ins_ids_to_delete = {
|
|
209
|
-
UUID(msg_res.metadata.
|
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
|
210
210
|
}
|
|
211
211
|
|
|
212
212
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
@@ -24,7 +24,7 @@ from grpc import ServicerContext
|
|
|
24
24
|
from flwr.common.constant import Status
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.serde import (
|
|
27
|
-
|
|
27
|
+
config_record_to_proto,
|
|
28
28
|
context_from_proto,
|
|
29
29
|
context_to_proto,
|
|
30
30
|
fab_to_proto,
|
|
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
182
182
|
)
|
|
183
183
|
return GetFederationOptionsResponse()
|
|
184
184
|
return GetFederationOptionsResponse(
|
|
185
|
-
federation_options=
|
|
185
|
+
federation_options=config_record_to_proto(federation_options)
|
|
186
186
|
)
|
flwr/server/typing.py
CHANGED
|
@@ -19,9 +19,9 @@ from typing import Callable
|
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context
|
|
21
21
|
|
|
22
|
-
from .
|
|
22
|
+
from .grid import Grid
|
|
23
23
|
from .serverapp_components import ServerAppComponents
|
|
24
24
|
|
|
25
|
-
ServerAppCallable = Callable[[
|
|
26
|
-
Workflow = Callable[[
|
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
flwr/server/utils/validator.py
CHANGED
|
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
68
68
|
|
|
69
69
|
# Link respose to original message
|
|
70
70
|
if not is_reply_message:
|
|
71
|
-
if metadata.
|
|
72
|
-
validation_errors.append("`metadata.
|
|
71
|
+
if metadata.reply_to_message_id != "":
|
|
72
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
|
|
73
73
|
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
74
74
|
validation_errors.append(
|
|
75
75
|
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
79
79
|
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
80
80
|
)
|
|
81
81
|
else:
|
|
82
|
-
if metadata.
|
|
83
|
-
validation_errors.append("`metadata.
|
|
82
|
+
if metadata.reply_to_message_id == "":
|
|
83
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
|
|
84
84
|
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
|
85
85
|
validation_errors.append(
|
|
86
86
|
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|