flwr-nightly 1.10.0.dev20240710__py3-none-any.whl → 1.10.0.dev20240712__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/config_utils.py +10 -0
- flwr/cli/run/run.py +25 -8
- flwr/client/app.py +49 -17
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +3 -2
- flwr/client/node_state.py +44 -11
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +4 -3
- flwr/client/supernode/app.py +14 -7
- flwr/common/config.py +3 -3
- flwr/common/context.py +13 -2
- flwr/common/logger.py +25 -0
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/run_serverapp.py +3 -1
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +8 -9
- flwr/server/superlink/fleet/vce/vce_api.py +88 -121
- flwr/server/typing.py +2 -0
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +3 -1
- flwr/simulation/run_simulation.py +49 -33
- flwr/superexec/app.py +3 -3
- {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/METADATA +2 -2
- {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/RECORD +31 -30
- {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/entry_points.txt +0 -0
flwr/server/server_app.py
CHANGED
@@ -17,8 +17,11 @@
 
 from typing import Callable, Optional
 
-from flwr.common import Context, RecordSet
-from flwr.common.logger import warn_preview_feature
+from flwr.common import Context
+from flwr.common.logger import (
+    warn_deprecated_feature_with_example,
+    warn_preview_feature,
+)
 from flwr.server.strategy import Strategy
 
 from .client_manager import ClientManager
@@ -26,7 +29,20 @@ from .compat import start_driver
 from .driver import Driver
 from .server import Server
 from .server_config import ServerConfig
-from .typing import ServerAppCallable
+from .typing import ServerAppCallable, ServerFn
+
+SERVER_FN_USAGE_EXAMPLE = """
+
+        def server_fn(context: Context):
+            server_config = ServerConfig(num_rounds=3)
+            strategy = FedAvg()
+            return ServerAppComponents(
+                strategy=strategy,
+                server_config=server_config,
+            )
+
+        app = ServerApp(server_fn=server_fn)
+"""
 
 
 class ServerApp:
@@ -36,13 +52,15 @@ class ServerApp:
     --------
     Use the `ServerApp` with an existing `Strategy`:
 
-    >>> server_config = ServerConfig(num_rounds=3)
-    >>> strategy = FedAvg()
+    >>> def server_fn(context: Context):
+    >>>     server_config = ServerConfig(num_rounds=3)
+    >>>     strategy = FedAvg()
+    >>>     return ServerAppComponents(
+    >>>         strategy=strategy,
+    >>>         server_config=server_config,
+    >>>     )
     >>>
-    >>> app = ServerApp(
-    >>>     server_config=server_config,
-    >>>     strategy=strategy,
-    >>> )
+    >>> app = ServerApp(server_fn=server_fn)
 
     Use the `ServerApp` with a custom main function:
 
@@ -53,23 +71,52 @@ class ServerApp:
     >>>     print("ServerApp running")
     """
 
+    # pylint: disable=too-many-arguments
     def __init__(
         self,
         server: Optional[Server] = None,
         config: Optional[ServerConfig] = None,
         strategy: Optional[Strategy] = None,
         client_manager: Optional[ClientManager] = None,
+        server_fn: Optional[ServerFn] = None,
     ) -> None:
+        if any([server, config, strategy, client_manager]):
+            warn_deprecated_feature_with_example(
+                deprecation_message="Passing either `server`, `config`, `strategy` or "
+                "`client_manager` directly to the ServerApp "
+                "constructor is deprecated.",
+                example_message="Pass `ServerApp` arguments wrapped "
+                "in a `flwr.server.ServerAppComponents` object that gets "
+                "returned by a function passed as the `server_fn` argument "
+                "to the `ServerApp` constructor. For example: ",
+                code_example=SERVER_FN_USAGE_EXAMPLE,
+            )
+
+            if server_fn:
+                raise ValueError(
+                    "Passing `server_fn` is incompatible with passing the "
+                    "other arguments (now deprecated) to ServerApp. "
+                    "Use `server_fn` exclusively."
+                )
+
         self._server = server
         self._config = config
         self._strategy = strategy
         self._client_manager = client_manager
+        self._server_fn = server_fn
         self._main: Optional[ServerAppCallable] = None
 
     def __call__(self, driver: Driver, context: Context) -> None:
         """Execute `ServerApp`."""
         # Compatibility mode
         if not self._main:
+            if self._server_fn:
+                # Execute server_fn()
+                components = self._server_fn(context)
+                self._server = components.server
+                self._config = components.config
+                self._strategy = components.strategy
+                self._client_manager = components.client_manager
             start_driver(
                 server=self._server,
                 config=self._config,
@@ -80,7 +127,6 @@ class ServerApp:
             return
 
         # New execution mode
-        context = Context(state=RecordSet(), run_config={})
         self._main(driver, context)
 
     def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
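Taken together, these hunks deprecate passing `server`, `config`, `strategy`, or `client_manager` directly to the `ServerApp` constructor in favor of a single `server_fn` callable that returns a `ServerAppComponents` bundle (the new file below). A minimal migration sketch, assuming `ServerAppComponents` is re-exported from `flwr.server` as the deprecation message and the `flwr/server/__init__.py` change suggest:

from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

# Before (now deprecated, triggers warn_deprecated_feature_with_example):
# app = ServerApp(config=ServerConfig(num_rounds=3), strategy=FedAvg())

def server_fn(context: Context) -> ServerAppComponents:
    # Build all server-side components in one place; note the dataclass
    # field is named `config` (see ServerAppComponents below)
    return ServerAppComponents(
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )

app = ServerApp(server_fn=server_fn)

Note that passing `server_fn` together with any of the deprecated arguments raises a `ValueError`, so the two styles cannot be mixed.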
flwr/server/serverapp_components.py
ADDED
@@ -0,0 +1,52 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ServerAppComponents for the ServerApp."""
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+from .client_manager import ClientManager
+from .server import Server
+from .server_config import ServerConfig
+from .strategy import Strategy
+
+
+@dataclass
+class ServerAppComponents:  # pylint: disable=too-many-instance-attributes
+    """Components to construct a ServerApp.
+
+    Parameters
+    ----------
+    server : Optional[Server] (default: None)
+        A server implementation, either `flwr.server.Server` or a subclass
+        thereof. If no instance is provided, one will be created internally.
+    config : Optional[ServerConfig] (default: None)
+        Currently supported values are `num_rounds` (int, default: 1) and
+        `round_timeout` in seconds (float, default: None).
+    strategy : Optional[Strategy] (default: None)
+        An implementation of the abstract base class
+        `flwr.server.strategy.Strategy`. If no strategy is provided, then
+        `flwr.server.strategy.FedAvg` will be used.
+    client_manager : Optional[ClientManager] (default: None)
+        An implementation of the class `flwr.server.ClientManager`. If no
+        implementation is provided, then `flwr.server.SimpleClientManager`
+        will be used.
+    """
+
+    server: Optional[Server] = None
+    config: Optional[ServerConfig] = None
+    strategy: Optional[Strategy] = None
+    client_manager: Optional[ClientManager] = None
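Because every field of the new dataclass defaults to `None`, a `server_fn` only needs to set what it customizes; per the docstring, omitted pieces fall back to `FedAvg`, `SimpleClientManager`, and a default `ServerConfig`. A small sketch:

from flwr.server import ServerConfig
from flwr.server.serverapp_components import ServerAppComponents

# Override only the run length; strategy and client_manager stay None and
# are filled in with the documented defaults when the ServerApp runs
components = ServerAppComponents(config=ServerConfig(num_rounds=5))
assert components.strategy is None
assert components.client_manager is None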
flwr/server/superlink/fleet/vce/backend/backend.py
CHANGED
@@ -33,8 +33,8 @@ class Backend(ABC):
         """Construct a backend."""
 
     @abstractmethod
-    async def build(self) -> None:
-        """Build backend
+    def build(self) -> None:
+        """Build backend.
 
         Different components need to be in place before workers in a backend are ready
         to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
         """Report whether a backend worker is idle and can therefore run a ClientApp."""
 
     @abstractmethod
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate backend."""
 
     @abstractmethod
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
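These hunks turn the `Backend` contract from coroutine methods (`async def`) into plain synchronous methods. For illustration only, here is a hypothetical backend-shaped class implementing just the three methods visible in these hunks; the real ABC also declares at least `is_worker_idle` (seen in context above) and construction details not shown here, and the exact `ClientApp` call signature is an assumption:

from typing import Callable, Tuple

from flwr.client.client_app import ClientApp
from flwr.common import Context, Message

class InlineBackend:
    """Hypothetical backend that runs each message in the calling thread."""

    def build(self) -> None:
        # No worker pool to start: the calling thread is the worker
        pass

    def process_message(
        self,
        app: Callable[[], ClientApp],
        message: Message,
        context: Context,
    ) -> Tuple[Message, Context]:
        # Instantiate the ClientApp and let it handle the message directly
        client_app = app()
        out_message = client_app(message=message, context=context)
        return out_message, context

    def terminate(self) -> None:
        # Nothing to tear down in this inline variant
        pass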
flwr/server/superlink/fleet/vce/backend/raybackend.py
CHANGED
@@ -153,12 +153,12 @@ class RayBackend(Backend):
         """Report whether the pool has idle actors."""
         return self.pool.is_actor_available()
 
-    async def build(self) -> None:
+    def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
-        await self.pool.add_actors_to_pool(self.pool.actors_capacity)
+        self.pool.add_actors_to_pool(self.pool.actors_capacity)
         log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
 
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
@@ -172,17 +172,16 @@
 
         try:
             # Submit a task to the pool
-            future = await self.pool.submit(
+            future = self.pool.submit(
                 lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
                 (app, message, str(partition_id), context),
             )
 
-            await future
             # Fetch result
             (
                 out_mssg,
                 updated_context,
-            ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
+            ) = self.pool.fetch_result_and_return_actor_to_pool(future)
 
             return out_mssg, updated_context
 
@@ -193,11 +192,11 @@
                 self.__class__.__name__,
             )
             # add actor back into pool
-            await self.pool.add_actor_back_to_pool(future)
+            self.pool.add_actor_back_to_pool(future)
             raise ex
 
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate all actors in actor pool."""
-        await self.pool.terminate_all_actors()
+        self.pool.terminate_all_actors()
         ray.shutdown()
         log(DEBUG, "Terminated %s", self.__class__.__name__)
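The same async-to-sync conversion shows up here as replacing `await future` with a blocking `ray.get(future)`. A self-contained sketch of that Ray pattern, using a remote function in place of the actor call above:

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def run_job(x: int) -> int:
    # Stand-in for an actor's a.run.remote(...) call
    return x * 2

future = run_job.remote(21)  # submission returns an ObjectRef immediately
result = ray.get(future)     # blocks the calling thread until the task is done
assert result == 42
ray.shutdown()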
flwr/server/superlink/fleet/vce/vce_api.py
CHANGED
@@ -14,14 +14,18 @@
 # ==============================================================================
 """Fleet Simulation Engine API."""
 
-import asyncio
+
 import json
 import sys
+import threading
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from logging import DEBUG, ERROR, INFO, WARN
 from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from queue import Empty, Queue
+from time import sleep
+from typing import Callable, Dict, Optional
 
 from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.node_state import NodeState
@@ -31,7 +35,7 @@ from flwr.common.message import Error
 from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.server.superlink.state import StateFactory
+from flwr.server.superlink.state import State, StateFactory
 
 from .backend import Backend, error_messages_backends, supported_backends
 
@@ -52,18 +56,21 @@ def _register_nodes(
 
 
 # pylint: disable=too-many-arguments,too-many-locals
-async def worker(
+def worker(
     app_fn: Callable[[], ClientApp],
-    taskins_queue: "asyncio.Queue[TaskIns]",
-    taskres_queue: "asyncio.Queue[TaskRes]",
+    taskins_queue: "Queue[TaskIns]",
+    taskres_queue: "Queue[TaskRes]",
     node_states: Dict[int, NodeState],
     backend: Backend,
+    f_stop: threading.Event,
 ) -> None:
     """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
-    while True:
+    while not f_stop.is_set():
         out_mssg = None
         try:
-            task_ins: TaskIns = await taskins_queue.get()
+            # Fetch from queue with timeout. We use a timeout so
+            # the stopping event can be evaluated even when the queue is empty.
+            task_ins: TaskIns = taskins_queue.get(timeout=1.0)
             node_id = task_ins.task.consumer.node_id
 
             # Register and retrieve runstate
@@ -74,7 +81,7 @@ async def worker(
             message = message_from_taskins(task_ins)
 
             # Let backend process message
-            out_mssg, updated_context = await backend.process_message(
+            out_mssg, updated_context = backend.process_message(
                 app_fn, message, context
             )
 
@@ -82,11 +89,9 @@
             node_states[node_id].update_context(
                 task_ins.run_id, context=updated_context
             )
-
-        except asyncio.CancelledError as e:
-            log(DEBUG, "Terminating async worker: %s", e)
-            break
-
+        except Empty:
+            # An exception raised if queue.get times out
+            pass
         # Exceptions aren't raised but reported as an error message
         except Exception as ex:  # pylint: disable=broad-exception-caught
             log(ERROR, ex)
@@ -110,83 +115,48 @@
             task_res = message_to_taskres(out_mssg)
             # Store TaskRes in state
             task_res.task.pushed_at = time.time()
-            await taskres_queue.put(task_res)
+            taskres_queue.put(task_res)
 
 
-async def add_taskins_to_queue(
-    queue: "asyncio.Queue[TaskIns]",
-    state_factory: StateFactory,
+def add_taskins_to_queue(
+    state: State,
+    queue: "Queue[TaskIns]",
     nodes_mapping: NodeToPartitionMapping,
-    backend: Backend,
-    consumers: List["asyncio.Task[None]"],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Retrieve TaskIns and add it to the queue."""
-    state = state_factory.state()
-    num_initial_consumers = len(consumers)
+    """Put TaskIns in a queue from State."""
     while not f_stop.is_set():
         for node_id in nodes_mapping.keys():
-            task_ins = state.get_task_ins(node_id=node_id, limit=1)
-            if task_ins:
-                await queue.put(task_ins[0])
-
-        # Count consumers that are running
-        num_active = sum(not (cc.done()) for cc in consumers)
-
-        # Alert if number of consumers decreased by half
-        if num_active < num_initial_consumers // 2:
-            log(
-                WARN,
-                "Number of active workers has more than halved: (%i/%i active)",
-                num_active,
-                num_initial_consumers,
-            )
+            task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
+            for task_ins in task_ins_list:
+                queue.put(task_ins)
+        sleep(0.1)
 
-        # Break if consumers died
-        if num_active == 0:
-            raise RuntimeError("All workers have died. Ending Simulation.")
 
-        # Log some stats
-        log(
-            DEBUG,
-            "Simulation Engine stats: "
-            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
-            num_active,
-            num_initial_consumers,
-            backend.__class__.__name__,
-            backend.num_workers,
-            queue.qsize(),
-        )
-        await asyncio.sleep(1.0)
-    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
-
-
-async def put_taskres_into_state(
-    queue: "asyncio.Queue[TaskRes]",
-    state_factory: StateFactory,
-    f_stop: asyncio.Event,
+def put_taskres_into_state(
+    state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
 ) -> None:
-    """Remove TaskRes from queue and add into State."""
-    state = state_factory.state()
+    """Put TaskRes into State from a queue."""
     while not f_stop.is_set():
-        if queue.qsize():
-            task_res = await queue.get()
-            state.store_task_res(task_res)
-        else:
-            await asyncio.sleep(0.1)
+        try:
+            taskres = queue.get(timeout=1.0)
+            state.store_task_res(taskres)
+        except Empty:
+            # queue is empty when timeout was triggered
+            pass
 
 
-async def run(
+def run(
     app_fn: Callable[[], ClientApp],
     backend_fn: Callable[[], Backend],
     nodes_mapping: NodeToPartitionMapping,
     state_factory: StateFactory,
     node_states: Dict[int, NodeState],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Run the VCE async."""
-    taskins_queue: "asyncio.Queue[TaskIns]" = asyncio.Queue()
-    taskres_queue: "asyncio.Queue[TaskRes]" = asyncio.Queue()
+    """Run the VCE."""
+    taskins_queue: "Queue[TaskIns]" = Queue()
+    taskres_queue: "Queue[TaskRes]" = Queue()
 
     try:
 
@@ -194,42 +164,48 @@ async def run(
         backend = backend_fn()
 
         # Build backend
-        await backend.build()
+        backend.build()
 
         # Add workers (they submit Messages to Backend)
-        worker_tasks = [
-            asyncio.create_task(
-                worker(
-                    app_fn,
-                    taskins_queue,
-                    taskres_queue,
-                    node_states,
-                    backend,
-                )
-            )
-            for _ in range(backend.num_workers)
-        ]
-        # Create producer (adds TaskIns into Queue)
-        taskins_producer = asyncio.create_task(
-            add_taskins_to_queue(
+        state = state_factory.state()
+
+        extractor_th = threading.Thread(
+            target=add_taskins_to_queue,
+            args=(
+                state,
                 taskins_queue,
-                state_factory,
                 nodes_mapping,
-                backend,
-                worker_tasks,
                 f_stop,
-            )
+            ),
         )
+        extractor_th.start()
 
-        taskres_consumer = asyncio.create_task(
-            put_taskres_into_state(taskres_queue, state_factory, f_stop)
+        injector_th = threading.Thread(
+            target=put_taskres_into_state,
+            args=(
+                state,
+                taskres_queue,
+                f_stop,
+            ),
         )
+        injector_th.start()
 
-        # Wait for producer and consumer to finish
-        await asyncio.gather(
-            taskins_producer,
-            taskres_consumer,
-        )
+        with ThreadPoolExecutor() as executor:
+            _ = [
+                executor.submit(
+                    worker,
+                    app_fn,
+                    taskins_queue,
+                    taskres_queue,
+                    node_states,
+                    backend,
+                    f_stop,
+                )
+                for _ in range(backend.num_workers)
+            ]
+
+            extractor_th.join()
+            injector_th.join()
 
     except Exception as ex:
 
@@ -244,18 +220,9 @@ async def run(
         raise RuntimeError("Simulation Engine crashed.") from ex
 
     finally:
-        # Produced task terminated, now cancel worker tasks
-        for w_t in worker_tasks:
-            _ = w_t.cancel()
-
-        while not all(w_t.done() for w_t in worker_tasks):
-            log(DEBUG, "Terminating async workers...")
-            await asyncio.sleep(0.5)
-
-        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
 
         # Terminate backend
-        await backend.terminate()
+        backend.terminate()
 
 
 # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -264,7 +231,7 @@ def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
     app_dir: str,
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
     num_supernodes: Optional[int] = None,
@@ -317,7 +284,9 @@ def start_vce(
     # Construct mapping of NodeStates
     node_states: Dict[int, NodeState] = {}
    for node_id, partition_id in nodes_mapping.items():
-        node_states[node_id] = NodeState(partition_id=partition_id)
+        node_states[node_id] = NodeState(
+            node_id=node_id, node_config={}, partition_id=partition_id
+        )
 
     # Load backend config
     log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -368,15 +337,13 @@
         _ = app_fn()
 
         # Run main simulation loop
-        asyncio.run(
-            run(
-                app_fn,
-                backend_fn,
-                nodes_mapping,
-                state_factory,
-                node_states,
-                f_stop,
-            )
+        run(
+            app_fn,
+            backend_fn,
+            nodes_mapping,
+            state_factory,
+            node_states,
+            f_stop,
         )
     except LoadClientAppError as loadapp_ex:
         f_stop_delay = 10
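The refactor replaces the asyncio event loop with a classic threaded producer/consumer layout: an extractor thread feeds a `Queue`, `backend.num_workers` worker threads drain it, an injector thread writes results back, and a shared `threading.Event` stops everyone; the `timeout=1.0` on `queue.get` exists so workers re-check the event even when the queue is empty. A minimal self-contained sketch of the same pattern (all names are illustrative, not Flower APIs):

import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from time import sleep

f_stop = threading.Event()
task_queue: "Queue[int]" = Queue()
result_queue: "Queue[int]" = Queue()

def producer() -> None:
    # Mirrors add_taskins_to_queue: poll a source and enqueue work
    for i in range(10):
        task_queue.put(i)
        sleep(0.01)

def worker_loop() -> None:
    # Mirrors worker(): use a timeout so f_stop is evaluated periodically
    while not f_stop.is_set():
        try:
            item = task_queue.get(timeout=0.1)
            result_queue.put(item * item)  # stand-in for backend work
        except Empty:
            pass

producer_th = threading.Thread(target=producer)
producer_th.start()
with ThreadPoolExecutor() as executor:
    for _ in range(4):
        executor.submit(worker_loop)
    producer_th.join()
    while result_queue.qsize() < 10:
        sleep(0.01)
    f_stop.set()  # workers observe the event and exit; the executor joins them
print(sorted(result_queue.get() for _ in range(10)))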
flwr/server/typing.py
CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
 from flwr.common import Context
 
 from .driver import Driver
+from .serverapp_components import ServerAppComponents
 
 ServerAppCallable = Callable[[Driver, Context], None]
 Workflow = Callable[[Driver, Context], None]
+ServerFn = Callable[[Context], ServerAppComponents]
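`ServerFn` is the type alias behind the new `server_fn` constructor argument: any callable taking a `Context` and returning `ServerAppComponents` conforms. A hedged sketch (the `num-rounds` key is illustrative, not a fixed Flower config key):

from flwr.common import Context
from flwr.server import ServerConfig
from flwr.server.serverapp_components import ServerAppComponents
from flwr.server.typing import ServerFn

def my_server_fn(context: Context) -> ServerAppComponents:
    # Derive components from run-time configuration carried by the Context
    num_rounds = int(context.run_config.get("num-rounds", 3))
    return ServerAppComponents(config=ServerConfig(num_rounds=num_rounds))

fn: ServerFn = my_server_fn  # satisfies the alias under a type checker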
flwr/simulation/ray_transport/ray_actor.py
CHANGED
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Ray-based Flower Actor and ActorPool implementation."""
 
-import asyncio
 import threading
 from abc import ABC
 from logging import DEBUG, ERROR, WARNING
@@ -411,9 +410,7 @@ class BasicActorPool:
         self.client_resources = client_resources
 
         # Queue of idle actors
-        self.pool: "asyncio.Queue[VirtualClientEngineActor]" = asyncio.Queue(
-            maxsize=1024
-        )
+        self.pool: List[VirtualClientEngineActor] = []
         self.num_actors = 0
 
         # Resolve arguments to pass during actor init
@@ -427,38 +424,37 @@ class BasicActorPool:
         # Figure out how many actors can be created given the cluster resources
         # and the resources the user indicates each VirtualClient will need
         self.actors_capacity = pool_size_from_resources(client_resources)
-        self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
+        self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
 
     def is_actor_available(self) -> bool:
         """Return true if there is an idle actor."""
-        return self.pool.qsize() > 0
+        return len(self.pool) > 0
 
-    async def add_actors_to_pool(self, num_actors: int) -> None:
+    def add_actors_to_pool(self, num_actors: int) -> None:
         """Add actors to the pool.
 
         This method may be executed also if new resources are added to your Ray cluster
         (e.g. you add a new node).
         """
         for _ in range(num_actors):
-            await self.pool.put(self.create_actor_fn())  # type: ignore
+            self.pool.append(self.create_actor_fn())  # type: ignore
         self.num_actors += num_actors
 
-    async def terminate_all_actors(self) -> None:
+    def terminate_all_actors(self) -> None:
         """Terminate actors in pool."""
         num_terminated = 0
-        while self.pool.qsize():
-            actor = await self.pool.get()
+        for actor in self.pool:
             actor.terminate.remote()  # type: ignore
             num_terminated += 1
 
         log(DEBUG, "Terminated %i actors", num_terminated)
 
-    async def submit(
+    def submit(
         self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
     ) -> Any:
         """On idle actor, submit job and return future."""
         # Remove idle actor from pool
-        actor = await self.pool.get()
+        actor = self.pool.pop()
         # Submit job to actor
         app_fn, mssg, cid, context = job
         future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -467,18 +463,18 @@ class BasicActorPool:
         self._future_to_actor[future] = actor
         return future
 
-    async def add_actor_back_to_pool(self, future: Any) -> None:
+    def add_actor_back_to_pool(self, future: Any) -> None:
         """Ad actor assigned to run future back into the pool."""
         actor = self._future_to_actor.pop(future)
-        await self.pool.put(actor)
+        self.pool.append(actor)
 
-    async def fetch_result_and_return_actor_to_pool(
+    def fetch_result_and_return_actor_to_pool(
         self, future: Any
     ) -> Tuple[Message, Context]:
         """Pull result given a future and add actor back to pool."""
-        # Get actor that ran job
-        await self.add_actor_back_to_pool(future)
         # Retrieve result for object store
         # Instead of doing ray.get(future) we await it
-        _, out_mssg, updated_context = await future
+        _, out_mssg, updated_context = ray.get(future)
+        # Get actor that ran job
+        self.add_actor_back_to_pool(future)
         return out_mssg, updated_context
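`BasicActorPool` swaps its `asyncio.Queue` of idle actors for a plain list, with `pop()` checking a worker out and `append()` checking it back in, while `_future_to_actor` tracks in-flight jobs. A minimal illustration of that bookkeeping with plain objects in place of Ray actors (all names hypothetical):

from typing import Any, Dict, List

class MiniPool:
    """Idle workers live in a list; in-flight jobs map future -> worker."""

    def __init__(self, workers: List[Any]) -> None:
        self.pool: List[Any] = list(workers)
        self._future_to_actor: Dict[Any, Any] = {}

    def is_actor_available(self) -> bool:
        return len(self.pool) > 0

    def submit(self, future: Any) -> Any:
        # Check out an idle worker and remember which future it serves
        self._future_to_actor[future] = self.pool.pop()
        return future

    def add_actor_back_to_pool(self, future: Any) -> None:
        # Check the worker back in once its result has been consumed
        self.pool.append(self._future_to_actor.pop(future))

pool = MiniPool(["w0", "w1"])
pool.submit("job-1")
assert len(pool.pool) == 1  # one worker checked out
pool.add_actor_back_to_pool("job-1")
assert pool.is_actor_available() and len(pool.pool) == 2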