flwr-nightly 1.19.0.dev20250515__py3-none-any.whl → 1.19.0.dev20250520__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/client/app.py +1 -1
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/app.py +1 -1
- flwr/client/grpc_rere_client/connection.py +2 -1
- flwr/client/rest_client/connection.py +2 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/constant.py +1 -1
- flwr/common/inflatable.py +41 -12
- flwr/common/inflatable_grpc_utils.py +97 -0
- flwr/common/message.py +4 -243
- flwr/common/record/array.py +2 -2
- flwr/common/record/arrayrecord.py +1 -1
- flwr/common/record/configrecord.py +2 -2
- flwr/common/record/metricrecord.py +1 -1
- flwr/common/record/recorddict.py +1 -1
- flwr/common/serde.py +4 -1
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/serverapp/__init__.py +15 -0
- flwr/supercore/__init__.py +15 -0
- flwr/superlink/__init__.py +15 -0
- flwr/supernode/__init__.py +15 -0
- {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/RECORD +34 -20
- {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/entry_points.txt +0 -0
flwr/app/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Public Flower App APIs."""
|
flwr/app/error.py
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Error."""
|
16
|
+
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import Optional, cast
|
21
|
+
|
22
|
+
# Default message time-to-live, in seconds (43200 s = 12 hours).
DEFAULT_TTL = 43200  # This is 12 hours
# Error text listing the accepted ``Message(...)`` constructor signatures.
# NOTE(review): nothing in this module uses it; presumably it is raised by the
# Message constructor on an invalid argument combination — confirm.
MESSAGE_INIT_ERROR_MESSAGE = (
    "Invalid arguments for Message. Expected one of the documented "
    "signatures: Message(content: RecordDict, dst_node_id: int, message_type: str,"
    " *, [ttl: float, group_id: str]) or Message(content: RecordDict | error: Error,"
    " *, reply_to: Message, [ttl: float])."
)
|
29
|
+
|
30
|
+
|
31
|
+
class Error:
    """The class storing information about an error that occurred.

    Parameters
    ----------
    code : int
        An identifier for the error.
    reason : Optional[str]
        A reason for why the error arose (e.g. an exception stack-trace)
    """

    def __init__(self, code: int, reason: str | None = None) -> None:
        var_dict = {
            "_code": code,
            "_reason": reason,
        }
        # Fields live under underscore-prefixed keys and are exposed read-only
        # through the ``code`` and ``reason`` properties.
        self.__dict__.update(var_dict)

    @property
    def code(self) -> int:
        """Error code."""
        return cast(int, self.__dict__["_code"])

    @property
    def reason(self) -> str | None:
        """Reason reported about the error."""
        return cast(Optional[str], self.__dict__["_reason"])

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        # Strip the leading underscore so the repr shows public field names.
        view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
        return f"{self.__class__.__qualname__}({view})"

    def __eq__(self, other: object) -> bool:
        """Compare two instances of the class.

        Two errors are equal when all of their stored fields are equal. For an
        object of a different type, ``NotImplemented`` is returned (rather than
        raising) so Python can try the reflected comparison and otherwise fall
        back to identity, e.g. ``Error(1) == 1`` is simply ``False``.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__
|
flwr/app/metadata.py
ADDED
@@ -0,0 +1,223 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Metadata."""
|
16
|
+
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import cast
|
21
|
+
|
22
|
+
from ..common.constant import MessageType, MessageTypeLegacy
|
23
|
+
|
24
|
+
|
25
|
+
class Metadata:  # pylint: disable=too-many-instance-attributes
    """The class representing metadata associated with the current message.

    Parameters
    ----------
    run_id : int
        An identifier for the current run.
    message_id : str
        An identifier for the current message.
    src_node_id : int
        An identifier for the node sending this message.
    dst_node_id : int
        An identifier for the node receiving this message.
    reply_to_message_id : str
        An identifier for the message to which this message is a reply.
    group_id : str
        An identifier for grouping messages. In some settings,
        this is used as the FL round.
    created_at : float
        Unix timestamp when the message was created.
    ttl : float
        Time-to-live for this message in seconds.
    message_type : str
        A string that encodes the action to be executed on
        the receiving end.
    """

    def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        run_id: int,
        message_id: str,
        src_node_id: int,
        dst_node_id: int,
        reply_to_message_id: str,
        group_id: str,
        created_at: float,
        ttl: float,
        message_type: str,
    ) -> None:
        var_dict = {
            "_run_id": run_id,
            "_message_id": message_id,
            "_src_node_id": src_node_id,
            "_dst_node_id": dst_node_id,
            "_reply_to_message_id": reply_to_message_id,
            "_group_id": group_id,
            "_created_at": created_at,
            "_ttl": ttl,
            "_message_type": message_type,
        }
        self.__dict__.update(var_dict)
        # Re-assign through the property setter so the message type is validated
        self.message_type = message_type  # Trigger validation

    @property
    def run_id(self) -> int:
        """An identifier for the current run."""
        return cast(int, self.__dict__["_run_id"])

    @property
    def message_id(self) -> str:
        """An identifier for the current message."""
        return cast(str, self.__dict__["_message_id"])

    @property
    def src_node_id(self) -> int:
        """An identifier for the node sending this message."""
        return cast(int, self.__dict__["_src_node_id"])

    @property
    def reply_to_message_id(self) -> str:
        """An identifier for the message to which this message is a reply."""
        return cast(str, self.__dict__["_reply_to_message_id"])

    @property
    def dst_node_id(self) -> int:
        """An identifier for the node receiving this message."""
        return cast(int, self.__dict__["_dst_node_id"])

    @dst_node_id.setter
    def dst_node_id(self, value: int) -> None:
        """Set dst_node_id."""
        self.__dict__["_dst_node_id"] = value

    @property
    def group_id(self) -> str:
        """An identifier for grouping messages."""
        return cast(str, self.__dict__["_group_id"])

    @group_id.setter
    def group_id(self, value: str) -> None:
        """Set group_id."""
        self.__dict__["_group_id"] = value

    @property
    def created_at(self) -> float:
        """Unix timestamp when the message was created."""
        return cast(float, self.__dict__["_created_at"])

    @created_at.setter
    def created_at(self, value: float) -> None:
        """Set creation timestamp of this message."""
        self.__dict__["_created_at"] = value

    @property
    def delivered_at(self) -> str:
        """Timestamp when the message was delivered.

        NOTE(review): unlike ``created_at`` (a float Unix timestamp) this value
        is stored and typed as ``str`` — confirm its exact format. It is not set
        by ``__init__``, so reading it before it has been assigned raises
        ``KeyError``.
        """
        return cast(str, self.__dict__["_delivered_at"])

    @delivered_at.setter
    def delivered_at(self, value: str) -> None:
        """Set delivery timestamp of this message."""
        self.__dict__["_delivered_at"] = value

    @property
    def ttl(self) -> float:
        """Time-to-live for this message, in seconds."""
        return cast(float, self.__dict__["_ttl"])

    @ttl.setter
    def ttl(self, value: float) -> None:
        """Set ttl."""
        self.__dict__["_ttl"] = value

    @property
    def message_type(self) -> str:
        """A string that encodes the action to be executed on the receiving end."""
        return cast(str, self.__dict__["_message_type"])

    @message_type.setter
    def message_type(self, value: str) -> None:
        """Set message_type after validating it.

        Raises
        ------
        ValueError
            If ``value`` is neither a legacy message type nor of the form
            ``<category>`` / ``<category>.<action>``.
        """
        # Validate message type
        if validate_legacy_message_type(value):
            pass  # Backward compatibility for legacy message types
        elif not validate_message_type(value):
            raise ValueError(
                f"Invalid message type: '{value}'. "
                "Expected format: '<category>' or '<category>.<action>', "
                "where <category> must be 'train', 'evaluate', or 'query', "
                "and <action> must be a valid Python identifier."
            )

        self.__dict__["_message_type"] = value

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        # Strip the leading underscore so the repr shows public field names.
        view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
        return f"{self.__class__.__qualname__}({view})"

    def __eq__(self, other: object) -> bool:
        """Compare two instances of the class.

        Two metadata objects are equal when all stored fields are equal. For an
        object of a different type, ``NotImplemented`` is returned (rather than
        raising) so Python falls back to the reflected comparison / identity.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__
|
179
|
+
|
180
|
+
|
181
|
+
def validate_message_type(message_type: str) -> bool:
    """Check whether ``message_type`` is a well-formed message type.

    Accepted shapes are:

    - "<category>"
    - "<category>.<action>"

    where `category` must be one of "train", "evaluate", or "query" (the
    ``MessageType.SYSTEM`` category is accepted as well), and `action` must be
    a valid Python identifier. Exactly one "." may separate the two parts.
    """
    categories = {
        MessageType.TRAIN,
        MessageType.EVALUATE,
        MessageType.QUERY,
        MessageType.SYSTEM,
    }

    # Bare "<category>"
    if message_type in categories:
        return True

    # "<category>.<action>" requires exactly one separator
    if message_type.count(".") != 1:
        return False

    category, _, action = message_type.partition(".")
    return category in categories and action.isidentifier()
|
211
|
+
|
212
|
+
|
213
|
+
def validate_legacy_message_type(message_type: str) -> bool:
    """Check whether ``message_type`` is one of the legacy message types.

    Accepted values are ``MessageTypeLegacy.GET_PARAMETERS``,
    ``MessageTypeLegacy.GET_PROPERTIES`` and ``"reconnect"``; they are kept for
    backward compatibility.
    """
    legacy_types = (
        MessageTypeLegacy.GET_PARAMETERS,
        MessageTypeLegacy.GET_PROPERTIES,
        "reconnect",
    )
    return message_type in legacy_types
|
flwr/client/app.py
CHANGED
@@ -30,6 +30,7 @@ import grpc
|
|
30
30
|
from cryptography.hazmat.primitives.asymmetric import ec
|
31
31
|
from grpc import RpcError
|
32
32
|
|
33
|
+
from flwr.app.error import Error
|
33
34
|
from flwr.cli.config_utils import get_fab_metadata
|
34
35
|
from flwr.cli.install import install_from_fab
|
35
36
|
from flwr.client.client import Client
|
@@ -57,7 +58,6 @@ from flwr.common.constant import (
|
|
57
58
|
from flwr.common.exit import ExitCode, flwr_exit
|
58
59
|
from flwr.common.grpc import generic_create_grpc_server
|
59
60
|
from flwr.common.logger import log, warn_deprecated_feature
|
60
|
-
from flwr.common.message import Error
|
61
61
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
62
62
|
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
|
63
63
|
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
|
flwr/client/client_app.py
CHANGED
@@ -20,6 +20,7 @@ from collections.abc import Iterator
|
|
20
20
|
from contextlib import contextmanager
|
21
21
|
from typing import Callable, Optional
|
22
22
|
|
23
|
+
from flwr.app.metadata import validate_message_type
|
23
24
|
from flwr.client.client import Client
|
24
25
|
from flwr.client.message_handler.message_handler import (
|
25
26
|
handle_legacy_message_from_msgtype,
|
@@ -28,7 +29,6 @@ from flwr.client.mod.utils import make_ffn
|
|
28
29
|
from flwr.client.typing import ClientFnExt, Mod
|
29
30
|
from flwr.common import Context, Message, MessageType
|
30
31
|
from flwr.common.logger import warn_deprecated_feature
|
31
|
-
from flwr.common.message import validate_message_type
|
32
32
|
|
33
33
|
from .typing import ClientAppCallable
|
34
34
|
|
flwr/client/clientapp/app.py
CHANGED
@@ -23,6 +23,7 @@ from typing import Optional
|
|
23
23
|
|
24
24
|
import grpc
|
25
25
|
|
26
|
+
from flwr.app.error import Error
|
26
27
|
from flwr.cli.install import install_from_fab
|
27
28
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
28
29
|
from flwr.common import Context, Message
|
@@ -32,7 +33,6 @@ from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCo
|
|
32
33
|
from flwr.common.exit import ExitCode, flwr_exit
|
33
34
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
34
35
|
from flwr.common.logger import log
|
35
|
-
from flwr.common.message import Error
|
36
36
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
37
37
|
from flwr.common.serde import (
|
38
38
|
context_from_proto,
|
@@ -25,13 +25,14 @@ from typing import Callable, Optional, Union, cast
|
|
25
25
|
import grpc
|
26
26
|
from cryptography.hazmat.primitives.asymmetric import ec
|
27
27
|
|
28
|
+
from flwr.app.metadata import Metadata
|
28
29
|
from flwr.client.message_handler.message_handler import validate_out_message
|
29
30
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
31
|
from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
|
31
32
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
32
33
|
from flwr.common.heartbeat import HeartbeatSender
|
33
34
|
from flwr.common.logger import log
|
34
|
-
from flwr.common.message import Message
|
35
|
+
from flwr.common.message import Message
|
35
36
|
from flwr.common.retry_invoker import RetryInvoker
|
36
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
37
38
|
generate_key_pairs,
|
@@ -25,13 +25,14 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
25
25
|
from google.protobuf.message import Message as GrpcMessage
|
26
26
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
27
27
|
|
28
|
+
from flwr.app.metadata import Metadata
|
28
29
|
from flwr.client.message_handler.message_handler import validate_out_message
|
29
30
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
31
|
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
31
32
|
from flwr.common.exit import ExitCode, flwr_exit
|
32
33
|
from flwr.common.heartbeat import HeartbeatSender
|
33
34
|
from flwr.common.logger import log
|
34
|
-
from flwr.common.message import Message
|
35
|
+
from flwr.common.message import Message
|
35
36
|
from flwr.common.retry_invoker import RetryInvoker
|
36
37
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
37
38
|
from flwr.common.typing import Fab, Run
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Public Flower ClientApp APIs."""
|
flwr/common/__init__.py
CHANGED
@@ -15,6 +15,8 @@
|
|
15
15
|
"""Common components shared between server and client."""
|
16
16
|
|
17
17
|
|
18
|
+
from ..app.error import Error as Error
|
19
|
+
from ..app.metadata import Metadata as Metadata
|
18
20
|
from .constant import MessageType as MessageType
|
19
21
|
from .constant import MessageTypeLegacy as MessageTypeLegacy
|
20
22
|
from .context import Context as Context
|
@@ -23,9 +25,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
|
23
25
|
from .logger import configure as configure
|
24
26
|
from .logger import log as log
|
25
27
|
from .message import DEFAULT_TTL
|
26
|
-
from .message import Error as Error
|
27
28
|
from .message import Message as Message
|
28
|
-
from .message import Metadata as Metadata
|
29
29
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
30
30
|
from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
flwr/common/constant.py
CHANGED
@@ -130,7 +130,7 @@ GC_THRESHOLD = 200_000_000 # 200 MB
|
|
130
130
|
|
131
131
|
# Constants for Inflatable
|
132
132
|
HEAD_BODY_DIVIDER = b"\x00"
|
133
|
-
|
133
|
+
HEAD_VALUE_DIVIDER = " "
|
134
134
|
|
135
135
|
# Constants for serialization
|
136
136
|
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
flwr/common/inflatable.py
CHANGED
@@ -20,9 +20,7 @@ from __future__ import annotations
|
|
20
20
|
import hashlib
|
21
21
|
from typing import TypeVar
|
22
22
|
|
23
|
-
from .constant import HEAD_BODY_DIVIDER,
|
24
|
-
|
25
|
-
T = TypeVar("T", bound="InflatableObject")
|
23
|
+
from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
|
26
24
|
|
27
25
|
|
28
26
|
class InflatableObject:
|
@@ -65,6 +63,9 @@ class InflatableObject:
|
|
65
63
|
return None
|
66
64
|
|
67
65
|
|
66
|
+
T = TypeVar("T", bound=InflatableObject)
|
67
|
+
|
68
|
+
|
68
69
|
def get_object_id(object_content: bytes) -> str:
    """Return a SHA-256 hash of the (deflated) object content.

    The ID is the hexadecimal digest string produced by ``hashlib.sha256``.
    """
    digest = hashlib.sha256(object_content)
    return digest.hexdigest()
|
@@ -84,13 +85,21 @@ def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
|
|
84
85
|
return _get_object_body(object_content)
|
85
86
|
|
86
87
|
|
87
|
-
def add_header_to_object_body(object_body: bytes,
|
88
|
+
def add_header_to_object_body(object_body: bytes, obj: InflatableObject) -> bytes:
    """Add header to object content.

    The header is ``<type><DIV><child_id,child_id,...><DIV><body_len>`` with
    ``<DIV>`` being ``HEAD_VALUE_DIVIDER``. It is UTF-8 encoded and separated
    from ``object_body`` by ``HEAD_BODY_DIVIDER``.

    Parameters
    ----------
    object_body : bytes
        The serialized body of ``obj``.
    obj : InflatableObject
        The object the body belongs to; its class ``__qualname__`` and the keys
        of its ``children`` mapping are recorded in the header.

    Returns
    -------
    bytes
        ``header + HEAD_BODY_DIVIDER + object_body``.
    """
    # Construct header by joining its three fields. This avoids mixing f-string
    # interpolation with ``%``-formatting, which would break if the divider
    # ever contained a ``%`` character.
    header = HEAD_VALUE_DIVIDER.join(
        (
            obj.__class__.__qualname__,  # Type of object
            ",".join((obj.children or {}).keys()),  # IDs of child objects
            str(len(object_body)),  # Length of object body
        )
    )

    # Concatenate header and object body
    return b"".join((header.encode(encoding="utf-8"), HEAD_BODY_DIVIDER, object_body))
|
94
103
|
|
95
104
|
|
96
105
|
def _get_object_head(object_content: bytes) -> bytes:
|
@@ -108,9 +117,14 @@ def get_object_type_from_object_content(object_content: bytes) -> str:
|
|
108
117
|
return get_object_head_values_from_object_content(object_content)[0]
|
109
118
|
|
110
119
|
|
120
|
+
def get_object_children_ids_from_object_content(object_content: bytes) -> list[str]:
    """Return object children IDs from bytes.

    The IDs are the second field of the object header; an object without
    children yields an empty list.
    """
    _, children_ids, _ = get_object_head_values_from_object_content(object_content)
    return children_ids
|
123
|
+
|
124
|
+
|
111
125
|
def get_object_body_len_from_object_content(object_content: bytes) -> int:
    """Return length of the object body, as recorded in the object header."""
    *_, body_len = get_object_head_values_from_object_content(object_content)
    return body_len
|
114
128
|
|
115
129
|
|
116
130
|
def check_body_len_consistency(object_content: bytes) -> bool:
|
@@ -121,8 +135,23 @@ def check_body_len_consistency(object_content: bytes) -> bool:
|
|
121
135
|
|
122
136
|
def get_object_head_values_from_object_content(
    object_content: bytes,
) -> tuple[str, list[str], int]:
    """Return object type, children IDs and body length from object content.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.

    Returns
    -------
    tuple[str, list[str], int]
        A tuple containing:
        - The object type as a string.
        - A list of child object IDs as strings (empty if there are none).
        - The length of the object body as an integer.

    Raises
    ------
    ValueError
        If the header does not consist of exactly three fields, or if the body
        length field is not an integer.
    """
    head = _get_object_head(object_content).decode(encoding="utf-8")
    fields = head.split(HEAD_VALUE_DIVIDER)
    if len(fields) != 3:
        # Same exception type as a failed unpack, but with a useful message
        raise ValueError(
            f"Malformed object header {head!r}: expected 3 fields separated "
            f"by {HEAD_VALUE_DIVIDER!r}, got {len(fields)}"
        )
    obj_type, children_str, body_len = fields
    # An empty children field means "no children", not [""]
    children_ids = children_str.split(",") if children_str else []
    return obj_type, children_ids, int(body_len)
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""InflatableObject utils."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Union
|
19
|
+
|
20
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
21
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
22
|
+
PullObjectRequest,
|
23
|
+
PullObjectResponse,
|
24
|
+
PushObjectRequest,
|
25
|
+
PushObjectResponse,
|
26
|
+
)
|
27
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
28
|
+
|
29
|
+
from .inflatable import (
|
30
|
+
InflatableObject,
|
31
|
+
get_object_head_values_from_object_content,
|
32
|
+
get_object_id,
|
33
|
+
)
|
34
|
+
from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
35
|
+
|
36
|
+
# Helper registry that maps names of classes to their type.
# Keys are each class's ``__qualname__``; ``pull_object_from_servicer`` looks up
# the object type read from a pulled object's header (via
# ``get_object_head_values_from_object_content``) here to find the class whose
# ``inflate`` rebuilds the object. Types missing from this dict cannot be pulled.
inflatable_class_registry: dict[str, type[InflatableObject]] = {
    Array.__qualname__: Array,
    ArrayRecord.__qualname__: ArrayRecord,
    ConfigRecord.__qualname__: ConfigRecord,
    MetricRecord.__qualname__: MetricRecord,
    RecordDict.__qualname__: RecordDict,
}
|
44
|
+
|
45
|
+
|
46
|
+
def push_object_to_servicer(
    obj: InflatableObject, stub: Union[FleetStub, ServerAppIoStub]
) -> set[str]:
    """Recursively deflate an object and push it to the servicer.

    Children are pushed before their parent (depth-first). Objects with the
    same ID are not pushed twice, even when they occur several times in the
    object tree. It returns the set of pushed object IDs.
    """
    pushed_object_ids: set[str] = set()
    _push_object_tree(obj, stub, pushed_object_ids)
    return pushed_object_ids


def _push_object_tree(
    obj: InflatableObject,
    stub: Union[FleetStub, ServerAppIoStub],
    pushed_object_ids: set[str],
) -> None:
    """Push ``obj`` and any not-yet-pushed descendants, recording pushed IDs."""
    # Push children first. The keys of ``children`` are the child object IDs
    # (the same IDs written into the parent's header and used to pull them
    # back), so subtrees that were already pushed are skipped without being
    # deflated again.
    if children := obj.children:
        for child_id, child in children.items():
            if child_id not in pushed_object_ids:
                _push_object_tree(child, stub, pushed_object_ids)

    # Deflate object and push it, unless an identical object was pushed already
    object_content = obj.deflate()
    object_id = get_object_id(object_content)
    if object_id in pushed_object_ids:
        return
    _: PushObjectResponse = stub.PushObject(
        PushObjectRequest(
            object_id=object_id,
            object_content=object_content,
        )
    )
    pushed_object_ids.add(object_id)
|
72
|
+
|
73
|
+
|
74
|
+
def pull_object_from_servicer(
    object_id: str, stub: Union[FleetStub, ServerAppIoStub]
) -> InflatableObject:
    """Recursively inflate an object by pulling it from the servicer.

    The object with ``object_id`` is fetched first. Its header names its class
    and the IDs of its children; each child is pulled (recursively) before the
    object itself is inflated with them.
    """
    response: PullObjectResponse = stub.PullObject(
        PullObjectRequest(object_id=object_id)
    )
    content = response.object_content

    # Header fields: class name, child object IDs, body length (unused here)
    type_name, child_ids, _ = get_object_head_values_from_object_content(
        object_content=content
    )
    obj_cls = inflatable_class_registry[type_name]

    # Materialise every child before the parent
    children: dict[str, InflatableObject] = {
        child_id: pull_object_from_servicer(child_id, stub) for child_id in child_ids
    }

    return obj_cls.inflate(content, children=children)
|