flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/common/grpc.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import DEBUG
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional, Sequence
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
@@ -30,6 +30,7 @@ def create_channel(
|
|
|
30
30
|
insecure: bool,
|
|
31
31
|
root_certificates: Optional[bytes] = None,
|
|
32
32
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
33
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
|
|
33
34
|
) -> grpc.Channel:
|
|
34
35
|
"""Create a gRPC channel, either secure or insecure."""
|
|
35
36
|
# Check for conflicting parameters
|
|
@@ -57,4 +58,7 @@ def create_channel(
|
|
|
57
58
|
)
|
|
58
59
|
log(DEBUG, "Opened secure gRPC connection using certificates")
|
|
59
60
|
|
|
61
|
+
if interceptors is not None:
|
|
62
|
+
channel = grpc.intercept_channel(channel, interceptors)
|
|
63
|
+
|
|
60
64
|
return channel
|
flwr/common/logger.py
CHANGED
|
@@ -82,13 +82,20 @@ class ConsoleHandler(StreamHandler):
|
|
|
82
82
|
return formatter.format(record)
|
|
83
83
|
|
|
84
84
|
|
|
85
|
-
def update_console_handler(
|
|
85
|
+
def update_console_handler(
|
|
86
|
+
level: Optional[int] = None,
|
|
87
|
+
timestamps: Optional[bool] = None,
|
|
88
|
+
colored: Optional[bool] = None,
|
|
89
|
+
) -> None:
|
|
86
90
|
"""Update the logging handler."""
|
|
87
91
|
for handler in logging.getLogger(LOGGER_NAME).handlers:
|
|
88
92
|
if isinstance(handler, ConsoleHandler):
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
93
|
+
if level is not None:
|
|
94
|
+
handler.setLevel(level)
|
|
95
|
+
if timestamps is not None:
|
|
96
|
+
handler.timestamps = timestamps
|
|
97
|
+
if colored is not None:
|
|
98
|
+
handler.colored = colored
|
|
92
99
|
|
|
93
100
|
|
|
94
101
|
# Configure console logger
|
|
@@ -188,3 +195,29 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
188
195
|
""",
|
|
189
196
|
name,
|
|
190
197
|
)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def set_logger_propagation(
|
|
201
|
+
child_logger: logging.Logger, value: bool = True
|
|
202
|
+
) -> logging.Logger:
|
|
203
|
+
"""Set the logger propagation attribute.
|
|
204
|
+
|
|
205
|
+
Parameters
|
|
206
|
+
----------
|
|
207
|
+
child_logger : logging.Logger
|
|
208
|
+
Child logger object
|
|
209
|
+
value : bool
|
|
210
|
+
Boolean setting for propagation. If True, both parent and child logger
|
|
211
|
+
display messages. Otherwise, only the child logger displays a message.
|
|
212
|
+
This False setting prevents duplicate logs in Colab notebooks.
|
|
213
|
+
Reference: https://stackoverflow.com/a/19561320
|
|
214
|
+
|
|
215
|
+
Returns
|
|
216
|
+
-------
|
|
217
|
+
logging.Logger
|
|
218
|
+
Child logger object with updated propagation setting
|
|
219
|
+
"""
|
|
220
|
+
child_logger.propagate = value
|
|
221
|
+
if not child_logger.propagate:
|
|
222
|
+
child_logger.log(logging.DEBUG, "Logger propagate set to False")
|
|
223
|
+
return child_logger
|
flwr/common/message.py
CHANGED
|
@@ -18,14 +18,13 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
import time
|
|
20
20
|
import warnings
|
|
21
|
-
from dataclasses import dataclass
|
|
21
|
+
from typing import Optional, cast
|
|
22
22
|
|
|
23
23
|
from .record import RecordSet
|
|
24
24
|
|
|
25
25
|
DEFAULT_TTL = 3600
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
@dataclass
|
|
29
28
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
|
30
29
|
"""A dataclass holding metadata associated with the current message.
|
|
31
30
|
|
|
@@ -55,17 +54,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
55
54
|
is more relevant when conducting simulations.
|
|
56
55
|
"""
|
|
57
56
|
|
|
58
|
-
_run_id: int
|
|
59
|
-
_message_id: str
|
|
60
|
-
_src_node_id: int
|
|
61
|
-
_dst_node_id: int
|
|
62
|
-
_reply_to_message: str
|
|
63
|
-
_group_id: str
|
|
64
|
-
_ttl: float
|
|
65
|
-
_message_type: str
|
|
66
|
-
_partition_id: int | None
|
|
67
|
-
_created_at: float # Unix timestamp (in seconds) to be set upon message creation
|
|
68
|
-
|
|
69
57
|
def __init__( # pylint: disable=too-many-arguments
|
|
70
58
|
self,
|
|
71
59
|
run_id: int,
|
|
@@ -78,98 +66,111 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
78
66
|
message_type: str,
|
|
79
67
|
partition_id: int | None = None,
|
|
80
68
|
) -> None:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
69
|
+
var_dict = {
|
|
70
|
+
"_run_id": run_id,
|
|
71
|
+
"_message_id": message_id,
|
|
72
|
+
"_src_node_id": src_node_id,
|
|
73
|
+
"_dst_node_id": dst_node_id,
|
|
74
|
+
"_reply_to_message": reply_to_message,
|
|
75
|
+
"_group_id": group_id,
|
|
76
|
+
"_ttl": ttl,
|
|
77
|
+
"_message_type": message_type,
|
|
78
|
+
"_partition_id": partition_id,
|
|
79
|
+
}
|
|
80
|
+
self.__dict__.update(var_dict)
|
|
90
81
|
|
|
91
82
|
@property
|
|
92
83
|
def run_id(self) -> int:
|
|
93
84
|
"""An identifier for the current run."""
|
|
94
|
-
return self._run_id
|
|
85
|
+
return cast(int, self.__dict__["_run_id"])
|
|
95
86
|
|
|
96
87
|
@property
|
|
97
88
|
def message_id(self) -> str:
|
|
98
89
|
"""An identifier for the current message."""
|
|
99
|
-
return self._message_id
|
|
90
|
+
return cast(str, self.__dict__["_message_id"])
|
|
100
91
|
|
|
101
92
|
@property
|
|
102
93
|
def src_node_id(self) -> int:
|
|
103
94
|
"""An identifier for the node sending this message."""
|
|
104
|
-
return self._src_node_id
|
|
95
|
+
return cast(int, self.__dict__["_src_node_id"])
|
|
105
96
|
|
|
106
97
|
@property
|
|
107
98
|
def reply_to_message(self) -> str:
|
|
108
99
|
"""An identifier for the message this message replies to."""
|
|
109
|
-
return self._reply_to_message
|
|
100
|
+
return cast(str, self.__dict__["_reply_to_message"])
|
|
110
101
|
|
|
111
102
|
@property
|
|
112
103
|
def dst_node_id(self) -> int:
|
|
113
104
|
"""An identifier for the node receiving this message."""
|
|
114
|
-
return self._dst_node_id
|
|
105
|
+
return cast(int, self.__dict__["_dst_node_id"])
|
|
115
106
|
|
|
116
107
|
@dst_node_id.setter
|
|
117
108
|
def dst_node_id(self, value: int) -> None:
|
|
118
109
|
"""Set dst_node_id."""
|
|
119
|
-
self._dst_node_id = value
|
|
110
|
+
self.__dict__["_dst_node_id"] = value
|
|
120
111
|
|
|
121
112
|
@property
|
|
122
113
|
def group_id(self) -> str:
|
|
123
114
|
"""An identifier for grouping messages."""
|
|
124
|
-
return self._group_id
|
|
115
|
+
return cast(str, self.__dict__["_group_id"])
|
|
125
116
|
|
|
126
117
|
@group_id.setter
|
|
127
118
|
def group_id(self, value: str) -> None:
|
|
128
119
|
"""Set group_id."""
|
|
129
|
-
self._group_id = value
|
|
120
|
+
self.__dict__["_group_id"] = value
|
|
130
121
|
|
|
131
122
|
@property
|
|
132
123
|
def created_at(self) -> float:
|
|
133
124
|
"""Unix timestamp when the message was created."""
|
|
134
|
-
return self._created_at
|
|
125
|
+
return cast(float, self.__dict__["_created_at"])
|
|
135
126
|
|
|
136
127
|
@created_at.setter
|
|
137
128
|
def created_at(self, value: float) -> None:
|
|
138
|
-
"""Set creation timestamp for this
|
|
139
|
-
self._created_at = value
|
|
129
|
+
"""Set creation timestamp for this message."""
|
|
130
|
+
self.__dict__["_created_at"] = value
|
|
140
131
|
|
|
141
132
|
@property
|
|
142
133
|
def ttl(self) -> float:
|
|
143
134
|
"""Time-to-live for this message."""
|
|
144
|
-
return self._ttl
|
|
135
|
+
return cast(float, self.__dict__["_ttl"])
|
|
145
136
|
|
|
146
137
|
@ttl.setter
|
|
147
138
|
def ttl(self, value: float) -> None:
|
|
148
139
|
"""Set ttl."""
|
|
149
|
-
self._ttl = value
|
|
140
|
+
self.__dict__["_ttl"] = value
|
|
150
141
|
|
|
151
142
|
@property
|
|
152
143
|
def message_type(self) -> str:
|
|
153
144
|
"""A string that encodes the action to be executed on the receiving end."""
|
|
154
|
-
return self._message_type
|
|
145
|
+
return cast(str, self.__dict__["_message_type"])
|
|
155
146
|
|
|
156
147
|
@message_type.setter
|
|
157
148
|
def message_type(self, value: str) -> None:
|
|
158
149
|
"""Set message_type."""
|
|
159
|
-
self._message_type = value
|
|
150
|
+
self.__dict__["_message_type"] = value
|
|
160
151
|
|
|
161
152
|
@property
|
|
162
153
|
def partition_id(self) -> int | None:
|
|
163
154
|
"""An identifier telling which data partition a ClientApp should use."""
|
|
164
|
-
return self._partition_id
|
|
155
|
+
return cast(int, self.__dict__["_partition_id"])
|
|
165
156
|
|
|
166
157
|
@partition_id.setter
|
|
167
158
|
def partition_id(self, value: int) -> None:
|
|
168
|
-
"""Set
|
|
169
|
-
self._partition_id = value
|
|
159
|
+
"""Set partition_id."""
|
|
160
|
+
self.__dict__["_partition_id"] = value
|
|
161
|
+
|
|
162
|
+
def __repr__(self) -> str:
|
|
163
|
+
"""Return a string representation of this instance."""
|
|
164
|
+
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
165
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
166
|
+
|
|
167
|
+
def __eq__(self, other: object) -> bool:
|
|
168
|
+
"""Compare two instances of the class."""
|
|
169
|
+
if not isinstance(other, self.__class__):
|
|
170
|
+
raise NotImplementedError
|
|
171
|
+
return self.__dict__ == other.__dict__
|
|
170
172
|
|
|
171
173
|
|
|
172
|
-
@dataclass
|
|
173
174
|
class Error:
|
|
174
175
|
"""A dataclass that stores information about an error that occurred.
|
|
175
176
|
|
|
@@ -181,25 +182,35 @@ class Error:
|
|
|
181
182
|
A reason for why the error arose (e.g. an exception stack-trace)
|
|
182
183
|
"""
|
|
183
184
|
|
|
184
|
-
_code: int
|
|
185
|
-
_reason: str | None = None
|
|
186
|
-
|
|
187
185
|
def __init__(self, code: int, reason: str | None = None) -> None:
|
|
188
|
-
|
|
189
|
-
|
|
186
|
+
var_dict = {
|
|
187
|
+
"_code": code,
|
|
188
|
+
"_reason": reason,
|
|
189
|
+
}
|
|
190
|
+
self.__dict__.update(var_dict)
|
|
190
191
|
|
|
191
192
|
@property
|
|
192
193
|
def code(self) -> int:
|
|
193
194
|
"""Error code."""
|
|
194
|
-
return self._code
|
|
195
|
+
return cast(int, self.__dict__["_code"])
|
|
195
196
|
|
|
196
197
|
@property
|
|
197
198
|
def reason(self) -> str | None:
|
|
198
199
|
"""Reason reported about the error."""
|
|
199
|
-
return self._reason
|
|
200
|
+
return cast(Optional[str], self.__dict__["_reason"])
|
|
201
|
+
|
|
202
|
+
def __repr__(self) -> str:
|
|
203
|
+
"""Return a string representation of this instance."""
|
|
204
|
+
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
205
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
206
|
+
|
|
207
|
+
def __eq__(self, other: object) -> bool:
|
|
208
|
+
"""Compare two instances of the class."""
|
|
209
|
+
if not isinstance(other, self.__class__):
|
|
210
|
+
raise NotImplementedError
|
|
211
|
+
return self.__dict__ == other.__dict__
|
|
200
212
|
|
|
201
213
|
|
|
202
|
-
@dataclass
|
|
203
214
|
class Message:
|
|
204
215
|
"""State of your application from the viewpoint of the entity using it.
|
|
205
216
|
|
|
@@ -215,88 +226,70 @@ class Message:
|
|
|
215
226
|
when processing another message.
|
|
216
227
|
"""
|
|
217
228
|
|
|
218
|
-
_metadata: Metadata
|
|
219
|
-
_content: RecordSet | None = None
|
|
220
|
-
_error: Error | None = None
|
|
221
|
-
|
|
222
229
|
def __init__(
|
|
223
230
|
self,
|
|
224
231
|
metadata: Metadata,
|
|
225
232
|
content: RecordSet | None = None,
|
|
226
233
|
error: Error | None = None,
|
|
227
234
|
) -> None:
|
|
228
|
-
self._metadata = metadata
|
|
229
|
-
|
|
230
|
-
# Set message creation timestamp
|
|
231
|
-
self._metadata.created_at = time.time()
|
|
232
|
-
|
|
233
235
|
if not (content is None) ^ (error is None):
|
|
234
236
|
raise ValueError("Either `content` or `error` must be set, but not both.")
|
|
235
237
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
+
metadata.created_at = time.time() # Set the message creation timestamp
|
|
239
|
+
var_dict = {
|
|
240
|
+
"_metadata": metadata,
|
|
241
|
+
"_content": content,
|
|
242
|
+
"_error": error,
|
|
243
|
+
}
|
|
244
|
+
self.__dict__.update(var_dict)
|
|
238
245
|
|
|
239
246
|
@property
|
|
240
247
|
def metadata(self) -> Metadata:
|
|
241
248
|
"""A dataclass including information about the message to be executed."""
|
|
242
|
-
return self._metadata
|
|
249
|
+
return cast(Metadata, self.__dict__["_metadata"])
|
|
243
250
|
|
|
244
251
|
@property
|
|
245
252
|
def content(self) -> RecordSet:
|
|
246
253
|
"""The content of this message."""
|
|
247
|
-
if self._content is None:
|
|
254
|
+
if self.__dict__["_content"] is None:
|
|
248
255
|
raise ValueError(
|
|
249
256
|
"Message content is None. Use <message>.has_content() "
|
|
250
257
|
"to check if a message has content."
|
|
251
258
|
)
|
|
252
|
-
return self._content
|
|
259
|
+
return cast(RecordSet, self.__dict__["_content"])
|
|
253
260
|
|
|
254
261
|
@content.setter
|
|
255
262
|
def content(self, value: RecordSet) -> None:
|
|
256
263
|
"""Set content."""
|
|
257
|
-
if self._error is None:
|
|
258
|
-
self._content = value
|
|
264
|
+
if self.__dict__["_error"] is None:
|
|
265
|
+
self.__dict__["_content"] = value
|
|
259
266
|
else:
|
|
260
267
|
raise ValueError("A message with an error set cannot have content.")
|
|
261
268
|
|
|
262
269
|
@property
|
|
263
270
|
def error(self) -> Error:
|
|
264
271
|
"""Error captured by this message."""
|
|
265
|
-
if self._error is None:
|
|
272
|
+
if self.__dict__["_error"] is None:
|
|
266
273
|
raise ValueError(
|
|
267
274
|
"Message error is None. Use <message>.has_error() "
|
|
268
275
|
"to check first if a message carries an error."
|
|
269
276
|
)
|
|
270
|
-
return self._error
|
|
277
|
+
return cast(Error, self.__dict__["_error"])
|
|
271
278
|
|
|
272
279
|
@error.setter
|
|
273
280
|
def error(self, value: Error) -> None:
|
|
274
281
|
"""Set error."""
|
|
275
282
|
if self.has_content():
|
|
276
283
|
raise ValueError("A message with content set cannot carry an error.")
|
|
277
|
-
self._error = value
|
|
284
|
+
self.__dict__["_error"] = value
|
|
278
285
|
|
|
279
286
|
def has_content(self) -> bool:
|
|
280
287
|
"""Return True if message has content, else False."""
|
|
281
|
-
return self._content is not None
|
|
288
|
+
return self.__dict__["_content"] is not None
|
|
282
289
|
|
|
283
290
|
def has_error(self) -> bool:
|
|
284
291
|
"""Return True if message has an error, else False."""
|
|
285
|
-
return self._error is not None
|
|
286
|
-
|
|
287
|
-
def _create_reply_metadata(self, ttl: float) -> Metadata:
|
|
288
|
-
"""Construct metadata for a reply message."""
|
|
289
|
-
return Metadata(
|
|
290
|
-
run_id=self.metadata.run_id,
|
|
291
|
-
message_id="",
|
|
292
|
-
src_node_id=self.metadata.dst_node_id,
|
|
293
|
-
dst_node_id=self.metadata.src_node_id,
|
|
294
|
-
reply_to_message=self.metadata.message_id,
|
|
295
|
-
group_id=self.metadata.group_id,
|
|
296
|
-
ttl=ttl,
|
|
297
|
-
message_type=self.metadata.message_type,
|
|
298
|
-
partition_id=self.metadata.partition_id,
|
|
299
|
-
)
|
|
292
|
+
return self.__dict__["_error"] is not None
|
|
300
293
|
|
|
301
294
|
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
|
|
302
295
|
"""Construct a reply message indicating an error happened.
|
|
@@ -323,7 +316,7 @@ class Message:
|
|
|
323
316
|
# message creation)
|
|
324
317
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
325
318
|
# Create reply with error
|
|
326
|
-
message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
|
|
319
|
+
message = Message(metadata=_create_reply_metadata(self, ttl_), error=error)
|
|
327
320
|
|
|
328
321
|
if ttl is None:
|
|
329
322
|
# Set TTL equal to the remaining time for the received message to expire
|
|
@@ -369,7 +362,7 @@ class Message:
|
|
|
369
362
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
370
363
|
|
|
371
364
|
message = Message(
|
|
372
|
-
metadata=self._create_reply_metadata(ttl_),
|
|
365
|
+
metadata=_create_reply_metadata(self, ttl_),
|
|
373
366
|
content=content,
|
|
374
367
|
)
|
|
375
368
|
|
|
@@ -381,3 +374,29 @@ class Message:
|
|
|
381
374
|
message.metadata.ttl = ttl
|
|
382
375
|
|
|
383
376
|
return message
|
|
377
|
+
|
|
378
|
+
def __repr__(self) -> str:
|
|
379
|
+
"""Return a string representation of this instance."""
|
|
380
|
+
view = ", ".join(
|
|
381
|
+
[
|
|
382
|
+
f"{k.lstrip('_')}={v!r}"
|
|
383
|
+
for k, v in self.__dict__.items()
|
|
384
|
+
if v is not None
|
|
385
|
+
]
|
|
386
|
+
)
|
|
387
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
|
|
391
|
+
"""Construct metadata for a reply message."""
|
|
392
|
+
return Metadata(
|
|
393
|
+
run_id=msg.metadata.run_id,
|
|
394
|
+
message_id="",
|
|
395
|
+
src_node_id=msg.metadata.dst_node_id,
|
|
396
|
+
dst_node_id=msg.metadata.src_node_id,
|
|
397
|
+
reply_to_message=msg.metadata.message_id,
|
|
398
|
+
group_id=msg.metadata.group_id,
|
|
399
|
+
ttl=ttl,
|
|
400
|
+
message_type=msg.metadata.message_type,
|
|
401
|
+
partition_id=msg.metadata.partition_id,
|
|
402
|
+
)
|
flwr/common/record/recordset.py
CHANGED
|
@@ -16,23 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Dict, Optional, cast
|
|
20
20
|
|
|
21
21
|
from .configsrecord import ConfigsRecord
|
|
22
22
|
from .metricsrecord import MetricsRecord
|
|
23
23
|
from .parametersrecord import ParametersRecord
|
|
24
24
|
from .typeddict import TypedDict
|
|
25
25
|
|
|
26
|
-
T = TypeVar("T")
|
|
27
|
-
|
|
28
26
|
|
|
29
27
|
@dataclass
|
|
30
|
-
class RecordSet:
|
|
31
|
-
"""
|
|
28
|
+
class RecordSetData:
|
|
29
|
+
"""Inner data container for the RecordSet class."""
|
|
32
30
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
31
|
+
parameters_records: TypedDict[str, ParametersRecord]
|
|
32
|
+
metrics_records: TypedDict[str, MetricsRecord]
|
|
33
|
+
configs_records: TypedDict[str, ConfigsRecord]
|
|
36
34
|
|
|
37
35
|
def __init__(
|
|
38
36
|
self,
|
|
@@ -40,40 +38,93 @@ class RecordSet:
|
|
|
40
38
|
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
|
|
41
39
|
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
|
|
42
40
|
) -> None:
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
if not isinstance(__v, __t):
|
|
46
|
-
raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.")
|
|
47
|
-
|
|
48
|
-
return _check_fn
|
|
49
|
-
|
|
50
|
-
self._parameters_records = TypedDict[str, ParametersRecord](
|
|
51
|
-
_get_check_fn(str), _get_check_fn(ParametersRecord)
|
|
41
|
+
self.parameters_records = TypedDict[str, ParametersRecord](
|
|
42
|
+
self._check_fn_str, self._check_fn_params
|
|
52
43
|
)
|
|
53
|
-
self.
|
|
54
|
-
|
|
44
|
+
self.metrics_records = TypedDict[str, MetricsRecord](
|
|
45
|
+
self._check_fn_str, self._check_fn_metrics
|
|
55
46
|
)
|
|
56
|
-
self.
|
|
57
|
-
|
|
47
|
+
self.configs_records = TypedDict[str, ConfigsRecord](
|
|
48
|
+
self._check_fn_str, self._check_fn_configs
|
|
58
49
|
)
|
|
59
50
|
if parameters_records is not None:
|
|
60
|
-
self.
|
|
51
|
+
self.parameters_records.update(parameters_records)
|
|
61
52
|
if metrics_records is not None:
|
|
62
|
-
self.
|
|
53
|
+
self.metrics_records.update(metrics_records)
|
|
63
54
|
if configs_records is not None:
|
|
64
|
-
self.
|
|
55
|
+
self.configs_records.update(configs_records)
|
|
56
|
+
|
|
57
|
+
def _check_fn_str(self, key: str) -> None:
|
|
58
|
+
if not isinstance(key, str):
|
|
59
|
+
raise TypeError(
|
|
60
|
+
f"Expected `{str.__name__}`, but "
|
|
61
|
+
f"received `{type(key).__name__}` for the key."
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
def _check_fn_params(self, record: ParametersRecord) -> None:
|
|
65
|
+
if not isinstance(record, ParametersRecord):
|
|
66
|
+
raise TypeError(
|
|
67
|
+
f"Expected `{ParametersRecord.__name__}`, but "
|
|
68
|
+
f"received `{type(record).__name__}` for the value."
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
def _check_fn_metrics(self, record: MetricsRecord) -> None:
|
|
72
|
+
if not isinstance(record, MetricsRecord):
|
|
73
|
+
raise TypeError(
|
|
74
|
+
f"Expected `{MetricsRecord.__name__}`, but "
|
|
75
|
+
f"received `{type(record).__name__}` for the value."
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
def _check_fn_configs(self, record: ConfigsRecord) -> None:
|
|
79
|
+
if not isinstance(record, ConfigsRecord):
|
|
80
|
+
raise TypeError(
|
|
81
|
+
f"Expected `{ConfigsRecord.__name__}`, but "
|
|
82
|
+
f"received `{type(record).__name__}` for the value."
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RecordSet:
|
|
87
|
+
"""RecordSet stores groups of parameters, metrics and configs."""
|
|
88
|
+
|
|
89
|
+
def __init__(
|
|
90
|
+
self,
|
|
91
|
+
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
|
|
92
|
+
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
|
|
93
|
+
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
|
|
94
|
+
) -> None:
|
|
95
|
+
data = RecordSetData(
|
|
96
|
+
parameters_records=parameters_records,
|
|
97
|
+
metrics_records=metrics_records,
|
|
98
|
+
configs_records=configs_records,
|
|
99
|
+
)
|
|
100
|
+
self.__dict__["_data"] = data
|
|
65
101
|
|
|
66
102
|
@property
|
|
67
103
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
|
68
104
|
"""Dictionary holding ParametersRecord instances."""
|
|
69
|
-
|
|
105
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
106
|
+
return data.parameters_records
|
|
70
107
|
|
|
71
108
|
@property
|
|
72
109
|
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
|
|
73
110
|
"""Dictionary holding MetricsRecord instances."""
|
|
74
|
-
|
|
111
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
112
|
+
return data.metrics_records
|
|
75
113
|
|
|
76
114
|
@property
|
|
77
115
|
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
|
|
78
116
|
"""Dictionary holding ConfigsRecord instances."""
|
|
79
|
-
|
|
117
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
118
|
+
return data.configs_records
|
|
119
|
+
|
|
120
|
+
def __repr__(self) -> str:
|
|
121
|
+
"""Return a string representation of this instance."""
|
|
122
|
+
flds = ("parameters_records", "metrics_records", "configs_records")
|
|
123
|
+
view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
|
|
124
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
125
|
+
|
|
126
|
+
def __eq__(self, other: object) -> bool:
|
|
127
|
+
"""Compare two instances of the class."""
|
|
128
|
+
if not isinstance(other, self.__class__):
|
|
129
|
+
raise NotImplementedError
|
|
130
|
+
return self.__dict__ == other.__dict__
|
|
flwr/common/secure_aggregation/crypto/symmetric_encryption.py
CHANGED
|
@@ -18,8 +18,9 @@
|
|
|
18
18
|
import base64
|
|
19
19
|
from typing import Tuple, cast
|
|
20
20
|
|
|
21
|
+
from cryptography.exceptions import InvalidSignature
|
|
21
22
|
from cryptography.fernet import Fernet
|
|
22
|
-
from cryptography.hazmat.primitives import hashes, serialization
|
|
23
|
+
from cryptography.hazmat.primitives import hashes, hmac, serialization
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
25
26
|
|
|
@@ -98,3 +99,36 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
|
|
|
98
99
|
# The input key must be url safe
|
|
99
100
|
fernet = Fernet(key)
|
|
100
101
|
return fernet.decrypt(ciphertext)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_hmac(key: bytes, message: bytes) -> bytes:
|
|
105
|
+
"""Compute hmac of a message using key as hash."""
|
|
106
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
107
|
+
computed_hmac.update(message)
|
|
108
|
+
return computed_hmac.finalize()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
112
|
+
"""Verify hmac of a message using key as hash."""
|
|
113
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
114
|
+
computed_hmac.update(message)
|
|
115
|
+
try:
|
|
116
|
+
computed_hmac.verify(hmac_value)
|
|
117
|
+
return True
|
|
118
|
+
except InvalidSignature:
|
|
119
|
+
return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def ssh_types_to_elliptic_curve(
|
|
123
|
+
private_key: serialization.SSHPrivateKeyTypes,
|
|
124
|
+
public_key: serialization.SSHPublicKeyTypes,
|
|
125
|
+
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
|
|
126
|
+
"""Cast SSH key types to elliptic curve."""
|
|
127
|
+
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
|
|
128
|
+
public_key, ec.EllipticCurvePublicKey
|
|
129
|
+
):
|
|
130
|
+
return (private_key, public_key)
|
|
131
|
+
|
|
132
|
+
raise TypeError(
|
|
133
|
+
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
|
|
134
|
+
)
|
flwr/common/telemetry.py
CHANGED
|
@@ -160,6 +160,10 @@ class EventType(str, Enum):
|
|
|
160
160
|
RUN_SERVER_APP_ENTER = auto()
|
|
161
161
|
RUN_SERVER_APP_LEAVE = auto()
|
|
162
162
|
|
|
163
|
+
# SuperNode
|
|
164
|
+
RUN_SUPERNODE_ENTER = auto()
|
|
165
|
+
RUN_SUPERNODE_LEAVE = auto()
|
|
166
|
+
|
|
163
167
|
|
|
164
168
|
# Use the ThreadPoolExecutor with max_workers=1 to have a queue
|
|
165
169
|
# and also ensure that telemetry calls are not blocking.
|