flwr-nightly 1.8.0.dev20240327__py3-none-any.whl → 1.8.0.dev20240402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +53 -29
- flwr/client/client_app.py +16 -0
- flwr/client/grpc_rere_client/connection.py +71 -29
- flwr/client/heartbeat.py +72 -0
- flwr/client/rest_client/connection.py +102 -28
- flwr/common/constant.py +20 -0
- flwr/common/logger.py +4 -4
- flwr/common/message.py +53 -14
- flwr/common/retry_invoker.py +24 -13
- flwr/proto/fleet_pb2.py +26 -26
- flwr/proto/fleet_pb2.pyi +5 -0
- flwr/server/compat/driver_client_proxy.py +16 -0
- flwr/server/driver/driver.py +15 -5
- flwr/server/server_app.py +3 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
- flwr/server/superlink/fleet/vce/vce_api.py +61 -27
- flwr/server/superlink/state/in_memory_state.py +25 -8
- flwr/server/superlink/state/sqlite_state.py +53 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/utils.py +56 -0
- flwr/server/workflow/default_workflows.py +1 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
- flwr/simulation/ray_transport/ray_actor.py +8 -24
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/RECORD +30 -28
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/entry_points.txt +0 -0
|
@@ -14,16 +14,19 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Fleet Simulation Engine API."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
import asyncio
|
|
19
18
|
import json
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
20
21
|
import traceback
|
|
21
22
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
22
23
|
from typing import Callable, Dict, List, Optional
|
|
23
24
|
|
|
24
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
25
|
+
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
25
26
|
from flwr.client.node_state import NodeState
|
|
27
|
+
from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
|
|
26
28
|
from flwr.common.logger import log
|
|
29
|
+
from flwr.common.message import Error
|
|
27
30
|
from flwr.common.object_ref import load_app
|
|
28
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
29
32
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
@@ -41,7 +44,7 @@ def _register_nodes(
|
|
|
41
44
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
42
45
|
state = state_factory.state()
|
|
43
46
|
for i in range(num_nodes):
|
|
44
|
-
node_id = state.create_node()
|
|
47
|
+
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
|
|
45
48
|
nodes_mapping[node_id] = i
|
|
46
49
|
log(INFO, "Registered %i nodes", len(nodes_mapping))
|
|
47
50
|
return nodes_mapping
|
|
@@ -59,6 +62,7 @@ async def worker(
|
|
|
59
62
|
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
|
|
60
63
|
state = state_factory.state()
|
|
61
64
|
while True:
|
|
65
|
+
out_mssg = None
|
|
62
66
|
try:
|
|
63
67
|
task_ins: TaskIns = await queue.get()
|
|
64
68
|
node_id = task_ins.task.consumer.node_id
|
|
@@ -82,24 +86,34 @@ async def worker(
|
|
|
82
86
|
task_ins.run_id, context=updated_context
|
|
83
87
|
)
|
|
84
88
|
|
|
85
|
-
# Convert to TaskRes
|
|
86
|
-
task_res = message_to_taskres(out_mssg)
|
|
87
|
-
# Store TaskRes in state
|
|
88
|
-
state.store_task_res(task_res)
|
|
89
|
-
|
|
90
89
|
except asyncio.CancelledError as e:
|
|
91
|
-
log(DEBUG, "
|
|
90
|
+
log(DEBUG, "Terminating async worker: %s", e)
|
|
92
91
|
break
|
|
93
92
|
|
|
94
|
-
|
|
95
|
-
log(ERROR, "Async worker: %s", app_ex)
|
|
96
|
-
log(ERROR, traceback.format_exc())
|
|
97
|
-
raise
|
|
98
|
-
|
|
93
|
+
# Exceptions aren't raised but reported as an error message
|
|
99
94
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
100
95
|
log(ERROR, ex)
|
|
101
96
|
log(ERROR, traceback.format_exc())
|
|
102
|
-
|
|
97
|
+
|
|
98
|
+
if isinstance(ex, ClientAppException):
|
|
99
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
100
|
+
elif isinstance(ex, LoadClientAppError):
|
|
101
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
102
|
+
else:
|
|
103
|
+
e_code = ErrorCode.UNKNOWN
|
|
104
|
+
|
|
105
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
106
|
+
out_mssg = message.create_error_reply(
|
|
107
|
+
error=Error(code=e_code, reason=reason)
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
finally:
|
|
111
|
+
if out_mssg:
|
|
112
|
+
# Convert to TaskRes
|
|
113
|
+
task_res = message_to_taskres(out_mssg)
|
|
114
|
+
# Store TaskRes in state
|
|
115
|
+
task_res.task.pushed_at = time.time()
|
|
116
|
+
state.store_task_res(task_res)
|
|
103
117
|
|
|
104
118
|
|
|
105
119
|
async def add_taskins_to_queue(
|
|
@@ -218,7 +232,8 @@ async def run(
|
|
|
218
232
|
await backend.terminate()
|
|
219
233
|
|
|
220
234
|
|
|
221
|
-
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
|
235
|
+
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
236
|
+
# pylint: disable=too-many-statements
|
|
222
237
|
def start_vce(
|
|
223
238
|
backend_name: str,
|
|
224
239
|
backend_config_json_stream: str,
|
|
@@ -300,12 +315,14 @@ def start_vce(
|
|
|
300
315
|
"""Instantiate a Backend."""
|
|
301
316
|
return backend_type(backend_config, work_dir=app_dir)
|
|
302
317
|
|
|
303
|
-
log(INFO, "client_app_attr = %s", client_app_attr)
|
|
304
|
-
|
|
305
318
|
# Load ClientApp if needed
|
|
306
319
|
def _load() -> ClientApp:
|
|
307
320
|
|
|
308
321
|
if client_app_attr:
|
|
322
|
+
|
|
323
|
+
if app_dir is not None:
|
|
324
|
+
sys.path.insert(0, app_dir)
|
|
325
|
+
|
|
309
326
|
app: ClientApp = load_app(client_app_attr, LoadClientAppError)
|
|
310
327
|
|
|
311
328
|
if not isinstance(app, ClientApp):
|
|
@@ -319,13 +336,30 @@ def start_vce(
|
|
|
319
336
|
|
|
320
337
|
app_fn = _load
|
|
321
338
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
339
|
+
try:
|
|
340
|
+
# Test if ClientApp can be loaded
|
|
341
|
+
_ = app_fn()
|
|
342
|
+
|
|
343
|
+
# Run main simulation loop
|
|
344
|
+
asyncio.run(
|
|
345
|
+
run(
|
|
346
|
+
app_fn,
|
|
347
|
+
backend_fn,
|
|
348
|
+
nodes_mapping,
|
|
349
|
+
state_factory,
|
|
350
|
+
node_states,
|
|
351
|
+
f_stop,
|
|
352
|
+
)
|
|
330
353
|
)
|
|
331
|
-
|
|
354
|
+
except LoadClientAppError as loadapp_ex:
|
|
355
|
+
f_stop_delay = 10
|
|
356
|
+
log(
|
|
357
|
+
ERROR,
|
|
358
|
+
"LoadClientAppError exception encountered. Terminating simulation in %is",
|
|
359
|
+
f_stop_delay,
|
|
360
|
+
)
|
|
361
|
+
time.sleep(f_stop_delay)
|
|
362
|
+
f_stop.set() # set termination event
|
|
363
|
+
raise loadapp_ex
|
|
364
|
+
except Exception as ex:
|
|
365
|
+
raise ex
|
|
@@ -27,6 +27,8 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
|
27
27
|
from flwr.server.superlink.state.state import State
|
|
28
28
|
from flwr.server.utils import validate_task_ins_or_res
|
|
29
29
|
|
|
30
|
+
from .utils import make_node_unavailable_taskres
|
|
31
|
+
|
|
30
32
|
|
|
31
33
|
class InMemoryState(State):
|
|
32
34
|
"""In-memory State implementation."""
|
|
@@ -129,15 +131,32 @@ class InMemoryState(State):
|
|
|
129
131
|
with self.lock:
|
|
130
132
|
# Find TaskRes that were not delivered yet
|
|
131
133
|
task_res_list: List[TaskRes] = []
|
|
134
|
+
replied_task_ids: Set[UUID] = set()
|
|
132
135
|
for _, task_res in self.task_res_store.items():
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
and task_res.task.delivered_at == ""
|
|
136
|
-
):
|
|
136
|
+
reply_to = UUID(task_res.task.ancestry[0])
|
|
137
|
+
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
137
138
|
task_res_list.append(task_res)
|
|
139
|
+
replied_task_ids.add(reply_to)
|
|
138
140
|
if limit and len(task_res_list) == limit:
|
|
139
141
|
break
|
|
140
142
|
|
|
143
|
+
# Check if the node is offline
|
|
144
|
+
for task_id in task_ids - replied_task_ids:
|
|
145
|
+
if limit and len(task_res_list) == limit:
|
|
146
|
+
break
|
|
147
|
+
task_ins = self.task_ins_store.get(task_id)
|
|
148
|
+
if task_ins is None:
|
|
149
|
+
continue
|
|
150
|
+
node_id = task_ins.task.consumer.node_id
|
|
151
|
+
online_until, _ = self.node_ids[node_id]
|
|
152
|
+
# Generate a TaskRes containing an error reply if the node is offline.
|
|
153
|
+
if online_until < time.time():
|
|
154
|
+
err_taskres = make_node_unavailable_taskres(
|
|
155
|
+
ref_taskins=task_ins,
|
|
156
|
+
)
|
|
157
|
+
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
|
|
158
|
+
task_res_list.append(err_taskres)
|
|
159
|
+
|
|
141
160
|
# Mark all of them as delivered
|
|
142
161
|
delivered_at = now().isoformat()
|
|
143
162
|
for task_res in task_res_list:
|
|
@@ -182,16 +201,14 @@ class InMemoryState(State):
|
|
|
182
201
|
"""
|
|
183
202
|
return len(self.task_res_store)
|
|
184
203
|
|
|
185
|
-
def create_node(self) -> int:
|
|
204
|
+
def create_node(self, ping_interval: float) -> int:
|
|
186
205
|
"""Create, store in state, and return `node_id`."""
|
|
187
206
|
# Sample a random int64 as node_id
|
|
188
207
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
189
208
|
|
|
190
209
|
with self.lock:
|
|
191
210
|
if node_id not in self.node_ids:
|
|
192
|
-
|
|
193
|
-
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
194
|
-
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
|
|
211
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
195
212
|
return node_id
|
|
196
213
|
log(ERROR, "Unexpected node registration failure.")
|
|
197
214
|
return 0
|
|
@@ -30,6 +30,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
|
30
30
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
31
31
|
|
|
32
32
|
from .state import State
|
|
33
|
+
from .utils import make_node_unavailable_taskres
|
|
33
34
|
|
|
34
35
|
SQL_CREATE_TABLE_NODE = """
|
|
35
36
|
CREATE TABLE IF NOT EXISTS node(
|
|
@@ -344,6 +345,7 @@ class SqliteState(State):
|
|
|
344
345
|
|
|
345
346
|
return task_id
|
|
346
347
|
|
|
348
|
+
# pylint: disable-next=R0914
|
|
347
349
|
def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
|
|
348
350
|
"""Get TaskRes for task_ids.
|
|
349
351
|
|
|
@@ -374,7 +376,7 @@ class SqliteState(State):
|
|
|
374
376
|
AND delivered_at = ""
|
|
375
377
|
"""
|
|
376
378
|
|
|
377
|
-
data: Dict[str, Union[str, int]] = {}
|
|
379
|
+
data: Dict[str, Union[str, float, int]] = {}
|
|
378
380
|
|
|
379
381
|
if limit is not None:
|
|
380
382
|
query += " LIMIT :limit"
|
|
@@ -408,6 +410,54 @@ class SqliteState(State):
|
|
|
408
410
|
rows = self.query(query, data)
|
|
409
411
|
|
|
410
412
|
result = [dict_to_task_res(row) for row in rows]
|
|
413
|
+
|
|
414
|
+
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
415
|
+
# Assume the ancestry field only contains one element
|
|
416
|
+
data.clear()
|
|
417
|
+
replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
418
|
+
remaining_task_ids = task_ids - replied_task_ids
|
|
419
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
420
|
+
query = f"""
|
|
421
|
+
SELECT consumer_node_id
|
|
422
|
+
FROM task_ins
|
|
423
|
+
WHERE task_id IN ({placeholders});
|
|
424
|
+
"""
|
|
425
|
+
for index, task_id in enumerate(remaining_task_ids):
|
|
426
|
+
data[f"id_{index}"] = str(task_id)
|
|
427
|
+
node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
|
|
428
|
+
|
|
429
|
+
# 2. Query: Select offline nodes
|
|
430
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
|
|
431
|
+
query = f"""
|
|
432
|
+
SELECT node_id
|
|
433
|
+
FROM node
|
|
434
|
+
WHERE node_id IN ({placeholders})
|
|
435
|
+
AND online_until < :time;
|
|
436
|
+
"""
|
|
437
|
+
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
|
|
438
|
+
data["time"] = time.time()
|
|
439
|
+
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
|
|
440
|
+
|
|
441
|
+
# 3. Query: Select TaskIns for offline nodes
|
|
442
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
|
|
443
|
+
query = f"""
|
|
444
|
+
SELECT *
|
|
445
|
+
FROM task_ins
|
|
446
|
+
WHERE consumer_node_id IN ({placeholders});
|
|
447
|
+
"""
|
|
448
|
+
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
449
|
+
task_ins_rows = self.query(query, data)
|
|
450
|
+
|
|
451
|
+
# Make TaskRes containing node unavailabe error
|
|
452
|
+
for row in task_ins_rows:
|
|
453
|
+
if limit and len(result) == limit:
|
|
454
|
+
break
|
|
455
|
+
task_ins = dict_to_task_ins(row)
|
|
456
|
+
err_taskres = make_node_unavailable_taskres(
|
|
457
|
+
ref_taskins=task_ins,
|
|
458
|
+
)
|
|
459
|
+
result.append(err_taskres)
|
|
460
|
+
|
|
411
461
|
return result
|
|
412
462
|
|
|
413
463
|
def num_task_ins(self) -> int:
|
|
@@ -468,7 +518,7 @@ class SqliteState(State):
|
|
|
468
518
|
|
|
469
519
|
return None
|
|
470
520
|
|
|
471
|
-
def create_node(self) -> int:
|
|
521
|
+
def create_node(self, ping_interval: float) -> int:
|
|
472
522
|
"""Create, store in state, and return `node_id`."""
|
|
473
523
|
# Sample a random int64 as node_id
|
|
474
524
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
@@ -478,9 +528,7 @@ class SqliteState(State):
|
|
|
478
528
|
)
|
|
479
529
|
|
|
480
530
|
try:
|
|
481
|
-
|
|
482
|
-
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
483
|
-
self.query(query, (node_id, time.time() + 1e9, 1e9))
|
|
531
|
+
self.query(query, (node_id, time.time() + ping_interval, ping_interval))
|
|
484
532
|
except sqlite3.IntegrityError:
|
|
485
533
|
log(ERROR, "Unexpected node registration failure.")
|
|
486
534
|
return 0
|
|
@@ -132,7 +132,7 @@ class State(abc.ABC):
|
|
|
132
132
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
133
133
|
|
|
134
134
|
@abc.abstractmethod
|
|
135
|
-
def create_node(self) -> int:
|
|
135
|
+
def create_node(self, ping_interval: float) -> int:
|
|
136
136
|
"""Create, store in state, and return `node_id`."""
|
|
137
137
|
|
|
138
138
|
@abc.abstractmethod
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Utility functions for State."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
from logging import ERROR
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
from flwr.common import log
|
|
23
|
+
from flwr.common.constant import ErrorCode
|
|
24
|
+
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
|
|
25
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
27
|
+
|
|
28
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
29
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
30
|
+
"It exceeds the time limit specified in its last ping."
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
35
|
+
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
|
|
36
|
+
current_time = time.time()
|
|
37
|
+
ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
|
|
38
|
+
if ttl < 0:
|
|
39
|
+
log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
|
|
40
|
+
ttl = 0
|
|
41
|
+
return TaskRes(
|
|
42
|
+
task_id=str(uuid4()),
|
|
43
|
+
group_id=ref_taskins.group_id,
|
|
44
|
+
run_id=ref_taskins.run_id,
|
|
45
|
+
task=Task(
|
|
46
|
+
producer=Node(node_id=ref_taskins.task.consumer.node_id, anonymous=False),
|
|
47
|
+
consumer=Node(node_id=ref_taskins.task.producer.node_id, anonymous=False),
|
|
48
|
+
created_at=current_time,
|
|
49
|
+
ttl=ttl,
|
|
50
|
+
ancestry=[ref_taskins.task_id],
|
|
51
|
+
task_type=ref_taskins.task.task_type,
|
|
52
|
+
error=Error(
|
|
53
|
+
code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
|
|
54
|
+
),
|
|
55
|
+
),
|
|
56
|
+
)
|
|
@@ -21,7 +21,7 @@ from logging import INFO
|
|
|
21
21
|
from typing import Optional, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
|
-
from flwr.common import
|
|
24
|
+
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
|
25
25
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
26
26
|
|
|
27
27
|
from ..compat.app_utils import start_update_client_manager_thread
|
|
@@ -127,7 +127,6 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
127
127
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
128
128
|
dst_node_id=random_client.node_id,
|
|
129
129
|
group_id="0",
|
|
130
|
-
ttl=DEFAULT_TTL,
|
|
131
130
|
)
|
|
132
131
|
]
|
|
133
132
|
)
|
|
@@ -226,7 +225,6 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
226
225
|
message_type=MessageType.TRAIN,
|
|
227
226
|
dst_node_id=proxy.node_id,
|
|
228
227
|
group_id=str(current_round),
|
|
229
|
-
ttl=DEFAULT_TTL,
|
|
230
228
|
)
|
|
231
229
|
for proxy, fitins in client_instructions
|
|
232
230
|
]
|
|
@@ -306,7 +304,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
306
304
|
message_type=MessageType.EVALUATE,
|
|
307
305
|
dst_node_id=proxy.node_id,
|
|
308
306
|
group_id=str(current_round),
|
|
309
|
-
ttl=DEFAULT_TTL,
|
|
310
307
|
)
|
|
311
308
|
for proxy, evalins in client_instructions
|
|
312
309
|
]
|
|
@@ -22,7 +22,6 @@ from typing import Dict, List, Optional, Set, Tuple, Union, cast
|
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
25
|
-
DEFAULT_TTL,
|
|
26
25
|
ConfigsRecord,
|
|
27
26
|
Context,
|
|
28
27
|
FitRes,
|
|
@@ -374,7 +373,6 @@ class SecAggPlusWorkflow:
|
|
|
374
373
|
message_type=MessageType.TRAIN,
|
|
375
374
|
dst_node_id=nid,
|
|
376
375
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
377
|
-
ttl=DEFAULT_TTL,
|
|
378
376
|
)
|
|
379
377
|
|
|
380
378
|
log(
|
|
@@ -422,7 +420,6 @@ class SecAggPlusWorkflow:
|
|
|
422
420
|
message_type=MessageType.TRAIN,
|
|
423
421
|
dst_node_id=nid,
|
|
424
422
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
425
|
-
ttl=DEFAULT_TTL,
|
|
426
423
|
)
|
|
427
424
|
|
|
428
425
|
# Broadcast public keys to clients and receive secret key shares
|
|
@@ -493,7 +490,6 @@ class SecAggPlusWorkflow:
|
|
|
493
490
|
message_type=MessageType.TRAIN,
|
|
494
491
|
dst_node_id=nid,
|
|
495
492
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
496
|
-
ttl=DEFAULT_TTL,
|
|
497
493
|
)
|
|
498
494
|
|
|
499
495
|
log(
|
|
@@ -564,7 +560,6 @@ class SecAggPlusWorkflow:
|
|
|
564
560
|
message_type=MessageType.TRAIN,
|
|
565
561
|
dst_node_id=nid,
|
|
566
562
|
group_id=str(current_round),
|
|
567
|
-
ttl=DEFAULT_TTL,
|
|
568
563
|
)
|
|
569
564
|
|
|
570
565
|
log(
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
18
|
import threading
|
|
19
|
-
import traceback
|
|
20
19
|
from abc import ABC
|
|
21
20
|
from logging import DEBUG, ERROR, WARNING
|
|
22
21
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
|
@@ -25,22 +24,13 @@ import ray
|
|
|
25
24
|
from ray import ObjectRef
|
|
26
25
|
from ray.util.actor_pool import ActorPool
|
|
27
26
|
|
|
28
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
27
|
+
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
29
28
|
from flwr.common import Context, Message
|
|
30
29
|
from flwr.common.logger import log
|
|
31
30
|
|
|
32
31
|
ClientAppFn = Callable[[], ClientApp]
|
|
33
32
|
|
|
34
33
|
|
|
35
|
-
class ClientException(Exception):
|
|
36
|
-
"""Raised when client side logic crashes with an exception."""
|
|
37
|
-
|
|
38
|
-
def __init__(self, message: str):
|
|
39
|
-
div = ">" * 7
|
|
40
|
-
self.message = "\n" + div + "A ClientException occurred." + message
|
|
41
|
-
super().__init__(self.message)
|
|
42
|
-
|
|
43
|
-
|
|
44
34
|
class VirtualClientEngineActor(ABC):
|
|
45
35
|
"""Abstract base class for VirtualClientEngine Actors."""
|
|
46
36
|
|
|
@@ -71,17 +61,7 @@ class VirtualClientEngineActor(ABC):
|
|
|
71
61
|
raise load_ex
|
|
72
62
|
|
|
73
63
|
except Exception as ex:
|
|
74
|
-
|
|
75
|
-
mssg = (
|
|
76
|
-
"\n\tSomething went wrong when running your client run."
|
|
77
|
-
"\n\tClient "
|
|
78
|
-
+ cid
|
|
79
|
-
+ " crashed when the "
|
|
80
|
-
+ self.__class__.__name__
|
|
81
|
-
+ " was running its run."
|
|
82
|
-
"\n\tException triggered on the client side: " + client_trace,
|
|
83
|
-
)
|
|
84
|
-
raise ClientException(str(mssg)) from ex
|
|
64
|
+
raise ClientAppException(str(ex)) from ex
|
|
85
65
|
|
|
86
66
|
return cid, out_message, context
|
|
87
67
|
|
|
@@ -493,13 +473,17 @@ class BasicActorPool:
|
|
|
493
473
|
self._future_to_actor[future] = actor
|
|
494
474
|
return future
|
|
495
475
|
|
|
476
|
+
async def add_actor_back_to_pool(self, future: Any) -> None:
|
|
477
|
+
"""Ad actor assigned to run future back into the pool."""
|
|
478
|
+
actor = self._future_to_actor.pop(future)
|
|
479
|
+
await self.pool.put(actor)
|
|
480
|
+
|
|
496
481
|
async def fetch_result_and_return_actor_to_pool(
|
|
497
482
|
self, future: Any
|
|
498
483
|
) -> Tuple[Message, Context]:
|
|
499
484
|
"""Pull result given a future and add actor back to pool."""
|
|
500
485
|
# Get actor that ran job
|
|
501
|
-
|
|
502
|
-
await self.pool.put(actor)
|
|
486
|
+
await self.add_actor_back_to_pool(future)
|
|
503
487
|
# Retrieve result for object store
|
|
504
488
|
# Instead of doing ray.get(future) we await it
|
|
505
489
|
_, out_mssg, updated_context = await future
|