flwr-nightly 1.17.0.dev20250317__py3-none-any.whl → 1.17.0.dev20250319__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/constant.py +5 -0
- flwr/common/logger.py +2 -2
- flwr/common/record/parametersrecord.py +336 -92
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +1 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +9 -9
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +44 -15
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +12 -20
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +6 -14
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -12
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -3
- flwr/server/superlink/linkstate/sqlite_linkstate.py +40 -2
- flwr/server/superlink/linkstate/utils.py +67 -10
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +1 -1
- flwr/server/typing.py +3 -3
- flwr/server/workflow/default_workflows.py +17 -19
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +15 -15
- flwr/simulation/run_simulation.py +10 -10
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/RECORD +31 -31
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/entry_points.txt +0 -0
flwr/server/server_app.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Flower ServerApp."""
|
16
16
|
|
17
17
|
|
18
|
+
import inspect
|
18
19
|
from collections.abc import Iterator
|
19
20
|
from contextlib import contextmanager
|
20
21
|
from typing import Callable, Optional
|
@@ -24,8 +25,8 @@ from flwr.common.logger import warn_deprecated_feature_with_example
|
|
24
25
|
from flwr.server.strategy import Strategy
|
25
26
|
|
26
27
|
from .client_manager import ClientManager
|
27
|
-
from .compat import
|
28
|
-
from .
|
28
|
+
from .compat import start_grid
|
29
|
+
from .grid import Driver, Grid
|
29
30
|
from .server import Server
|
30
31
|
from .server_config import ServerConfig
|
31
32
|
from .typing import ServerAppCallable, ServerFn
|
@@ -43,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
|
|
43
44
|
app = ServerApp(server_fn=server_fn)
|
44
45
|
"""
|
45
46
|
|
47
|
+
GRID_USAGE_EXAMPLE = """
|
48
|
+
app = ServerApp()
|
49
|
+
|
50
|
+
@app.main()
|
51
|
+
def main(grid: Grid, context: Context) -> None:
|
52
|
+
# Your existing ServerApp code ...
|
53
|
+
"""
|
54
|
+
|
55
|
+
DRIVER_DEPRECATION_MSG = """
|
56
|
+
The `Driver` class is deprecated, it will be removed in a future release.
|
57
|
+
"""
|
58
|
+
DRIVER_EXAMPLE_MSG = """
|
59
|
+
Instead, use `Grid` in the signature of your `ServerApp`. For example:
|
60
|
+
"""
|
61
|
+
|
46
62
|
|
47
63
|
@contextmanager
|
48
64
|
def _empty_lifespan(_: Context) -> Iterator[None]:
|
@@ -54,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
54
70
|
|
55
71
|
Examples
|
56
72
|
--------
|
57
|
-
Use the
|
73
|
+
Use the ``ServerApp`` with an existing ``Strategy``:
|
58
74
|
|
59
75
|
>>> def server_fn(context: Context):
|
60
76
|
>>> server_config = ServerConfig(num_rounds=3)
|
@@ -66,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
66
82
|
>>>
|
67
83
|
>>> app = ServerApp(server_fn=server_fn)
|
68
84
|
|
69
|
-
Use the
|
85
|
+
Use the ``ServerApp`` with a custom main function:
|
70
86
|
|
71
87
|
>>> app = ServerApp()
|
72
88
|
>>>
|
73
89
|
>>> @app.main()
|
74
|
-
>>> def main(
|
90
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
75
91
|
>>> print("ServerApp running")
|
76
92
|
"""
|
77
93
|
|
@@ -111,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
111
127
|
self._main: Optional[ServerAppCallable] = None
|
112
128
|
self._lifespan = _empty_lifespan
|
113
129
|
|
114
|
-
def __call__(self,
|
130
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
115
131
|
"""Execute `ServerApp`."""
|
116
132
|
with self._lifespan(context):
|
117
133
|
# Compatibility mode
|
@@ -123,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
123
139
|
self._config = components.config
|
124
140
|
self._strategy = components.strategy
|
125
141
|
self._client_manager = components.client_manager
|
126
|
-
|
142
|
+
start_grid(
|
127
143
|
server=self._server,
|
128
144
|
config=self._config,
|
129
145
|
strategy=self._strategy,
|
130
146
|
client_manager=self._client_manager,
|
131
|
-
|
147
|
+
grid=grid,
|
132
148
|
)
|
133
149
|
return
|
134
150
|
|
135
151
|
# New execution mode
|
136
|
-
self._main(
|
152
|
+
self._main(grid, context)
|
137
153
|
|
138
154
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
139
155
|
"""Return a decorator that registers the main fn with the server app.
|
@@ -143,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
143
159
|
>>> app = ServerApp()
|
144
160
|
>>>
|
145
161
|
>>> @app.main()
|
146
|
-
>>> def main(
|
162
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
147
163
|
>>> print("ServerApp running")
|
148
164
|
"""
|
149
165
|
|
@@ -168,11 +184,21 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
168
184
|
>>> app = ServerApp()
|
169
185
|
>>>
|
170
186
|
>>> @app.main()
|
171
|
-
>>> def main(
|
187
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
172
188
|
>>> print("ServerApp running")
|
173
189
|
""",
|
174
190
|
)
|
175
191
|
|
192
|
+
sig = inspect.signature(main_fn)
|
193
|
+
param = list(sig.parameters.values())[0]
|
194
|
+
# Check if parameter name or the annotation should be updated
|
195
|
+
if param.name == "driver" or param.annotation is Driver:
|
196
|
+
warn_deprecated_feature_with_example(
|
197
|
+
deprecation_message=DRIVER_DEPRECATION_MSG,
|
198
|
+
example_message=DRIVER_EXAMPLE_MSG,
|
199
|
+
code_example=GRID_USAGE_EXAMPLE,
|
200
|
+
)
|
201
|
+
|
176
202
|
# Register provided function with the ServerApp object
|
177
203
|
self._main = main_fn
|
178
204
|
|
@@ -207,7 +233,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
207
233
|
"""
|
208
234
|
|
209
235
|
def lifespan_decorator(
|
210
|
-
lifespan_fn: Callable[[Context], Iterator[None]]
|
236
|
+
lifespan_fn: Callable[[Context], Iterator[None]],
|
211
237
|
) -> Callable[[Context], Iterator[None]]:
|
212
238
|
"""Register the lifespan fn with the ServerApp object."""
|
213
239
|
|
flwr/server/serverapp/app.py
CHANGED
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
60
60
|
PullServerAppInputsResponse,
|
61
61
|
PushServerAppOutputsRequest,
|
62
62
|
)
|
63
|
-
from flwr.server.
|
63
|
+
from flwr.server.grid.grpc_grid import GrpcGrid
|
64
64
|
from flwr.server.run_serverapp import run as run_
|
65
65
|
|
66
66
|
|
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
106
106
|
certificates: Optional[bytes] = None,
|
107
107
|
) -> None:
|
108
108
|
"""Run Flower ServerApp process."""
|
109
|
-
|
109
|
+
grid = GrpcGrid(
|
110
110
|
serverappio_service_address=serverappio_api_address,
|
111
111
|
root_certificates=certificates,
|
112
112
|
)
|
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
124
124
|
req = PullServerAppInputsRequest()
|
125
125
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
126
|
-
res: PullServerAppInputsResponse =
|
126
|
+
res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
|
127
127
|
if not res.HasField("run"):
|
128
128
|
sleep(3)
|
129
129
|
run_status = None
|
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
135
135
|
|
136
136
|
hash_run_id = get_sha256_hash(run.run_id)
|
137
137
|
|
138
|
-
|
138
|
+
grid.set_run(run.run_id)
|
139
139
|
|
140
140
|
# Start log uploader for this run
|
141
141
|
log_uploader = start_log_uploader(
|
142
142
|
log_queue=log_queue,
|
143
143
|
node_id=0,
|
144
144
|
run_id=run.run_id,
|
145
|
-
stub=
|
145
|
+
stub=grid._stub,
|
146
146
|
)
|
147
147
|
|
148
148
|
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
173
173
|
|
174
174
|
# Change status to Running
|
175
175
|
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
176
|
-
|
176
|
+
grid._stub.UpdateRunStatus(
|
177
177
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
178
178
|
)
|
179
179
|
|
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
182
182
|
event_details={"run-id-hash": hash_run_id},
|
183
183
|
)
|
184
184
|
|
185
|
-
# Load and run the ServerApp with the
|
185
|
+
# Load and run the ServerApp with the Grid
|
186
186
|
updated_context = run_(
|
187
|
-
|
187
|
+
grid=grid,
|
188
188
|
server_app_dir=app_path,
|
189
189
|
server_app_attr=server_app_attr,
|
190
190
|
context=context,
|
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
196
196
|
out_req = PushServerAppOutputsRequest(
|
197
197
|
run_id=run.run_id, context=context_proto
|
198
198
|
)
|
199
|
-
_ =
|
199
|
+
_ = grid._stub.PushServerAppOutputs(out_req)
|
200
200
|
|
201
201
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
202
202
|
except RunNotRunningException:
|
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
221
221
|
# Update run status
|
222
222
|
if run_status:
|
223
223
|
run_status_proto = run_status_to_proto(run_status)
|
224
|
-
|
224
|
+
grid._stub.UpdateRunStatus(
|
225
225
|
UpdateRunStatusRequest(
|
226
226
|
run_id=run.run_id, run_status=run_status_proto
|
227
227
|
)
|
@@ -27,6 +27,7 @@ from flwr.common import Context, Message, log, now
|
|
27
27
|
from flwr.common.constant import (
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
29
29
|
NODE_ID_NUM_BYTES,
|
30
|
+
PING_PATIENCE,
|
30
31
|
RUN_ID_NUM_BYTES,
|
31
32
|
SUPERLINK_NODE_ID,
|
32
33
|
Status,
|
@@ -37,6 +38,7 @@ from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
37
38
|
from flwr.server.utils import validate_message
|
38
39
|
|
39
40
|
from .utils import (
|
41
|
+
check_node_availability_for_in_message,
|
40
42
|
generate_rand_int_from_bytes,
|
41
43
|
has_valid_sub_status,
|
42
44
|
is_valid_transition,
|
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
232
234
|
with self.lock:
|
233
235
|
current = time.time()
|
234
236
|
|
235
|
-
# Verify
|
237
|
+
# Verify Message IDs
|
236
238
|
ret = verify_message_ids(
|
237
239
|
inquired_message_ids=message_ids,
|
238
240
|
found_message_ins_dict=self.message_ins_store,
|
239
241
|
current_time=current,
|
240
242
|
)
|
241
243
|
|
244
|
+
# Check node availability
|
245
|
+
dst_node_ids = {
|
246
|
+
self.message_ins_store[message_id].metadata.dst_node_id
|
247
|
+
for message_id in message_ids
|
248
|
+
}
|
249
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
250
|
+
inquired_in_message_ids=message_ids,
|
251
|
+
found_in_message_dict=self.message_ins_store,
|
252
|
+
node_id_to_online_until={
|
253
|
+
node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
|
254
|
+
},
|
255
|
+
current_time=current,
|
256
|
+
)
|
257
|
+
ret.update(tmp_ret_dict)
|
258
|
+
|
242
259
|
# Find all reply Messages
|
243
260
|
message_res_found: list[Message] = []
|
244
261
|
for message_id in message_ids:
|
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
317
334
|
log(ERROR, "Unexpected node registration failure.")
|
318
335
|
return 0
|
319
336
|
|
337
|
+
# Mark the node online until time.time() + ping_interval
|
320
338
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
321
339
|
return node_id
|
322
340
|
|
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
519
537
|
return self.federation_options[run_id]
|
520
538
|
|
521
539
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
522
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
540
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
541
|
+
|
542
|
+
It allows for one missed ping (within PING_PATIENCE * ping_interval) before
|
543
|
+
marking the node as offline, where PING_PATIENCE = 2 by default.
|
544
|
+
"""
|
523
545
|
with self.lock:
|
524
546
|
if node_id in self.node_ids:
|
525
|
-
self.node_ids[node_id] = (
|
547
|
+
self.node_ids[node_id] = (
|
548
|
+
time.time() + PING_PATIENCE * ping_interval,
|
549
|
+
ping_interval,
|
550
|
+
)
|
526
551
|
return True
|
527
552
|
return False
|
528
553
|
|
@@ -30,6 +30,7 @@ from flwr.common import Context, Message, Metadata, log, now
|
|
30
30
|
from flwr.common.constant import (
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
32
32
|
NODE_ID_NUM_BYTES,
|
33
|
+
PING_PATIENCE,
|
33
34
|
RUN_ID_NUM_BYTES,
|
34
35
|
SUPERLINK_NODE_ID,
|
35
36
|
Status,
|
@@ -52,6 +53,7 @@ from flwr.server.utils.validator import validate_message
|
|
52
53
|
|
53
54
|
from .linkstate import LinkState
|
54
55
|
from .utils import (
|
56
|
+
check_node_availability_for_in_message,
|
55
57
|
configsrecord_from_bytes,
|
56
58
|
configsrecord_to_bytes,
|
57
59
|
context_from_bytes,
|
@@ -442,6 +444,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
442
444
|
|
443
445
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
444
446
|
"""Get reply Messages for the given Message IDs."""
|
447
|
+
# pylint: disable-msg=too-many-locals
|
445
448
|
ret: dict[UUID, Message] = {}
|
446
449
|
|
447
450
|
# Verify Message IDs
|
@@ -465,6 +468,29 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
465
468
|
current_time=current,
|
466
469
|
)
|
467
470
|
|
471
|
+
# Check node availability
|
472
|
+
dst_node_ids: set[int] = set()
|
473
|
+
for message_id in message_ids:
|
474
|
+
in_message = found_message_ins_dict[message_id]
|
475
|
+
sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
|
476
|
+
dst_node_ids.add(sint_node_id)
|
477
|
+
query = f"""
|
478
|
+
SELECT node_id, online_until
|
479
|
+
FROM node
|
480
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
|
481
|
+
"""
|
482
|
+
rows = self.query(query, tuple(dst_node_ids))
|
483
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
484
|
+
inquired_in_message_ids=message_ids,
|
485
|
+
found_in_message_dict=found_message_ins_dict,
|
486
|
+
node_id_to_online_until={
|
487
|
+
convert_sint64_to_uint64(row["node_id"]): row["online_until"]
|
488
|
+
for row in rows
|
489
|
+
},
|
490
|
+
current_time=current,
|
491
|
+
)
|
492
|
+
ret.update(tmp_ret_dict)
|
493
|
+
|
468
494
|
# Find all reply Messages
|
469
495
|
query = f"""
|
470
496
|
SELECT *
|
@@ -584,6 +610,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
584
610
|
"VALUES (?, ?, ?, ?)"
|
585
611
|
)
|
586
612
|
|
613
|
+
# Mark the node online until time.time() + ping_interval
|
587
614
|
try:
|
588
615
|
self.query(
|
589
616
|
query,
|
@@ -899,7 +926,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
899
926
|
return configsrecord_from_bytes(row["federation_options"])
|
900
927
|
|
901
928
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
902
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
929
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
930
|
+
|
931
|
+
It allows for one missed ping (within PING_PATIENCE * ping_interval) before
|
932
|
+
marking the node as offline, where PING_PATIENCE = 2 by default.
|
933
|
+
"""
|
903
934
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
904
935
|
|
905
936
|
# Check if the node exists in the `node` table
|
@@ -909,7 +940,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
909
940
|
|
910
941
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
911
942
|
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
912
|
-
self.query(
|
943
|
+
self.query(
|
944
|
+
query,
|
945
|
+
(
|
946
|
+
time.time() + PING_PATIENCE * ping_interval,
|
947
|
+
ping_interval,
|
948
|
+
sint64_node_id,
|
949
|
+
),
|
950
|
+
)
|
913
951
|
return True
|
914
952
|
|
915
953
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
@@ -34,12 +34,6 @@ from flwr.proto.message_pb2 import Context as ProtoContext
|
|
34
34
|
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
35
35
|
|
36
36
|
# pylint: enable=E0611
|
37
|
-
|
38
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
39
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
40
|
-
"It exceeds the time limit specified in its last ping."
|
41
|
-
)
|
42
|
-
|
43
37
|
VALID_RUN_STATUS_TRANSITIONS = {
|
44
38
|
(Status.PENDING, Status.STARTING),
|
45
39
|
(Status.STARTING, Status.RUNNING),
|
@@ -60,6 +54,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
60
54
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
61
55
|
"Error: Reply Message Unavailable - The reply message has expired."
|
62
56
|
)
|
57
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
58
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
59
|
+
"It exceeds twice the time limit specified in its last ping."
|
60
|
+
)
|
63
61
|
|
64
62
|
|
65
63
|
def generate_rand_int_from_bytes(
|
@@ -237,7 +235,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
237
235
|
return status.sub_status == ""
|
238
236
|
|
239
237
|
|
240
|
-
def create_message_error_unavailable_res_message(
|
238
|
+
def create_message_error_unavailable_res_message(
|
239
|
+
ins_metadata: Metadata, error_type: str
|
240
|
+
) -> Message:
|
241
241
|
"""Generate an error Message that the SuperLink returns carrying the specified
|
242
242
|
error."""
|
243
243
|
current_time = now().timestamp()
|
@@ -256,8 +256,16 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
|
|
256
256
|
return Message(
|
257
257
|
metadata=metadata,
|
258
258
|
error=Error(
|
259
|
-
code=
|
260
|
-
|
259
|
+
code=(
|
260
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
261
|
+
if error_type == "msg_unavail"
|
262
|
+
else ErrorCode.NODE_UNAVAILABLE
|
263
|
+
),
|
264
|
+
reason=(
|
265
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
266
|
+
if error_type == "msg_unavail"
|
267
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
268
|
+
),
|
261
269
|
),
|
262
270
|
)
|
263
271
|
|
@@ -371,7 +379,56 @@ def verify_found_message_replies(
|
|
371
379
|
if message_ttl_has_expired(message_res.metadata, current):
|
372
380
|
# No need to insert the error Message
|
373
381
|
message_res = create_message_error_unavailable_res_message(
|
374
|
-
found_message_ins_dict[message_ins_id].metadata
|
382
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
375
383
|
)
|
376
384
|
ret_dict[message_ins_id] = message_res
|
377
385
|
return ret_dict
|
386
|
+
|
387
|
+
|
388
|
+
def check_node_availability_for_in_message(
|
389
|
+
inquired_in_message_ids: set[UUID],
|
390
|
+
found_in_message_dict: dict[UUID, Message],
|
391
|
+
node_id_to_online_until: dict[int, float],
|
392
|
+
current_time: Optional[float] = None,
|
393
|
+
update_set: bool = True,
|
394
|
+
) -> dict[UUID, Message]:
|
395
|
+
"""Check node availability for given Message and generate error reply Message if
|
396
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
397
|
+
each given Message whose destination node is offline or non-existent.
|
398
|
+
|
399
|
+
Parameters
|
400
|
+
----------
|
401
|
+
inquired_in_message_ids : set[UUID]
|
402
|
+
Set of Message IDs for which to check destination node availability.
|
403
|
+
found_in_message_dict : dict[UUID, Message]
|
404
|
+
Dictionary containing all found Message indexed by their IDs.
|
405
|
+
node_id_to_online_until : dict[int, float]
|
406
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
407
|
+
current_time : Optional[float] (default: None)
|
408
|
+
The current time to check for expiration. If set to `None`, the current time
|
409
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
410
|
+
update_set : bool (default: True)
|
411
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
412
|
+
by default True.
|
413
|
+
|
414
|
+
Returns
|
415
|
+
-------
|
416
|
+
dict[UUID, Message]
|
417
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
418
|
+
"""
|
419
|
+
ret_dict = {}
|
420
|
+
current = current_time if current_time else now().timestamp()
|
421
|
+
for in_message_id in list(inquired_in_message_ids):
|
422
|
+
in_message = found_in_message_dict[in_message_id]
|
423
|
+
node_id = in_message.metadata.dst_node_id
|
424
|
+
online_until = node_id_to_online_until.get(node_id)
|
425
|
+
# Generate a reply message containing an error reply
|
426
|
+
# if the node is offline or doesn't exist.
|
427
|
+
if online_until is None or online_until < current:
|
428
|
+
if update_set:
|
429
|
+
inquired_in_message_ids.remove(in_message_id)
|
430
|
+
reply_message = create_message_error_unavailable_res_message(
|
431
|
+
in_message.metadata, "node_unavail"
|
432
|
+
)
|
433
|
+
ret_dict[in_message_id] = reply_message
|
434
|
+
return ret_dict
|
flwr/server/typing.py
CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
|
|
19
19
|
|
20
20
|
from flwr.common import Context
|
21
21
|
|
22
|
-
from .
|
22
|
+
from .grid import Grid
|
23
23
|
from .serverapp_components import ServerAppComponents
|
24
24
|
|
25
|
-
ServerAppCallable = Callable[[
|
26
|
-
Workflow = Callable[[
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
@@ -36,7 +36,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
36
36
|
from ..client_proxy import ClientProxy
|
37
37
|
from ..compat.app_utils import start_update_client_manager_thread
|
38
38
|
from ..compat.legacy_context import LegacyContext
|
39
|
-
from ..
|
39
|
+
from ..grid import Grid
|
40
40
|
from ..typing import Workflow
|
41
41
|
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
|
42
42
|
|
@@ -56,7 +56,7 @@ class DefaultWorkflow:
|
|
56
56
|
self.fit_workflow: Workflow = fit_workflow
|
57
57
|
self.evaluate_workflow: Workflow = evaluate_workflow
|
58
58
|
|
59
|
-
def __call__(self,
|
59
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
60
60
|
"""Execute the workflow."""
|
61
61
|
if not isinstance(context, LegacyContext):
|
62
62
|
raise TypeError(
|
@@ -65,7 +65,7 @@ class DefaultWorkflow:
|
|
65
65
|
|
66
66
|
# Start the thread updating nodes
|
67
67
|
thread, f_stop, c_done = start_update_client_manager_thread(
|
68
|
-
|
68
|
+
grid, context.client_manager
|
69
69
|
)
|
70
70
|
|
71
71
|
# Wait until the node registration done
|
@@ -73,7 +73,7 @@ class DefaultWorkflow:
|
|
73
73
|
|
74
74
|
# Initialize parameters
|
75
75
|
log(INFO, "[INIT]")
|
76
|
-
default_init_params_workflow(
|
76
|
+
default_init_params_workflow(grid, context)
|
77
77
|
|
78
78
|
# Run federated learning for num_rounds
|
79
79
|
start_time = timeit.default_timer()
|
@@ -87,13 +87,13 @@ class DefaultWorkflow:
|
|
87
87
|
cfg[Key.CURRENT_ROUND] = current_round
|
88
88
|
|
89
89
|
# Fit round
|
90
|
-
self.fit_workflow(
|
90
|
+
self.fit_workflow(grid, context)
|
91
91
|
|
92
92
|
# Centralized evaluation
|
93
|
-
default_centralized_evaluation_workflow(
|
93
|
+
default_centralized_evaluation_workflow(grid, context)
|
94
94
|
|
95
95
|
# Evaluate round
|
96
|
-
self.evaluate_workflow(
|
96
|
+
self.evaluate_workflow(grid, context)
|
97
97
|
|
98
98
|
# Bookkeeping and log results
|
99
99
|
end_time = timeit.default_timer()
|
@@ -119,7 +119,7 @@ class DefaultWorkflow:
|
|
119
119
|
thread.join()
|
120
120
|
|
121
121
|
|
122
|
-
def default_init_params_workflow(
|
122
|
+
def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
123
123
|
"""Execute the default workflow for parameters initialization."""
|
124
124
|
if not isinstance(context, LegacyContext):
|
125
125
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -138,9 +138,9 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
138
138
|
random_client = context.client_manager.sample(1)[0]
|
139
139
|
# Send GetParametersIns and get the response
|
140
140
|
content = compat.getparametersins_to_recordset(GetParametersIns({}))
|
141
|
-
messages =
|
141
|
+
messages = grid.send_and_receive(
|
142
142
|
[
|
143
|
-
|
143
|
+
grid.create_message(
|
144
144
|
content=content,
|
145
145
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
146
146
|
dst_node_id=random_client.node_id,
|
@@ -186,7 +186,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
186
186
|
log(INFO, "Evaluation returned no results (`None`)")
|
187
187
|
|
188
188
|
|
189
|
-
def default_centralized_evaluation_workflow(_:
|
189
|
+
def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
190
190
|
"""Execute the default workflow for centralized evaluation."""
|
191
191
|
if not isinstance(context, LegacyContext):
|
192
192
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -218,9 +218,7 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
218
218
|
)
|
219
219
|
|
220
220
|
|
221
|
-
def default_fit_workflow( # pylint: disable=R0914
|
222
|
-
driver: Driver, context: Context
|
223
|
-
) -> None:
|
221
|
+
def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
|
224
222
|
"""Execute the default workflow for a single fit round."""
|
225
223
|
if not isinstance(context, LegacyContext):
|
226
224
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -255,7 +253,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
255
253
|
|
256
254
|
# Build out messages
|
257
255
|
out_messages = [
|
258
|
-
|
256
|
+
grid.create_message(
|
259
257
|
content=compat.fitins_to_recordset(fitins, True),
|
260
258
|
message_type=MessageType.TRAIN,
|
261
259
|
dst_node_id=proxy.node_id,
|
@@ -266,7 +264,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
266
264
|
|
267
265
|
# Send instructions to clients and
|
268
266
|
# collect `fit` results from all clients participating in this round
|
269
|
-
messages = list(
|
267
|
+
messages = list(grid.send_and_receive(out_messages))
|
270
268
|
del out_messages
|
271
269
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
272
270
|
|
@@ -307,7 +305,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
307
305
|
|
308
306
|
|
309
307
|
# pylint: disable-next=R0914
|
310
|
-
def default_evaluate_workflow(
|
308
|
+
def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
311
309
|
"""Execute the default workflow for a single evaluate round."""
|
312
310
|
if not isinstance(context, LegacyContext):
|
313
311
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
@@ -341,7 +339,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
341
339
|
|
342
340
|
# Build out messages
|
343
341
|
out_messages = [
|
344
|
-
|
342
|
+
grid.create_message(
|
345
343
|
content=compat.evaluateins_to_recordset(evalins, True),
|
346
344
|
message_type=MessageType.EVALUATE,
|
347
345
|
dst_node_id=proxy.node_id,
|
@@ -352,7 +350,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
352
350
|
|
353
351
|
# Send instructions to clients and
|
354
352
|
# collect `evaluate` results from all clients participating in this round
|
355
|
-
messages = list(
|
353
|
+
messages = list(grid.send_and_receive(out_messages))
|
356
354
|
del out_messages
|
357
355
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
358
356
|
|