flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +6 -4
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +23 -20
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +2 -0
- flwr/common/constant.py +2 -0
- flwr/common/context.py +4 -4
- flwr/common/logger.py +2 -2
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +2 -2
- flwr/common/record/metricsrecord.py +1 -1
- flwr/common/record/parametersrecord.py +1 -1
- flwr/common/record/{recordset.py → recorddict.py} +57 -17
- flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
- flwr/common/serde.py +33 -37
- flwr/proto/exec_pb2.py +32 -32
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
- flwr/proto/run_pb2.py +32 -32
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +2 -0
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/grid_client_proxy.py +38 -38
- flwr/server/grid/__init__.py +7 -6
- flwr/server/grid/grid.py +46 -17
- flwr/server/grid/grpc_grid.py +26 -33
- flwr/server/grid/inmemory_grid.py +19 -25
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +37 -11
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
- flwr/server/superlink/linkstate/utils.py +77 -17
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +24 -26
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +13 -13
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
@@ -27,19 +27,14 @@ from flwr.common.constant import (
|
|
27
27
|
Status,
|
28
28
|
SubStatus,
|
29
29
|
)
|
30
|
+
from flwr.common.message import make_message
|
30
31
|
from flwr.common.typing import RunStatus
|
31
32
|
|
32
33
|
# pylint: disable=E0611
|
33
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
34
|
-
from flwr.proto.
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigsRecord as ProtoConfigsRecord
|
35
36
|
|
36
37
|
# pylint: enable=E0611
|
37
|
-
|
38
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
39
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
40
|
-
"It exceeds the time limit specified in its last ping."
|
41
|
-
)
|
42
|
-
|
43
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
44
39
|
(Status.PENDING, Status.STARTING),
|
45
40
|
(Status.STARTING, Status.RUNNING),
|
@@ -60,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
60
55
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
61
56
|
"Error: Reply Message Unavailable - The reply message has expired."
|
62
57
|
)
|
58
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
59
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
60
|
+
"It exceeds twice the time limit specified in its last ping."
|
61
|
+
)
|
63
62
|
|
64
63
|
|
65
64
|
def generate_rand_int_from_bytes(
|
@@ -237,7 +236,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
237
236
|
return status.sub_status == ""
|
238
237
|
|
239
238
|
|
240
|
-
def create_message_error_unavailable_res_message(
|
239
|
+
def create_message_error_unavailable_res_message(
|
240
|
+
ins_metadata: Metadata, error_type: str
|
241
|
+
) -> Message:
|
241
242
|
"""Generate an error Message that the SuperLink returns carrying the specified
|
242
243
|
error."""
|
243
244
|
current_time = now().timestamp()
|
@@ -247,22 +248,31 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
|
|
247
248
|
message_id=str(uuid4()),
|
248
249
|
src_node_id=SUPERLINK_NODE_ID,
|
249
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
250
|
-
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
251
252
|
group_id=ins_metadata.group_id,
|
252
253
|
message_type=ins_metadata.message_type,
|
254
|
+
created_at=current_time,
|
253
255
|
ttl=ttl,
|
254
256
|
)
|
255
257
|
|
256
|
-
return
|
258
|
+
return make_message(
|
257
259
|
metadata=metadata,
|
258
260
|
error=Error(
|
259
|
-
code=
|
260
|
-
|
261
|
+
code=(
|
262
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
263
|
+
if error_type == "msg_unavail"
|
264
|
+
else ErrorCode.NODE_UNAVAILABLE
|
265
|
+
),
|
266
|
+
reason=(
|
267
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
268
|
+
if error_type == "msg_unavail"
|
269
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
270
|
+
),
|
261
271
|
),
|
262
272
|
)
|
263
273
|
|
264
274
|
|
265
|
-
def create_message_error_unavailable_ins_message(
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
266
276
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
267
277
|
that it isn't found."""
|
268
278
|
metadata = Metadata(
|
@@ -270,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
|
|
270
280
|
message_id=str(uuid4()),
|
271
281
|
src_node_id=SUPERLINK_NODE_ID,
|
272
282
|
dst_node_id=SUPERLINK_NODE_ID,
|
273
|
-
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
274
284
|
group_id="", # Unknown
|
275
285
|
message_type=MessageType.SYSTEM,
|
286
|
+
created_at=now().timestamp(),
|
276
287
|
ttl=0,
|
277
288
|
)
|
278
289
|
|
279
|
-
return
|
290
|
+
return make_message(
|
280
291
|
metadata=metadata,
|
281
292
|
error=Error(
|
282
293
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
@@ -364,14 +375,63 @@ def verify_found_message_replies(
|
|
364
375
|
ret_dict: dict[UUID, Message] = {}
|
365
376
|
current = current_time if current_time else now().timestamp()
|
366
377
|
for message_res in found_message_res_list:
|
367
|
-
message_ins_id = UUID(message_res.metadata.
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
368
379
|
if update_set:
|
369
380
|
inquired_message_ids.remove(message_ins_id)
|
370
381
|
# Check if the reply Message has expired
|
371
382
|
if message_ttl_has_expired(message_res.metadata, current):
|
372
383
|
# No need to insert the error Message
|
373
384
|
message_res = create_message_error_unavailable_res_message(
|
374
|
-
found_message_ins_dict[message_ins_id].metadata
|
385
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
375
386
|
)
|
376
387
|
ret_dict[message_ins_id] = message_res
|
377
388
|
return ret_dict
|
389
|
+
|
390
|
+
|
391
|
+
def check_node_availability_for_in_message(
|
392
|
+
inquired_in_message_ids: set[UUID],
|
393
|
+
found_in_message_dict: dict[UUID, Message],
|
394
|
+
node_id_to_online_until: dict[int, float],
|
395
|
+
current_time: Optional[float] = None,
|
396
|
+
update_set: bool = True,
|
397
|
+
) -> dict[UUID, Message]:
|
398
|
+
"""Check node availability for given Message and generate error reply Message if
|
399
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
400
|
+
each given Message whose destination node is offline or non-existent.
|
401
|
+
|
402
|
+
Parameters
|
403
|
+
----------
|
404
|
+
inquired_in_message_ids : set[UUID]
|
405
|
+
Set of Message IDs for which to check destination node availability.
|
406
|
+
found_in_message_dict : dict[UUID, Message]
|
407
|
+
Dictionary containing all found Message indexed by their IDs.
|
408
|
+
node_id_to_online_until : dict[int, float]
|
409
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
410
|
+
current_time : Optional[float] (default: None)
|
411
|
+
The current time to check for expiration. If set to `None`, the current time
|
412
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
413
|
+
update_set : bool (default: True)
|
414
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
415
|
+
by default True.
|
416
|
+
|
417
|
+
Returns
|
418
|
+
-------
|
419
|
+
dict[UUID, Message]
|
420
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
421
|
+
"""
|
422
|
+
ret_dict = {}
|
423
|
+
current = current_time if current_time else now().timestamp()
|
424
|
+
for in_message_id in list(inquired_in_message_ids):
|
425
|
+
in_message = found_in_message_dict[in_message_id]
|
426
|
+
node_id = in_message.metadata.dst_node_id
|
427
|
+
online_until = node_id_to_online_until.get(node_id)
|
428
|
+
# Generate a reply message containing an error reply
|
429
|
+
# if the node is offline or doesn't exist.
|
430
|
+
if online_until is None or online_until < current:
|
431
|
+
if update_set:
|
432
|
+
inquired_in_message_ids.remove(in_message_id)
|
433
|
+
reply_message = create_message_error_unavailable_res_message(
|
434
|
+
in_message.metadata, "node_unavail"
|
435
|
+
)
|
436
|
+
ret_dict[in_message_id] = reply_message
|
437
|
+
return ret_dict
|
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
206
206
|
|
207
207
|
# Delete the instruction Messages and their replies if found
|
208
208
|
message_ins_ids_to_delete = {
|
209
|
-
UUID(msg_res.metadata.
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
210
210
|
}
|
211
211
|
|
212
212
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
flwr/server/typing.py
CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
|
|
19
19
|
|
20
20
|
from flwr.common import Context
|
21
21
|
|
22
|
-
from .grid import
|
22
|
+
from .grid import Grid
|
23
23
|
from .serverapp_components import ServerAppComponents
|
24
24
|
|
25
|
-
ServerAppCallable = Callable[[
|
26
|
-
Workflow = Callable[[
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
flwr/server/utils/validator.py
CHANGED
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
68
68
|
|
69
69
|
# Link respose to original message
|
70
70
|
if not is_reply_message:
|
71
|
-
if metadata.
|
72
|
-
validation_errors.append("`metadata.
|
71
|
+
if metadata.reply_to_message_id != "":
|
72
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
|
73
73
|
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
74
74
|
validation_errors.append(
|
75
75
|
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
79
79
|
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
80
80
|
)
|
81
81
|
else:
|
82
|
-
if metadata.
|
83
|
-
validation_errors.append("`metadata.
|
82
|
+
if metadata.reply_to_message_id == "":
|
83
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
|
84
84
|
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
85
85
|
validation_errors.append(
|
86
86
|
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
@@ -20,7 +20,7 @@ import timeit
|
|
20
20
|
from logging import INFO, WARN
|
21
21
|
from typing import Optional, Union, cast
|
22
22
|
|
23
|
-
import flwr.common.
|
23
|
+
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
25
|
Code,
|
26
26
|
ConfigsRecord,
|
@@ -36,7 +36,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
36
36
|
from ..client_proxy import ClientProxy
|
37
37
|
from ..compat.app_utils import start_update_client_manager_thread
|
38
38
|
from ..compat.legacy_context import LegacyContext
|
39
|
-
from ..grid import
|
39
|
+
from ..grid import Grid
|
40
40
|
from ..typing import Workflow
|
41
41
|
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
|
42
42
|
|
@@ -56,7 +56,7 @@ class DefaultWorkflow:
|
|
56
56
|
self.fit_workflow: Workflow = fit_workflow
|
57
57
|
self.evaluate_workflow: Workflow = evaluate_workflow
|
58
58
|
|
59
|
-
def __call__(self,
|
59
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
60
60
|
"""Execute the workflow."""
|
61
61
|
if not isinstance(context, LegacyContext):
|
62
62
|
raise TypeError(
|
@@ -65,7 +65,7 @@ class DefaultWorkflow:
|
|
65
65
|
|
66
66
|
# Start the thread updating nodes
|
67
67
|
thread, f_stop, c_done = start_update_client_manager_thread(
|
68
|
-
|
68
|
+
grid, context.client_manager
|
69
69
|
)
|
70
70
|
|
71
71
|
# Wait until the node registration done
|
@@ -73,7 +73,7 @@ class DefaultWorkflow:
|
|
73
73
|
|
74
74
|
# Initialize parameters
|
75
75
|
log(INFO, "[INIT]")
|
76
|
-
default_init_params_workflow(
|
76
|
+
default_init_params_workflow(grid, context)
|
77
77
|
|
78
78
|
# Run federated learning for num_rounds
|
79
79
|
start_time = timeit.default_timer()
|
@@ -87,13 +87,13 @@ class DefaultWorkflow:
|
|
87
87
|
cfg[Key.CURRENT_ROUND] = current_round
|
88
88
|
|
89
89
|
# Fit round
|
90
|
-
self.fit_workflow(
|
90
|
+
self.fit_workflow(grid, context)
|
91
91
|
|
92
92
|
# Centralized evaluation
|
93
|
-
default_centralized_evaluation_workflow(
|
93
|
+
default_centralized_evaluation_workflow(grid, context)
|
94
94
|
|
95
95
|
# Evaluate round
|
96
|
-
self.evaluate_workflow(
|
96
|
+
self.evaluate_workflow(grid, context)
|
97
97
|
|
98
98
|
# Bookkeeping and log results
|
99
99
|
end_time = timeit.default_timer()
|
@@ -119,7 +119,7 @@ class DefaultWorkflow:
|
|
119
119
|
thread.join()
|
120
120
|
|
121
121
|
|
122
|
-
def default_init_params_workflow(
|
122
|
+
def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
123
123
|
"""Execute the default workflow for parameters initialization."""
|
124
124
|
if not isinstance(context, LegacyContext):
|
125
125
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -137,10 +137,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
137
137
|
log(INFO, "Requesting initial parameters from one random client")
|
138
138
|
random_client = context.client_manager.sample(1)[0]
|
139
139
|
# Send GetParametersIns and get the response
|
140
|
-
content = compat.
|
141
|
-
messages =
|
140
|
+
content = compat.getparametersins_to_recorddict(GetParametersIns({}))
|
141
|
+
messages = grid.send_and_receive(
|
142
142
|
[
|
143
|
-
|
143
|
+
grid.create_message(
|
144
144
|
content=content,
|
145
145
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
146
146
|
dst_node_id=random_client.node_id,
|
@@ -152,7 +152,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
152
152
|
|
153
153
|
if (
|
154
154
|
msg.has_content()
|
155
|
-
and compat.
|
155
|
+
and compat._extract_status_from_recorddict( # pylint: disable=W0212
|
156
156
|
"getparametersres", msg.content
|
157
157
|
).code
|
158
158
|
== Code.OK
|
@@ -186,7 +186,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
186
186
|
log(INFO, "Evaluation returned no results (`None`)")
|
187
187
|
|
188
188
|
|
189
|
-
def default_centralized_evaluation_workflow(_:
|
189
|
+
def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
190
190
|
"""Execute the default workflow for centralized evaluation."""
|
191
191
|
if not isinstance(context, LegacyContext):
|
192
192
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -218,9 +218,7 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
218
218
|
)
|
219
219
|
|
220
220
|
|
221
|
-
def default_fit_workflow( # pylint: disable=R0914
|
222
|
-
driver: Driver, context: Context
|
223
|
-
) -> None:
|
221
|
+
def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
|
224
222
|
"""Execute the default workflow for a single fit round."""
|
225
223
|
if not isinstance(context, LegacyContext):
|
226
224
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -255,8 +253,8 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
255
253
|
|
256
254
|
# Build out messages
|
257
255
|
out_messages = [
|
258
|
-
|
259
|
-
content=compat.
|
256
|
+
grid.create_message(
|
257
|
+
content=compat.fitins_to_recorddict(fitins, True),
|
260
258
|
message_type=MessageType.TRAIN,
|
261
259
|
dst_node_id=proxy.node_id,
|
262
260
|
group_id=str(current_round),
|
@@ -266,7 +264,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
266
264
|
|
267
265
|
# Send instructions to clients and
|
268
266
|
# collect `fit` results from all clients participating in this round
|
269
|
-
messages = list(
|
267
|
+
messages = list(grid.send_and_receive(out_messages))
|
270
268
|
del out_messages
|
271
269
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
272
270
|
|
@@ -284,7 +282,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
284
282
|
for msg in messages:
|
285
283
|
if msg.has_content():
|
286
284
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
287
|
-
fitres = compat.
|
285
|
+
fitres = compat.recorddict_to_fitres(msg.content, False)
|
288
286
|
if fitres.status.code == Code.OK:
|
289
287
|
results.append((proxy, fitres))
|
290
288
|
else:
|
@@ -307,7 +305,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
307
305
|
|
308
306
|
|
309
307
|
# pylint: disable-next=R0914
|
310
|
-
def default_evaluate_workflow(
|
308
|
+
def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
311
309
|
"""Execute the default workflow for a single evaluate round."""
|
312
310
|
if not isinstance(context, LegacyContext):
|
313
311
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -341,8 +339,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
341
339
|
|
342
340
|
# Build out messages
|
343
341
|
out_messages = [
|
344
|
-
|
345
|
-
content=compat.
|
342
|
+
grid.create_message(
|
343
|
+
content=compat.evaluateins_to_recorddict(evalins, True),
|
346
344
|
message_type=MessageType.EVALUATE,
|
347
345
|
dst_node_id=proxy.node_id,
|
348
346
|
group_id=str(current_round),
|
@@ -352,7 +350,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
352
350
|
|
353
351
|
# Send instructions to clients and
|
354
352
|
# collect `evaluate` results from all clients participating in this round
|
355
|
-
messages = list(
|
353
|
+
messages = list(grid.send_and_receive(out_messages))
|
356
354
|
del out_messages
|
357
355
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
358
356
|
|
@@ -370,7 +368,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
370
368
|
for msg in messages:
|
371
369
|
if msg.has_content():
|
372
370
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
373
|
-
evalres = compat.
|
371
|
+
evalres = compat.recorddict_to_evaluateres(msg.content)
|
374
372
|
if evalres.status.code == Code.OK:
|
375
373
|
results.append((proxy, evalres))
|
376
374
|
else:
|
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
21
21
|
from typing import Optional, Union, cast
|
22
22
|
|
23
|
-
import flwr.common.
|
23
|
+
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
25
|
ConfigsRecord,
|
26
26
|
Context,
|
@@ -28,7 +28,7 @@ from flwr.common import (
|
|
28
28
|
Message,
|
29
29
|
MessageType,
|
30
30
|
NDArrays,
|
31
|
-
|
31
|
+
RecordDict,
|
32
32
|
bytes_to_ndarray,
|
33
33
|
log,
|
34
34
|
ndarrays_to_parameters,
|
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
|
|
55
55
|
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
56
56
|
from flwr.server.client_proxy import ClientProxy
|
57
57
|
from flwr.server.compat.legacy_context import LegacyContext
|
58
|
-
from flwr.server.grid import
|
58
|
+
from flwr.server.grid import Grid
|
59
59
|
|
60
60
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
61
61
|
from ..constant import Key as WorkflowKey
|
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
66
66
|
"""The state of the SecAgg+ protocol."""
|
67
67
|
|
68
68
|
nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
|
69
|
-
nid_to_fitins: dict[int,
|
69
|
+
nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
|
70
70
|
sampled_node_ids: set[int] = field(default_factory=set)
|
71
71
|
active_node_ids: set[int] = field(default_factory=set)
|
72
72
|
num_shares: int = 0
|
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:
|
|
186
186
|
|
187
187
|
self._check_init_params()
|
188
188
|
|
189
|
-
def __call__(self,
|
189
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
190
190
|
"""Run the SecAgg+ protocol."""
|
191
191
|
if not isinstance(context, LegacyContext):
|
192
192
|
raise TypeError(
|
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
|
|
202
202
|
)
|
203
203
|
log(INFO, "Secure aggregation commencing.")
|
204
204
|
for step in steps:
|
205
|
-
if not step(
|
205
|
+
if not step(grid, context, state):
|
206
206
|
log(INFO, "Secure aggregation halted.")
|
207
207
|
return
|
208
208
|
log(INFO, "Secure aggregation completed.")
|
@@ -279,7 +279,7 @@ class SecAggPlusWorkflow:
|
|
279
279
|
return True
|
280
280
|
|
281
281
|
def setup_stage( # pylint: disable=R0912, R0914, R0915
|
282
|
-
self,
|
282
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
283
283
|
) -> bool:
|
284
284
|
"""Execute the 'setup' stage."""
|
285
285
|
# Obtain fit instructions
|
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
|
|
303
303
|
)
|
304
304
|
|
305
305
|
state.nid_to_fitins = {
|
306
|
-
proxy.node_id: compat.
|
306
|
+
proxy.node_id: compat.fitins_to_recorddict(fitins, True)
|
307
307
|
for proxy, fitins in proxy_fitins_lst
|
308
308
|
}
|
309
309
|
state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
|
@@ -367,10 +367,10 @@ class SecAggPlusWorkflow:
|
|
367
367
|
|
368
368
|
# Send setup configuration to clients
|
369
369
|
cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
|
370
|
-
content =
|
370
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
|
371
371
|
|
372
372
|
def make(nid: int) -> Message:
|
373
|
-
return
|
373
|
+
return grid.create_message(
|
374
374
|
content=content,
|
375
375
|
message_type=MessageType.TRAIN,
|
376
376
|
dst_node_id=nid,
|
@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
|
|
382
382
|
"[Stage 0] Sending configurations to %s clients.",
|
383
383
|
len(state.active_node_ids),
|
384
384
|
)
|
385
|
-
msgs =
|
385
|
+
msgs = grid.send_and_receive(
|
386
386
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
387
387
|
)
|
388
388
|
state.active_node_ids = {
|
@@ -406,7 +406,7 @@ class SecAggPlusWorkflow:
|
|
406
406
|
return self._check_threshold(state)
|
407
407
|
|
408
408
|
def share_keys_stage( # pylint: disable=R0914
|
409
|
-
self,
|
409
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
410
410
|
) -> bool:
|
411
411
|
"""Execute the 'share keys' stage."""
|
412
412
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -417,8 +417,8 @@ class SecAggPlusWorkflow:
|
|
417
417
|
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
|
418
418
|
)
|
419
419
|
cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
|
420
|
-
content =
|
421
|
-
return
|
420
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
|
421
|
+
return grid.create_message(
|
422
422
|
content=content,
|
423
423
|
message_type=MessageType.TRAIN,
|
424
424
|
dst_node_id=nid,
|
@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
|
|
431
431
|
"[Stage 1] Forwarding public keys to %s clients.",
|
432
432
|
len(state.active_node_ids),
|
433
433
|
)
|
434
|
-
msgs =
|
434
|
+
msgs = grid.send_and_receive(
|
435
435
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
436
436
|
)
|
437
437
|
state.active_node_ids = {
|
@@ -476,7 +476,7 @@ class SecAggPlusWorkflow:
|
|
476
476
|
return self._check_threshold(state)
|
477
477
|
|
478
478
|
def collect_masked_vectors_stage(
|
479
|
-
self,
|
479
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
480
480
|
) -> bool:
|
481
481
|
"""Execute the 'collect masked vectors' stage."""
|
482
482
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -491,7 +491,7 @@ class SecAggPlusWorkflow:
|
|
491
491
|
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
|
492
492
|
content = state.nid_to_fitins[nid]
|
493
493
|
content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
|
494
|
-
return
|
494
|
+
return grid.create_message(
|
495
495
|
content=content,
|
496
496
|
message_type=MessageType.TRAIN,
|
497
497
|
dst_node_id=nid,
|
@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
|
|
503
503
|
"[Stage 2] Forwarding encrypted key shares to %s clients.",
|
504
504
|
len(state.active_node_ids),
|
505
505
|
)
|
506
|
-
msgs =
|
506
|
+
msgs = grid.send_and_receive(
|
507
507
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
508
508
|
)
|
509
509
|
state.active_node_ids = {
|
@@ -540,14 +540,14 @@ class SecAggPlusWorkflow:
|
|
540
540
|
if msg.has_error():
|
541
541
|
state.failures.append(Exception(msg.error))
|
542
542
|
continue
|
543
|
-
fitres = compat.
|
543
|
+
fitres = compat.recorddict_to_fitres(msg.content, True)
|
544
544
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
545
545
|
state.legacy_results.append((proxy, fitres))
|
546
546
|
|
547
547
|
return self._check_threshold(state)
|
548
548
|
|
549
549
|
def unmask_stage( # pylint: disable=R0912, R0914, R0915
|
550
|
-
self,
|
550
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
551
551
|
) -> bool:
|
552
552
|
"""Execute the 'unmask' stage."""
|
553
553
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -566,8 +566,8 @@ class SecAggPlusWorkflow:
|
|
566
566
|
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
|
567
567
|
}
|
568
568
|
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
|
569
|
-
content =
|
570
|
-
return
|
569
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
|
570
|
+
return grid.create_message(
|
571
571
|
content=content,
|
572
572
|
message_type=MessageType.TRAIN,
|
573
573
|
dst_node_id=nid,
|
@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
|
|
579
579
|
"[Stage 3] Requesting key shares from %s clients to remove masks.",
|
580
580
|
len(state.active_node_ids),
|
581
581
|
)
|
582
|
-
msgs =
|
582
|
+
msgs = grid.send_and_receive(
|
583
583
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
584
584
|
)
|
585
585
|
state.active_node_ids = {
|