flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
flwr/common/message.py
CHANGED
|
@@ -18,14 +18,32 @@
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
20
|
from logging import WARNING
|
|
21
|
-
from typing import Any,
|
|
21
|
+
from typing import Any, cast, overload
|
|
22
22
|
|
|
23
23
|
from flwr.common.date import now
|
|
24
24
|
from flwr.common.logger import warn_deprecated_feature
|
|
25
|
-
|
|
26
|
-
from .
|
|
25
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
|
|
26
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata # pylint: disable=E0611
|
|
27
|
+
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
|
28
|
+
|
|
29
|
+
from ..app.error import Error
|
|
30
|
+
from ..app.metadata import Metadata
|
|
31
|
+
from .constant import MESSAGE_TTL_TOLERANCE
|
|
32
|
+
from .inflatable import (
|
|
33
|
+
InflatableObject,
|
|
34
|
+
add_header_to_object_body,
|
|
35
|
+
get_descendant_object_ids,
|
|
36
|
+
get_object_body,
|
|
37
|
+
get_object_children_ids_from_object_content,
|
|
38
|
+
)
|
|
27
39
|
from .logger import log
|
|
28
40
|
from .record import RecordDict
|
|
41
|
+
from .serde_utils import (
|
|
42
|
+
error_from_proto,
|
|
43
|
+
error_to_proto,
|
|
44
|
+
metadata_from_proto,
|
|
45
|
+
metadata_to_proto,
|
|
46
|
+
)
|
|
29
47
|
|
|
30
48
|
DEFAULT_TTL = 43200 # This is 12 hours
|
|
31
49
|
MESSAGE_INIT_ERROR_MESSAGE = (
|
|
@@ -56,203 +74,7 @@ class MessageInitializationError(TypeError):
|
|
|
56
74
|
super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
|
|
57
75
|
|
|
58
76
|
|
|
59
|
-
class
|
|
60
|
-
"""The class representing metadata associated with the current message.
|
|
61
|
-
|
|
62
|
-
Parameters
|
|
63
|
-
----------
|
|
64
|
-
run_id : int
|
|
65
|
-
An identifier for the current run.
|
|
66
|
-
message_id : str
|
|
67
|
-
An identifier for the current message.
|
|
68
|
-
src_node_id : int
|
|
69
|
-
An identifier for the node sending this message.
|
|
70
|
-
dst_node_id : int
|
|
71
|
-
An identifier for the node receiving this message.
|
|
72
|
-
reply_to_message_id : str
|
|
73
|
-
An identifier for the message to which this message is a reply.
|
|
74
|
-
group_id : str
|
|
75
|
-
An identifier for grouping messages. In some settings,
|
|
76
|
-
this is used as the FL round.
|
|
77
|
-
created_at : float
|
|
78
|
-
Unix timestamp when the message was created.
|
|
79
|
-
ttl : float
|
|
80
|
-
Time-to-live for this message in seconds.
|
|
81
|
-
message_type : str
|
|
82
|
-
A string that encodes the action to be executed on
|
|
83
|
-
the receiving end.
|
|
84
|
-
"""
|
|
85
|
-
|
|
86
|
-
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
87
|
-
self,
|
|
88
|
-
run_id: int,
|
|
89
|
-
message_id: str,
|
|
90
|
-
src_node_id: int,
|
|
91
|
-
dst_node_id: int,
|
|
92
|
-
reply_to_message_id: str,
|
|
93
|
-
group_id: str,
|
|
94
|
-
created_at: float,
|
|
95
|
-
ttl: float,
|
|
96
|
-
message_type: str,
|
|
97
|
-
) -> None:
|
|
98
|
-
var_dict = {
|
|
99
|
-
"_run_id": run_id,
|
|
100
|
-
"_message_id": message_id,
|
|
101
|
-
"_src_node_id": src_node_id,
|
|
102
|
-
"_dst_node_id": dst_node_id,
|
|
103
|
-
"_reply_to_message_id": reply_to_message_id,
|
|
104
|
-
"_group_id": group_id,
|
|
105
|
-
"_created_at": created_at,
|
|
106
|
-
"_ttl": ttl,
|
|
107
|
-
"_message_type": message_type,
|
|
108
|
-
}
|
|
109
|
-
self.__dict__.update(var_dict)
|
|
110
|
-
self.message_type = message_type # Trigger validation
|
|
111
|
-
|
|
112
|
-
@property
|
|
113
|
-
def run_id(self) -> int:
|
|
114
|
-
"""An identifier for the current run."""
|
|
115
|
-
return cast(int, self.__dict__["_run_id"])
|
|
116
|
-
|
|
117
|
-
@property
|
|
118
|
-
def message_id(self) -> str:
|
|
119
|
-
"""An identifier for the current message."""
|
|
120
|
-
return cast(str, self.__dict__["_message_id"])
|
|
121
|
-
|
|
122
|
-
@property
|
|
123
|
-
def src_node_id(self) -> int:
|
|
124
|
-
"""An identifier for the node sending this message."""
|
|
125
|
-
return cast(int, self.__dict__["_src_node_id"])
|
|
126
|
-
|
|
127
|
-
@property
|
|
128
|
-
def reply_to_message_id(self) -> str:
|
|
129
|
-
"""An identifier for the message to which this message is a reply."""
|
|
130
|
-
return cast(str, self.__dict__["_reply_to_message_id"])
|
|
131
|
-
|
|
132
|
-
@property
|
|
133
|
-
def dst_node_id(self) -> int:
|
|
134
|
-
"""An identifier for the node receiving this message."""
|
|
135
|
-
return cast(int, self.__dict__["_dst_node_id"])
|
|
136
|
-
|
|
137
|
-
@dst_node_id.setter
|
|
138
|
-
def dst_node_id(self, value: int) -> None:
|
|
139
|
-
"""Set dst_node_id."""
|
|
140
|
-
self.__dict__["_dst_node_id"] = value
|
|
141
|
-
|
|
142
|
-
@property
|
|
143
|
-
def group_id(self) -> str:
|
|
144
|
-
"""An identifier for grouping messages."""
|
|
145
|
-
return cast(str, self.__dict__["_group_id"])
|
|
146
|
-
|
|
147
|
-
@group_id.setter
|
|
148
|
-
def group_id(self, value: str) -> None:
|
|
149
|
-
"""Set group_id."""
|
|
150
|
-
self.__dict__["_group_id"] = value
|
|
151
|
-
|
|
152
|
-
@property
|
|
153
|
-
def created_at(self) -> float:
|
|
154
|
-
"""Unix timestamp when the message was created."""
|
|
155
|
-
return cast(float, self.__dict__["_created_at"])
|
|
156
|
-
|
|
157
|
-
@created_at.setter
|
|
158
|
-
def created_at(self, value: float) -> None:
|
|
159
|
-
"""Set creation timestamp of this message."""
|
|
160
|
-
self.__dict__["_created_at"] = value
|
|
161
|
-
|
|
162
|
-
@property
|
|
163
|
-
def delivered_at(self) -> str:
|
|
164
|
-
"""Unix timestamp when the message was delivered."""
|
|
165
|
-
return cast(str, self.__dict__["_delivered_at"])
|
|
166
|
-
|
|
167
|
-
@delivered_at.setter
|
|
168
|
-
def delivered_at(self, value: str) -> None:
|
|
169
|
-
"""Set delivery timestamp of this message."""
|
|
170
|
-
self.__dict__["_delivered_at"] = value
|
|
171
|
-
|
|
172
|
-
@property
|
|
173
|
-
def ttl(self) -> float:
|
|
174
|
-
"""Time-to-live for this message."""
|
|
175
|
-
return cast(float, self.__dict__["_ttl"])
|
|
176
|
-
|
|
177
|
-
@ttl.setter
|
|
178
|
-
def ttl(self, value: float) -> None:
|
|
179
|
-
"""Set ttl."""
|
|
180
|
-
self.__dict__["_ttl"] = value
|
|
181
|
-
|
|
182
|
-
@property
|
|
183
|
-
def message_type(self) -> str:
|
|
184
|
-
"""A string that encodes the action to be executed on the receiving end."""
|
|
185
|
-
return cast(str, self.__dict__["_message_type"])
|
|
186
|
-
|
|
187
|
-
@message_type.setter
|
|
188
|
-
def message_type(self, value: str) -> None:
|
|
189
|
-
"""Set message_type."""
|
|
190
|
-
# Validate message type
|
|
191
|
-
if validate_legacy_message_type(value):
|
|
192
|
-
pass # Backward compatibility for legacy message types
|
|
193
|
-
elif not validate_message_type(value):
|
|
194
|
-
raise ValueError(
|
|
195
|
-
f"Invalid message type: '{value}'. "
|
|
196
|
-
"Expected format: '<category>' or '<category>.<action>', "
|
|
197
|
-
"where <category> must be 'train', 'evaluate', or 'query', "
|
|
198
|
-
"and <action> must be a valid Python identifier."
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
self.__dict__["_message_type"] = value
|
|
202
|
-
|
|
203
|
-
def __repr__(self) -> str:
|
|
204
|
-
"""Return a string representation of this instance."""
|
|
205
|
-
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
206
|
-
return f"{self.__class__.__qualname__}({view})"
|
|
207
|
-
|
|
208
|
-
def __eq__(self, other: object) -> bool:
|
|
209
|
-
"""Compare two instances of the class."""
|
|
210
|
-
if not isinstance(other, self.__class__):
|
|
211
|
-
raise NotImplementedError
|
|
212
|
-
return self.__dict__ == other.__dict__
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
class Error:
|
|
216
|
-
"""The class storing information about an error that occurred.
|
|
217
|
-
|
|
218
|
-
Parameters
|
|
219
|
-
----------
|
|
220
|
-
code : int
|
|
221
|
-
An identifier for the error.
|
|
222
|
-
reason : Optional[str]
|
|
223
|
-
A reason for why the error arose (e.g. an exception stack-trace)
|
|
224
|
-
"""
|
|
225
|
-
|
|
226
|
-
def __init__(self, code: int, reason: str | None = None) -> None:
|
|
227
|
-
var_dict = {
|
|
228
|
-
"_code": code,
|
|
229
|
-
"_reason": reason,
|
|
230
|
-
}
|
|
231
|
-
self.__dict__.update(var_dict)
|
|
232
|
-
|
|
233
|
-
@property
|
|
234
|
-
def code(self) -> int:
|
|
235
|
-
"""Error code."""
|
|
236
|
-
return cast(int, self.__dict__["_code"])
|
|
237
|
-
|
|
238
|
-
@property
|
|
239
|
-
def reason(self) -> str | None:
|
|
240
|
-
"""Reason reported about the error."""
|
|
241
|
-
return cast(Optional[str], self.__dict__["_reason"])
|
|
242
|
-
|
|
243
|
-
def __repr__(self) -> str:
|
|
244
|
-
"""Return a string representation of this instance."""
|
|
245
|
-
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
246
|
-
return f"{self.__class__.__qualname__}({view})"
|
|
247
|
-
|
|
248
|
-
def __eq__(self, other: object) -> bool:
|
|
249
|
-
"""Compare two instances of the class."""
|
|
250
|
-
if not isinstance(other, self.__class__):
|
|
251
|
-
raise NotImplementedError
|
|
252
|
-
return self.__dict__ == other.__dict__
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
class Message:
|
|
77
|
+
class Message(InflatableObject):
|
|
256
78
|
"""Represents a message exchanged between ClientApp and ServerApp.
|
|
257
79
|
|
|
258
80
|
This class encapsulates the payload and metadata necessary for communication
|
|
@@ -525,6 +347,77 @@ class Message:
|
|
|
525
347
|
)
|
|
526
348
|
return f"{self.__class__.__qualname__}({view})"
|
|
527
349
|
|
|
350
|
+
@property
|
|
351
|
+
def children(self) -> dict[str, InflatableObject] | None:
|
|
352
|
+
"""Return a dictionary of a single RecordDict with its Object IDs as key."""
|
|
353
|
+
return {self.content.object_id: self.content} if self.has_content() else None
|
|
354
|
+
|
|
355
|
+
def deflate(self) -> bytes:
|
|
356
|
+
"""Deflate message."""
|
|
357
|
+
# Exclude message_id from serialization
|
|
358
|
+
proto_metadata: ProtoMetadata = metadata_to_proto(self.metadata)
|
|
359
|
+
proto_metadata.message_id = ""
|
|
360
|
+
# Store message metadata and error in object body
|
|
361
|
+
obj_body = ProtoMessage(
|
|
362
|
+
metadata=proto_metadata,
|
|
363
|
+
content=None,
|
|
364
|
+
error=error_to_proto(self.error) if self.has_error() else None,
|
|
365
|
+
).SerializeToString(deterministic=True)
|
|
366
|
+
|
|
367
|
+
return add_header_to_object_body(object_body=obj_body, obj=self)
|
|
368
|
+
|
|
369
|
+
@classmethod
|
|
370
|
+
def inflate(
|
|
371
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
372
|
+
) -> Message:
|
|
373
|
+
"""Inflate an Message from bytes.
|
|
374
|
+
|
|
375
|
+
Parameters
|
|
376
|
+
----------
|
|
377
|
+
object_content : bytes
|
|
378
|
+
The deflated object content of the Message.
|
|
379
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
380
|
+
Dictionary of children InflatableObjects mapped to their Object IDs.
|
|
381
|
+
These children enable the full inflation of the Message.
|
|
382
|
+
|
|
383
|
+
Returns
|
|
384
|
+
-------
|
|
385
|
+
Message
|
|
386
|
+
The inflated Message.
|
|
387
|
+
"""
|
|
388
|
+
if children is None:
|
|
389
|
+
children = {}
|
|
390
|
+
|
|
391
|
+
# Get the children id from the deflated message
|
|
392
|
+
children_ids = get_object_children_ids_from_object_content(object_content)
|
|
393
|
+
|
|
394
|
+
# If the message had content, only one children is possible
|
|
395
|
+
# If the message carried an error, the returned listed should be empty
|
|
396
|
+
if children_ids != list(children.keys()):
|
|
397
|
+
raise ValueError(
|
|
398
|
+
f"Mismatch in children object IDs: expected {children_ids}, but "
|
|
399
|
+
f"received {list(children.keys())}. The provided children must exactly "
|
|
400
|
+
"match the IDs specified in the object head."
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
# Inflate content
|
|
404
|
+
obj_body = get_object_body(object_content, cls)
|
|
405
|
+
proto_message = ProtoMessage.FromString(obj_body)
|
|
406
|
+
|
|
407
|
+
# Prepare content if error wasn't set in protobuf message
|
|
408
|
+
if proto_message.HasField("error"):
|
|
409
|
+
content = None
|
|
410
|
+
error = error_from_proto(proto_message.error)
|
|
411
|
+
else:
|
|
412
|
+
content = cast(RecordDict, children[children_ids[0]])
|
|
413
|
+
error = None
|
|
414
|
+
# Return message
|
|
415
|
+
return make_message(
|
|
416
|
+
metadata=metadata_from_proto(proto_message.metadata),
|
|
417
|
+
content=content,
|
|
418
|
+
error=error,
|
|
419
|
+
)
|
|
420
|
+
|
|
528
421
|
|
|
529
422
|
def make_message(
|
|
530
423
|
metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
|
|
@@ -533,6 +426,17 @@ def make_message(
|
|
|
533
426
|
return Message(metadata=metadata, content=content, error=error) # type: ignore
|
|
534
427
|
|
|
535
428
|
|
|
429
|
+
def remove_content_from_message(message: Message) -> Message:
|
|
430
|
+
"""Return a copy of the Message but with an empty RecordDict as content.
|
|
431
|
+
|
|
432
|
+
If message has no content, it returns itself.
|
|
433
|
+
"""
|
|
434
|
+
if message.has_error():
|
|
435
|
+
return message
|
|
436
|
+
|
|
437
|
+
return make_message(metadata=message.metadata, content=RecordDict())
|
|
438
|
+
|
|
439
|
+
|
|
536
440
|
def _limit_reply_ttl(
|
|
537
441
|
current: float, reply_ttl: float | None, reply_to: Message
|
|
538
442
|
) -> float:
|
|
@@ -616,46 +520,10 @@ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
|
|
|
616
520
|
raise MessageInitializationError()
|
|
617
521
|
|
|
618
522
|
|
|
619
|
-
def
|
|
620
|
-
"""
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
- "<category>.<action>"
|
|
626
|
-
|
|
627
|
-
where `category` must be one of "train", "evaluate", or "query",
|
|
628
|
-
and `action` must be a valid Python identifier.
|
|
629
|
-
"""
|
|
630
|
-
# Check if conforming to the format "<category>"
|
|
631
|
-
valid_types = {
|
|
632
|
-
MessageType.TRAIN,
|
|
633
|
-
MessageType.EVALUATE,
|
|
634
|
-
MessageType.QUERY,
|
|
635
|
-
MessageType.SYSTEM,
|
|
523
|
+
def get_message_to_descendant_id_mapping(message: Message) -> dict[str, ObjectIDs]:
|
|
524
|
+
"""Construct a mapping between message object_id and that of its descendants."""
|
|
525
|
+
return {
|
|
526
|
+
message.object_id: ObjectIDs(
|
|
527
|
+
object_ids=list(get_descendant_object_ids(message))
|
|
528
|
+
)
|
|
636
529
|
}
|
|
637
|
-
if message_type in valid_types:
|
|
638
|
-
return True
|
|
639
|
-
|
|
640
|
-
# Check if conforming to the format "<category>.<action>"
|
|
641
|
-
if message_type.count(".") != 1:
|
|
642
|
-
return False
|
|
643
|
-
|
|
644
|
-
category, action = message_type.split(".")
|
|
645
|
-
if category in valid_types and action.isidentifier():
|
|
646
|
-
return True
|
|
647
|
-
|
|
648
|
-
return False
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
def validate_legacy_message_type(message_type: str) -> bool:
|
|
652
|
-
"""Validate if the legacy message type is valid."""
|
|
653
|
-
# Backward compatibility for legacy message types
|
|
654
|
-
if message_type in (
|
|
655
|
-
MessageTypeLegacy.GET_PARAMETERS,
|
|
656
|
-
MessageTypeLegacy.GET_PROPERTIES,
|
|
657
|
-
"reconnect",
|
|
658
|
-
):
|
|
659
|
-
return True
|
|
660
|
-
|
|
661
|
-
return False
|
flwr/common/record/__init__.py
CHANGED
|
@@ -15,7 +15,8 @@
|
|
|
15
15
|
"""Record APIs."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
18
|
+
from .array import Array
|
|
19
|
+
from .arrayrecord import ArrayRecord, ParametersRecord
|
|
19
20
|
from .configrecord import ConfigRecord, ConfigsRecord
|
|
20
21
|
from .conversion_utils import array_from_numpy
|
|
21
22
|
from .metricrecord import MetricRecord, MetricsRecord
|