flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
flwr/common/message.py CHANGED
@@ -17,19 +17,34 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import time
21
20
  from logging import WARNING
22
- from typing import Optional, cast
21
+ from typing import Any, Optional, cast, overload
23
22
 
24
- from .constant import MESSAGE_TTL_TOLERANCE
23
+ from flwr.common.date import now
24
+ from flwr.common.logger import warn_deprecated_feature
25
+
26
+ from .constant import MESSAGE_TTL_TOLERANCE, MessageType, MessageTypeLegacy
25
27
  from .logger import log
26
- from .record import RecordSet
28
+ from .record import RecordDict
29
+
30
+ DEFAULT_TTL = 43200 # This is 12 hours
31
+ MESSAGE_INIT_ERROR_MESSAGE = (
32
+ "Invalid arguments for Message. Expected one of the documented "
33
+ "signatures: Message(content: RecordDict, dst_node_id: int, message_type: str,"
34
+ " *, [ttl: float, group_id: str]) or Message(content: RecordDict | error: Error,"
35
+ " *, reply_to: Message, [ttl: float])."
36
+ )
37
+
27
38
 
28
- DEFAULT_TTL = 3600
39
+ class MessageInitializationError(TypeError):
40
+ """Error raised when initializing a message with invalid arguments."""
41
+
42
+ def __init__(self, message: str | None = None) -> None:
43
+ super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
29
44
 
30
45
 
31
46
  class Metadata: # pylint: disable=too-many-instance-attributes
32
- """A dataclass holding metadata associated with the current message.
47
+ """The class representing metadata associated with the current message.
33
48
 
34
49
  Parameters
35
50
  ----------
@@ -41,11 +56,13 @@ class Metadata: # pylint: disable=too-many-instance-attributes
41
56
  An identifier for the node sending this message.
42
57
  dst_node_id : int
43
58
  An identifier for the node receiving this message.
44
- reply_to_message : str
45
- An identifier for the message this message replies to.
59
+ reply_to_message_id : str
60
+ An identifier for the message to which this message is a reply.
46
61
  group_id : str
47
62
  An identifier for grouping messages. In some settings,
48
63
  this is used as the FL round.
64
+ created_at : float
65
+ Unix timestamp when the message was created.
49
66
  ttl : float
50
67
  Time-to-live for this message in seconds.
51
68
  message_type : str
@@ -59,8 +76,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
59
76
  message_id: str,
60
77
  src_node_id: int,
61
78
  dst_node_id: int,
62
- reply_to_message: str,
79
+ reply_to_message_id: str,
63
80
  group_id: str,
81
+ created_at: float,
64
82
  ttl: float,
65
83
  message_type: str,
66
84
  ) -> None:
@@ -69,12 +87,14 @@ class Metadata: # pylint: disable=too-many-instance-attributes
69
87
  "_message_id": message_id,
70
88
  "_src_node_id": src_node_id,
71
89
  "_dst_node_id": dst_node_id,
72
- "_reply_to_message": reply_to_message,
90
+ "_reply_to_message_id": reply_to_message_id,
73
91
  "_group_id": group_id,
92
+ "_created_at": created_at,
74
93
  "_ttl": ttl,
75
94
  "_message_type": message_type,
76
95
  }
77
96
  self.__dict__.update(var_dict)
97
+ self.message_type = message_type # Trigger validation
78
98
 
79
99
  @property
80
100
  def run_id(self) -> int:
@@ -92,9 +112,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
92
112
  return cast(int, self.__dict__["_src_node_id"])
93
113
 
94
114
  @property
95
- def reply_to_message(self) -> str:
96
- """An identifier for the message this message replies to."""
97
- return cast(str, self.__dict__["_reply_to_message"])
115
+ def reply_to_message_id(self) -> str:
116
+ """An identifier for the message to which this message is a reply."""
117
+ return cast(str, self.__dict__["_reply_to_message_id"])
98
118
 
99
119
  @property
100
120
  def dst_node_id(self) -> int:
@@ -123,9 +143,19 @@ class Metadata: # pylint: disable=too-many-instance-attributes
123
143
 
124
144
  @created_at.setter
125
145
  def created_at(self, value: float) -> None:
126
- """Set creation timestamp for this message."""
146
+ """Set creation timestamp of this message."""
127
147
  self.__dict__["_created_at"] = value
128
148
 
149
+ @property
150
+ def delivered_at(self) -> str:
151
+ """Unix timestamp when the message was delivered."""
152
+ return cast(str, self.__dict__["_delivered_at"])
153
+
154
+ @delivered_at.setter
155
+ def delivered_at(self, value: str) -> None:
156
+ """Set delivery timestamp of this message."""
157
+ self.__dict__["_delivered_at"] = value
158
+
129
159
  @property
130
160
  def ttl(self) -> float:
131
161
  """Time-to-live for this message."""
@@ -144,6 +174,17 @@ class Metadata: # pylint: disable=too-many-instance-attributes
144
174
  @message_type.setter
145
175
  def message_type(self, value: str) -> None:
146
176
  """Set message_type."""
177
+ # Validate message type
178
+ if validate_legacy_message_type(value):
179
+ pass # Backward compatibility for legacy message types
180
+ elif not validate_message_type(value):
181
+ raise ValueError(
182
+ f"Invalid message type: '{value}'. "
183
+ "Expected format: '<category>' or '<category>.<action>', "
184
+ "where <category> must be 'train', 'evaluate', or 'query', "
185
+ "and <action> must be a valid Python identifier."
186
+ )
187
+
147
188
  self.__dict__["_message_type"] = value
148
189
 
149
190
  def __repr__(self) -> str:
@@ -159,7 +200,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
159
200
 
160
201
 
161
202
  class Error:
162
- """A dataclass that stores information about an error that occurred.
203
+ """The class storing information about an error that occurred.
163
204
 
164
205
  Parameters
165
206
  ----------
@@ -199,30 +240,148 @@ class Error:
199
240
 
200
241
 
201
242
  class Message:
202
- """State of your application from the viewpoint of the entity using it.
243
+ """Represents a message exchanged between ClientApp and ServerApp.
244
+
245
+ This class encapsulates the payload and metadata necessary for communication
246
+ between a ClientApp and a ServerApp.
203
247
 
204
248
  Parameters
205
249
  ----------
206
- metadata : Metadata
207
- A dataclass including information about the message to be executed.
208
- content : Optional[RecordSet]
250
+ content : Optional[RecordDict] (default: None)
209
251
  Holds records either sent by another entity (e.g. sent by the server-side
210
252
  logic to a client, or vice-versa) or that will be sent to it.
211
- error : Optional[Error]
253
+ error : Optional[Error] (default: None)
212
254
  A dataclass that captures information about an error that took place
213
255
  when processing another message.
256
+ dst_node_id : Optional[int] (default: None)
257
+ An identifier for the node receiving this message.
258
+ message_type : Optional[str] (default: None)
259
+ A string that encodes the action to be executed on
260
+ the receiving end.
261
+ ttl : Optional[float] (default: None)
262
+ Time-to-live (TTL) for this message in seconds. If `None` (default),
263
+ the TTL is set to 43,200 seconds (12 hours).
264
+ group_id : Optional[str] (default: None)
265
+ An identifier for grouping messages. In some settings, this is used as
266
+ the FL round.
267
+ reply_to : Optional[Message] (default: None)
268
+ The instruction message to which this message is a reply. This message does
269
+ not retain the original message's content but derives its metadata from it.
214
270
  """
215
271
 
216
- def __init__(
272
+ @overload
273
+ def __init__( # pylint: disable=too-many-arguments # noqa: E704
274
+ self,
275
+ content: RecordDict,
276
+ dst_node_id: int,
277
+ message_type: str,
278
+ *,
279
+ ttl: float | None = None,
280
+ group_id: str | None = None,
281
+ ) -> None: ...
282
+
283
+ @overload
284
+ def __init__( # noqa: E704
285
+ self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
286
+ ) -> None: ...
287
+
288
+ @overload
289
+ def __init__( # noqa: E704
290
+ self, error: Error, *, reply_to: Message, ttl: float | None = None
291
+ ) -> None: ...
292
+
293
+ def __init__( # pylint: disable=too-many-arguments
217
294
  self,
218
- metadata: Metadata,
219
- content: RecordSet | None = None,
295
+ *args: Any,
296
+ dst_node_id: int | None = None,
297
+ message_type: str | None = None,
298
+ content: RecordDict | None = None,
220
299
  error: Error | None = None,
300
+ ttl: float | None = None,
301
+ group_id: str | None = None,
302
+ reply_to: Message | None = None,
303
+ metadata: Metadata | None = None,
221
304
  ) -> None:
222
- if not (content is None) ^ (error is None):
223
- raise ValueError("Either `content` or `error` must be set, but not both.")
305
+ # Set positional arguments
306
+ content, error, dst_node_id, message_type = _extract_positional_args(
307
+ *args,
308
+ content=content,
309
+ error=error,
310
+ dst_node_id=dst_node_id,
311
+ message_type=message_type,
312
+ )
313
+ _check_arg_types(
314
+ dst_node_id=dst_node_id,
315
+ message_type=message_type,
316
+ content=content,
317
+ error=error,
318
+ ttl=ttl,
319
+ group_id=group_id,
320
+ reply_to=reply_to,
321
+ metadata=metadata,
322
+ )
323
+
324
+ # Set metadata directly (This is for internal use only)
325
+ if metadata is not None:
326
+ # When metadata is set, all other arguments must be None,
327
+ # except `content`, `error`, or `content_or_error`
328
+ if any(
329
+ x is not None
330
+ for x in [dst_node_id, message_type, ttl, group_id, reply_to]
331
+ ):
332
+ raise MessageInitializationError(
333
+ f"Invalid arguments for {Message.__qualname__}. "
334
+ "Expected only `metadata` to be set when creating a message "
335
+ "with provided metadata."
336
+ )
337
+
338
+ # Create metadata for an instruction message
339
+ elif reply_to is None:
340
+ # Check arguments
341
+ # `content`, `dst_node_id` and `message_type` must be set
342
+ if not (
343
+ isinstance(content, RecordDict)
344
+ and isinstance(dst_node_id, int)
345
+ and isinstance(message_type, str)
346
+ ):
347
+ raise MessageInitializationError()
348
+
349
+ # Set metadata
350
+ metadata = Metadata(
351
+ run_id=0, # Will be set before pushed
352
+ message_id="", # Will be set by the SuperLink
353
+ src_node_id=0, # Will be set before pushed
354
+ dst_node_id=dst_node_id,
355
+ # Instruction messages do not reply to any message
356
+ reply_to_message_id="",
357
+ group_id=group_id or "",
358
+ created_at=now().timestamp(),
359
+ ttl=ttl or DEFAULT_TTL,
360
+ message_type=message_type,
361
+ )
362
+
363
+ # Create metadata for a reply message
364
+ else:
365
+ # Check arguments
366
+ # `dst_node_id`, `message_type` and `group_id` must not be set
367
+ if any(x is not None for x in [dst_node_id, message_type, group_id]):
368
+ raise MessageInitializationError()
369
+
370
+ # Set metadata
371
+ current = now().timestamp()
372
+ metadata = Metadata(
373
+ run_id=reply_to.metadata.run_id,
374
+ message_id="", # Will be set by the SuperLink
375
+ src_node_id=reply_to.metadata.dst_node_id,
376
+ dst_node_id=reply_to.metadata.src_node_id,
377
+ reply_to_message_id=reply_to.metadata.message_id,
378
+ group_id=reply_to.metadata.group_id,
379
+ created_at=current,
380
+ ttl=_limit_reply_ttl(current, ttl, reply_to),
381
+ message_type=reply_to.metadata.message_type,
382
+ )
224
383
 
225
- metadata.created_at = time.time() # Set the message creation timestamp
384
+ metadata.delivered_at = "" # Backward compatibility
226
385
  var_dict = {
227
386
  "_metadata": metadata,
228
387
  "_content": content,
@@ -236,17 +395,17 @@ class Message:
236
395
  return cast(Metadata, self.__dict__["_metadata"])
237
396
 
238
397
  @property
239
- def content(self) -> RecordSet:
398
+ def content(self) -> RecordDict:
240
399
  """The content of this message."""
241
400
  if self.__dict__["_content"] is None:
242
401
  raise ValueError(
243
402
  "Message content is None. Use <message>.has_content() "
244
403
  "to check if a message has content."
245
404
  )
246
- return cast(RecordSet, self.__dict__["_content"])
405
+ return cast(RecordDict, self.__dict__["_content"])
247
406
 
248
407
  @content.setter
249
- def content(self, value: RecordSet) -> None:
408
+ def content(self, value: RecordDict) -> None:
250
409
  """Set content."""
251
410
  if self.__dict__["_error"] is None:
252
411
  self.__dict__["_content"] = value
@@ -297,33 +456,25 @@ class Message:
297
456
  message : Message
298
457
  A Message containing only the relevant error and metadata.
299
458
  """
300
- # If no TTL passed, use default for message creation (will update after
301
- # message creation)
302
- ttl_ = DEFAULT_TTL if ttl is None else ttl
303
- # Create reply with error
304
- message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
305
-
306
- if ttl is None:
307
- # Set TTL equal to the remaining time for the received message to expire
308
- ttl = self.metadata.ttl - (
309
- message.metadata.created_at - self.metadata.created_at
310
- )
311
- message.metadata.ttl = ttl
312
-
313
- self._limit_task_res_ttl(message)
314
-
315
- return message
459
+ warn_deprecated_feature(
460
+ "`Message.create_error_reply` is deprecated. "
461
+ "Instead of calling `some_message.create_error_reply(some_error, ttl=...)`"
462
+ ", use `Message(some_error, reply_to=some_message, ttl=...)`."
463
+ )
464
+ if ttl is not None:
465
+ return Message(error, reply_to=self, ttl=ttl)
466
+ return Message(error, reply_to=self)
316
467
 
317
- def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
468
+ def create_reply(self, content: RecordDict, ttl: float | None = None) -> Message:
318
469
  """Create a reply to this message with specified content and TTL.
319
470
 
320
471
  The method generates a new `Message` as a reply to this message.
321
472
  It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from
322
- this message and sets 'reply_to_message' to the ID of this message.
473
+ this message and sets 'reply_to_message_id' to the ID of this message.
323
474
 
324
475
  Parameters
325
476
  ----------
326
- content : RecordSet
477
+ content : RecordDict
327
478
  The content for the reply message.
328
479
  ttl : Optional[float] (default: None)
329
480
  Time-to-live for this message in seconds. If unset, it will be set based
@@ -337,25 +488,14 @@ class Message:
337
488
  Message
338
489
  A new `Message` instance representing the reply.
339
490
  """
340
- # If no TTL passed, use default for message creation (will update after
341
- # message creation)
342
- ttl_ = DEFAULT_TTL if ttl is None else ttl
343
-
344
- message = Message(
345
- metadata=_create_reply_metadata(self, ttl_),
346
- content=content,
491
+ warn_deprecated_feature(
492
+ "`Message.create_reply` is deprecated. "
493
+ "Instead of calling `some_message.create_reply(some_content, ttl=...)`"
494
+ ", use `Message(some_content, reply_to=some_message, ttl=...)`."
347
495
  )
348
-
349
- if ttl is None:
350
- # Set TTL equal to the remaining time for the received message to expire
351
- ttl = self.metadata.ttl - (
352
- message.metadata.created_at - self.metadata.created_at
353
- )
354
- message.metadata.ttl = ttl
355
-
356
- self._limit_task_res_ttl(message)
357
-
358
- return message
496
+ if ttl is not None:
497
+ return Message(content, reply_to=self, ttl=ttl)
498
+ return Message(content, reply_to=self)
359
499
 
360
500
  def __repr__(self) -> str:
361
501
  """Return a string representation of this instance."""
@@ -368,41 +508,137 @@ class Message:
368
508
  )
369
509
  return f"{self.__class__.__qualname__}({view})"
370
510
 
371
- def _limit_task_res_ttl(self, message: Message) -> None:
372
- """Limit the TaskRes TTL to not exceed the expiration time of the TaskIns it
373
- replies to.
374
511
 
375
- Parameters
376
- ----------
377
- message : Message
378
- The message to which the TaskRes is replying.
379
- """
380
- # Calculate the maximum allowed TTL
381
- max_allowed_ttl = (
382
- self.metadata.created_at + self.metadata.ttl - message.metadata.created_at
512
+ def make_message(
513
+ metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
514
+ ) -> Message:
515
+ """Create a message with the provided metadata, content, and error."""
516
+ return Message(metadata=metadata, content=content, error=error) # type: ignore
517
+
518
+
519
+ def _limit_reply_ttl(
520
+ current: float, reply_ttl: float | None, reply_to: Message
521
+ ) -> float:
522
+ """Limit the TTL of a reply message such that it does exceed the expiration time of
523
+ the message it replies to."""
524
+ # Calculate the maximum allowed TTL
525
+ max_allowed_ttl = reply_to.metadata.created_at + reply_to.metadata.ttl - current
526
+
527
+ if reply_ttl is not None and reply_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
528
+ log(
529
+ WARNING,
530
+ "The reply TTL of %.2f seconds exceeded the "
531
+ "allowed maximum of %.2f seconds. "
532
+ "The TTL has been updated to the allowed maximum.",
533
+ reply_ttl,
534
+ max_allowed_ttl,
383
535
  )
384
-
385
- if message.metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
386
- log(
387
- WARNING,
388
- "The reply TTL of %.2f seconds exceeded the "
389
- "allowed maximum of %.2f seconds. "
390
- "The TTL has been updated to the allowed maximum.",
391
- message.metadata.ttl,
392
- max_allowed_ttl,
393
- )
394
- message.metadata.ttl = max_allowed_ttl
395
-
396
-
397
- def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
398
- """Construct metadata for a reply message."""
399
- return Metadata(
400
- run_id=msg.metadata.run_id,
401
- message_id="",
402
- src_node_id=msg.metadata.dst_node_id,
403
- dst_node_id=msg.metadata.src_node_id,
404
- reply_to_message=msg.metadata.message_id,
405
- group_id=msg.metadata.group_id,
406
- ttl=ttl,
407
- message_type=msg.metadata.message_type,
408
- )
536
+ return max_allowed_ttl
537
+
538
+ return reply_ttl or max_allowed_ttl
539
+
540
+
541
+ def _extract_positional_args(
542
+ *args: Any,
543
+ content: RecordDict | None,
544
+ error: Error | None,
545
+ dst_node_id: int | None,
546
+ message_type: str | None,
547
+ ) -> tuple[RecordDict | None, Error | None, int | None, str | None]:
548
+ """Extract positional arguments for the `Message` constructor."""
549
+ content_or_error = args[0] if args else None
550
+ if len(args) > 1:
551
+ if dst_node_id is not None:
552
+ raise MessageInitializationError()
553
+ dst_node_id = args[1]
554
+ if len(args) > 2:
555
+ if message_type is not None:
556
+ raise MessageInitializationError()
557
+ message_type = args[2]
558
+ if len(args) > 3:
559
+ raise MessageInitializationError()
560
+
561
+ # One and only one of `content_or_error`, `content` and `error` must be set
562
+ if sum(x is not None for x in [content_or_error, content, error]) != 1:
563
+ raise MessageInitializationError()
564
+
565
+ # Set `content` or `error` based on `content_or_error`
566
+ if content_or_error is not None: # This means `content` and `error` are None
567
+ if isinstance(content_or_error, RecordDict):
568
+ content = content_or_error
569
+ elif isinstance(content_or_error, Error):
570
+ error = content_or_error
571
+ else:
572
+ raise MessageInitializationError()
573
+ return content, error, dst_node_id, message_type
574
+
575
+
576
+ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
577
+ dst_node_id: int | None = None,
578
+ message_type: str | None = None,
579
+ content: RecordDict | None = None,
580
+ error: Error | None = None,
581
+ ttl: float | None = None,
582
+ group_id: str | None = None,
583
+ reply_to: Message | None = None,
584
+ metadata: Metadata | None = None,
585
+ ) -> None:
586
+ """Check argument types for the `Message` constructor."""
587
+ # pylint: disable=too-many-boolean-expressions
588
+ if (
589
+ (dst_node_id is None or isinstance(dst_node_id, int))
590
+ and (message_type is None or isinstance(message_type, str))
591
+ and (content is None or isinstance(content, RecordDict))
592
+ and (error is None or isinstance(error, Error))
593
+ and (ttl is None or isinstance(ttl, (int, float)))
594
+ and (group_id is None or isinstance(group_id, str))
595
+ and (reply_to is None or isinstance(reply_to, Message))
596
+ and (metadata is None or isinstance(metadata, Metadata))
597
+ ):
598
+ return
599
+ raise MessageInitializationError()
600
+
601
+
602
+ def validate_message_type(message_type: str) -> bool:
603
+ """Validate if the message type is valid.
604
+
605
+ A valid message type format must be one of the following:
606
+
607
+ - "<category>"
608
+ - "<category>.<action>"
609
+
610
+ where `category` must be one of "train", "evaluate", or "query",
611
+ and `action` must be a valid Python identifier.
612
+ """
613
+ # Check if conforming to the format "<category>"
614
+ valid_types = {
615
+ MessageType.TRAIN,
616
+ MessageType.EVALUATE,
617
+ MessageType.QUERY,
618
+ MessageType.SYSTEM,
619
+ }
620
+ if message_type in valid_types:
621
+ return True
622
+
623
+ # Check if conforming to the format "<category>.<action>"
624
+ if message_type.count(".") != 1:
625
+ return False
626
+
627
+ category, action = message_type.split(".")
628
+ if category in valid_types and action.isidentifier():
629
+ return True
630
+
631
+ return False
632
+
633
+
634
+ def validate_legacy_message_type(message_type: str) -> bool:
635
+ """Validate if the legacy message type is valid."""
636
+ # Backward compatibility for legacy message types
637
+ if message_type in (
638
+ MessageTypeLegacy.GET_PARAMETERS,
639
+ MessageTypeLegacy.GET_PROPERTIES,
640
+ "reconnect",
641
+ ):
642
+ return True
643
+
644
+ return False
flwr/common/object_ref.py CHANGED
@@ -170,7 +170,6 @@ def load_app( # pylint: disable= too-many-branches
170
170
  module = importlib.import_module(module_str)
171
171
  else:
172
172
  module = sys.modules[module_str]
173
- _reload_modules(project_dir)
174
173
 
175
174
  except ModuleNotFoundError as err:
176
175
  raise error_type(
@@ -200,15 +199,6 @@ def _unload_modules(project_dir: Path) -> None:
200
199
  del sys.modules[name]
201
200
 
202
201
 
203
- def _reload_modules(project_dir: Path) -> None:
204
- """Reload modules from the project directory."""
205
- dir_str = str(project_dir.absolute())
206
- for m in list(sys.modules.values()):
207
- path: Optional[str] = getattr(m, "__file__", None)
208
- if path is not None and path.startswith(dir_str):
209
- importlib.reload(m)
210
-
211
-
212
202
  def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
213
203
  """Set the system path."""
214
204
  if directory is None:
@@ -15,17 +15,21 @@
15
15
  """Record APIs."""
16
16
 
17
17
 
18
- from .configsrecord import ConfigsRecord
18
+ from .arrayrecord import Array, ArrayRecord, ParametersRecord
19
+ from .configrecord import ConfigRecord, ConfigsRecord
19
20
  from .conversion_utils import array_from_numpy
20
- from .metricsrecord import MetricsRecord
21
- from .parametersrecord import Array, ParametersRecord
22
- from .recordset import RecordSet
21
+ from .metricrecord import MetricRecord, MetricsRecord
22
+ from .recorddict import RecordDict, RecordSet
23
23
 
24
24
  __all__ = [
25
25
  "Array",
26
+ "ArrayRecord",
27
+ "ConfigRecord",
26
28
  "ConfigsRecord",
29
+ "MetricRecord",
27
30
  "MetricsRecord",
28
31
  "ParametersRecord",
32
+ "RecordDict",
29
33
  "RecordSet",
30
34
  "array_from_numpy",
31
35
  ]