flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
flwr/common/message.py
CHANGED
|
@@ -17,19 +17,34 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
-
import time
|
|
21
20
|
from logging import WARNING
|
|
22
|
-
from typing import Optional, cast
|
|
21
|
+
from typing import Any, Optional, cast, overload
|
|
23
22
|
|
|
24
|
-
from .
|
|
23
|
+
from flwr.common.date import now
|
|
24
|
+
from flwr.common.logger import warn_deprecated_feature
|
|
25
|
+
|
|
26
|
+
from .constant import MESSAGE_TTL_TOLERANCE, MessageType, MessageTypeLegacy
|
|
25
27
|
from .logger import log
|
|
26
|
-
from .record import
|
|
28
|
+
from .record import RecordDict
|
|
27
29
|
|
|
28
30
|
DEFAULT_TTL = 43200 # This is 12 hours
|
|
31
|
+
MESSAGE_INIT_ERROR_MESSAGE = (
|
|
32
|
+
"Invalid arguments for Message. Expected one of the documented "
|
|
33
|
+
"signatures: Message(content: RecordDict, dst_node_id: int, message_type: str,"
|
|
34
|
+
" *, [ttl: float, group_id: str]) or Message(content: RecordDict | error: Error,"
|
|
35
|
+
" *, reply_to: Message, [ttl: float])."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MessageInitializationError(TypeError):
|
|
40
|
+
"""Error raised when initializing a message with invalid arguments."""
|
|
41
|
+
|
|
42
|
+
def __init__(self, message: str | None = None) -> None:
|
|
43
|
+
super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
|
|
29
44
|
|
|
30
45
|
|
|
31
46
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
|
32
|
-
"""
|
|
47
|
+
"""The class representing metadata associated with the current message.
|
|
33
48
|
|
|
34
49
|
Parameters
|
|
35
50
|
----------
|
|
@@ -41,11 +56,13 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
41
56
|
An identifier for the node sending this message.
|
|
42
57
|
dst_node_id : int
|
|
43
58
|
An identifier for the node receiving this message.
|
|
44
|
-
|
|
45
|
-
An identifier for the message this message
|
|
59
|
+
reply_to_message_id : str
|
|
60
|
+
An identifier for the message to which this message is a reply.
|
|
46
61
|
group_id : str
|
|
47
62
|
An identifier for grouping messages. In some settings,
|
|
48
63
|
this is used as the FL round.
|
|
64
|
+
created_at : float
|
|
65
|
+
Unix timestamp when the message was created.
|
|
49
66
|
ttl : float
|
|
50
67
|
Time-to-live for this message in seconds.
|
|
51
68
|
message_type : str
|
|
@@ -59,8 +76,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
59
76
|
message_id: str,
|
|
60
77
|
src_node_id: int,
|
|
61
78
|
dst_node_id: int,
|
|
62
|
-
|
|
79
|
+
reply_to_message_id: str,
|
|
63
80
|
group_id: str,
|
|
81
|
+
created_at: float,
|
|
64
82
|
ttl: float,
|
|
65
83
|
message_type: str,
|
|
66
84
|
) -> None:
|
|
@@ -69,12 +87,14 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
69
87
|
"_message_id": message_id,
|
|
70
88
|
"_src_node_id": src_node_id,
|
|
71
89
|
"_dst_node_id": dst_node_id,
|
|
72
|
-
"
|
|
90
|
+
"_reply_to_message_id": reply_to_message_id,
|
|
73
91
|
"_group_id": group_id,
|
|
92
|
+
"_created_at": created_at,
|
|
74
93
|
"_ttl": ttl,
|
|
75
94
|
"_message_type": message_type,
|
|
76
95
|
}
|
|
77
96
|
self.__dict__.update(var_dict)
|
|
97
|
+
self.message_type = message_type # Trigger validation
|
|
78
98
|
|
|
79
99
|
@property
|
|
80
100
|
def run_id(self) -> int:
|
|
@@ -92,9 +112,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
92
112
|
return cast(int, self.__dict__["_src_node_id"])
|
|
93
113
|
|
|
94
114
|
@property
|
|
95
|
-
def
|
|
96
|
-
"""An identifier for the message this message
|
|
97
|
-
return cast(str, self.__dict__["
|
|
115
|
+
def reply_to_message_id(self) -> str:
|
|
116
|
+
"""An identifier for the message to which this message is a reply."""
|
|
117
|
+
return cast(str, self.__dict__["_reply_to_message_id"])
|
|
98
118
|
|
|
99
119
|
@property
|
|
100
120
|
def dst_node_id(self) -> int:
|
|
@@ -123,7 +143,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
123
143
|
|
|
124
144
|
@created_at.setter
|
|
125
145
|
def created_at(self, value: float) -> None:
|
|
126
|
-
"""Set creation timestamp
|
|
146
|
+
"""Set creation timestamp of this message."""
|
|
127
147
|
self.__dict__["_created_at"] = value
|
|
128
148
|
|
|
129
149
|
@property
|
|
@@ -154,6 +174,17 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
154
174
|
@message_type.setter
|
|
155
175
|
def message_type(self, value: str) -> None:
|
|
156
176
|
"""Set message_type."""
|
|
177
|
+
# Validate message type
|
|
178
|
+
if validate_legacy_message_type(value):
|
|
179
|
+
pass # Backward compatibility for legacy message types
|
|
180
|
+
elif not validate_message_type(value):
|
|
181
|
+
raise ValueError(
|
|
182
|
+
f"Invalid message type: '{value}'. "
|
|
183
|
+
"Expected format: '<category>' or '<category>.<action>', "
|
|
184
|
+
"where <category> must be 'train', 'evaluate', or 'query', "
|
|
185
|
+
"and <action> must be a valid Python identifier."
|
|
186
|
+
)
|
|
187
|
+
|
|
157
188
|
self.__dict__["_message_type"] = value
|
|
158
189
|
|
|
159
190
|
def __repr__(self) -> str:
|
|
@@ -169,7 +200,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
169
200
|
|
|
170
201
|
|
|
171
202
|
class Error:
|
|
172
|
-
"""
|
|
203
|
+
"""The class storing information about an error that occurred.
|
|
173
204
|
|
|
174
205
|
Parameters
|
|
175
206
|
----------
|
|
@@ -209,31 +240,148 @@ class Error:
|
|
|
209
240
|
|
|
210
241
|
|
|
211
242
|
class Message:
|
|
212
|
-
"""
|
|
243
|
+
"""Represents a message exchanged between ClientApp and ServerApp.
|
|
244
|
+
|
|
245
|
+
This class encapsulates the payload and metadata necessary for communication
|
|
246
|
+
between a ClientApp and a ServerApp.
|
|
213
247
|
|
|
214
248
|
Parameters
|
|
215
249
|
----------
|
|
216
|
-
|
|
217
|
-
A dataclass including information about the message to be executed.
|
|
218
|
-
content : Optional[RecordSet]
|
|
250
|
+
content : Optional[RecordDict] (default: None)
|
|
219
251
|
Holds records either sent by another entity (e.g. sent by the server-side
|
|
220
252
|
logic to a client, or vice-versa) or that will be sent to it.
|
|
221
|
-
error : Optional[Error]
|
|
253
|
+
error : Optional[Error] (default: None)
|
|
222
254
|
A dataclass that captures information about an error that took place
|
|
223
255
|
when processing another message.
|
|
256
|
+
dst_node_id : Optional[int] (default: None)
|
|
257
|
+
An identifier for the node receiving this message.
|
|
258
|
+
message_type : Optional[str] (default: None)
|
|
259
|
+
A string that encodes the action to be executed on
|
|
260
|
+
the receiving end.
|
|
261
|
+
ttl : Optional[float] (default: None)
|
|
262
|
+
Time-to-live (TTL) for this message in seconds. If `None` (default),
|
|
263
|
+
the TTL is set to 43,200 seconds (12 hours).
|
|
264
|
+
group_id : Optional[str] (default: None)
|
|
265
|
+
An identifier for grouping messages. In some settings, this is used as
|
|
266
|
+
the FL round.
|
|
267
|
+
reply_to : Optional[Message] (default: None)
|
|
268
|
+
The instruction message to which this message is a reply. This message does
|
|
269
|
+
not retain the original message's content but derives its metadata from it.
|
|
224
270
|
"""
|
|
225
271
|
|
|
226
|
-
|
|
272
|
+
@overload
|
|
273
|
+
def __init__( # pylint: disable=too-many-arguments # noqa: E704
|
|
227
274
|
self,
|
|
228
|
-
|
|
229
|
-
|
|
275
|
+
content: RecordDict,
|
|
276
|
+
dst_node_id: int,
|
|
277
|
+
message_type: str,
|
|
278
|
+
*,
|
|
279
|
+
ttl: float | None = None,
|
|
280
|
+
group_id: str | None = None,
|
|
281
|
+
) -> None: ...
|
|
282
|
+
|
|
283
|
+
@overload
|
|
284
|
+
def __init__( # noqa: E704
|
|
285
|
+
self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
|
|
286
|
+
) -> None: ...
|
|
287
|
+
|
|
288
|
+
@overload
|
|
289
|
+
def __init__( # noqa: E704
|
|
290
|
+
self, error: Error, *, reply_to: Message, ttl: float | None = None
|
|
291
|
+
) -> None: ...
|
|
292
|
+
|
|
293
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
294
|
+
self,
|
|
295
|
+
*args: Any,
|
|
296
|
+
dst_node_id: int | None = None,
|
|
297
|
+
message_type: str | None = None,
|
|
298
|
+
content: RecordDict | None = None,
|
|
230
299
|
error: Error | None = None,
|
|
300
|
+
ttl: float | None = None,
|
|
301
|
+
group_id: str | None = None,
|
|
302
|
+
reply_to: Message | None = None,
|
|
303
|
+
metadata: Metadata | None = None,
|
|
231
304
|
) -> None:
|
|
232
|
-
|
|
233
|
-
|
|
305
|
+
# Set positional arguments
|
|
306
|
+
content, error, dst_node_id, message_type = _extract_positional_args(
|
|
307
|
+
*args,
|
|
308
|
+
content=content,
|
|
309
|
+
error=error,
|
|
310
|
+
dst_node_id=dst_node_id,
|
|
311
|
+
message_type=message_type,
|
|
312
|
+
)
|
|
313
|
+
_check_arg_types(
|
|
314
|
+
dst_node_id=dst_node_id,
|
|
315
|
+
message_type=message_type,
|
|
316
|
+
content=content,
|
|
317
|
+
error=error,
|
|
318
|
+
ttl=ttl,
|
|
319
|
+
group_id=group_id,
|
|
320
|
+
reply_to=reply_to,
|
|
321
|
+
metadata=metadata,
|
|
322
|
+
)
|
|
234
323
|
|
|
235
|
-
metadata
|
|
236
|
-
metadata
|
|
324
|
+
# Set metadata directly (This is for internal use only)
|
|
325
|
+
if metadata is not None:
|
|
326
|
+
# When metadata is set, all other arguments must be None,
|
|
327
|
+
# except `content`, `error`, or `content_or_error`
|
|
328
|
+
if any(
|
|
329
|
+
x is not None
|
|
330
|
+
for x in [dst_node_id, message_type, ttl, group_id, reply_to]
|
|
331
|
+
):
|
|
332
|
+
raise MessageInitializationError(
|
|
333
|
+
f"Invalid arguments for {Message.__qualname__}. "
|
|
334
|
+
"Expected only `metadata` to be set when creating a message "
|
|
335
|
+
"with provided metadata."
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
# Create metadata for an instruction message
|
|
339
|
+
elif reply_to is None:
|
|
340
|
+
# Check arguments
|
|
341
|
+
# `content`, `dst_node_id` and `message_type` must be set
|
|
342
|
+
if not (
|
|
343
|
+
isinstance(content, RecordDict)
|
|
344
|
+
and isinstance(dst_node_id, int)
|
|
345
|
+
and isinstance(message_type, str)
|
|
346
|
+
):
|
|
347
|
+
raise MessageInitializationError()
|
|
348
|
+
|
|
349
|
+
# Set metadata
|
|
350
|
+
metadata = Metadata(
|
|
351
|
+
run_id=0, # Will be set before pushed
|
|
352
|
+
message_id="", # Will be set by the SuperLink
|
|
353
|
+
src_node_id=0, # Will be set before pushed
|
|
354
|
+
dst_node_id=dst_node_id,
|
|
355
|
+
# Instruction messages do not reply to any message
|
|
356
|
+
reply_to_message_id="",
|
|
357
|
+
group_id=group_id or "",
|
|
358
|
+
created_at=now().timestamp(),
|
|
359
|
+
ttl=ttl or DEFAULT_TTL,
|
|
360
|
+
message_type=message_type,
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
# Create metadata for a reply message
|
|
364
|
+
else:
|
|
365
|
+
# Check arguments
|
|
366
|
+
# `dst_node_id`, `message_type` and `group_id` must not be set
|
|
367
|
+
if any(x is not None for x in [dst_node_id, message_type, group_id]):
|
|
368
|
+
raise MessageInitializationError()
|
|
369
|
+
|
|
370
|
+
# Set metadata
|
|
371
|
+
current = now().timestamp()
|
|
372
|
+
metadata = Metadata(
|
|
373
|
+
run_id=reply_to.metadata.run_id,
|
|
374
|
+
message_id="", # Will be set by the SuperLink
|
|
375
|
+
src_node_id=reply_to.metadata.dst_node_id,
|
|
376
|
+
dst_node_id=reply_to.metadata.src_node_id,
|
|
377
|
+
reply_to_message_id=reply_to.metadata.message_id,
|
|
378
|
+
group_id=reply_to.metadata.group_id,
|
|
379
|
+
created_at=current,
|
|
380
|
+
ttl=_limit_reply_ttl(current, ttl, reply_to),
|
|
381
|
+
message_type=reply_to.metadata.message_type,
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
metadata.delivered_at = "" # Backward compatibility
|
|
237
385
|
var_dict = {
|
|
238
386
|
"_metadata": metadata,
|
|
239
387
|
"_content": content,
|
|
@@ -247,17 +395,17 @@ class Message:
|
|
|
247
395
|
return cast(Metadata, self.__dict__["_metadata"])
|
|
248
396
|
|
|
249
397
|
@property
|
|
250
|
-
def content(self) ->
|
|
398
|
+
def content(self) -> RecordDict:
|
|
251
399
|
"""The content of this message."""
|
|
252
400
|
if self.__dict__["_content"] is None:
|
|
253
401
|
raise ValueError(
|
|
254
402
|
"Message content is None. Use <message>.has_content() "
|
|
255
403
|
"to check if a message has content."
|
|
256
404
|
)
|
|
257
|
-
return cast(
|
|
405
|
+
return cast(RecordDict, self.__dict__["_content"])
|
|
258
406
|
|
|
259
407
|
@content.setter
|
|
260
|
-
def content(self, value:
|
|
408
|
+
def content(self, value: RecordDict) -> None:
|
|
261
409
|
"""Set content."""
|
|
262
410
|
if self.__dict__["_error"] is None:
|
|
263
411
|
self.__dict__["_content"] = value
|
|
@@ -308,33 +456,25 @@ class Message:
|
|
|
308
456
|
message : Message
|
|
309
457
|
A Message containing only the relevant error and metadata.
|
|
310
458
|
"""
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
ttl = self.metadata.ttl - (
|
|
320
|
-
message.metadata.created_at - self.metadata.created_at
|
|
321
|
-
)
|
|
322
|
-
message.metadata.ttl = ttl
|
|
323
|
-
|
|
324
|
-
self._limit_message_res_ttl(message)
|
|
325
|
-
|
|
326
|
-
return message
|
|
459
|
+
warn_deprecated_feature(
|
|
460
|
+
"`Message.create_error_reply` is deprecated. "
|
|
461
|
+
"Instead of calling `some_message.create_error_reply(some_error, ttl=...)`"
|
|
462
|
+
", use `Message(some_error, reply_to=some_message, ttl=...)`."
|
|
463
|
+
)
|
|
464
|
+
if ttl is not None:
|
|
465
|
+
return Message(error, reply_to=self, ttl=ttl)
|
|
466
|
+
return Message(error, reply_to=self)
|
|
327
467
|
|
|
328
|
-
def create_reply(self, content:
|
|
468
|
+
def create_reply(self, content: RecordDict, ttl: float | None = None) -> Message:
|
|
329
469
|
"""Create a reply to this message with specified content and TTL.
|
|
330
470
|
|
|
331
471
|
The method generates a new `Message` as a reply to this message.
|
|
332
472
|
It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from
|
|
333
|
-
this message and sets '
|
|
473
|
+
this message and sets 'reply_to_message_id' to the ID of this message.
|
|
334
474
|
|
|
335
475
|
Parameters
|
|
336
476
|
----------
|
|
337
|
-
content :
|
|
477
|
+
content : RecordDict
|
|
338
478
|
The content for the reply message.
|
|
339
479
|
ttl : Optional[float] (default: None)
|
|
340
480
|
Time-to-live for this message in seconds. If unset, it will be set based
|
|
@@ -348,25 +488,14 @@ class Message:
|
|
|
348
488
|
Message
|
|
349
489
|
A new `Message` instance representing the reply.
|
|
350
490
|
"""
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
message = Message(
|
|
356
|
-
metadata=_create_reply_metadata(self, ttl_),
|
|
357
|
-
content=content,
|
|
491
|
+
warn_deprecated_feature(
|
|
492
|
+
"`Message.create_reply` is deprecated. "
|
|
493
|
+
"Instead of calling `some_message.create_reply(some_content, ttl=...)`"
|
|
494
|
+
", use `Message(some_content, reply_to=some_message, ttl=...)`."
|
|
358
495
|
)
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
ttl = self.metadata.ttl - (
|
|
363
|
-
message.metadata.created_at - self.metadata.created_at
|
|
364
|
-
)
|
|
365
|
-
message.metadata.ttl = ttl
|
|
366
|
-
|
|
367
|
-
self._limit_message_res_ttl(message)
|
|
368
|
-
|
|
369
|
-
return message
|
|
496
|
+
if ttl is not None:
|
|
497
|
+
return Message(content, reply_to=self, ttl=ttl)
|
|
498
|
+
return Message(content, reply_to=self)
|
|
370
499
|
|
|
371
500
|
def __repr__(self) -> str:
|
|
372
501
|
"""Return a string representation of this instance."""
|
|
@@ -379,41 +508,137 @@ class Message:
|
|
|
379
508
|
)
|
|
380
509
|
return f"{self.__class__.__qualname__}({view})"
|
|
381
510
|
|
|
382
|
-
def _limit_message_res_ttl(self, message: Message) -> None:
|
|
383
|
-
"""Limit the TTL of the provided Message to not exceed the expiration time of
|
|
384
|
-
this Message it replies to.
|
|
385
511
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
512
|
+
def make_message(
|
|
513
|
+
metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
|
|
514
|
+
) -> Message:
|
|
515
|
+
"""Create a message with the provided metadata, content, and error."""
|
|
516
|
+
return Message(metadata=metadata, content=content, error=error) # type: ignore
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def _limit_reply_ttl(
|
|
520
|
+
current: float, reply_ttl: float | None, reply_to: Message
|
|
521
|
+
) -> float:
|
|
522
|
+
"""Limit the TTL of a reply message such that it does exceed the expiration time of
|
|
523
|
+
the message it replies to."""
|
|
524
|
+
# Calculate the maximum allowed TTL
|
|
525
|
+
max_allowed_ttl = reply_to.metadata.created_at + reply_to.metadata.ttl - current
|
|
526
|
+
|
|
527
|
+
if reply_ttl is not None and reply_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
|
|
528
|
+
log(
|
|
529
|
+
WARNING,
|
|
530
|
+
"The reply TTL of %.2f seconds exceeded the "
|
|
531
|
+
"allowed maximum of %.2f seconds. "
|
|
532
|
+
"The TTL has been updated to the allowed maximum.",
|
|
533
|
+
reply_ttl,
|
|
534
|
+
max_allowed_ttl,
|
|
394
535
|
)
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
536
|
+
return max_allowed_ttl
|
|
537
|
+
|
|
538
|
+
return reply_ttl or max_allowed_ttl
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def _extract_positional_args(
|
|
542
|
+
*args: Any,
|
|
543
|
+
content: RecordDict | None,
|
|
544
|
+
error: Error | None,
|
|
545
|
+
dst_node_id: int | None,
|
|
546
|
+
message_type: str | None,
|
|
547
|
+
) -> tuple[RecordDict | None, Error | None, int | None, str | None]:
|
|
548
|
+
"""Extract positional arguments for the `Message` constructor."""
|
|
549
|
+
content_or_error = args[0] if args else None
|
|
550
|
+
if len(args) > 1:
|
|
551
|
+
if dst_node_id is not None:
|
|
552
|
+
raise MessageInitializationError()
|
|
553
|
+
dst_node_id = args[1]
|
|
554
|
+
if len(args) > 2:
|
|
555
|
+
if message_type is not None:
|
|
556
|
+
raise MessageInitializationError()
|
|
557
|
+
message_type = args[2]
|
|
558
|
+
if len(args) > 3:
|
|
559
|
+
raise MessageInitializationError()
|
|
560
|
+
|
|
561
|
+
# One and only one of `content_or_error`, `content` and `error` must be set
|
|
562
|
+
if sum(x is not None for x in [content_or_error, content, error]) != 1:
|
|
563
|
+
raise MessageInitializationError()
|
|
564
|
+
|
|
565
|
+
# Set `content` or `error` based on `content_or_error`
|
|
566
|
+
if content_or_error is not None: # This means `content` and `error` are None
|
|
567
|
+
if isinstance(content_or_error, RecordDict):
|
|
568
|
+
content = content_or_error
|
|
569
|
+
elif isinstance(content_or_error, Error):
|
|
570
|
+
error = content_or_error
|
|
571
|
+
else:
|
|
572
|
+
raise MessageInitializationError()
|
|
573
|
+
return content, error, dst_node_id, message_type
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def _check_arg_types( # pylint: disable=too-many-arguments, R0917
|
|
577
|
+
dst_node_id: int | None = None,
|
|
578
|
+
message_type: str | None = None,
|
|
579
|
+
content: RecordDict | None = None,
|
|
580
|
+
error: Error | None = None,
|
|
581
|
+
ttl: float | None = None,
|
|
582
|
+
group_id: str | None = None,
|
|
583
|
+
reply_to: Message | None = None,
|
|
584
|
+
metadata: Metadata | None = None,
|
|
585
|
+
) -> None:
|
|
586
|
+
"""Check argument types for the `Message` constructor."""
|
|
587
|
+
# pylint: disable=too-many-boolean-expressions
|
|
588
|
+
if (
|
|
589
|
+
(dst_node_id is None or isinstance(dst_node_id, int))
|
|
590
|
+
and (message_type is None or isinstance(message_type, str))
|
|
591
|
+
and (content is None or isinstance(content, RecordDict))
|
|
592
|
+
and (error is None or isinstance(error, Error))
|
|
593
|
+
and (ttl is None or isinstance(ttl, (int, float)))
|
|
594
|
+
and (group_id is None or isinstance(group_id, str))
|
|
595
|
+
and (reply_to is None or isinstance(reply_to, Message))
|
|
596
|
+
and (metadata is None or isinstance(metadata, Metadata))
|
|
597
|
+
):
|
|
598
|
+
return
|
|
599
|
+
raise MessageInitializationError()
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def validate_message_type(message_type: str) -> bool:
|
|
603
|
+
"""Validate if the message type is valid.
|
|
604
|
+
|
|
605
|
+
A valid message type format must be one of the following:
|
|
606
|
+
|
|
607
|
+
- "<category>"
|
|
608
|
+
- "<category>.<action>"
|
|
609
|
+
|
|
610
|
+
where `category` must be one of "train", "evaluate", or "query",
|
|
611
|
+
and `action` must be a valid Python identifier.
|
|
612
|
+
"""
|
|
613
|
+
# Check if conforming to the format "<category>"
|
|
614
|
+
valid_types = {
|
|
615
|
+
MessageType.TRAIN,
|
|
616
|
+
MessageType.EVALUATE,
|
|
617
|
+
MessageType.QUERY,
|
|
618
|
+
MessageType.SYSTEM,
|
|
619
|
+
}
|
|
620
|
+
if message_type in valid_types:
|
|
621
|
+
return True
|
|
622
|
+
|
|
623
|
+
# Check if conforming to the format "<category>.<action>"
|
|
624
|
+
if message_type.count(".") != 1:
|
|
625
|
+
return False
|
|
626
|
+
|
|
627
|
+
category, action = message_type.split(".")
|
|
628
|
+
if category in valid_types and action.isidentifier():
|
|
629
|
+
return True
|
|
630
|
+
|
|
631
|
+
return False
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def validate_legacy_message_type(message_type: str) -> bool:
|
|
635
|
+
"""Validate if the legacy message type is valid."""
|
|
636
|
+
# Backward compatibility for legacy message types
|
|
637
|
+
if message_type in (
|
|
638
|
+
MessageTypeLegacy.GET_PARAMETERS,
|
|
639
|
+
MessageTypeLegacy.GET_PROPERTIES,
|
|
640
|
+
"reconnect",
|
|
641
|
+
):
|
|
642
|
+
return True
|
|
643
|
+
|
|
644
|
+
return False
|
flwr/common/record/__init__.py
CHANGED
|
@@ -15,17 +15,21 @@
|
|
|
15
15
|
"""Record APIs."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
18
|
+
from .arrayrecord import Array, ArrayRecord, ParametersRecord
|
|
19
|
+
from .configrecord import ConfigRecord, ConfigsRecord
|
|
19
20
|
from .conversion_utils import array_from_numpy
|
|
20
|
-
from .
|
|
21
|
-
from .
|
|
22
|
-
from .recordset import RecordSet
|
|
21
|
+
from .metricrecord import MetricRecord, MetricsRecord
|
|
22
|
+
from .recorddict import RecordDict, RecordSet
|
|
23
23
|
|
|
24
24
|
__all__ = [
|
|
25
25
|
"Array",
|
|
26
|
+
"ArrayRecord",
|
|
27
|
+
"ConfigRecord",
|
|
26
28
|
"ConfigsRecord",
|
|
29
|
+
"MetricRecord",
|
|
27
30
|
"MetricsRecord",
|
|
28
31
|
"ParametersRecord",
|
|
32
|
+
"RecordDict",
|
|
29
33
|
"RecordSet",
|
|
30
34
|
"array_from_numpy",
|
|
31
35
|
]
|