flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240509__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (71)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +44 -18
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  10. flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
  11. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  12. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  14. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  15. flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
  16. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  17. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  18. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
  19. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  20. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  21. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  22. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  23. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  24. flwr/cli/run/run.py +1 -1
  25. flwr/cli/utils.py +18 -17
  26. flwr/client/__init__.py +1 -1
  27. flwr/client/app.py +17 -93
  28. flwr/client/grpc_client/connection.py +6 -1
  29. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  30. flwr/client/grpc_rere_client/connection.py +17 -2
  31. flwr/client/mod/centraldp_mods.py +4 -2
  32. flwr/client/mod/localdp_mod.py +9 -3
  33. flwr/client/rest_client/connection.py +5 -1
  34. flwr/client/supernode/__init__.py +2 -0
  35. flwr/client/supernode/app.py +181 -7
  36. flwr/common/grpc.py +5 -1
  37. flwr/common/logger.py +37 -4
  38. flwr/common/message.py +105 -86
  39. flwr/common/record/parametersrecord.py +0 -1
  40. flwr/common/record/recordset.py +17 -5
  41. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  42. flwr/server/__init__.py +0 -2
  43. flwr/server/app.py +118 -2
  44. flwr/server/compat/app.py +5 -56
  45. flwr/server/compat/app_utils.py +1 -1
  46. flwr/server/compat/driver_client_proxy.py +27 -72
  47. flwr/server/driver/__init__.py +3 -0
  48. flwr/server/driver/driver.py +12 -242
  49. flwr/server/driver/grpc_driver.py +315 -0
  50. flwr/server/history.py +20 -20
  51. flwr/server/run_serverapp.py +18 -4
  52. flwr/server/server.py +2 -5
  53. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  54. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  55. flwr/server/superlink/driver/driver_servicer.py +1 -1
  56. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  57. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  58. flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
  59. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  60. flwr/server/superlink/state/in_memory_state.py +76 -8
  61. flwr/server/superlink/state/sqlite_state.py +116 -11
  62. flwr/server/superlink/state/state.py +35 -3
  63. flwr/simulation/__init__.py +2 -2
  64. flwr/simulation/app.py +16 -1
  65. flwr/simulation/run_simulation.py +14 -9
  66. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
  67. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
  68. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
  69. flwr/server/driver/abc_driver.py +0 -140
  70. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
  71. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
@@ -15,11 +15,27 @@
15
15
  """Flower SuperNode."""
16
16
 
17
17
  import argparse
18
- from logging import DEBUG, INFO
18
+ import sys
19
+ from logging import DEBUG, INFO, WARN
20
+ from pathlib import Path
21
+ from typing import Callable, Optional, Tuple
19
22
 
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+ from cryptography.hazmat.primitives.serialization import (
25
+ load_ssh_private_key,
26
+ load_ssh_public_key,
27
+ )
28
+
29
+ from flwr.client.client_app import ClientApp, LoadClientAppError
20
30
  from flwr.common import EventType, event
21
31
  from flwr.common.exit_handlers import register_exit_handlers
22
32
  from flwr.common.logger import log
33
+ from flwr.common.object_ref import load_app, validate
34
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
35
+ ssh_types_to_elliptic_curve,
36
+ )
37
+
38
+ from ..app import _start_client_internal
23
39
 
24
40
 
25
41
  def run_supernode() -> None:
@@ -28,12 +44,11 @@ def run_supernode() -> None:
28
44
 
29
45
  event(EventType.RUN_SUPERNODE_ENTER)
30
46
 
31
- args = _parse_args_run_supernode().parse_args()
47
+ _ = _parse_args_run_supernode().parse_args()
32
48
 
33
49
  log(
34
50
  DEBUG,
35
- "Flower will load ClientApp `%s`",
36
- getattr(args, "client-app"),
51
+ "Flower SuperNode starting...",
37
52
  )
38
53
 
39
54
  # Graceful shutdown
@@ -42,23 +57,144 @@ def run_supernode() -> None:
42
57
  )
43
58
 
44
59
 
def run_client_app() -> None:
    """Start a long-running Flower client that executes a ClientApp.

    Parses the `flower-client-app` command line, resolves the TLS root
    certificates and the optional client authentication keys, and hands
    everything to `_start_client_internal`. Exit handlers emitting
    `RUN_CLIENT_APP_LEAVE` are registered once that call returns.
    """
    log(INFO, "Long-running Flower client starting")
    event(EventType.RUN_CLIENT_APP_ENTER)

    parsed = _parse_args_run_client_app().parse_args()

    certificates = _get_certificates(parsed)
    # The positional argument contains a hyphen, so it is only reachable via getattr
    log(
        DEBUG,
        "Flower will load ClientApp `%s`",
        getattr(parsed, "client-app"),
    )
    client_app_loader = _get_load_client_app_fn(parsed)
    auth_keys = _try_setup_client_authentication(parsed)

    # `--rest` selects the REST transport; gRPC request-response otherwise
    transport = "rest" if parsed.rest else "grpc-rere"
    _start_client_internal(
        server_address=parsed.server,
        load_client_app_fn=client_app_loader,
        transport=transport,
        root_certificates=certificates,
        insecure=parsed.insecure,
        authentication_keys=auth_keys,
        max_retries=parsed.max_retries,
        max_wait_time=parsed.max_wait_time,
    )
    register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
88
+
89
+
def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
    """Resolve the TLS root certificates requested on the command line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; `insecure`, `root_certificates` and `server` are read.

    Returns
    -------
    Optional[bytes]
        The contents of the `--root-certificates` file, or `None` when the
        connection is insecure or no certificate file was given.

    Exits the process when `--insecure` and `--root-certificates` are combined.
    """
    cert_path = args.root_certificates

    if not args.insecure:
        # Without an explicit file, `None` is passed on and the transport
        # decides which roots to trust (presumably the system defaults)
        certificates = None if cert_path is None else Path(cert_path).read_bytes()
        log(
            DEBUG,
            "Starting secure HTTPS client connected to %s "
            "with the following certificates: %s.",
            args.server,
            cert_path,
        )
        return certificates

    if cert_path is not None:
        sys.exit(
            "Conflicting options: The '--insecure' flag disables HTTPS, "
            "but '--root-certificates' was also specified. Please remove "
            "the '--root-certificates' option when running in insecure mode, "
            "or omit '--insecure' to use HTTPS."
        )
    log(
        WARN,
        "Option `--insecure` was set. "
        "Starting insecure HTTP client connected to %s.",
        args.server,
    )
    return None
123
+
124
+
def _get_load_client_app_fn(
    args: argparse.Namespace,
) -> Callable[[], ClientApp]:
    """Build a zero-argument callable that imports and returns the ClientApp.

    If `--dir` was given it is prepended to `sys.path`, so the reference can be
    resolved from that directory. The `client-app` reference is validated
    immediately; the import itself is deferred until the callable is invoked.

    Raises
    ------
    LoadClientAppError
        Here, when validation fails and reports an error message; from the
        returned callable, when the loaded object is not a `ClientApp`.
    """
    if args.dir is not None:
        sys.path.insert(0, args.dir)

    reference: str = getattr(args, "client-app")
    is_valid, message = validate(reference)
    if not is_valid and message:
        raise LoadClientAppError(message) from None

    def _load() -> ClientApp:
        loaded = load_app(reference, LoadClientAppError)
        if isinstance(loaded, ClientApp):
            return loaded
        raise LoadClientAppError(
            f"Attribute {reference} is not of type {ClientApp}",
        ) from None

    return _load
149
+
150
+
def _parse_args_run_supernode() -> argparse.ArgumentParser:
    """Build the argument parser for the `flower-supernode` command.

    Registers an optional positional `client-app` reference (default `""`),
    the options shared with `flower-client-app` via `_parse_args_common`, and
    `--flwr-dir`.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower SuperNode",
    )

    parser.add_argument(
        "client-app",
        nargs="?",
        default="",
        help="For example: `client:app` or `project.package.module:wrapper.app`. "
        "This is optional and serves as the default ClientApp to be loaded when "
        "the ServerApp does not specify `fab_id` and `fab_version`. "
        "If not provided, defaults to an empty string.",
    )
    _parse_args_common(parser)
    parser.add_argument(
        "--flwr-dir",
        default=None,
        help="""The path containing installed Flower Apps.
    By default, this value is equal to:

    - `$FLWR_HOME/` if `$FLWR_HOME` is defined
    - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
    - `$HOME/.flwr/` in all other cases
    """,
    )

    return parser
54
180
 
55
181
 
def _parse_args_run_client_app() -> argparse.ArgumentParser:
    """Build the argument parser for the `flower-client-app` command.

    Here the `client-app` reference is a required positional argument; the
    remaining options come from `_parse_args_common`.
    """
    client_app_parser = argparse.ArgumentParser(
        description="Start a Flower client app",
    )
    client_app_parser.add_argument(
        "client-app",
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )
    _parse_args_common(parser=client_app_parser)
    return client_app_parser
195
+
196
+
197
+ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
62
198
  parser.add_argument(
63
199
  "--insecure",
64
200
  action="store_true",
@@ -105,3 +241,41 @@ def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
105
241
  "app from there."
106
242
  " Default: current working directory.",
107
243
  )
244
+ parser.add_argument(
245
+ "--authentication-keys",
246
+ nargs=2,
247
+ metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
248
+ type=str,
249
+ help="Provide two file paths: (1) the client's private "
250
+ "key file, and (2) the client's public key file.",
251
+ )
252
+
253
+
def _try_setup_client_authentication(
    args: argparse.Namespace,
) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
    """Load the client's elliptic curve key pair used for authentication.

    Returns `None` when `--authentication-keys` was not given. Otherwise the
    two files (private key first, then public key) are read, parsed as SSH
    keys and converted with `ssh_types_to_elliptic_curve`.

    Exits the process with an explanatory message if a key file cannot be
    read or parsed, or if the keys are not an elliptic curve pair.
    """
    # Local import: only needed on this error path
    from cryptography.exceptions import UnsupportedAlgorithm

    if not args.authentication_keys:
        return None

    # `nargs=2` on `--authentication-keys` guarantees exactly two paths
    private_key_path, public_key_path = args.authentication_keys
    try:
        ssh_private_key = load_ssh_private_key(
            Path(private_key_path).read_bytes(),
            None,
        )
        ssh_public_key = load_ssh_public_key(Path(public_key_path).read_bytes())
    except (OSError, ValueError, TypeError, UnsupportedAlgorithm) as err:
        # Missing/unreadable files, malformed or encrypted keys, unsupported types
        sys.exit(
            "The files provided to '--authentication-keys' could not be loaded "
            f"as an unencrypted SSH private and public key pair: {err}"
        )

    try:
        client_private_key, client_public_key = ssh_types_to_elliptic_curve(
            ssh_private_key, ssh_public_key
        )
    except TypeError:
        sys.exit(
            "The file paths provided could not be read as a private and public "
            "key pair. Client authentication requires an elliptic curve public and "
            "private key pair. Please provide the file paths containing elliptic "
            "curve private and public keys to '--authentication-keys'."
        )

    return (
        client_private_key,
        client_public_key,
    )
flwr/common/grpc.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from logging import DEBUG
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
20
 
21
21
  import grpc
22
22
 
@@ -30,6 +30,7 @@ def create_channel(
30
30
  insecure: bool,
31
31
  root_certificates: Optional[bytes] = None,
32
32
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
33
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
33
34
  ) -> grpc.Channel:
34
35
  """Create a gRPC channel, either secure or insecure."""
35
36
  # Check for conflicting parameters
@@ -57,4 +58,7 @@ def create_channel(
57
58
  )
58
59
  log(DEBUG, "Opened secure gRPC connection using certificates")
59
60
 
61
+ if interceptors is not None:
62
+ channel = grpc.intercept_channel(channel, interceptors)
63
+
60
64
  return channel
flwr/common/logger.py CHANGED
@@ -82,13 +82,20 @@ class ConsoleHandler(StreamHandler):
82
82
  return formatter.format(record)
83
83
 
84
84
 
def update_console_handler(
    level: Optional[int] = None,
    timestamps: Optional[bool] = None,
    colored: Optional[bool] = None,
) -> None:
    """Update every `ConsoleHandler` attached to the Flower logger in place.

    Any argument left as `None` keeps the handler's current setting.
    """
    console_handlers = [
        handler
        for handler in logging.getLogger(LOGGER_NAME).handlers
        if isinstance(handler, ConsoleHandler)
    ]
    for handler in console_handlers:
        if level is not None:
            handler.setLevel(level)
        if timestamps is not None:
            handler.timestamps = timestamps
        if colored is not None:
            handler.colored = colored
92
99
 
93
100
 
94
101
  # Configure console logger
@@ -188,3 +195,29 @@ def warn_deprecated_feature(name: str) -> None:
188
195
  """,
189
196
  name,
190
197
  )
198
+
199
+
def set_logger_propagation(
    child_logger: logging.Logger, value: bool = True
) -> logging.Logger:
    """Set the logger propagation attribute.

    Parameters
    ----------
    child_logger : logging.Logger
        Logger whose `propagate` attribute is updated in place.
    value : bool
        Boolean setting for propagation. If True, both parent and child logger
        display messages. Otherwise, only the child logger displays a message.
        This False setting prevents duplicate logs in Colab notebooks.
        Reference: https://stackoverflow.com/a/19561320

    Returns
    -------
    logging.Logger
        The same logger object, with the updated propagation setting.
    """
    child_logger.propagate = value
    if child_logger.propagate:
        return child_logger
    # Leave a trace when propagation is turned off
    child_logger.log(logging.DEBUG, "Logger propagate set to False")
    return child_logger
flwr/common/message.py CHANGED
@@ -18,14 +18,13 @@ from __future__ import annotations
18
18
 
19
19
  import time
20
20
  import warnings
21
- from dataclasses import dataclass
21
+ from typing import Optional, cast
22
22
 
23
23
  from .record import RecordSet
24
24
 
25
25
  DEFAULT_TTL = 3600
26
26
 
27
27
 
28
- @dataclass
29
28
  class Metadata: # pylint: disable=too-many-instance-attributes
30
29
  """A dataclass holding metadata associated with the current message.
31
30
 
@@ -55,17 +54,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
55
54
  is more relevant when conducting simulations.
56
55
  """
57
56
 
58
- _run_id: int
59
- _message_id: str
60
- _src_node_id: int
61
- _dst_node_id: int
62
- _reply_to_message: str
63
- _group_id: str
64
- _ttl: float
65
- _message_type: str
66
- _partition_id: int | None
67
- _created_at: float # Unix timestamp (in seconds) to be set upon message creation
68
-
69
57
  def __init__( # pylint: disable=too-many-arguments
70
58
  self,
71
59
  run_id: int,
@@ -78,98 +66,111 @@ class Metadata: # pylint: disable=too-many-instance-attributes
78
66
  message_type: str,
79
67
  partition_id: int | None = None,
80
68
  ) -> None:
81
- self._run_id = run_id
82
- self._message_id = message_id
83
- self._src_node_id = src_node_id
84
- self._dst_node_id = dst_node_id
85
- self._reply_to_message = reply_to_message
86
- self._group_id = group_id
87
- self._ttl = ttl
88
- self._message_type = message_type
89
- self._partition_id = partition_id
69
+ var_dict = {
70
+ "_run_id": run_id,
71
+ "_message_id": message_id,
72
+ "_src_node_id": src_node_id,
73
+ "_dst_node_id": dst_node_id,
74
+ "_reply_to_message": reply_to_message,
75
+ "_group_id": group_id,
76
+ "_ttl": ttl,
77
+ "_message_type": message_type,
78
+ "_partition_id": partition_id,
79
+ }
80
+ self.__dict__.update(var_dict)
90
81
 
91
82
  @property
92
83
  def run_id(self) -> int:
93
84
  """An identifier for the current run."""
94
- return self._run_id
85
+ return cast(int, self.__dict__["_run_id"])
95
86
 
96
87
  @property
97
88
  def message_id(self) -> str:
98
89
  """An identifier for the current message."""
99
- return self._message_id
90
+ return cast(str, self.__dict__["_message_id"])
100
91
 
101
92
  @property
102
93
  def src_node_id(self) -> int:
103
94
  """An identifier for the node sending this message."""
104
- return self._src_node_id
95
+ return cast(int, self.__dict__["_src_node_id"])
105
96
 
106
97
  @property
107
98
  def reply_to_message(self) -> str:
108
99
  """An identifier for the message this message replies to."""
109
- return self._reply_to_message
100
+ return cast(str, self.__dict__["_reply_to_message"])
110
101
 
111
102
  @property
112
103
  def dst_node_id(self) -> int:
113
104
  """An identifier for the node receiving this message."""
114
- return self._dst_node_id
105
+ return cast(int, self.__dict__["_dst_node_id"])
115
106
 
116
107
  @dst_node_id.setter
117
108
  def dst_node_id(self, value: int) -> None:
118
109
  """Set dst_node_id."""
119
- self._dst_node_id = value
110
+ self.__dict__["_dst_node_id"] = value
120
111
 
121
112
  @property
122
113
  def group_id(self) -> str:
123
114
  """An identifier for grouping messages."""
124
- return self._group_id
115
+ return cast(str, self.__dict__["_group_id"])
125
116
 
126
117
  @group_id.setter
127
118
  def group_id(self, value: str) -> None:
128
119
  """Set group_id."""
129
- self._group_id = value
120
+ self.__dict__["_group_id"] = value
130
121
 
131
122
  @property
132
123
  def created_at(self) -> float:
133
124
  """Unix timestamp when the message was created."""
134
- return self._created_at
125
+ return cast(float, self.__dict__["_created_at"])
135
126
 
136
127
  @created_at.setter
137
128
  def created_at(self, value: float) -> None:
138
- """Set creation timestamp for this messages."""
139
- self._created_at = value
129
+ """Set creation timestamp for this message."""
130
+ self.__dict__["_created_at"] = value
140
131
 
141
132
  @property
142
133
  def ttl(self) -> float:
143
134
  """Time-to-live for this message."""
144
- return self._ttl
135
+ return cast(float, self.__dict__["_ttl"])
145
136
 
146
137
  @ttl.setter
147
138
  def ttl(self, value: float) -> None:
148
139
  """Set ttl."""
149
- self._ttl = value
140
+ self.__dict__["_ttl"] = value
150
141
 
151
142
  @property
152
143
  def message_type(self) -> str:
153
144
  """A string that encodes the action to be executed on the receiving end."""
154
- return self._message_type
145
+ return cast(str, self.__dict__["_message_type"])
155
146
 
156
147
  @message_type.setter
157
148
  def message_type(self, value: str) -> None:
158
149
  """Set message_type."""
159
- self._message_type = value
150
+ self.__dict__["_message_type"] = value
160
151
 
161
152
  @property
162
153
  def partition_id(self) -> int | None:
163
154
  """An identifier telling which data partition a ClientApp should use."""
164
- return self._partition_id
155
+ return cast(int, self.__dict__["_partition_id"])
165
156
 
166
157
  @partition_id.setter
167
158
  def partition_id(self, value: int) -> None:
168
- """Set patition_id."""
169
- self._partition_id = value
159
+ """Set partition_id."""
160
+ self.__dict__["_partition_id"] = value
161
+
162
+ def __repr__(self) -> str:
163
+ """Return a string representation of this instance."""
164
+ view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
165
+ return f"{self.__class__.__qualname__}({view})"
166
+
167
+ def __eq__(self, other: object) -> bool:
168
+ """Compare two instances of the class."""
169
+ if not isinstance(other, self.__class__):
170
+ raise NotImplementedError
171
+ return self.__dict__ == other.__dict__
170
172
 
171
173
 
172
- @dataclass
173
174
  class Error:
174
175
  """A dataclass that stores information about an error that occurred.
175
176
 
@@ -181,25 +182,35 @@ class Error:
181
182
  A reason for why the error arose (e.g. an exception stack-trace)
182
183
  """
183
184
 
184
- _code: int
185
- _reason: str | None = None
186
-
187
185
  def __init__(self, code: int, reason: str | None = None) -> None:
188
- self._code = code
189
- self._reason = reason
186
+ var_dict = {
187
+ "_code": code,
188
+ "_reason": reason,
189
+ }
190
+ self.__dict__.update(var_dict)
190
191
 
191
192
  @property
192
193
  def code(self) -> int:
193
194
  """Error code."""
194
- return self._code
195
+ return cast(int, self.__dict__["_code"])
195
196
 
196
197
  @property
197
198
  def reason(self) -> str | None:
198
199
  """Reason reported about the error."""
199
- return self._reason
200
+ return cast(Optional[str], self.__dict__["_reason"])
201
+
202
+ def __repr__(self) -> str:
203
+ """Return a string representation of this instance."""
204
+ view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
205
+ return f"{self.__class__.__qualname__}({view})"
206
+
207
+ def __eq__(self, other: object) -> bool:
208
+ """Compare two instances of the class."""
209
+ if not isinstance(other, self.__class__):
210
+ raise NotImplementedError
211
+ return self.__dict__ == other.__dict__
200
212
 
201
213
 
202
- @dataclass
203
214
  class Message:
204
215
  """State of your application from the viewpoint of the entity using it.
205
216
 
@@ -215,88 +226,70 @@ class Message:
215
226
  when processing another message.
216
227
  """
217
228
 
218
- _metadata: Metadata
219
- _content: RecordSet | None = None
220
- _error: Error | None = None
221
-
222
229
  def __init__(
223
230
  self,
224
231
  metadata: Metadata,
225
232
  content: RecordSet | None = None,
226
233
  error: Error | None = None,
227
234
  ) -> None:
228
- self._metadata = metadata
229
-
230
- # Set message creation timestamp
231
- self._metadata.created_at = time.time()
232
-
233
235
  if not (content is None) ^ (error is None):
234
236
  raise ValueError("Either `content` or `error` must be set, but not both.")
235
237
 
236
- self._content = content
237
- self._error = error
238
+ metadata.created_at = time.time() # Set the message creation timestamp
239
+ var_dict = {
240
+ "_metadata": metadata,
241
+ "_content": content,
242
+ "_error": error,
243
+ }
244
+ self.__dict__.update(var_dict)
238
245
 
239
246
  @property
240
247
  def metadata(self) -> Metadata:
241
248
  """A dataclass including information about the message to be executed."""
242
- return self._metadata
249
+ return cast(Metadata, self.__dict__["_metadata"])
243
250
 
244
251
  @property
245
252
  def content(self) -> RecordSet:
246
253
  """The content of this message."""
247
- if self._content is None:
254
+ if self.__dict__["_content"] is None:
248
255
  raise ValueError(
249
256
  "Message content is None. Use <message>.has_content() "
250
257
  "to check if a message has content."
251
258
  )
252
- return self._content
259
+ return cast(RecordSet, self.__dict__["_content"])
253
260
 
254
261
  @content.setter
255
262
  def content(self, value: RecordSet) -> None:
256
263
  """Set content."""
257
- if self._error is None:
258
- self._content = value
264
+ if self.__dict__["_error"] is None:
265
+ self.__dict__["_content"] = value
259
266
  else:
260
267
  raise ValueError("A message with an error set cannot have content.")
261
268
 
262
269
  @property
263
270
  def error(self) -> Error:
264
271
  """Error captured by this message."""
265
- if self._error is None:
272
+ if self.__dict__["_error"] is None:
266
273
  raise ValueError(
267
274
  "Message error is None. Use <message>.has_error() "
268
275
  "to check first if a message carries an error."
269
276
  )
270
- return self._error
277
+ return cast(Error, self.__dict__["_error"])
271
278
 
272
279
  @error.setter
273
280
  def error(self, value: Error) -> None:
274
281
  """Set error."""
275
282
  if self.has_content():
276
283
  raise ValueError("A message with content set cannot carry an error.")
277
- self._error = value
284
+ self.__dict__["_error"] = value
278
285
 
279
286
  def has_content(self) -> bool:
280
287
  """Return True if message has content, else False."""
281
- return self._content is not None
288
+ return self.__dict__["_content"] is not None
282
289
 
283
290
  def has_error(self) -> bool:
284
291
  """Return True if message has an error, else False."""
285
- return self._error is not None
286
-
287
- def _create_reply_metadata(self, ttl: float) -> Metadata:
288
- """Construct metadata for a reply message."""
289
- return Metadata(
290
- run_id=self.metadata.run_id,
291
- message_id="",
292
- src_node_id=self.metadata.dst_node_id,
293
- dst_node_id=self.metadata.src_node_id,
294
- reply_to_message=self.metadata.message_id,
295
- group_id=self.metadata.group_id,
296
- ttl=ttl,
297
- message_type=self.metadata.message_type,
298
- partition_id=self.metadata.partition_id,
299
- )
292
+ return self.__dict__["_error"] is not None
300
293
 
301
294
  def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
302
295
  """Construct a reply message indicating an error happened.
@@ -323,7 +316,7 @@ class Message:
323
316
  # message creation)
324
317
  ttl_ = DEFAULT_TTL if ttl is None else ttl
325
318
  # Create reply with error
326
- message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
319
+ message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
327
320
 
328
321
  if ttl is None:
329
322
  # Set TTL equal to the remaining time for the received message to expire
@@ -369,7 +362,7 @@ class Message:
369
362
  ttl_ = DEFAULT_TTL if ttl is None else ttl
370
363
 
371
364
  message = Message(
372
- metadata=self._create_reply_metadata(ttl_),
365
+ metadata=_create_reply_metadata(self, ttl_),
373
366
  content=content,
374
367
  )
375
368
 
@@ -381,3 +374,29 @@ class Message:
381
374
  message.metadata.ttl = ttl
382
375
 
383
376
  return message
377
+
378
+ def __repr__(self) -> str:
379
+ """Return a string representation of this instance."""
380
+ view = ", ".join(
381
+ [
382
+ f"{k.lstrip('_')}={v!r}"
383
+ for k, v in self.__dict__.items()
384
+ if v is not None
385
+ ]
386
+ )
387
+ return f"{self.__class__.__qualname__}({view})"
388
+
389
+
def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
    """Construct metadata for a reply to `msg`.

    Source and destination node IDs are swapped, `reply_to_message` points at
    the incoming message's ID and `message_id` is left as an empty string.
    Run ID, group ID, message type and partition ID are copied unchanged;
    `ttl` is taken from the argument.
    """
    incoming = msg.metadata
    return Metadata(
        run_id=incoming.run_id,
        message_id="",
        src_node_id=incoming.dst_node_id,
        dst_node_id=incoming.src_node_id,
        reply_to_message=incoming.message_id,
        group_id=incoming.group_id,
        ttl=ttl,
        message_type=incoming.message_type,
        partition_id=incoming.partition_id,
    )
@@ -82,7 +82,6 @@ def _check_value(value: Array) -> None:
82
82
  )
83
83
 
84
84
 
85
- @dataclass
86
85
  class ParametersRecord(TypedDict[str, Array]):
87
86
  """Parameters record.
88
87