flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/run/run.py +5 -9
  12. flwr/client/app.py +6 -4
  13. flwr/client/client_app.py +162 -99
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/grpc_client/connection.py +24 -21
  16. flwr/client/message_handler/message_handler.py +27 -27
  17. flwr/client/mod/__init__.py +2 -2
  18. flwr/client/mod/centraldp_mods.py +7 -7
  19. flwr/client/mod/comms_mods.py +16 -22
  20. flwr/client/mod/localdp_mod.py +4 -4
  21. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  22. flwr/client/run_info_store.py +2 -2
  23. flwr/common/__init__.py +12 -4
  24. flwr/common/config.py +4 -4
  25. flwr/common/constant.py +6 -6
  26. flwr/common/context.py +4 -4
  27. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  28. flwr/common/logger.py +2 -2
  29. flwr/common/message.py +327 -102
  30. flwr/common/record/__init__.py +8 -4
  31. flwr/common/record/arrayrecord.py +626 -0
  32. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  33. flwr/common/record/conversion_utils.py +1 -1
  34. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  35. flwr/common/record/recorddict.py +288 -0
  36. flwr/common/recorddict_compat.py +410 -0
  37. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  38. flwr/common/serde.py +66 -71
  39. flwr/common/typing.py +8 -8
  40. flwr/proto/exec_pb2.py +3 -3
  41. flwr/proto/exec_pb2.pyi +3 -3
  42. flwr/proto/message_pb2.py +12 -12
  43. flwr/proto/message_pb2.pyi +9 -9
  44. flwr/proto/recorddict_pb2.py +70 -0
  45. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  46. flwr/proto/run_pb2.py +31 -31
  47. flwr/proto/run_pb2.pyi +3 -3
  48. flwr/server/__init__.py +3 -1
  49. flwr/server/app.py +56 -1
  50. flwr/server/compat/__init__.py +2 -2
  51. flwr/server/compat/app.py +11 -11
  52. flwr/server/compat/app_utils.py +16 -16
  53. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  54. flwr/server/fleet_event_log_interceptor.py +94 -0
  55. flwr/server/{driver → grid}/__init__.py +8 -7
  56. flwr/server/{driver/driver.py → grid/grid.py} +47 -18
  57. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  58. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  59. flwr/server/run_serverapp.py +4 -4
  60. flwr/server/server_app.py +38 -18
  61. flwr/server/serverapp/app.py +10 -10
  62. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  63. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  64. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  65. flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
  66. flwr/server/superlink/linkstate/linkstate.py +4 -4
  67. flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
  68. flwr/server/superlink/linkstate/utils.py +93 -27
  69. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  70. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  71. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  72. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  73. flwr/server/typing.py +3 -3
  74. flwr/server/utils/validator.py +4 -4
  75. flwr/server/workflow/default_workflows.py +48 -57
  76. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  77. flwr/simulation/app.py +2 -2
  78. flwr/simulation/ray_transport/ray_actor.py +4 -2
  79. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  80. flwr/simulation/run_simulation.py +15 -15
  81. flwr/superexec/deployment.py +4 -4
  82. flwr/superexec/exec_event_log_interceptor.py +135 -0
  83. flwr/superexec/exec_grpc.py +10 -4
  84. flwr/superexec/exec_servicer.py +2 -2
  85. flwr/superexec/exec_user_auth_interceptor.py +18 -2
  86. flwr/superexec/executor.py +3 -3
  87. flwr/superexec/simulation.py +3 -3
  88. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
  89. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
  90. flwr/common/record/parametersrecord.py +0 -339
  91. flwr/common/record/recordset.py +0 -209
  92. flwr/common/recordset_compat.py +0 -418
  93. flwr/proto/recordset_pb2.py +0 -70
  94. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  95. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  96. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  97. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
  98. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
flwr/common/message.py CHANGED
@@ -17,19 +17,34 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import time
21
20
  from logging import WARNING
22
- from typing import Optional, cast
21
+ from typing import Any, Optional, cast, overload
23
22
 
24
- from .constant import MESSAGE_TTL_TOLERANCE
23
+ from flwr.common.date import now
24
+ from flwr.common.logger import warn_deprecated_feature
25
+
26
+ from .constant import MESSAGE_TTL_TOLERANCE, MessageType, MessageTypeLegacy
25
27
  from .logger import log
26
- from .record import RecordSet
28
+ from .record import RecordDict
27
29
 
28
30
  DEFAULT_TTL = 43200 # This is 12 hours
31
+ MESSAGE_INIT_ERROR_MESSAGE = (
32
+ "Invalid arguments for Message. Expected one of the documented "
33
+ "signatures: Message(content: RecordDict, dst_node_id: int, message_type: str,"
34
+ " *, [ttl: float, group_id: str]) or Message(content: RecordDict | error: Error,"
35
+ " *, reply_to: Message, [ttl: float])."
36
+ )
37
+
38
+
39
+ class MessageInitializationError(TypeError):
40
+ """Error raised when initializing a message with invalid arguments."""
41
+
42
+ def __init__(self, message: str | None = None) -> None:
43
+ super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
29
44
 
30
45
 
31
46
  class Metadata: # pylint: disable=too-many-instance-attributes
32
- """A dataclass holding metadata associated with the current message.
47
+ """The class representing metadata associated with the current message.
33
48
 
34
49
  Parameters
35
50
  ----------
@@ -41,11 +56,13 @@ class Metadata: # pylint: disable=too-many-instance-attributes
41
56
  An identifier for the node sending this message.
42
57
  dst_node_id : int
43
58
  An identifier for the node receiving this message.
44
- reply_to_message : str
45
- An identifier for the message this message replies to.
59
+ reply_to_message_id : str
60
+ An identifier for the message to which this message is a reply.
46
61
  group_id : str
47
62
  An identifier for grouping messages. In some settings,
48
63
  this is used as the FL round.
64
+ created_at : float
65
+ Unix timestamp when the message was created.
49
66
  ttl : float
50
67
  Time-to-live for this message in seconds.
51
68
  message_type : str
@@ -59,8 +76,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
59
76
  message_id: str,
60
77
  src_node_id: int,
61
78
  dst_node_id: int,
62
- reply_to_message: str,
79
+ reply_to_message_id: str,
63
80
  group_id: str,
81
+ created_at: float,
64
82
  ttl: float,
65
83
  message_type: str,
66
84
  ) -> None:
@@ -69,12 +87,14 @@ class Metadata: # pylint: disable=too-many-instance-attributes
69
87
  "_message_id": message_id,
70
88
  "_src_node_id": src_node_id,
71
89
  "_dst_node_id": dst_node_id,
72
- "_reply_to_message": reply_to_message,
90
+ "_reply_to_message_id": reply_to_message_id,
73
91
  "_group_id": group_id,
92
+ "_created_at": created_at,
74
93
  "_ttl": ttl,
75
94
  "_message_type": message_type,
76
95
  }
77
96
  self.__dict__.update(var_dict)
97
+ self.message_type = message_type # Trigger validation
78
98
 
79
99
  @property
80
100
  def run_id(self) -> int:
@@ -92,9 +112,9 @@ class Metadata: # pylint: disable=too-many-instance-attributes
92
112
  return cast(int, self.__dict__["_src_node_id"])
93
113
 
94
114
  @property
95
- def reply_to_message(self) -> str:
96
- """An identifier for the message this message replies to."""
97
- return cast(str, self.__dict__["_reply_to_message"])
115
+ def reply_to_message_id(self) -> str:
116
+ """An identifier for the message to which this message is a reply."""
117
+ return cast(str, self.__dict__["_reply_to_message_id"])
98
118
 
99
119
  @property
100
120
  def dst_node_id(self) -> int:
@@ -123,7 +143,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
123
143
 
124
144
  @created_at.setter
125
145
  def created_at(self, value: float) -> None:
126
- """Set creation timestamp for this message."""
146
+ """Set creation timestamp of this message."""
127
147
  self.__dict__["_created_at"] = value
128
148
 
129
149
  @property
@@ -154,6 +174,17 @@ class Metadata: # pylint: disable=too-many-instance-attributes
154
174
  @message_type.setter
155
175
  def message_type(self, value: str) -> None:
156
176
  """Set message_type."""
177
+ # Validate message type
178
+ if validate_legacy_message_type(value):
179
+ pass # Backward compatibility for legacy message types
180
+ elif not validate_message_type(value):
181
+ raise ValueError(
182
+ f"Invalid message type: '{value}'. "
183
+ "Expected format: '<category>' or '<category>.<action>', "
184
+ "where <category> must be 'train', 'evaluate', or 'query', "
185
+ "and <action> must be a valid Python identifier."
186
+ )
187
+
157
188
  self.__dict__["_message_type"] = value
158
189
 
159
190
  def __repr__(self) -> str:
@@ -169,7 +200,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
169
200
 
170
201
 
171
202
  class Error:
172
- """A dataclass that stores information about an error that occurred.
203
+ """The class storing information about an error that occurred.
173
204
 
174
205
  Parameters
175
206
  ----------
@@ -209,31 +240,148 @@ class Error:
209
240
 
210
241
 
211
242
  class Message:
212
- """State of your application from the viewpoint of the entity using it.
243
+ """Represents a message exchanged between ClientApp and ServerApp.
244
+
245
+ This class encapsulates the payload and metadata necessary for communication
246
+ between a ClientApp and a ServerApp.
213
247
 
214
248
  Parameters
215
249
  ----------
216
- metadata : Metadata
217
- A dataclass including information about the message to be executed.
218
- content : Optional[RecordSet]
250
+ content : Optional[RecordDict] (default: None)
219
251
  Holds records either sent by another entity (e.g. sent by the server-side
220
252
  logic to a client, or vice-versa) or that will be sent to it.
221
- error : Optional[Error]
253
+ error : Optional[Error] (default: None)
222
254
  A dataclass that captures information about an error that took place
223
255
  when processing another message.
256
+ dst_node_id : Optional[int] (default: None)
257
+ An identifier for the node receiving this message.
258
+ message_type : Optional[str] (default: None)
259
+ A string that encodes the action to be executed on
260
+ the receiving end.
261
+ ttl : Optional[float] (default: None)
262
+ Time-to-live (TTL) for this message in seconds. If `None` (default),
263
+ the TTL is set to 43,200 seconds (12 hours).
264
+ group_id : Optional[str] (default: None)
265
+ An identifier for grouping messages. In some settings, this is used as
266
+ the FL round.
267
+ reply_to : Optional[Message] (default: None)
268
+ The instruction message to which this message is a reply. This message does
269
+ not retain the original message's content but derives its metadata from it.
224
270
  """
225
271
 
226
- def __init__(
272
+ @overload
273
+ def __init__( # pylint: disable=too-many-arguments # noqa: E704
227
274
  self,
228
- metadata: Metadata,
229
- content: RecordSet | None = None,
275
+ content: RecordDict,
276
+ dst_node_id: int,
277
+ message_type: str,
278
+ *,
279
+ ttl: float | None = None,
280
+ group_id: str | None = None,
281
+ ) -> None: ...
282
+
283
+ @overload
284
+ def __init__( # noqa: E704
285
+ self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
286
+ ) -> None: ...
287
+
288
+ @overload
289
+ def __init__( # noqa: E704
290
+ self, error: Error, *, reply_to: Message, ttl: float | None = None
291
+ ) -> None: ...
292
+
293
+ def __init__( # pylint: disable=too-many-arguments
294
+ self,
295
+ *args: Any,
296
+ dst_node_id: int | None = None,
297
+ message_type: str | None = None,
298
+ content: RecordDict | None = None,
230
299
  error: Error | None = None,
300
+ ttl: float | None = None,
301
+ group_id: str | None = None,
302
+ reply_to: Message | None = None,
303
+ metadata: Metadata | None = None,
231
304
  ) -> None:
232
- if not (content is None) ^ (error is None):
233
- raise ValueError("Either `content` or `error` must be set, but not both.")
305
+ # Set positional arguments
306
+ content, error, dst_node_id, message_type = _extract_positional_args(
307
+ *args,
308
+ content=content,
309
+ error=error,
310
+ dst_node_id=dst_node_id,
311
+ message_type=message_type,
312
+ )
313
+ _check_arg_types(
314
+ dst_node_id=dst_node_id,
315
+ message_type=message_type,
316
+ content=content,
317
+ error=error,
318
+ ttl=ttl,
319
+ group_id=group_id,
320
+ reply_to=reply_to,
321
+ metadata=metadata,
322
+ )
234
323
 
235
- metadata.created_at = time.time() # Set the message creation timestamp
236
- metadata.delivered_at = ""
324
+ # Set metadata directly (This is for internal use only)
325
+ if metadata is not None:
326
+ # When metadata is set, all other arguments must be None,
327
+ # except `content`, `error`, or `content_or_error`
328
+ if any(
329
+ x is not None
330
+ for x in [dst_node_id, message_type, ttl, group_id, reply_to]
331
+ ):
332
+ raise MessageInitializationError(
333
+ f"Invalid arguments for {Message.__qualname__}. "
334
+ "Expected only `metadata` to be set when creating a message "
335
+ "with provided metadata."
336
+ )
337
+
338
+ # Create metadata for an instruction message
339
+ elif reply_to is None:
340
+ # Check arguments
341
+ # `content`, `dst_node_id` and `message_type` must be set
342
+ if not (
343
+ isinstance(content, RecordDict)
344
+ and isinstance(dst_node_id, int)
345
+ and isinstance(message_type, str)
346
+ ):
347
+ raise MessageInitializationError()
348
+
349
+ # Set metadata
350
+ metadata = Metadata(
351
+ run_id=0, # Will be set before pushed
352
+ message_id="", # Will be set by the SuperLink
353
+ src_node_id=0, # Will be set before pushed
354
+ dst_node_id=dst_node_id,
355
+ # Instruction messages do not reply to any message
356
+ reply_to_message_id="",
357
+ group_id=group_id or "",
358
+ created_at=now().timestamp(),
359
+ ttl=ttl or DEFAULT_TTL,
360
+ message_type=message_type,
361
+ )
362
+
363
+ # Create metadata for a reply message
364
+ else:
365
+ # Check arguments
366
+ # `dst_node_id`, `message_type` and `group_id` must not be set
367
+ if any(x is not None for x in [dst_node_id, message_type, group_id]):
368
+ raise MessageInitializationError()
369
+
370
+ # Set metadata
371
+ current = now().timestamp()
372
+ metadata = Metadata(
373
+ run_id=reply_to.metadata.run_id,
374
+ message_id="", # Will be set by the SuperLink
375
+ src_node_id=reply_to.metadata.dst_node_id,
376
+ dst_node_id=reply_to.metadata.src_node_id,
377
+ reply_to_message_id=reply_to.metadata.message_id,
378
+ group_id=reply_to.metadata.group_id,
379
+ created_at=current,
380
+ ttl=_limit_reply_ttl(current, ttl, reply_to),
381
+ message_type=reply_to.metadata.message_type,
382
+ )
383
+
384
+ metadata.delivered_at = "" # Backward compatibility
237
385
  var_dict = {
238
386
  "_metadata": metadata,
239
387
  "_content": content,
@@ -247,17 +395,17 @@ class Message:
247
395
  return cast(Metadata, self.__dict__["_metadata"])
248
396
 
249
397
  @property
250
- def content(self) -> RecordSet:
398
+ def content(self) -> RecordDict:
251
399
  """The content of this message."""
252
400
  if self.__dict__["_content"] is None:
253
401
  raise ValueError(
254
402
  "Message content is None. Use <message>.has_content() "
255
403
  "to check if a message has content."
256
404
  )
257
- return cast(RecordSet, self.__dict__["_content"])
405
+ return cast(RecordDict, self.__dict__["_content"])
258
406
 
259
407
  @content.setter
260
- def content(self, value: RecordSet) -> None:
408
+ def content(self, value: RecordDict) -> None:
261
409
  """Set content."""
262
410
  if self.__dict__["_error"] is None:
263
411
  self.__dict__["_content"] = value
@@ -308,33 +456,25 @@ class Message:
308
456
  message : Message
309
457
  A Message containing only the relevant error and metadata.
310
458
  """
311
- # If no TTL passed, use default for message creation (will update after
312
- # message creation)
313
- ttl_ = DEFAULT_TTL if ttl is None else ttl
314
- # Create reply with error
315
- message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
316
-
317
- if ttl is None:
318
- # Set TTL equal to the remaining time for the received message to expire
319
- ttl = self.metadata.ttl - (
320
- message.metadata.created_at - self.metadata.created_at
321
- )
322
- message.metadata.ttl = ttl
323
-
324
- self._limit_message_res_ttl(message)
325
-
326
- return message
459
+ warn_deprecated_feature(
460
+ "`Message.create_error_reply` is deprecated. "
461
+ "Instead of calling `some_message.create_error_reply(some_error, ttl=...)`"
462
+ ", use `Message(some_error, reply_to=some_message, ttl=...)`."
463
+ )
464
+ if ttl is not None:
465
+ return Message(error, reply_to=self, ttl=ttl)
466
+ return Message(error, reply_to=self)
327
467
 
328
- def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
468
+ def create_reply(self, content: RecordDict, ttl: float | None = None) -> Message:
329
469
  """Create a reply to this message with specified content and TTL.
330
470
 
331
471
  The method generates a new `Message` as a reply to this message.
332
472
  It inherits 'run_id', 'src_node_id', 'dst_node_id', and 'message_type' from
333
- this message and sets 'reply_to_message' to the ID of this message.
473
+ this message and sets 'reply_to_message_id' to the ID of this message.
334
474
 
335
475
  Parameters
336
476
  ----------
337
- content : RecordSet
477
+ content : RecordDict
338
478
  The content for the reply message.
339
479
  ttl : Optional[float] (default: None)
340
480
  Time-to-live for this message in seconds. If unset, it will be set based
@@ -348,25 +488,14 @@ class Message:
348
488
  Message
349
489
  A new `Message` instance representing the reply.
350
490
  """
351
- # If no TTL passed, use default for message creation (will update after
352
- # message creation)
353
- ttl_ = DEFAULT_TTL if ttl is None else ttl
354
-
355
- message = Message(
356
- metadata=_create_reply_metadata(self, ttl_),
357
- content=content,
491
+ warn_deprecated_feature(
492
+ "`Message.create_reply` is deprecated. "
493
+ "Instead of calling `some_message.create_reply(some_content, ttl=...)`"
494
+ ", use `Message(some_content, reply_to=some_message, ttl=...)`."
358
495
  )
359
-
360
- if ttl is None:
361
- # Set TTL equal to the remaining time for the received message to expire
362
- ttl = self.metadata.ttl - (
363
- message.metadata.created_at - self.metadata.created_at
364
- )
365
- message.metadata.ttl = ttl
366
-
367
- self._limit_message_res_ttl(message)
368
-
369
- return message
496
+ if ttl is not None:
497
+ return Message(content, reply_to=self, ttl=ttl)
498
+ return Message(content, reply_to=self)
370
499
 
371
500
  def __repr__(self) -> str:
372
501
  """Return a string representation of this instance."""
@@ -379,41 +508,137 @@ class Message:
379
508
  )
380
509
  return f"{self.__class__.__qualname__}({view})"
381
510
 
382
- def _limit_message_res_ttl(self, message: Message) -> None:
383
- """Limit the TTL of the provided Message to not exceed the expiration time of
384
- this Message it replies to.
385
511
 
386
- Parameters
387
- ----------
388
- message : Message
389
- The reply Message to limit the TTL for.
390
- """
391
- # Calculate the maximum allowed TTL
392
- max_allowed_ttl = (
393
- self.metadata.created_at + self.metadata.ttl - message.metadata.created_at
512
+ def make_message(
513
+ metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
514
+ ) -> Message:
515
+ """Create a message with the provided metadata, content, and error."""
516
+ return Message(metadata=metadata, content=content, error=error) # type: ignore
517
+
518
+
519
+ def _limit_reply_ttl(
520
+ current: float, reply_ttl: float | None, reply_to: Message
521
+ ) -> float:
522
+ """Limit the TTL of a reply message such that it does exceed the expiration time of
523
+ the message it replies to."""
524
+ # Calculate the maximum allowed TTL
525
+ max_allowed_ttl = reply_to.metadata.created_at + reply_to.metadata.ttl - current
526
+
527
+ if reply_ttl is not None and reply_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
528
+ log(
529
+ WARNING,
530
+ "The reply TTL of %.2f seconds exceeded the "
531
+ "allowed maximum of %.2f seconds. "
532
+ "The TTL has been updated to the allowed maximum.",
533
+ reply_ttl,
534
+ max_allowed_ttl,
394
535
  )
395
-
396
- if message.metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
397
- log(
398
- WARNING,
399
- "The reply TTL of %.2f seconds exceeded the "
400
- "allowed maximum of %.2f seconds. "
401
- "The TTL has been updated to the allowed maximum.",
402
- message.metadata.ttl,
403
- max_allowed_ttl,
404
- )
405
- message.metadata.ttl = max_allowed_ttl
406
-
407
-
408
- def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
409
- """Construct metadata for a reply message."""
410
- return Metadata(
411
- run_id=msg.metadata.run_id,
412
- message_id="",
413
- src_node_id=msg.metadata.dst_node_id,
414
- dst_node_id=msg.metadata.src_node_id,
415
- reply_to_message=msg.metadata.message_id,
416
- group_id=msg.metadata.group_id,
417
- ttl=ttl,
418
- message_type=msg.metadata.message_type,
419
- )
536
+ return max_allowed_ttl
537
+
538
+ return reply_ttl or max_allowed_ttl
539
+
540
+
541
+ def _extract_positional_args(
542
+ *args: Any,
543
+ content: RecordDict | None,
544
+ error: Error | None,
545
+ dst_node_id: int | None,
546
+ message_type: str | None,
547
+ ) -> tuple[RecordDict | None, Error | None, int | None, str | None]:
548
+ """Extract positional arguments for the `Message` constructor."""
549
+ content_or_error = args[0] if args else None
550
+ if len(args) > 1:
551
+ if dst_node_id is not None:
552
+ raise MessageInitializationError()
553
+ dst_node_id = args[1]
554
+ if len(args) > 2:
555
+ if message_type is not None:
556
+ raise MessageInitializationError()
557
+ message_type = args[2]
558
+ if len(args) > 3:
559
+ raise MessageInitializationError()
560
+
561
+ # One and only one of `content_or_error`, `content` and `error` must be set
562
+ if sum(x is not None for x in [content_or_error, content, error]) != 1:
563
+ raise MessageInitializationError()
564
+
565
+ # Set `content` or `error` based on `content_or_error`
566
+ if content_or_error is not None: # This means `content` and `error` are None
567
+ if isinstance(content_or_error, RecordDict):
568
+ content = content_or_error
569
+ elif isinstance(content_or_error, Error):
570
+ error = content_or_error
571
+ else:
572
+ raise MessageInitializationError()
573
+ return content, error, dst_node_id, message_type
574
+
575
+
576
+ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
577
+ dst_node_id: int | None = None,
578
+ message_type: str | None = None,
579
+ content: RecordDict | None = None,
580
+ error: Error | None = None,
581
+ ttl: float | None = None,
582
+ group_id: str | None = None,
583
+ reply_to: Message | None = None,
584
+ metadata: Metadata | None = None,
585
+ ) -> None:
586
+ """Check argument types for the `Message` constructor."""
587
+ # pylint: disable=too-many-boolean-expressions
588
+ if (
589
+ (dst_node_id is None or isinstance(dst_node_id, int))
590
+ and (message_type is None or isinstance(message_type, str))
591
+ and (content is None or isinstance(content, RecordDict))
592
+ and (error is None or isinstance(error, Error))
593
+ and (ttl is None or isinstance(ttl, (int, float)))
594
+ and (group_id is None or isinstance(group_id, str))
595
+ and (reply_to is None or isinstance(reply_to, Message))
596
+ and (metadata is None or isinstance(metadata, Metadata))
597
+ ):
598
+ return
599
+ raise MessageInitializationError()
600
+
601
+
602
+ def validate_message_type(message_type: str) -> bool:
603
+ """Validate if the message type is valid.
604
+
605
+ A valid message type format must be one of the following:
606
+
607
+ - "<category>"
608
+ - "<category>.<action>"
609
+
610
+ where `category` must be one of "train", "evaluate", or "query",
611
+ and `action` must be a valid Python identifier.
612
+ """
613
+ # Check if conforming to the format "<category>"
614
+ valid_types = {
615
+ MessageType.TRAIN,
616
+ MessageType.EVALUATE,
617
+ MessageType.QUERY,
618
+ MessageType.SYSTEM,
619
+ }
620
+ if message_type in valid_types:
621
+ return True
622
+
623
+ # Check if conforming to the format "<category>.<action>"
624
+ if message_type.count(".") != 1:
625
+ return False
626
+
627
+ category, action = message_type.split(".")
628
+ if category in valid_types and action.isidentifier():
629
+ return True
630
+
631
+ return False
632
+
633
+
634
+ def validate_legacy_message_type(message_type: str) -> bool:
635
+ """Validate if the legacy message type is valid."""
636
+ # Backward compatibility for legacy message types
637
+ if message_type in (
638
+ MessageTypeLegacy.GET_PARAMETERS,
639
+ MessageTypeLegacy.GET_PROPERTIES,
640
+ "reconnect",
641
+ ):
642
+ return True
643
+
644
+ return False
@@ -15,17 +15,21 @@
15
15
  """Record APIs."""
16
16
 
17
17
 
18
- from .configsrecord import ConfigsRecord
18
+ from .arrayrecord import Array, ArrayRecord, ParametersRecord
19
+ from .configrecord import ConfigRecord, ConfigsRecord
19
20
  from .conversion_utils import array_from_numpy
20
- from .metricsrecord import MetricsRecord
21
- from .parametersrecord import Array, ParametersRecord
22
- from .recordset import RecordSet
21
+ from .metricrecord import MetricRecord, MetricsRecord
22
+ from .recorddict import RecordDict, RecordSet
23
23
 
24
24
  __all__ = [
25
25
  "Array",
26
+ "ArrayRecord",
27
+ "ConfigRecord",
26
28
  "ConfigsRecord",
29
+ "MetricRecord",
27
30
  "MetricsRecord",
28
31
  "ParametersRecord",
32
+ "RecordDict",
29
33
  "RecordSet",
30
34
  "array_from_numpy",
31
35
  ]