flwr-nightly 1.19.0.dev20250515__py3-none-any.whl → 1.19.0.dev20250520__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/client/app.py +1 -1
  5. flwr/client/client_app.py +1 -1
  6. flwr/client/clientapp/app.py +1 -1
  7. flwr/client/grpc_rere_client/connection.py +2 -1
  8. flwr/client/rest_client/connection.py +2 -1
  9. flwr/clientapp/__init__.py +15 -0
  10. flwr/common/__init__.py +2 -2
  11. flwr/common/constant.py +1 -1
  12. flwr/common/inflatable.py +41 -12
  13. flwr/common/inflatable_grpc_utils.py +97 -0
  14. flwr/common/message.py +4 -243
  15. flwr/common/record/array.py +2 -2
  16. flwr/common/record/arrayrecord.py +1 -1
  17. flwr/common/record/configrecord.py +2 -2
  18. flwr/common/record/metricrecord.py +1 -1
  19. flwr/common/record/recorddict.py +1 -1
  20. flwr/common/serde.py +4 -1
  21. flwr/compat/__init__.py +15 -0
  22. flwr/compat/client/__init__.py +15 -0
  23. flwr/compat/common/__init__.py +15 -0
  24. flwr/compat/server/__init__.py +15 -0
  25. flwr/compat/simulation/__init__.py +15 -0
  26. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  27. flwr/serverapp/__init__.py +15 -0
  28. flwr/supercore/__init__.py +15 -0
  29. flwr/superlink/__init__.py +15 -0
  30. flwr/supernode/__init__.py +15 -0
  31. {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/METADATA +1 -1
  32. {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/RECORD +34 -20
  33. {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/WHEEL +0 -0
  34. {flwr_nightly-1.19.0.dev20250515.dist-info → flwr_nightly-1.19.0.dev20250520.dist-info}/entry_points.txt +0 -0
flwr/app/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Public Flower App APIs."""
flwr/app/error.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Error."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional, cast
21
+
22
# Default time-to-live, in seconds, expressed as hours * minutes * seconds.
DEFAULT_TTL = 12 * 60 * 60  # 43200 seconds, i.e. 12 hours

# Error text listing the supported ``Message`` constructor signatures.
MESSAGE_INIT_ERROR_MESSAGE = (
    "Invalid arguments for Message. "
    "Expected one of the documented signatures: "
    "Message(content: RecordDict, dst_node_id: int, message_type: str, "
    "*, [ttl: float, group_id: str]) "
    "or Message(content: RecordDict | error: Error, "
    "*, reply_to: Message, [ttl: float])."
)
29
+
30
+
31
class Error:
    """The class storing information about an error that occurred.

    Parameters
    ----------
    code : int
        An identifier for the error.
    reason : Optional[str]
        A reason for why the error arose (e.g. an exception stack-trace)
    """

    def __init__(self, code: int, reason: str | None = None) -> None:
        # Values are stored in ``__dict__`` under underscore-prefixed keys and
        # exposed through read-only properties; ``__repr__`` and ``__eq__``
        # operate on ``__dict__`` directly.
        var_dict = {
            "_code": code,
            "_reason": reason,
        }
        self.__dict__.update(var_dict)

    @property
    def code(self) -> int:
        """Error code."""
        return cast(int, self.__dict__["_code"])

    @property
    def reason(self) -> str | None:
        """Reason reported about the error."""
        return cast(Optional[str], self.__dict__["_reason"])

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
        return f"{self.__class__.__qualname__}({view})"

    def __eq__(self, other: object) -> bool:
        """Compare two instances of the class.

        For objects of a different class, ``NotImplemented`` is returned so that
        Python can try the reflected comparison (and ultimately fall back to
        identity, i.e. ``False``) instead of raising an exception.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__
flwr/app/metadata.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Metadata."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import cast
21
+
22
+ from ..common.constant import MessageType, MessageTypeLegacy
23
+
24
+
25
class Metadata:  # pylint: disable=too-many-instance-attributes
    """The class representing metadata associated with the current message.

    Parameters
    ----------
    run_id : int
        An identifier for the current run.
    message_id : str
        An identifier for the current message.
    src_node_id : int
        An identifier for the node sending this message.
    dst_node_id : int
        An identifier for the node receiving this message.
    reply_to_message_id : str
        An identifier for the message to which this message is a reply.
    group_id : str
        An identifier for grouping messages. In some settings,
        this is used as the FL round.
    created_at : float
        Unix timestamp when the message was created.
    ttl : float
        Time-to-live for this message in seconds.
    message_type : str
        A string that encodes the action to be executed on
        the receiving end.

    Raises
    ------
    ValueError
        If ``message_type`` is neither a legacy message type nor a valid
        ``"<category>"`` / ``"<category>.<action>"`` message type.
    """

    def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        run_id: int,
        message_id: str,
        src_node_id: int,
        dst_node_id: int,
        reply_to_message_id: str,
        group_id: str,
        created_at: float,
        ttl: float,
        message_type: str,
    ) -> None:
        # Values live in ``__dict__`` under underscore-prefixed keys and are
        # exposed through properties; ``__repr__`` and ``__eq__`` operate on
        # ``__dict__`` directly.
        var_dict = {
            "_run_id": run_id,
            "_message_id": message_id,
            "_src_node_id": src_node_id,
            "_dst_node_id": dst_node_id,
            "_reply_to_message_id": reply_to_message_id,
            "_group_id": group_id,
            "_created_at": created_at,
            "_ttl": ttl,
            "_message_type": message_type,
        }
        self.__dict__.update(var_dict)
        self.message_type = message_type  # Trigger validation

    @property
    def run_id(self) -> int:
        """An identifier for the current run."""
        return cast(int, self.__dict__["_run_id"])

    @property
    def message_id(self) -> str:
        """An identifier for the current message."""
        return cast(str, self.__dict__["_message_id"])

    @property
    def src_node_id(self) -> int:
        """An identifier for the node sending this message."""
        return cast(int, self.__dict__["_src_node_id"])

    @property
    def reply_to_message_id(self) -> str:
        """An identifier for the message to which this message is a reply."""
        return cast(str, self.__dict__["_reply_to_message_id"])

    @property
    def dst_node_id(self) -> int:
        """An identifier for the node receiving this message."""
        return cast(int, self.__dict__["_dst_node_id"])

    @dst_node_id.setter
    def dst_node_id(self, value: int) -> None:
        """Set dst_node_id."""
        self.__dict__["_dst_node_id"] = value

    @property
    def group_id(self) -> str:
        """An identifier for grouping messages."""
        return cast(str, self.__dict__["_group_id"])

    @group_id.setter
    def group_id(self, value: str) -> None:
        """Set group_id."""
        self.__dict__["_group_id"] = value

    @property
    def created_at(self) -> float:
        """Unix timestamp when the message was created."""
        return cast(float, self.__dict__["_created_at"])

    @created_at.setter
    def created_at(self, value: float) -> None:
        """Set creation timestamp of this message."""
        self.__dict__["_created_at"] = value

    @property
    def delivered_at(self) -> str:
        """Delivery timestamp of this message, stored as a string.

        Returns an empty string if the delivery timestamp has not been set.
        """
        # ``_delivered_at`` is not initialized in ``__init__``. Fall back to ""
        # rather than leaking ``KeyError('_delivered_at')``. The key is not
        # inserted here, so ``__repr__``/``__eq__`` are unaffected until set.
        return cast(str, self.__dict__.get("_delivered_at", ""))

    @delivered_at.setter
    def delivered_at(self, value: str) -> None:
        """Set delivery timestamp of this message."""
        self.__dict__["_delivered_at"] = value

    @property
    def ttl(self) -> float:
        """Time-to-live for this message."""
        return cast(float, self.__dict__["_ttl"])

    @ttl.setter
    def ttl(self, value: float) -> None:
        """Set ttl."""
        self.__dict__["_ttl"] = value

    @property
    def message_type(self) -> str:
        """A string that encodes the action to be executed on the receiving end."""
        return cast(str, self.__dict__["_message_type"])

    @message_type.setter
    def message_type(self, value: str) -> None:
        """Set message_type.

        Raises
        ------
        ValueError
            If ``value`` is neither a legacy nor a valid message type.
        """
        # Validate message type
        if validate_legacy_message_type(value):
            pass  # Backward compatibility for legacy message types
        elif not validate_message_type(value):
            raise ValueError(
                f"Invalid message type: '{value}'. "
                "Expected format: '<category>' or '<category>.<action>', "
                "where <category> must be 'train', 'evaluate', or 'query', "
                "and <action> must be a valid Python identifier."
            )

        self.__dict__["_message_type"] = value

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
        return f"{self.__class__.__qualname__}({view})"

    def __eq__(self, other: object) -> bool:
        """Compare two instances of the class.

        For objects of a different class, ``NotImplemented`` is returned so that
        Python can try the reflected comparison (and ultimately fall back to
        identity, i.e. ``False``) instead of raising an exception.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__
179
+
180
+
181
def validate_message_type(message_type: str) -> bool:
    """Validate if the message type is valid.

    A valid message type format must be one of the following:

    - "<category>"
    - "<category>.<action>"

    where `category` must be one of ``MessageType.TRAIN``,
    ``MessageType.EVALUATE``, ``MessageType.QUERY`` or ``MessageType.SYSTEM``,
    and `action` must be a valid Python identifier.

    Parameters
    ----------
    message_type : str
        The message type to validate.

    Returns
    -------
    bool
        ``True`` if ``message_type`` follows one of the formats above,
        ``False`` otherwise.
    """
    valid_categories = {
        MessageType.TRAIN,
        MessageType.EVALUATE,
        MessageType.QUERY,
        MessageType.SYSTEM,
    }

    # Format "<category>"
    if message_type in valid_categories:
        return True

    # Format "<category>.<action>": exactly one dot separates the two parts
    if message_type.count(".") != 1:
        return False

    category, action = message_type.split(".")
    return category in valid_categories and action.isidentifier()
211
+
212
+
213
def validate_legacy_message_type(message_type: str) -> bool:
    """Check whether ``message_type`` is one of the accepted legacy types.

    Kept for backward compatibility: ``MessageTypeLegacy.GET_PARAMETERS``,
    ``MessageTypeLegacy.GET_PROPERTIES`` and the literal ``"reconnect"`` are
    accepted even though they do not follow the ``<category>[.<action>]`` format.
    """
    legacy_types = (
        MessageTypeLegacy.GET_PARAMETERS,
        MessageTypeLegacy.GET_PROPERTIES,
        "reconnect",
    )
    return message_type in legacy_types
flwr/client/app.py CHANGED
@@ -30,6 +30,7 @@ import grpc
30
30
  from cryptography.hazmat.primitives.asymmetric import ec
31
31
  from grpc import RpcError
32
32
 
33
+ from flwr.app.error import Error
33
34
  from flwr.cli.config_utils import get_fab_metadata
34
35
  from flwr.cli.install import install_from_fab
35
36
  from flwr.client.client import Client
@@ -57,7 +58,6 @@ from flwr.common.constant import (
57
58
  from flwr.common.exit import ExitCode, flwr_exit
58
59
  from flwr.common.grpc import generic_create_grpc_server
59
60
  from flwr.common.logger import log, warn_deprecated_feature
60
- from flwr.common.message import Error
61
61
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
62
62
  from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
63
63
  from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
flwr/client/client_app.py CHANGED
@@ -20,6 +20,7 @@ from collections.abc import Iterator
20
20
  from contextlib import contextmanager
21
21
  from typing import Callable, Optional
22
22
 
23
+ from flwr.app.metadata import validate_message_type
23
24
  from flwr.client.client import Client
24
25
  from flwr.client.message_handler.message_handler import (
25
26
  handle_legacy_message_from_msgtype,
@@ -28,7 +29,6 @@ from flwr.client.mod.utils import make_ffn
28
29
  from flwr.client.typing import ClientFnExt, Mod
29
30
  from flwr.common import Context, Message, MessageType
30
31
  from flwr.common.logger import warn_deprecated_feature
31
- from flwr.common.message import validate_message_type
32
32
 
33
33
  from .typing import ClientAppCallable
34
34
 
@@ -23,6 +23,7 @@ from typing import Optional
23
23
 
24
24
  import grpc
25
25
 
26
+ from flwr.app.error import Error
26
27
  from flwr.cli.install import install_from_fab
27
28
  from flwr.client.client_app import ClientApp, LoadClientAppError
28
29
  from flwr.common import Context, Message
@@ -32,7 +33,6 @@ from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCo
32
33
  from flwr.common.exit import ExitCode, flwr_exit
33
34
  from flwr.common.grpc import create_channel, on_channel_state_change
34
35
  from flwr.common.logger import log
35
- from flwr.common.message import Error
36
36
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
37
37
  from flwr.common.serde import (
38
38
  context_from_proto,
@@ -25,13 +25,14 @@ from typing import Callable, Optional, Union, cast
25
25
  import grpc
26
26
  from cryptography.hazmat.primitives.asymmetric import ec
27
27
 
28
+ from flwr.app.metadata import Metadata
28
29
  from flwr.client.message_handler.message_handler import validate_out_message
29
30
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
31
  from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
31
32
  from flwr.common.grpc import create_channel, on_channel_state_change
32
33
  from flwr.common.heartbeat import HeartbeatSender
33
34
  from flwr.common.logger import log
34
- from flwr.common.message import Message, Metadata
35
+ from flwr.common.message import Message
35
36
  from flwr.common.retry_invoker import RetryInvoker
36
37
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
37
38
  generate_key_pairs,
@@ -25,13 +25,14 @@ from cryptography.hazmat.primitives.asymmetric import ec
25
25
  from google.protobuf.message import Message as GrpcMessage
26
26
  from requests.exceptions import ConnectionError as RequestsConnectionError
27
27
 
28
+ from flwr.app.metadata import Metadata
28
29
  from flwr.client.message_handler.message_handler import validate_out_message
29
30
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
31
  from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
31
32
  from flwr.common.exit import ExitCode, flwr_exit
32
33
  from flwr.common.heartbeat import HeartbeatSender
33
34
  from flwr.common.logger import log
34
- from flwr.common.message import Message, Metadata
35
+ from flwr.common.message import Message
35
36
  from flwr.common.retry_invoker import RetryInvoker
36
37
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
37
38
  from flwr.common.typing import Fab, Run
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Public Flower ClientApp APIs."""
flwr/common/__init__.py CHANGED
@@ -15,6 +15,8 @@
15
15
  """Common components shared between server and client."""
16
16
 
17
17
 
18
+ from ..app.error import Error as Error
19
+ from ..app.metadata import Metadata as Metadata
18
20
  from .constant import MessageType as MessageType
19
21
  from .constant import MessageTypeLegacy as MessageTypeLegacy
20
22
  from .context import Context as Context
@@ -23,9 +25,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH
23
25
  from .logger import configure as configure
24
26
  from .logger import log as log
25
27
  from .message import DEFAULT_TTL
26
- from .message import Error as Error
27
28
  from .message import Message as Message
28
- from .message import Metadata as Metadata
29
29
  from .parameter import bytes_to_ndarray as bytes_to_ndarray
30
30
  from .parameter import ndarray_to_bytes as ndarray_to_bytes
31
31
  from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
flwr/common/constant.py CHANGED
@@ -130,7 +130,7 @@ GC_THRESHOLD = 200_000_000 # 200 MB
130
130
 
131
131
# Constants for Inflatable
HEAD_BODY_DIVIDER = bytes([0])  # Single NUL byte placed between header and body
HEAD_VALUE_DIVIDER = " "  # Separates the individual values inside a header

# Constants for serialization
INT64_MAX_VALUE = (1 << 63) - 1  # 9223372036854775807
flwr/common/inflatable.py CHANGED
@@ -20,9 +20,7 @@ from __future__ import annotations
20
20
  import hashlib
21
21
  from typing import TypeVar
22
22
 
23
- from .constant import HEAD_BODY_DIVIDER, TYPE_BODY_LEN_DIVIDER
24
-
25
- T = TypeVar("T", bound="InflatableObject")
23
+ from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
26
24
 
27
25
 
28
26
  class InflatableObject:
@@ -65,6 +63,9 @@ class InflatableObject:
65
63
  return None
66
64
 
67
65
 
66
# Generic type variable for ``InflatableObject`` subclasses (used e.g. as
# ``cls: type[T]``). Declared after the class so the bound can reference the
# class object directly instead of a forward-reference string.
T = TypeVar("T", bound=InflatableObject)
67
+
68
+
68
69
def get_object_id(object_content: bytes) -> str:
    """Compute an object's ID as the SHA-256 hex digest of its (deflated) content."""
    hasher = hashlib.sha256()
    hasher.update(object_content)
    return hasher.hexdigest()
@@ -84,13 +85,21 @@ def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
84
85
  return _get_object_body(object_content)
85
86
 
86
87
 
87
def add_header_to_object_body(object_body: bytes, obj: InflatableObject) -> bytes:
    """Prepend the serialized header to an object body.

    The header consists of three values joined by ``HEAD_VALUE_DIVIDER``:
    the object's class ``__qualname__``, the comma-separated keys of
    ``obj.children`` (the child object IDs; empty when there are none), and
    the length of ``object_body``. The UTF-8 encoded header and the body are
    separated by ``HEAD_BODY_DIVIDER``.
    """
    header_values = (
        obj.__class__.__qualname__,  # Type of object
        ",".join((obj.children or {}).keys()),  # IDs of child objects
        str(len(object_body)),  # Length of object body
    )
    header = HEAD_VALUE_DIVIDER.join(header_values)

    # Concatenate header, divider and object body
    return b"".join((header.encode(encoding="utf-8"), HEAD_BODY_DIVIDER, object_body))
94
103
 
95
104
 
96
105
  def _get_object_head(object_content: bytes) -> bytes:
@@ -108,9 +117,14 @@ def get_object_type_from_object_content(object_content: bytes) -> str:
108
117
  return get_object_head_values_from_object_content(object_content)[0]
109
118
 
110
119
 
120
def get_object_children_ids_from_object_content(object_content: bytes) -> list[str]:
    """Extract the IDs of the child objects listed in an object's header."""
    _, children_ids, _ = get_object_head_values_from_object_content(object_content)
    return children_ids
123
+
124
+
111
125
def get_object_body_len_from_object_content(object_content: bytes) -> int:
    """Read the body length declared in an object's header."""
    _, _, body_len = get_object_head_values_from_object_content(object_content)
    return body_len
114
128
 
115
129
 
116
130
  def check_body_len_consistency(object_content: bytes) -> bool:
@@ -121,8 +135,23 @@ def check_body_len_consistency(object_content: bytes) -> bool:
121
135
 
122
136
def get_object_head_values_from_object_content(
    object_content: bytes,
) -> tuple[str, list[str], int]:
    """Return object type, children IDs and body length from object content.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.

    Returns
    -------
    tuple[str, list[str], int]
        A tuple containing:
        - The object type as a string.
        - A list of child object IDs as strings.
        - The length of the object body as an integer.

    Raises
    ------
    ValueError
        If the header does not consist of exactly three values separated by
        ``HEAD_VALUE_DIVIDER`` (e.g. content produced with an older header
        layout), or if the body length is not an integer.
    """
    head = _get_object_head(object_content).decode(encoding="utf-8")
    head_values = head.split(HEAD_VALUE_DIVIDER)
    if len(head_values) != 3:
        # Fail with an explicit message instead of an opaque unpacking error
        raise ValueError(
            f"Malformed object header {head!r}: expected 3 values separated by "
            f"{HEAD_VALUE_DIVIDER!r}, got {len(head_values)}."
        )
    obj_type, children_str, body_len = head_values
    # An empty children field means the object has no children
    children_ids = children_str.split(",") if children_str else []
    return obj_type, children_ids, int(body_len)
@@ -0,0 +1,97 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject utils."""
16
+
17
+
18
+ from typing import Union
19
+
20
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
21
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
22
+ PullObjectRequest,
23
+ PullObjectResponse,
24
+ PushObjectRequest,
25
+ PushObjectResponse,
26
+ )
27
+ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
28
+
29
+ from .inflatable import (
30
+ InflatableObject,
31
+ get_object_head_values_from_object_content,
32
+ get_object_id,
33
+ )
34
+ from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
35
+
36
# Registry mapping each inflatable class's ``__qualname__`` (the type name
# written into object headers) to the class itself.
inflatable_class_registry: dict[str, type[InflatableObject]] = {
    cls.__qualname__: cls
    for cls in (Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict)
}
44
+
45
+
46
def push_object_to_servicer(
    obj: InflatableObject, stub: Union[FleetStub, ServerAppIoStub]
) -> set[str]:
    """Recursively deflate an object and push it to the servicer.

    Children are pushed before their parent. An object's ID is the SHA-256 of
    its deflated content (``get_object_id``). An object whose ID was already
    pushed during this call is not pushed again, e.g. identical children
    shared by several parents.

    Parameters
    ----------
    obj : InflatableObject
        The root object to push.
    stub : Union[FleetStub, ServerAppIoStub]
        The gRPC stub used to call ``PushObject``.

    Returns
    -------
    set[str]
        The IDs of all objects in the tree rooted at ``obj``.
    """
    pushed_object_ids: set[str] = set()
    _push_object_tree(obj, stub, pushed_object_ids)
    return pushed_object_ids


def _push_object_tree(
    obj: InflatableObject,
    stub: Union[FleetStub, ServerAppIoStub],
    pushed_object_ids: set[str],
) -> None:
    """Push ``obj`` and its descendants, skipping IDs in ``pushed_object_ids``."""
    # Push children first (depth-first, before their parent)
    if children := obj.children:
        for child in children.values():
            _push_object_tree(child, stub, pushed_object_ids)

    # Deflate object; its ID is derived from the deflated content
    object_content = obj.deflate()
    object_id = get_object_id(object_content)
    if object_id in pushed_object_ids:
        return  # Identical object already pushed during this call

    _: PushObjectResponse = stub.PushObject(
        PushObjectRequest(
            object_id=object_id,
            object_content=object_content,
        )
    )
    pushed_object_ids.add(object_id)
72
+
73
+
74
def pull_object_from_servicer(
    object_id: str, stub: Union[FleetStub, ServerAppIoStub]
) -> InflatableObject:
    """Pull an object from the servicer and inflate it.

    The object's header is read to find its class (resolved through
    ``inflatable_class_registry``) and the IDs of its children. Each child is
    pulled recursively before the object itself is inflated.
    """
    response: PullObjectResponse = stub.PullObject(
        PullObjectRequest(object_id=object_id)
    )
    content = response.object_content

    # The header carries the class name and the IDs of the children
    type_name, child_ids, _ = get_object_head_values_from_object_content(
        object_content=content
    )
    inflatable_cls = inflatable_class_registry[type_name]

    # Pull every child (in header order) before inflating the parent
    pulled_children: dict[str, InflatableObject] = {
        child_id: pull_object_from_servicer(child_id, stub) for child_id in child_ids
    }
    return inflatable_cls.inflate(content, children=pulled_children)