flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250319__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/constant.py +2 -0
- flwr/common/logger.py +2 -2
- flwr/server/__init__.py +2 -0
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/grid_client_proxy.py +8 -8
- flwr/server/grid/__init__.py +7 -6
- flwr/server/grid/grid.py +43 -14
- flwr/server/grid/grpc_grid.py +11 -10
- flwr/server/grid/inmemory_grid.py +5 -5
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +37 -11
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -3
- flwr/server/superlink/linkstate/sqlite_linkstate.py +40 -2
- flwr/server/superlink/linkstate/utils.py +67 -10
- flwr/server/typing.py +3 -3
- flwr/server/workflow/default_workflows.py +17 -19
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +15 -15
- flwr/simulation/run_simulation.py +10 -10
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/RECORD +26 -26
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
60
60
|
PullServerAppInputsResponse,
|
61
61
|
PushServerAppOutputsRequest,
|
62
62
|
)
|
63
|
-
from flwr.server.grid.grpc_grid import
|
63
|
+
from flwr.server.grid.grpc_grid import GrpcGrid
|
64
64
|
from flwr.server.run_serverapp import run as run_
|
65
65
|
|
66
66
|
|
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
106
106
|
certificates: Optional[bytes] = None,
|
107
107
|
) -> None:
|
108
108
|
"""Run Flower ServerApp process."""
|
109
|
-
|
109
|
+
grid = GrpcGrid(
|
110
110
|
serverappio_service_address=serverappio_api_address,
|
111
111
|
root_certificates=certificates,
|
112
112
|
)
|
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
124
124
|
req = PullServerAppInputsRequest()
|
125
125
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
126
|
-
res: PullServerAppInputsResponse =
|
126
|
+
res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
|
127
127
|
if not res.HasField("run"):
|
128
128
|
sleep(3)
|
129
129
|
run_status = None
|
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
135
135
|
|
136
136
|
hash_run_id = get_sha256_hash(run.run_id)
|
137
137
|
|
138
|
-
|
138
|
+
grid.set_run(run.run_id)
|
139
139
|
|
140
140
|
# Start log uploader for this run
|
141
141
|
log_uploader = start_log_uploader(
|
142
142
|
log_queue=log_queue,
|
143
143
|
node_id=0,
|
144
144
|
run_id=run.run_id,
|
145
|
-
stub=
|
145
|
+
stub=grid._stub,
|
146
146
|
)
|
147
147
|
|
148
148
|
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
173
173
|
|
174
174
|
# Change status to Running
|
175
175
|
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
176
|
-
|
176
|
+
grid._stub.UpdateRunStatus(
|
177
177
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
178
178
|
)
|
179
179
|
|
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
182
182
|
event_details={"run-id-hash": hash_run_id},
|
183
183
|
)
|
184
184
|
|
185
|
-
# Load and run the ServerApp with the
|
185
|
+
# Load and run the ServerApp with the Grid
|
186
186
|
updated_context = run_(
|
187
|
-
|
187
|
+
grid=grid,
|
188
188
|
server_app_dir=app_path,
|
189
189
|
server_app_attr=server_app_attr,
|
190
190
|
context=context,
|
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
196
196
|
out_req = PushServerAppOutputsRequest(
|
197
197
|
run_id=run.run_id, context=context_proto
|
198
198
|
)
|
199
|
-
_ =
|
199
|
+
_ = grid._stub.PushServerAppOutputs(out_req)
|
200
200
|
|
201
201
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
202
202
|
except RunNotRunningException:
|
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
221
221
|
# Update run status
|
222
222
|
if run_status:
|
223
223
|
run_status_proto = run_status_to_proto(run_status)
|
224
|
-
|
224
|
+
grid._stub.UpdateRunStatus(
|
225
225
|
UpdateRunStatusRequest(
|
226
226
|
run_id=run.run_id, run_status=run_status_proto
|
227
227
|
)
|
@@ -27,6 +27,7 @@ from flwr.common import Context, Message, log, now
|
|
27
27
|
from flwr.common.constant import (
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
29
29
|
NODE_ID_NUM_BYTES,
|
30
|
+
PING_PATIENCE,
|
30
31
|
RUN_ID_NUM_BYTES,
|
31
32
|
SUPERLINK_NODE_ID,
|
32
33
|
Status,
|
@@ -37,6 +38,7 @@ from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
37
38
|
from flwr.server.utils import validate_message
|
38
39
|
|
39
40
|
from .utils import (
|
41
|
+
check_node_availability_for_in_message,
|
40
42
|
generate_rand_int_from_bytes,
|
41
43
|
has_valid_sub_status,
|
42
44
|
is_valid_transition,
|
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
232
234
|
with self.lock:
|
233
235
|
current = time.time()
|
234
236
|
|
235
|
-
# Verify
|
237
|
+
# Verify Message IDs
|
236
238
|
ret = verify_message_ids(
|
237
239
|
inquired_message_ids=message_ids,
|
238
240
|
found_message_ins_dict=self.message_ins_store,
|
239
241
|
current_time=current,
|
240
242
|
)
|
241
243
|
|
244
|
+
# Check node availability
|
245
|
+
dst_node_ids = {
|
246
|
+
self.message_ins_store[message_id].metadata.dst_node_id
|
247
|
+
for message_id in message_ids
|
248
|
+
}
|
249
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
250
|
+
inquired_in_message_ids=message_ids,
|
251
|
+
found_in_message_dict=self.message_ins_store,
|
252
|
+
node_id_to_online_until={
|
253
|
+
node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
|
254
|
+
},
|
255
|
+
current_time=current,
|
256
|
+
)
|
257
|
+
ret.update(tmp_ret_dict)
|
258
|
+
|
242
259
|
# Find all reply Messages
|
243
260
|
message_res_found: list[Message] = []
|
244
261
|
for message_id in message_ids:
|
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
317
334
|
log(ERROR, "Unexpected node registration failure.")
|
318
335
|
return 0
|
319
336
|
|
337
|
+
# Mark the node online util time.time() + ping_interval
|
320
338
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
321
339
|
return node_id
|
322
340
|
|
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
519
537
|
return self.federation_options[run_id]
|
520
538
|
|
521
539
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
522
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
540
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
541
|
+
|
542
|
+
It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
|
543
|
+
marking the node as offline, where PING_PATIENCE = 2 in default.
|
544
|
+
"""
|
523
545
|
with self.lock:
|
524
546
|
if node_id in self.node_ids:
|
525
|
-
self.node_ids[node_id] = (
|
547
|
+
self.node_ids[node_id] = (
|
548
|
+
time.time() + PING_PATIENCE * ping_interval,
|
549
|
+
ping_interval,
|
550
|
+
)
|
526
551
|
return True
|
527
552
|
return False
|
528
553
|
|
@@ -30,6 +30,7 @@ from flwr.common import Context, Message, Metadata, log, now
|
|
30
30
|
from flwr.common.constant import (
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
32
32
|
NODE_ID_NUM_BYTES,
|
33
|
+
PING_PATIENCE,
|
33
34
|
RUN_ID_NUM_BYTES,
|
34
35
|
SUPERLINK_NODE_ID,
|
35
36
|
Status,
|
@@ -52,6 +53,7 @@ from flwr.server.utils.validator import validate_message
|
|
52
53
|
|
53
54
|
from .linkstate import LinkState
|
54
55
|
from .utils import (
|
56
|
+
check_node_availability_for_in_message,
|
55
57
|
configsrecord_from_bytes,
|
56
58
|
configsrecord_to_bytes,
|
57
59
|
context_from_bytes,
|
@@ -442,6 +444,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
442
444
|
|
443
445
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
444
446
|
"""Get reply Messages for the given Message IDs."""
|
447
|
+
# pylint: disable-msg=too-many-locals
|
445
448
|
ret: dict[UUID, Message] = {}
|
446
449
|
|
447
450
|
# Verify Message IDs
|
@@ -465,6 +468,29 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
465
468
|
current_time=current,
|
466
469
|
)
|
467
470
|
|
471
|
+
# Check node availability
|
472
|
+
dst_node_ids: set[int] = set()
|
473
|
+
for message_id in message_ids:
|
474
|
+
in_message = found_message_ins_dict[message_id]
|
475
|
+
sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
|
476
|
+
dst_node_ids.add(sint_node_id)
|
477
|
+
query = f"""
|
478
|
+
SELECT node_id, online_until
|
479
|
+
FROM node
|
480
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
|
481
|
+
"""
|
482
|
+
rows = self.query(query, tuple(dst_node_ids))
|
483
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
484
|
+
inquired_in_message_ids=message_ids,
|
485
|
+
found_in_message_dict=found_message_ins_dict,
|
486
|
+
node_id_to_online_until={
|
487
|
+
convert_sint64_to_uint64(row["node_id"]): row["online_until"]
|
488
|
+
for row in rows
|
489
|
+
},
|
490
|
+
current_time=current,
|
491
|
+
)
|
492
|
+
ret.update(tmp_ret_dict)
|
493
|
+
|
468
494
|
# Find all reply Messages
|
469
495
|
query = f"""
|
470
496
|
SELECT *
|
@@ -584,6 +610,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
584
610
|
"VALUES (?, ?, ?, ?)"
|
585
611
|
)
|
586
612
|
|
613
|
+
# Mark the node online util time.time() + ping_interval
|
587
614
|
try:
|
588
615
|
self.query(
|
589
616
|
query,
|
@@ -899,7 +926,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
899
926
|
return configsrecord_from_bytes(row["federation_options"])
|
900
927
|
|
901
928
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
902
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
929
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
930
|
+
|
931
|
+
It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
|
932
|
+
marking the node as offline, where PING_PATIENCE = 2 in default.
|
933
|
+
"""
|
903
934
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
904
935
|
|
905
936
|
# Check if the node exists in the `node` table
|
@@ -909,7 +940,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
909
940
|
|
910
941
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
911
942
|
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
912
|
-
self.query(
|
943
|
+
self.query(
|
944
|
+
query,
|
945
|
+
(
|
946
|
+
time.time() + PING_PATIENCE * ping_interval,
|
947
|
+
ping_interval,
|
948
|
+
sint64_node_id,
|
949
|
+
),
|
950
|
+
)
|
913
951
|
return True
|
914
952
|
|
915
953
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
@@ -34,12 +34,6 @@ from flwr.proto.message_pb2 import Context as ProtoContext
|
|
34
34
|
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
35
35
|
|
36
36
|
# pylint: enable=E0611
|
37
|
-
|
38
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
39
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
40
|
-
"It exceeds the time limit specified in its last ping."
|
41
|
-
)
|
42
|
-
|
43
37
|
VALID_RUN_STATUS_TRANSITIONS = {
|
44
38
|
(Status.PENDING, Status.STARTING),
|
45
39
|
(Status.STARTING, Status.RUNNING),
|
@@ -60,6 +54,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
60
54
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
61
55
|
"Error: Reply Message Unavailable - The reply message has expired."
|
62
56
|
)
|
57
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
58
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
59
|
+
"It exceeds twice the time limit specified in its last ping."
|
60
|
+
)
|
63
61
|
|
64
62
|
|
65
63
|
def generate_rand_int_from_bytes(
|
@@ -237,7 +235,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
237
235
|
return status.sub_status == ""
|
238
236
|
|
239
237
|
|
240
|
-
def create_message_error_unavailable_res_message(
|
238
|
+
def create_message_error_unavailable_res_message(
|
239
|
+
ins_metadata: Metadata, error_type: str
|
240
|
+
) -> Message:
|
241
241
|
"""Generate an error Message that the SuperLink returns carrying the specified
|
242
242
|
error."""
|
243
243
|
current_time = now().timestamp()
|
@@ -256,8 +256,16 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
|
|
256
256
|
return Message(
|
257
257
|
metadata=metadata,
|
258
258
|
error=Error(
|
259
|
-
code=
|
260
|
-
|
259
|
+
code=(
|
260
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
261
|
+
if error_type == "msg_unavail"
|
262
|
+
else ErrorCode.NODE_UNAVAILABLE
|
263
|
+
),
|
264
|
+
reason=(
|
265
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
266
|
+
if error_type == "msg_unavail"
|
267
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
268
|
+
),
|
261
269
|
),
|
262
270
|
)
|
263
271
|
|
@@ -371,7 +379,56 @@ def verify_found_message_replies(
|
|
371
379
|
if message_ttl_has_expired(message_res.metadata, current):
|
372
380
|
# No need to insert the error Message
|
373
381
|
message_res = create_message_error_unavailable_res_message(
|
374
|
-
found_message_ins_dict[message_ins_id].metadata
|
382
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
375
383
|
)
|
376
384
|
ret_dict[message_ins_id] = message_res
|
377
385
|
return ret_dict
|
386
|
+
|
387
|
+
|
388
|
+
def check_node_availability_for_in_message(
|
389
|
+
inquired_in_message_ids: set[UUID],
|
390
|
+
found_in_message_dict: dict[UUID, Message],
|
391
|
+
node_id_to_online_until: dict[int, float],
|
392
|
+
current_time: Optional[float] = None,
|
393
|
+
update_set: bool = True,
|
394
|
+
) -> dict[UUID, Message]:
|
395
|
+
"""Check node availability for given Message and generate error reply Message if
|
396
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
397
|
+
each given Message whose destination node is offline or non-existent.
|
398
|
+
|
399
|
+
Parameters
|
400
|
+
----------
|
401
|
+
inquired_in_message_ids : set[UUID]
|
402
|
+
Set of Message IDs for which to check destination node availability.
|
403
|
+
found_in_message_dict : dict[UUID, Message]
|
404
|
+
Dictionary containing all found Message indexed by their IDs.
|
405
|
+
node_id_to_online_until : dict[int, float]
|
406
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
407
|
+
current_time : Optional[float] (default: None)
|
408
|
+
The current time to check for expiration. If set to `None`, the current time
|
409
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
410
|
+
update_set : bool (default: True)
|
411
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
412
|
+
by default True.
|
413
|
+
|
414
|
+
Returns
|
415
|
+
-------
|
416
|
+
dict[UUID, Message]
|
417
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
418
|
+
"""
|
419
|
+
ret_dict = {}
|
420
|
+
current = current_time if current_time else now().timestamp()
|
421
|
+
for in_message_id in list(inquired_in_message_ids):
|
422
|
+
in_message = found_in_message_dict[in_message_id]
|
423
|
+
node_id = in_message.metadata.dst_node_id
|
424
|
+
online_until = node_id_to_online_until.get(node_id)
|
425
|
+
# Generate a reply message containing an error reply
|
426
|
+
# if the node is offline or doesn't exist.
|
427
|
+
if online_until is None or online_until < current:
|
428
|
+
if update_set:
|
429
|
+
inquired_in_message_ids.remove(in_message_id)
|
430
|
+
reply_message = create_message_error_unavailable_res_message(
|
431
|
+
in_message.metadata, "node_unavail"
|
432
|
+
)
|
433
|
+
ret_dict[in_message_id] = reply_message
|
434
|
+
return ret_dict
|
flwr/server/typing.py
CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
|
|
19
19
|
|
20
20
|
from flwr.common import Context
|
21
21
|
|
22
|
-
from .grid import
|
22
|
+
from .grid import Grid
|
23
23
|
from .serverapp_components import ServerAppComponents
|
24
24
|
|
25
|
-
ServerAppCallable = Callable[[
|
26
|
-
Workflow = Callable[[
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
@@ -36,7 +36,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
36
36
|
from ..client_proxy import ClientProxy
|
37
37
|
from ..compat.app_utils import start_update_client_manager_thread
|
38
38
|
from ..compat.legacy_context import LegacyContext
|
39
|
-
from ..grid import
|
39
|
+
from ..grid import Grid
|
40
40
|
from ..typing import Workflow
|
41
41
|
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
|
42
42
|
|
@@ -56,7 +56,7 @@ class DefaultWorkflow:
|
|
56
56
|
self.fit_workflow: Workflow = fit_workflow
|
57
57
|
self.evaluate_workflow: Workflow = evaluate_workflow
|
58
58
|
|
59
|
-
def __call__(self,
|
59
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
60
60
|
"""Execute the workflow."""
|
61
61
|
if not isinstance(context, LegacyContext):
|
62
62
|
raise TypeError(
|
@@ -65,7 +65,7 @@ class DefaultWorkflow:
|
|
65
65
|
|
66
66
|
# Start the thread updating nodes
|
67
67
|
thread, f_stop, c_done = start_update_client_manager_thread(
|
68
|
-
|
68
|
+
grid, context.client_manager
|
69
69
|
)
|
70
70
|
|
71
71
|
# Wait until the node registration done
|
@@ -73,7 +73,7 @@ class DefaultWorkflow:
|
|
73
73
|
|
74
74
|
# Initialize parameters
|
75
75
|
log(INFO, "[INIT]")
|
76
|
-
default_init_params_workflow(
|
76
|
+
default_init_params_workflow(grid, context)
|
77
77
|
|
78
78
|
# Run federated learning for num_rounds
|
79
79
|
start_time = timeit.default_timer()
|
@@ -87,13 +87,13 @@ class DefaultWorkflow:
|
|
87
87
|
cfg[Key.CURRENT_ROUND] = current_round
|
88
88
|
|
89
89
|
# Fit round
|
90
|
-
self.fit_workflow(
|
90
|
+
self.fit_workflow(grid, context)
|
91
91
|
|
92
92
|
# Centralized evaluation
|
93
|
-
default_centralized_evaluation_workflow(
|
93
|
+
default_centralized_evaluation_workflow(grid, context)
|
94
94
|
|
95
95
|
# Evaluate round
|
96
|
-
self.evaluate_workflow(
|
96
|
+
self.evaluate_workflow(grid, context)
|
97
97
|
|
98
98
|
# Bookkeeping and log results
|
99
99
|
end_time = timeit.default_timer()
|
@@ -119,7 +119,7 @@ class DefaultWorkflow:
|
|
119
119
|
thread.join()
|
120
120
|
|
121
121
|
|
122
|
-
def default_init_params_workflow(
|
122
|
+
def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
123
123
|
"""Execute the default workflow for parameters initialization."""
|
124
124
|
if not isinstance(context, LegacyContext):
|
125
125
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -138,9 +138,9 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
138
138
|
random_client = context.client_manager.sample(1)[0]
|
139
139
|
# Send GetParametersIns and get the response
|
140
140
|
content = compat.getparametersins_to_recordset(GetParametersIns({}))
|
141
|
-
messages =
|
141
|
+
messages = grid.send_and_receive(
|
142
142
|
[
|
143
|
-
|
143
|
+
grid.create_message(
|
144
144
|
content=content,
|
145
145
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
146
146
|
dst_node_id=random_client.node_id,
|
@@ -186,7 +186,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
186
186
|
log(INFO, "Evaluation returned no results (`None`)")
|
187
187
|
|
188
188
|
|
189
|
-
def default_centralized_evaluation_workflow(_:
|
189
|
+
def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
190
190
|
"""Execute the default workflow for centralized evaluation."""
|
191
191
|
if not isinstance(context, LegacyContext):
|
192
192
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -218,9 +218,7 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
218
218
|
)
|
219
219
|
|
220
220
|
|
221
|
-
def default_fit_workflow( # pylint: disable=R0914
|
222
|
-
driver: Driver, context: Context
|
223
|
-
) -> None:
|
221
|
+
def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
|
224
222
|
"""Execute the default workflow for a single fit round."""
|
225
223
|
if not isinstance(context, LegacyContext):
|
226
224
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -255,7 +253,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
255
253
|
|
256
254
|
# Build out messages
|
257
255
|
out_messages = [
|
258
|
-
|
256
|
+
grid.create_message(
|
259
257
|
content=compat.fitins_to_recordset(fitins, True),
|
260
258
|
message_type=MessageType.TRAIN,
|
261
259
|
dst_node_id=proxy.node_id,
|
@@ -266,7 +264,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
266
264
|
|
267
265
|
# Send instructions to clients and
|
268
266
|
# collect `fit` results from all clients participating in this round
|
269
|
-
messages = list(
|
267
|
+
messages = list(grid.send_and_receive(out_messages))
|
270
268
|
del out_messages
|
271
269
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
272
270
|
|
@@ -307,7 +305,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
307
305
|
|
308
306
|
|
309
307
|
# pylint: disable-next=R0914
|
310
|
-
def default_evaluate_workflow(
|
308
|
+
def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
311
309
|
"""Execute the default workflow for a single evaluate round."""
|
312
310
|
if not isinstance(context, LegacyContext):
|
313
311
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -341,7 +339,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
341
339
|
|
342
340
|
# Build out messages
|
343
341
|
out_messages = [
|
344
|
-
|
342
|
+
grid.create_message(
|
345
343
|
content=compat.evaluateins_to_recordset(evalins, True),
|
346
344
|
message_type=MessageType.EVALUATE,
|
347
345
|
dst_node_id=proxy.node_id,
|
@@ -352,7 +350,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
352
350
|
|
353
351
|
# Send instructions to clients and
|
354
352
|
# collect `evaluate` results from all clients participating in this round
|
355
|
-
messages = list(
|
353
|
+
messages = list(grid.send_and_receive(out_messages))
|
356
354
|
del out_messages
|
357
355
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
358
356
|
|
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
|
|
55
55
|
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
56
56
|
from flwr.server.client_proxy import ClientProxy
|
57
57
|
from flwr.server.compat.legacy_context import LegacyContext
|
58
|
-
from flwr.server.grid import
|
58
|
+
from flwr.server.grid import Grid
|
59
59
|
|
60
60
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
61
61
|
from ..constant import Key as WorkflowKey
|
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:
|
|
186
186
|
|
187
187
|
self._check_init_params()
|
188
188
|
|
189
|
-
def __call__(self,
|
189
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
190
190
|
"""Run the SecAgg+ protocol."""
|
191
191
|
if not isinstance(context, LegacyContext):
|
192
192
|
raise TypeError(
|
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
|
|
202
202
|
)
|
203
203
|
log(INFO, "Secure aggregation commencing.")
|
204
204
|
for step in steps:
|
205
|
-
if not step(
|
205
|
+
if not step(grid, context, state):
|
206
206
|
log(INFO, "Secure aggregation halted.")
|
207
207
|
return
|
208
208
|
log(INFO, "Secure aggregation completed.")
|
@@ -279,7 +279,7 @@ class SecAggPlusWorkflow:
|
|
279
279
|
return True
|
280
280
|
|
281
281
|
def setup_stage( # pylint: disable=R0912, R0914, R0915
|
282
|
-
self,
|
282
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
283
283
|
) -> bool:
|
284
284
|
"""Execute the 'setup' stage."""
|
285
285
|
# Obtain fit instructions
|
@@ -370,7 +370,7 @@ class SecAggPlusWorkflow:
|
|
370
370
|
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
371
371
|
|
372
372
|
def make(nid: int) -> Message:
|
373
|
-
return
|
373
|
+
return grid.create_message(
|
374
374
|
content=content,
|
375
375
|
message_type=MessageType.TRAIN,
|
376
376
|
dst_node_id=nid,
|
@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
|
|
382
382
|
"[Stage 0] Sending configurations to %s clients.",
|
383
383
|
len(state.active_node_ids),
|
384
384
|
)
|
385
|
-
msgs =
|
385
|
+
msgs = grid.send_and_receive(
|
386
386
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
387
387
|
)
|
388
388
|
state.active_node_ids = {
|
@@ -406,7 +406,7 @@ class SecAggPlusWorkflow:
|
|
406
406
|
return self._check_threshold(state)
|
407
407
|
|
408
408
|
def share_keys_stage( # pylint: disable=R0914
|
409
|
-
self,
|
409
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
410
410
|
) -> bool:
|
411
411
|
"""Execute the 'share keys' stage."""
|
412
412
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -418,7 +418,7 @@ class SecAggPlusWorkflow:
|
|
418
418
|
)
|
419
419
|
cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
|
420
420
|
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
421
|
-
return
|
421
|
+
return grid.create_message(
|
422
422
|
content=content,
|
423
423
|
message_type=MessageType.TRAIN,
|
424
424
|
dst_node_id=nid,
|
@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
|
|
431
431
|
"[Stage 1] Forwarding public keys to %s clients.",
|
432
432
|
len(state.active_node_ids),
|
433
433
|
)
|
434
|
-
msgs =
|
434
|
+
msgs = grid.send_and_receive(
|
435
435
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
436
436
|
)
|
437
437
|
state.active_node_ids = {
|
@@ -476,7 +476,7 @@ class SecAggPlusWorkflow:
|
|
476
476
|
return self._check_threshold(state)
|
477
477
|
|
478
478
|
def collect_masked_vectors_stage(
|
479
|
-
self,
|
479
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
480
480
|
) -> bool:
|
481
481
|
"""Execute the 'collect masked vectors' stage."""
|
482
482
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -491,7 +491,7 @@ class SecAggPlusWorkflow:
|
|
491
491
|
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
|
492
492
|
content = state.nid_to_fitins[nid]
|
493
493
|
content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
|
494
|
-
return
|
494
|
+
return grid.create_message(
|
495
495
|
content=content,
|
496
496
|
message_type=MessageType.TRAIN,
|
497
497
|
dst_node_id=nid,
|
@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
|
|
503
503
|
"[Stage 2] Forwarding encrypted key shares to %s clients.",
|
504
504
|
len(state.active_node_ids),
|
505
505
|
)
|
506
|
-
msgs =
|
506
|
+
msgs = grid.send_and_receive(
|
507
507
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
508
508
|
)
|
509
509
|
state.active_node_ids = {
|
@@ -547,7 +547,7 @@ class SecAggPlusWorkflow:
|
|
547
547
|
return self._check_threshold(state)
|
548
548
|
|
549
549
|
def unmask_stage( # pylint: disable=R0912, R0914, R0915
|
550
|
-
self,
|
550
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
551
551
|
) -> bool:
|
552
552
|
"""Execute the 'unmask' stage."""
|
553
553
|
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
@@ -567,7 +567,7 @@ class SecAggPlusWorkflow:
|
|
567
567
|
}
|
568
568
|
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
|
569
569
|
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
570
|
-
return
|
570
|
+
return grid.create_message(
|
571
571
|
content=content,
|
572
572
|
message_type=MessageType.TRAIN,
|
573
573
|
dst_node_id=nid,
|
@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
|
|
579
579
|
"[Stage 3] Requesting key shares from %s clients to remove masks.",
|
580
580
|
len(state.active_node_ids),
|
581
581
|
)
|
582
|
-
msgs =
|
582
|
+
msgs = grid.send_and_receive(
|
583
583
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
584
584
|
)
|
585
585
|
state.active_node_ids = {
|