flwr-nightly 1.19.0.dev20250516__py3-none-any.whl → 1.19.0.dev20250521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/app.py +1 -1
- flwr/client/grpc_rere_client/connection.py +2 -1
- flwr/client/rest_client/connection.py +2 -1
- flwr/client/start_client_internal.py +608 -0
- flwr/client/supernode/app.py +1 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/inflatable_grpc_utils.py +97 -0
- flwr/common/message.py +87 -245
- flwr/common/record/array.py +1 -1
- flwr/common/record/configrecord.py +1 -1
- flwr/common/serde.py +9 -54
- flwr/common/serde_utils.py +50 -0
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +13 -11
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -6
- flwr/serverapp/__init__.py +15 -0
- flwr/supercore/__init__.py +15 -0
- flwr/superlink/__init__.py +15 -0
- flwr/supernode/__init__.py +15 -0
- flwr/{client → supernode}/nodestate/in_memory_nodestate.py +1 -1
- {flwr_nightly-1.19.0.dev20250516.dist-info → flwr_nightly-1.19.0.dev20250521.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250516.dist-info → flwr_nightly-1.19.0.dev20250521.dist-info}/RECORD +38 -23
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- {flwr_nightly-1.19.0.dev20250516.dist-info → flwr_nightly-1.19.0.dev20250521.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250516.dist-info → flwr_nightly-1.19.0.dev20250521.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""InflatableObject utils."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Union
|
19
|
+
|
20
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
21
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
22
|
+
PullObjectRequest,
|
23
|
+
PullObjectResponse,
|
24
|
+
PushObjectRequest,
|
25
|
+
PushObjectResponse,
|
26
|
+
)
|
27
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
28
|
+
|
29
|
+
from .inflatable import (
|
30
|
+
InflatableObject,
|
31
|
+
get_object_head_values_from_object_content,
|
32
|
+
get_object_id,
|
33
|
+
)
|
34
|
+
from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
35
|
+
|
36
|
+
# Registry mapping each inflatable class's qualified name to the class itself.
# `pull_object_from_servicer` uses it to resolve the type name found in an
# object head back into the class whose `inflate` must be called.
inflatable_class_registry: dict[str, type[InflatableObject]] = {
    inflatable_cls.__qualname__: inflatable_cls
    for inflatable_cls in (
        Array,
        ArrayRecord,
        ConfigRecord,
        MetricRecord,
        RecordDict,
    )
}
|
44
|
+
|
45
|
+
|
46
|
+
def push_object_to_servicer(
    obj: InflatableObject, stub: Union[FleetStub, ServerAppIoStub]
) -> set[str]:
    """Recursively deflate an object and push it to the servicer.

    Children are pushed before their parent. Objects with the same ID are not
    pushed twice, even when the same object appears under several parents in
    the object tree. It returns the set of pushed object IDs.

    Parameters
    ----------
    obj : InflatableObject
        The root object to deflate and push.
    stub : Union[FleetStub, ServerAppIoStub]
        The gRPC stub whose ``PushObject`` RPC is called once per distinct ID.

    Returns
    -------
    set[str]
        The IDs of all objects that were pushed.
    """
    pushed_object_ids: set[str] = set()
    _push_object_tree(obj, stub, pushed_object_ids)
    return pushed_object_ids


def _push_object_tree(
    obj: InflatableObject,
    stub: Union[FleetStub, ServerAppIoStub],
    pushed_object_ids: set[str],
) -> None:
    """Push `obj` and its descendants, skipping IDs already in `pushed_object_ids`."""
    # Deflate first so the object's ID is known before visiting its children
    object_content = obj.deflate()
    object_id = get_object_id(object_content)
    if object_id in pushed_object_ids:
        # The ID is computed from the deflated content (whose head lists the
        # children's IDs), so a repeated ID is taken to mean this whole
        # subtree has already been pushed
        return

    # Push children first so they are available before their parent
    if children := obj.children:
        for child in children.values():
            _push_object_tree(child, stub, pushed_object_ids)

    _: PushObjectResponse = stub.PushObject(
        PushObjectRequest(
            object_id=object_id,
            object_content=object_content,
        )
    )
    pushed_object_ids.add(object_id)
|
72
|
+
|
73
|
+
|
74
|
+
def pull_object_from_servicer(
    object_id: str, stub: Union[FleetStub, ServerAppIoStub]
) -> InflatableObject:
    """Recursively inflate an object by pulling it from the servicer.

    Each distinct object ID is fetched from the servicer at most once per call.
    When the same ID appears under several parents, it is inflated again from
    the cached bytes, so no inflated instance is shared between parents.

    Parameters
    ----------
    object_id : str
        The ID of the root object to pull.
    stub : Union[FleetStub, ServerAppIoStub]
        The gRPC stub whose ``PullObject`` RPC is used.

    Returns
    -------
    InflatableObject
        The fully inflated root object.

    Raises
    ------
    KeyError
        If an object head names a type missing from `inflatable_class_registry`.
    """
    return _pull_object_tree(object_id, stub, {})


def _pull_object_tree(
    object_id: str,
    stub: Union[FleetStub, ServerAppIoStub],
    pulled_contents: dict[str, bytes],
) -> InflatableObject:
    """Inflate `object_id`, reusing content already fetched into `pulled_contents`."""
    # Pull object (only if its bytes were not fetched earlier in this call)
    object_content = pulled_contents.get(object_id)
    if object_content is None:
        object_proto: PullObjectResponse = stub.PullObject(
            PullObjectRequest(object_id=object_id)
        )
        object_content = object_proto.object_content
        pulled_contents[object_id] = object_content

    # Extract object class and object_ids of children
    obj_type, children_obj_ids, _ = get_object_head_values_from_object_content(
        object_content=object_content
    )
    # Resolve object class
    cls_type = inflatable_class_registry[obj_type]

    # Pull all children objects; an ID listed twice only needs one entry
    children: dict[str, InflatableObject] = {}
    for child_object_id in children_obj_ids:
        if child_object_id not in children:
            children[child_object_id] = _pull_object_tree(
                child_object_id, stub, pulled_contents
            )

    # Inflate object passing its children
    return cls_type.inflate(object_content, children=children)
|
flwr/common/message.py
CHANGED
@@ -18,14 +18,29 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
from logging import WARNING
|
21
|
-
from typing import Any,
|
21
|
+
from typing import Any, cast, overload
|
22
22
|
|
23
23
|
from flwr.common.date import now
|
24
24
|
from flwr.common.logger import warn_deprecated_feature
|
25
|
-
|
26
|
-
|
25
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
|
26
|
+
|
27
|
+
from ..app.error import Error
|
28
|
+
from ..app.metadata import Metadata
|
29
|
+
from .constant import MESSAGE_TTL_TOLERANCE
|
30
|
+
from .inflatable import (
|
31
|
+
InflatableObject,
|
32
|
+
add_header_to_object_body,
|
33
|
+
get_object_body,
|
34
|
+
get_object_children_ids_from_object_content,
|
35
|
+
)
|
27
36
|
from .logger import log
|
28
37
|
from .record import RecordDict
|
38
|
+
from .serde_utils import (
|
39
|
+
error_from_proto,
|
40
|
+
error_to_proto,
|
41
|
+
metadata_from_proto,
|
42
|
+
metadata_to_proto,
|
43
|
+
)
|
29
44
|
|
30
45
|
DEFAULT_TTL = 43200 # This is 12 hours
|
31
46
|
MESSAGE_INIT_ERROR_MESSAGE = (
|
@@ -56,203 +71,7 @@ class MessageInitializationError(TypeError):
|
|
56
71
|
super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
|
57
72
|
|
58
73
|
|
59
|
-
class
|
60
|
-
"""The class representing metadata associated with the current message.
|
61
|
-
|
62
|
-
Parameters
|
63
|
-
----------
|
64
|
-
run_id : int
|
65
|
-
An identifier for the current run.
|
66
|
-
message_id : str
|
67
|
-
An identifier for the current message.
|
68
|
-
src_node_id : int
|
69
|
-
An identifier for the node sending this message.
|
70
|
-
dst_node_id : int
|
71
|
-
An identifier for the node receiving this message.
|
72
|
-
reply_to_message_id : str
|
73
|
-
An identifier for the message to which this message is a reply.
|
74
|
-
group_id : str
|
75
|
-
An identifier for grouping messages. In some settings,
|
76
|
-
this is used as the FL round.
|
77
|
-
created_at : float
|
78
|
-
Unix timestamp when the message was created.
|
79
|
-
ttl : float
|
80
|
-
Time-to-live for this message in seconds.
|
81
|
-
message_type : str
|
82
|
-
A string that encodes the action to be executed on
|
83
|
-
the receiving end.
|
84
|
-
"""
|
85
|
-
|
86
|
-
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
87
|
-
self,
|
88
|
-
run_id: int,
|
89
|
-
message_id: str,
|
90
|
-
src_node_id: int,
|
91
|
-
dst_node_id: int,
|
92
|
-
reply_to_message_id: str,
|
93
|
-
group_id: str,
|
94
|
-
created_at: float,
|
95
|
-
ttl: float,
|
96
|
-
message_type: str,
|
97
|
-
) -> None:
|
98
|
-
var_dict = {
|
99
|
-
"_run_id": run_id,
|
100
|
-
"_message_id": message_id,
|
101
|
-
"_src_node_id": src_node_id,
|
102
|
-
"_dst_node_id": dst_node_id,
|
103
|
-
"_reply_to_message_id": reply_to_message_id,
|
104
|
-
"_group_id": group_id,
|
105
|
-
"_created_at": created_at,
|
106
|
-
"_ttl": ttl,
|
107
|
-
"_message_type": message_type,
|
108
|
-
}
|
109
|
-
self.__dict__.update(var_dict)
|
110
|
-
self.message_type = message_type # Trigger validation
|
111
|
-
|
112
|
-
@property
|
113
|
-
def run_id(self) -> int:
|
114
|
-
"""An identifier for the current run."""
|
115
|
-
return cast(int, self.__dict__["_run_id"])
|
116
|
-
|
117
|
-
@property
|
118
|
-
def message_id(self) -> str:
|
119
|
-
"""An identifier for the current message."""
|
120
|
-
return cast(str, self.__dict__["_message_id"])
|
121
|
-
|
122
|
-
@property
|
123
|
-
def src_node_id(self) -> int:
|
124
|
-
"""An identifier for the node sending this message."""
|
125
|
-
return cast(int, self.__dict__["_src_node_id"])
|
126
|
-
|
127
|
-
@property
|
128
|
-
def reply_to_message_id(self) -> str:
|
129
|
-
"""An identifier for the message to which this message is a reply."""
|
130
|
-
return cast(str, self.__dict__["_reply_to_message_id"])
|
131
|
-
|
132
|
-
@property
|
133
|
-
def dst_node_id(self) -> int:
|
134
|
-
"""An identifier for the node receiving this message."""
|
135
|
-
return cast(int, self.__dict__["_dst_node_id"])
|
136
|
-
|
137
|
-
@dst_node_id.setter
|
138
|
-
def dst_node_id(self, value: int) -> None:
|
139
|
-
"""Set dst_node_id."""
|
140
|
-
self.__dict__["_dst_node_id"] = value
|
141
|
-
|
142
|
-
@property
|
143
|
-
def group_id(self) -> str:
|
144
|
-
"""An identifier for grouping messages."""
|
145
|
-
return cast(str, self.__dict__["_group_id"])
|
146
|
-
|
147
|
-
@group_id.setter
|
148
|
-
def group_id(self, value: str) -> None:
|
149
|
-
"""Set group_id."""
|
150
|
-
self.__dict__["_group_id"] = value
|
151
|
-
|
152
|
-
@property
|
153
|
-
def created_at(self) -> float:
|
154
|
-
"""Unix timestamp when the message was created."""
|
155
|
-
return cast(float, self.__dict__["_created_at"])
|
156
|
-
|
157
|
-
@created_at.setter
|
158
|
-
def created_at(self, value: float) -> None:
|
159
|
-
"""Set creation timestamp of this message."""
|
160
|
-
self.__dict__["_created_at"] = value
|
161
|
-
|
162
|
-
@property
|
163
|
-
def delivered_at(self) -> str:
|
164
|
-
"""Unix timestamp when the message was delivered."""
|
165
|
-
return cast(str, self.__dict__["_delivered_at"])
|
166
|
-
|
167
|
-
@delivered_at.setter
|
168
|
-
def delivered_at(self, value: str) -> None:
|
169
|
-
"""Set delivery timestamp of this message."""
|
170
|
-
self.__dict__["_delivered_at"] = value
|
171
|
-
|
172
|
-
@property
|
173
|
-
def ttl(self) -> float:
|
174
|
-
"""Time-to-live for this message."""
|
175
|
-
return cast(float, self.__dict__["_ttl"])
|
176
|
-
|
177
|
-
@ttl.setter
|
178
|
-
def ttl(self, value: float) -> None:
|
179
|
-
"""Set ttl."""
|
180
|
-
self.__dict__["_ttl"] = value
|
181
|
-
|
182
|
-
@property
|
183
|
-
def message_type(self) -> str:
|
184
|
-
"""A string that encodes the action to be executed on the receiving end."""
|
185
|
-
return cast(str, self.__dict__["_message_type"])
|
186
|
-
|
187
|
-
@message_type.setter
|
188
|
-
def message_type(self, value: str) -> None:
|
189
|
-
"""Set message_type."""
|
190
|
-
# Validate message type
|
191
|
-
if validate_legacy_message_type(value):
|
192
|
-
pass # Backward compatibility for legacy message types
|
193
|
-
elif not validate_message_type(value):
|
194
|
-
raise ValueError(
|
195
|
-
f"Invalid message type: '{value}'. "
|
196
|
-
"Expected format: '<category>' or '<category>.<action>', "
|
197
|
-
"where <category> must be 'train', 'evaluate', or 'query', "
|
198
|
-
"and <action> must be a valid Python identifier."
|
199
|
-
)
|
200
|
-
|
201
|
-
self.__dict__["_message_type"] = value
|
202
|
-
|
203
|
-
def __repr__(self) -> str:
|
204
|
-
"""Return a string representation of this instance."""
|
205
|
-
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
206
|
-
return f"{self.__class__.__qualname__}({view})"
|
207
|
-
|
208
|
-
def __eq__(self, other: object) -> bool:
|
209
|
-
"""Compare two instances of the class."""
|
210
|
-
if not isinstance(other, self.__class__):
|
211
|
-
raise NotImplementedError
|
212
|
-
return self.__dict__ == other.__dict__
|
213
|
-
|
214
|
-
|
215
|
-
class Error:
|
216
|
-
"""The class storing information about an error that occurred.
|
217
|
-
|
218
|
-
Parameters
|
219
|
-
----------
|
220
|
-
code : int
|
221
|
-
An identifier for the error.
|
222
|
-
reason : Optional[str]
|
223
|
-
A reason for why the error arose (e.g. an exception stack-trace)
|
224
|
-
"""
|
225
|
-
|
226
|
-
def __init__(self, code: int, reason: str | None = None) -> None:
|
227
|
-
var_dict = {
|
228
|
-
"_code": code,
|
229
|
-
"_reason": reason,
|
230
|
-
}
|
231
|
-
self.__dict__.update(var_dict)
|
232
|
-
|
233
|
-
@property
|
234
|
-
def code(self) -> int:
|
235
|
-
"""Error code."""
|
236
|
-
return cast(int, self.__dict__["_code"])
|
237
|
-
|
238
|
-
@property
|
239
|
-
def reason(self) -> str | None:
|
240
|
-
"""Reason reported about the error."""
|
241
|
-
return cast(Optional[str], self.__dict__["_reason"])
|
242
|
-
|
243
|
-
def __repr__(self) -> str:
|
244
|
-
"""Return a string representation of this instance."""
|
245
|
-
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
246
|
-
return f"{self.__class__.__qualname__}({view})"
|
247
|
-
|
248
|
-
def __eq__(self, other: object) -> bool:
|
249
|
-
"""Compare two instances of the class."""
|
250
|
-
if not isinstance(other, self.__class__):
|
251
|
-
raise NotImplementedError
|
252
|
-
return self.__dict__ == other.__dict__
|
253
|
-
|
254
|
-
|
255
|
-
class Message:
|
74
|
+
class Message(InflatableObject):
|
256
75
|
"""Represents a message exchanged between ClientApp and ServerApp.
|
257
76
|
|
258
77
|
This class encapsulates the payload and metadata necessary for communication
|
@@ -525,6 +344,74 @@ class Message:
|
|
525
344
|
)
|
526
345
|
return f"{self.__class__.__qualname__}({view})"
|
527
346
|
|
347
|
+
@property
def children(self) -> dict[str, InflatableObject] | None:
    """Children of this message, keyed by object ID.

    Returns
    -------
    dict[str, InflatableObject] | None
        A single-entry dict mapping the content's object ID to the content
        when the message has content; ``None`` otherwise.
    """
    if not self.has_content():
        return None
    return {self.content.object_id: self.content}
|
351
|
+
|
352
|
+
def deflate(self) -> bytes:
    """Deflate message.

    Only the metadata and, if present, the error are serialized into the
    object body; the content is left out (``content=None``) and is exposed
    separately through ``children``.
    """
    metadata_proto = metadata_to_proto(self.metadata)
    error_proto = error_to_proto(self.error) if self.has_error() else None

    # Deterministic serialization keeps the bytes stable for identical messages
    proto_body = ProtoMessage(
        metadata=metadata_proto,
        content=None,
        error=error_proto,
    )
    serialized_body = proto_body.SerializeToString(deterministic=True)

    return add_header_to_object_body(object_body=serialized_body, obj=self)
|
362
|
+
|
363
|
+
@classmethod
def inflate(
    cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
) -> Message:
    """Inflate a Message from bytes.

    Parameters
    ----------
    object_content : bytes
        The deflated object content of the Message.
    children : Optional[dict[str, InflatableObject]] (default: None)
        Dictionary of children InflatableObjects mapped to their Object IDs.
        These children enable the full inflation of the Message.

    Returns
    -------
    Message
        The inflated Message.

    Raises
    ------
    ValueError
        If the IDs of `children` differ from the children IDs listed in the
        object head, or if a Message that carries no error does not list
        exactly one child (its content).
    """
    if children is None:
        children = {}

    # Get the children IDs from the deflated message
    children_ids = get_object_children_ids_from_object_content(object_content)

    # A message with content has exactly one child (its RecordDict); a message
    # carrying an error is expected to have none. The provided children must
    # match the head exactly (same IDs, same order).
    if children_ids != list(children.keys()):
        raise ValueError(
            f"Mismatch in children object IDs: expected {children_ids}, but "
            f"received {list(children.keys())}. The provided children must exactly "
            "match the IDs specified in the object head."
        )

    # Decode metadata and (optional) error from the object body
    obj_body = get_object_body(object_content, cls)
    proto_message = ProtoMessage.FromString(obj_body)

    if proto_message.HasField("error"):
        content = None
        error = error_from_proto(proto_message.error)
    else:
        # Without an error the message must carry content; reject a malformed
        # head explicitly instead of failing with an opaque IndexError
        if len(children_ids) != 1:
            raise ValueError(
                "A Message without an error must reference exactly one content "
                f"object, but its object head lists {len(children_ids)}."
            )
        content = cast(RecordDict, children[children_ids[0]])
        error = None

    return make_message(
        metadata=metadata_from_proto(proto_message.metadata),
        content=content,
        error=error,
    )
|
414
|
+
|
528
415
|
|
529
416
|
def make_message(
|
530
417
|
metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
|
@@ -614,48 +501,3 @@ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
|
|
614
501
|
):
|
615
502
|
return
|
616
503
|
raise MessageInitializationError()
|
617
|
-
|
618
|
-
|
619
|
-
def validate_message_type(message_type: str) -> bool:
|
620
|
-
"""Validate if the message type is valid.
|
621
|
-
|
622
|
-
A valid message type format must be one of the following:
|
623
|
-
|
624
|
-
- "<category>"
|
625
|
-
- "<category>.<action>"
|
626
|
-
|
627
|
-
where `category` must be one of "train", "evaluate", or "query",
|
628
|
-
and `action` must be a valid Python identifier.
|
629
|
-
"""
|
630
|
-
# Check if conforming to the format "<category>"
|
631
|
-
valid_types = {
|
632
|
-
MessageType.TRAIN,
|
633
|
-
MessageType.EVALUATE,
|
634
|
-
MessageType.QUERY,
|
635
|
-
MessageType.SYSTEM,
|
636
|
-
}
|
637
|
-
if message_type in valid_types:
|
638
|
-
return True
|
639
|
-
|
640
|
-
# Check if conforming to the format "<category>.<action>"
|
641
|
-
if message_type.count(".") != 1:
|
642
|
-
return False
|
643
|
-
|
644
|
-
category, action = message_type.split(".")
|
645
|
-
if category in valid_types and action.isidentifier():
|
646
|
-
return True
|
647
|
-
|
648
|
-
return False
|
649
|
-
|
650
|
-
|
651
|
-
def validate_legacy_message_type(message_type: str) -> bool:
|
652
|
-
"""Validate if the legacy message type is valid."""
|
653
|
-
# Backward compatibility for legacy message types
|
654
|
-
if message_type in (
|
655
|
-
MessageTypeLegacy.GET_PARAMETERS,
|
656
|
-
MessageTypeLegacy.GET_PROPERTIES,
|
657
|
-
"reconnect",
|
658
|
-
):
|
659
|
-
return True
|
660
|
-
|
661
|
-
return False
|
flwr/common/record/array.py
CHANGED
@@ -204,7 +204,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
|
204
204
|
ConfigRecord
|
205
205
|
The inflated ConfigRecord.
|
206
206
|
"""
|
207
|
-
if children
|
207
|
+
if children:
|
208
208
|
raise ValueError("`ConfigRecord` objects do not have children.")
|
209
209
|
|
210
210
|
obj_body = get_object_body(object_content, cls)
|
flwr/common/serde.py
CHANGED
@@ -20,11 +20,9 @@ from typing import Any, cast
|
|
20
20
|
|
21
21
|
# pylint: disable=E0611
|
22
22
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
23
|
-
from flwr.proto.error_pb2 import Error as ProtoError
|
24
23
|
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
25
24
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
26
25
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
27
|
-
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
28
26
|
from flwr.proto.recorddict_pb2 import Array as ProtoArray
|
29
27
|
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
|
30
28
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
@@ -55,8 +53,15 @@ from . import (
|
|
55
53
|
typing,
|
56
54
|
)
|
57
55
|
from .constant import INT64_MAX_VALUE
|
58
|
-
from .message import
|
59
|
-
from .serde_utils import
|
56
|
+
from .message import Message, make_message
|
57
|
+
from .serde_utils import (
|
58
|
+
error_from_proto,
|
59
|
+
error_to_proto,
|
60
|
+
metadata_from_proto,
|
61
|
+
metadata_to_proto,
|
62
|
+
record_value_dict_from_proto,
|
63
|
+
record_value_dict_to_proto,
|
64
|
+
)
|
60
65
|
|
61
66
|
# === Parameters message ===
|
62
67
|
|
@@ -446,21 +451,6 @@ def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
|
|
446
451
|
)
|
447
452
|
|
448
453
|
|
449
|
-
# === Error message ===
|
450
|
-
|
451
|
-
|
452
|
-
def error_to_proto(error: Error) -> ProtoError:
|
453
|
-
"""Serialize Error to ProtoBuf."""
|
454
|
-
reason = error.reason if error.reason else ""
|
455
|
-
return ProtoError(code=error.code, reason=reason)
|
456
|
-
|
457
|
-
|
458
|
-
def error_from_proto(error_proto: ProtoError) -> Error:
|
459
|
-
"""Deserialize Error from ProtoBuf."""
|
460
|
-
reason = error_proto.reason if len(error_proto.reason) > 0 else None
|
461
|
-
return Error(code=error_proto.code, reason=reason)
|
462
|
-
|
463
|
-
|
464
454
|
# === RecordDict message ===
|
465
455
|
|
466
456
|
|
@@ -549,41 +539,6 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
549
539
|
return cast(typing.UserConfigValue, scalar)
|
550
540
|
|
551
541
|
|
552
|
-
# === Metadata messages ===
|
553
|
-
|
554
|
-
|
555
|
-
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
556
|
-
"""Serialize `Metadata` to ProtoBuf."""
|
557
|
-
proto = ProtoMetadata( # pylint: disable=E1101
|
558
|
-
run_id=metadata.run_id,
|
559
|
-
message_id=metadata.message_id,
|
560
|
-
src_node_id=metadata.src_node_id,
|
561
|
-
dst_node_id=metadata.dst_node_id,
|
562
|
-
reply_to_message_id=metadata.reply_to_message_id,
|
563
|
-
group_id=metadata.group_id,
|
564
|
-
ttl=metadata.ttl,
|
565
|
-
message_type=metadata.message_type,
|
566
|
-
created_at=metadata.created_at,
|
567
|
-
)
|
568
|
-
return proto
|
569
|
-
|
570
|
-
|
571
|
-
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
|
572
|
-
"""Deserialize `Metadata` from ProtoBuf."""
|
573
|
-
metadata = Metadata(
|
574
|
-
run_id=metadata_proto.run_id,
|
575
|
-
message_id=metadata_proto.message_id,
|
576
|
-
src_node_id=metadata_proto.src_node_id,
|
577
|
-
dst_node_id=metadata_proto.dst_node_id,
|
578
|
-
reply_to_message_id=metadata_proto.reply_to_message_id,
|
579
|
-
group_id=metadata_proto.group_id,
|
580
|
-
created_at=metadata_proto.created_at,
|
581
|
-
ttl=metadata_proto.ttl,
|
582
|
-
message_type=metadata_proto.message_type,
|
583
|
-
)
|
584
|
-
return metadata
|
585
|
-
|
586
|
-
|
587
542
|
# === Message messages ===
|
588
543
|
|
589
544
|
|
flwr/common/serde_utils.py
CHANGED
@@ -20,6 +20,8 @@ from typing import Any, TypeVar, cast
|
|
20
20
|
from google.protobuf.message import Message as GrpcMessage
|
21
21
|
|
22
22
|
# pylint: disable=E0611
|
23
|
+
from flwr.proto.error_pb2 import Error as ProtoError
|
24
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
23
25
|
from flwr.proto.recorddict_pb2 import (
|
24
26
|
BoolList,
|
25
27
|
BytesList,
|
@@ -29,9 +31,13 @@ from flwr.proto.recorddict_pb2 import (
|
|
29
31
|
UintList,
|
30
32
|
)
|
31
33
|
|
34
|
+
from ..app.error import Error
|
35
|
+
from ..app.metadata import Metadata
|
32
36
|
from .constant import INT64_MAX_VALUE
|
33
37
|
from .record.typeddict import TypedDict
|
34
38
|
|
39
|
+
# pylint: enable=E0611
|
40
|
+
|
35
41
|
_type_to_field: dict[type, str] = {
|
36
42
|
float: "double",
|
37
43
|
int: "sint64",
|
@@ -121,3 +127,47 @@ def record_value_dict_from_proto(
|
|
121
127
|
) -> dict[str, Any]:
|
122
128
|
"""Deserialize the record value dict from ProtoBuf."""
|
123
129
|
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
130
|
+
|
131
|
+
|
132
|
+
def error_to_proto(error: Error) -> ProtoError:
    """Convert an `Error` into its ProtoBuf message.

    A falsy reason (``None`` or an empty string) is sent as ``""``.
    """
    reason_text = error.reason or ""
    return ProtoError(code=error.code, reason=reason_text)
|
136
|
+
|
137
|
+
|
138
|
+
def error_from_proto(error_proto: ProtoError) -> Error:
    """Convert a ProtoBuf error message back into an `Error`.

    An empty ``reason`` string is mapped back to ``None``.
    """
    reason_or_none = error_proto.reason or None
    return Error(code=error_proto.code, reason=reason_or_none)
|
142
|
+
|
143
|
+
|
144
|
+
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
    """Convert `Metadata` into its ProtoBuf message.

    Each attribute is copied to the ProtoBuf field of the same name.
    """
    return ProtoMetadata(  # pylint: disable=E1101
        run_id=metadata.run_id,
        message_id=metadata.message_id,
        src_node_id=metadata.src_node_id,
        dst_node_id=metadata.dst_node_id,
        reply_to_message_id=metadata.reply_to_message_id,
        group_id=metadata.group_id,
        ttl=metadata.ttl,
        message_type=metadata.message_type,
        created_at=metadata.created_at,
    )
|
158
|
+
|
159
|
+
|
160
|
+
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
    """Build a `Metadata` instance from its ProtoBuf message.

    Each ProtoBuf field is passed to the `Metadata` constructor argument of the
    same name.
    """
    return Metadata(
        run_id=metadata_proto.run_id,
        message_id=metadata_proto.message_id,
        src_node_id=metadata_proto.src_node_id,
        dst_node_id=metadata_proto.dst_node_id,
        reply_to_message_id=metadata_proto.reply_to_message_id,
        group_id=metadata_proto.group_id,
        created_at=metadata_proto.created_at,
        ttl=metadata_proto.ttl,
        message_type=metadata_proto.message_type,
    )
|
flwr/compat/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Compatibility package containing deprecated legacy components."""
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Legacy components previously located in ``flwr.client``."""
|