flwr-nightly 1.10.0.dev20240710__py3-none-any.whl → 1.10.0.dev20240712__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. More details are available on the original registry page.

Files changed (31)
  1. flwr/cli/config_utils.py +10 -0
  2. flwr/cli/run/run.py +25 -8
  3. flwr/client/app.py +49 -17
  4. flwr/client/grpc_adapter_client/connection.py +1 -1
  5. flwr/client/grpc_client/connection.py +1 -1
  6. flwr/client/grpc_rere_client/connection.py +3 -2
  7. flwr/client/node_state.py +44 -11
  8. flwr/client/node_state_tests.py +4 -3
  9. flwr/client/rest_client/connection.py +4 -3
  10. flwr/client/supernode/app.py +14 -7
  11. flwr/common/config.py +3 -3
  12. flwr/common/context.py +13 -2
  13. flwr/common/logger.py +25 -0
  14. flwr/server/__init__.py +2 -0
  15. flwr/server/compat/legacy_context.py +1 -1
  16. flwr/server/run_serverapp.py +3 -1
  17. flwr/server/server_app.py +56 -10
  18. flwr/server/serverapp_components.py +52 -0
  19. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  20. flwr/server/superlink/fleet/vce/backend/raybackend.py +8 -9
  21. flwr/server/superlink/fleet/vce/vce_api.py +88 -121
  22. flwr/server/typing.py +2 -0
  23. flwr/simulation/ray_transport/ray_actor.py +15 -19
  24. flwr/simulation/ray_transport/ray_client_proxy.py +3 -1
  25. flwr/simulation/run_simulation.py +49 -33
  26. flwr/superexec/app.py +3 -3
  27. {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/METADATA +2 -2
  28. {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/RECORD +31 -30
  29. {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/LICENSE +0 -0
  30. {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/WHEEL +0 -0
  31. {flwr_nightly-1.10.0.dev20240710.dist-info → flwr_nightly-1.10.0.dev20240712.dist-info}/entry_points.txt +0 -0
flwr/server/server_app.py CHANGED
@@ -17,8 +17,11 @@
17
17
 
18
18
  from typing import Callable, Optional
19
19
 
20
- from flwr.common import Context, RecordSet
21
- from flwr.common.logger import warn_preview_feature
20
+ from flwr.common import Context
21
+ from flwr.common.logger import (
22
+ warn_deprecated_feature_with_example,
23
+ warn_preview_feature,
24
+ )
22
25
  from flwr.server.strategy import Strategy
23
26
 
24
27
  from .client_manager import ClientManager
@@ -26,7 +29,20 @@ from .compat import start_driver
26
29
  from .driver import Driver
27
30
  from .server import Server
28
31
  from .server_config import ServerConfig
29
- from .typing import ServerAppCallable
32
+ from .typing import ServerAppCallable, ServerFn
33
+
34
+ SERVER_FN_USAGE_EXAMPLE = """
35
+
36
+ def server_fn(context: Context):
37
+ server_config = ServerConfig(num_rounds=3)
38
+ strategy = FedAvg()
39
+ return ServerAppComponents(
40
+ strategy=strategy,
41
+ server_config=server_config,
42
+ )
43
+
44
+ app = ServerApp(server_fn=server_fn)
45
+ """
30
46
 
31
47
 
32
48
  class ServerApp:
@@ -36,13 +52,15 @@ class ServerApp:
36
52
  --------
37
53
  Use the `ServerApp` with an existing `Strategy`:
38
54
 
39
- >>> server_config = ServerConfig(num_rounds=3)
40
- >>> strategy = FedAvg()
55
+ >>> def server_fn(context: Context):
56
+ >>> server_config = ServerConfig(num_rounds=3)
57
+ >>> strategy = FedAvg()
58
+ >>> return ServerAppComponents(
59
+ >>> strategy=strategy,
60
+ >>> server_config=server_config,
61
+ >>> )
41
62
  >>>
42
- >>> app = ServerApp(
43
- >>> server_config=server_config,
44
- >>> strategy=strategy,
45
- >>> )
63
+ >>> app = ServerApp(server_fn=server_fn)
46
64
 
47
65
  Use the `ServerApp` with a custom main function:
48
66
 
@@ -53,23 +71,52 @@ class ServerApp:
53
71
  >>> print("ServerApp running")
54
72
  """
55
73
 
74
+ # pylint: disable=too-many-arguments
56
75
  def __init__(
57
76
  self,
58
77
  server: Optional[Server] = None,
59
78
  config: Optional[ServerConfig] = None,
60
79
  strategy: Optional[Strategy] = None,
61
80
  client_manager: Optional[ClientManager] = None,
81
+ server_fn: Optional[ServerFn] = None,
62
82
  ) -> None:
83
+ if any([server, config, strategy, client_manager]):
84
+ warn_deprecated_feature_with_example(
85
+ deprecation_message="Passing either `server`, `config`, `strategy` or "
86
+ "`client_manager` directly to the ServerApp "
87
+ "constructor is deprecated.",
88
+ example_message="Pass `ServerApp` arguments wrapped "
89
+ "in a `flwr.server.ServerAppComponents` object that gets "
90
+ "returned by a function passed as the `server_fn` argument "
91
+ "to the `ServerApp` constructor. For example: ",
92
+ code_example=SERVER_FN_USAGE_EXAMPLE,
93
+ )
94
+
95
+ if server_fn:
96
+ raise ValueError(
97
+ "Passing `server_fn` is incompatible with passing the "
98
+ "other arguments (now deprecated) to ServerApp. "
99
+ "Use `server_fn` exclusively."
100
+ )
101
+
63
102
  self._server = server
64
103
  self._config = config
65
104
  self._strategy = strategy
66
105
  self._client_manager = client_manager
106
+ self._server_fn = server_fn
67
107
  self._main: Optional[ServerAppCallable] = None
68
108
 
69
109
  def __call__(self, driver: Driver, context: Context) -> None:
70
110
  """Execute `ServerApp`."""
71
111
  # Compatibility mode
72
112
  if not self._main:
113
+ if self._server_fn:
114
+ # Execute server_fn()
115
+ components = self._server_fn(context)
116
+ self._server = components.server
117
+ self._config = components.config
118
+ self._strategy = components.strategy
119
+ self._client_manager = components.client_manager
73
120
  start_driver(
74
121
  server=self._server,
75
122
  config=self._config,
@@ -80,7 +127,6 @@ class ServerApp:
80
127
  return
81
128
 
82
129
  # New execution mode
83
- context = Context(state=RecordSet(), run_config={})
84
130
  self._main(driver, context)
85
131
 
86
132
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
flwr/server/serverapp_components.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ServerAppComponents for the ServerApp."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from .client_manager import ClientManager
22
+ from .server import Server
23
+ from .server_config import ServerConfig
24
+ from .strategy import Strategy
25
+
26
+
27
+ @dataclass
28
+ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
29
+ """Components to construct a ServerApp.
30
+
31
+ Parameters
32
+ ----------
33
+ server : Optional[Server] (default: None)
34
+ A server implementation, either `flwr.server.Server` or a subclass
35
+ thereof. If no instance is provided, one will be created internally.
36
+ config : Optional[ServerConfig] (default: None)
37
+ Currently supported values are `num_rounds` (int, default: 1) and
38
+ `round_timeout` in seconds (float, default: None).
39
+ strategy : Optional[Strategy] (default: None)
40
+ An implementation of the abstract base class
41
+ `flwr.server.strategy.Strategy`. If no strategy is provided, then
42
+ `flwr.server.strategy.FedAvg` will be used.
43
+ client_manager : Optional[ClientManager] (default: None)
44
+ An implementation of the class `flwr.server.ClientManager`. If no
45
+ implementation is provided, then `flwr.server.SimpleClientManager`
46
+ will be used.
47
+ """
48
+
49
+ server: Optional[Server] = None
50
+ config: Optional[ServerConfig] = None
51
+ strategy: Optional[Strategy] = None
52
+ client_manager: Optional[ClientManager] = None
flwr/server/superlink/fleet/vce/backend/backend.py CHANGED
@@ -33,8 +33,8 @@ class Backend(ABC):
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- async def build(self) -> None:
37
- """Build backend asynchronously.
36
+ def build(self) -> None:
37
+ """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
40
40
  to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
54
54
  """Report whether a backend worker is idle and can therefore run a ClientApp."""
55
55
 
56
56
  @abstractmethod
57
- async def terminate(self) -> None:
57
+ def terminate(self) -> None:
58
58
  """Terminate backend."""
59
59
 
60
60
  @abstractmethod
61
- async def process_message(
61
+ def process_message(
62
62
  self,
63
63
  app: Callable[[], ClientApp],
64
64
  message: Message,
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -153,12 +153,12 @@ class RayBackend(Backend):
153
153
  """Report whether the pool has idle actors."""
154
154
  return self.pool.is_actor_available()
155
155
 
156
- async def build(self) -> None:
156
+ def build(self) -> None:
157
157
  """Build pool of Ray actors that this backend will submit jobs to."""
158
- await self.pool.add_actors_to_pool(self.pool.actors_capacity)
158
+ self.pool.add_actors_to_pool(self.pool.actors_capacity)
159
159
  log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
160
160
 
161
- async def process_message(
161
+ def process_message(
162
162
  self,
163
163
  app: Callable[[], ClientApp],
164
164
  message: Message,
@@ -172,17 +172,16 @@ class RayBackend(Backend):
172
172
 
173
173
  try:
174
174
  # Submit a task to the pool
175
- future = await self.pool.submit(
175
+ future = self.pool.submit(
176
176
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
177
177
  (app, message, str(partition_id), context),
178
178
  )
179
179
 
180
- await future
181
180
  # Fetch result
182
181
  (
183
182
  out_mssg,
184
183
  updated_context,
185
- ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
184
+ ) = self.pool.fetch_result_and_return_actor_to_pool(future)
186
185
 
187
186
  return out_mssg, updated_context
188
187
 
@@ -193,11 +192,11 @@ class RayBackend(Backend):
193
192
  self.__class__.__name__,
194
193
  )
195
194
  # add actor back into pool
196
- await self.pool.add_actor_back_to_pool(future)
195
+ self.pool.add_actor_back_to_pool(future)
197
196
  raise ex
198
197
 
199
- async def terminate(self) -> None:
198
+ def terminate(self) -> None:
200
199
  """Terminate all actors in actor pool."""
201
- await self.pool.terminate_all_actors()
200
+ self.pool.terminate_all_actors()
202
201
  ray.shutdown()
203
202
  log(DEBUG, "Terminated %s", self.__class__.__name__)
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -14,14 +14,18 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine API."""
16
16
 
17
- import asyncio
17
+
18
18
  import json
19
19
  import sys
20
+ import threading
20
21
  import time
21
22
  import traceback
23
+ from concurrent.futures import ThreadPoolExecutor
22
24
  from logging import DEBUG, ERROR, INFO, WARN
23
25
  from pathlib import Path
24
- from typing import Callable, Dict, List, Optional
26
+ from queue import Empty, Queue
27
+ from time import sleep
28
+ from typing import Callable, Dict, Optional
25
29
 
26
30
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
27
31
  from flwr.client.node_state import NodeState
@@ -31,7 +35,7 @@ from flwr.common.message import Error
31
35
  from flwr.common.object_ref import load_app
32
36
  from flwr.common.serde import message_from_taskins, message_to_taskres
33
37
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
34
- from flwr.server.superlink.state import StateFactory
38
+ from flwr.server.superlink.state import State, StateFactory
35
39
 
36
40
  from .backend import Backend, error_messages_backends, supported_backends
37
41
 
@@ -52,18 +56,21 @@ def _register_nodes(
52
56
 
53
57
 
54
58
  # pylint: disable=too-many-arguments,too-many-locals
55
- async def worker(
59
+ def worker(
56
60
  app_fn: Callable[[], ClientApp],
57
- taskins_queue: "asyncio.Queue[TaskIns]",
58
- taskres_queue: "asyncio.Queue[TaskRes]",
61
+ taskins_queue: "Queue[TaskIns]",
62
+ taskres_queue: "Queue[TaskRes]",
59
63
  node_states: Dict[int, NodeState],
60
64
  backend: Backend,
65
+ f_stop: threading.Event,
61
66
  ) -> None:
62
67
  """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
63
- while True:
68
+ while not f_stop.is_set():
64
69
  out_mssg = None
65
70
  try:
66
- task_ins: TaskIns = await taskins_queue.get()
71
+ # Fetch from queue with timeout. We use a timeout so
72
+ # the stopping event can be evaluated even when the queue is empty.
73
+ task_ins: TaskIns = taskins_queue.get(timeout=1.0)
67
74
  node_id = task_ins.task.consumer.node_id
68
75
 
69
76
  # Register and retrieve runstate
@@ -74,7 +81,7 @@ async def worker(
74
81
  message = message_from_taskins(task_ins)
75
82
 
76
83
  # Let backend process message
77
- out_mssg, updated_context = await backend.process_message(
84
+ out_mssg, updated_context = backend.process_message(
78
85
  app_fn, message, context
79
86
  )
80
87
 
@@ -82,11 +89,9 @@ async def worker(
82
89
  node_states[node_id].update_context(
83
90
  task_ins.run_id, context=updated_context
84
91
  )
85
-
86
- except asyncio.CancelledError as e:
87
- log(DEBUG, "Terminating async worker: %s", e)
88
- break
89
-
92
+ except Empty:
93
+ # An exception raised if queue.get times out
94
+ pass
90
95
  # Exceptions aren't raised but reported as an error message
91
96
  except Exception as ex: # pylint: disable=broad-exception-caught
92
97
  log(ERROR, ex)
@@ -110,83 +115,48 @@ async def worker(
110
115
  task_res = message_to_taskres(out_mssg)
111
116
  # Store TaskRes in state
112
117
  task_res.task.pushed_at = time.time()
113
- await taskres_queue.put(task_res)
118
+ taskres_queue.put(task_res)
114
119
 
115
120
 
116
- async def add_taskins_to_queue(
117
- queue: "asyncio.Queue[TaskIns]",
118
- state_factory: StateFactory,
121
+ def add_taskins_to_queue(
122
+ state: State,
123
+ queue: "Queue[TaskIns]",
119
124
  nodes_mapping: NodeToPartitionMapping,
120
- backend: Backend,
121
- consumers: List["asyncio.Task[None]"],
122
- f_stop: asyncio.Event,
125
+ f_stop: threading.Event,
123
126
  ) -> None:
124
- """Retrieve TaskIns and add it to the queue."""
125
- state = state_factory.state()
126
- num_initial_consumers = len(consumers)
127
+ """Put TaskIns in a queue from State."""
127
128
  while not f_stop.is_set():
128
129
  for node_id in nodes_mapping.keys():
129
- task_ins = state.get_task_ins(node_id=node_id, limit=1)
130
- if task_ins:
131
- await queue.put(task_ins[0])
132
-
133
- # Count consumers that are running
134
- num_active = sum(not (cc.done()) for cc in consumers)
135
-
136
- # Alert if number of consumers decreased by half
137
- if num_active < num_initial_consumers // 2:
138
- log(
139
- WARN,
140
- "Number of active workers has more than halved: (%i/%i active)",
141
- num_active,
142
- num_initial_consumers,
143
- )
130
+ task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
131
+ for task_ins in task_ins_list:
132
+ queue.put(task_ins)
133
+ sleep(0.1)
144
134
 
145
- # Break if consumers died
146
- if num_active == 0:
147
- raise RuntimeError("All workers have died. Ending Simulation.")
148
135
 
149
- # Log some stats
150
- log(
151
- DEBUG,
152
- "Simulation Engine stats: "
153
- "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
154
- num_active,
155
- num_initial_consumers,
156
- backend.__class__.__name__,
157
- backend.num_workers,
158
- queue.qsize(),
159
- )
160
- await asyncio.sleep(1.0)
161
- log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
162
-
163
-
164
- async def put_taskres_into_state(
165
- queue: "asyncio.Queue[TaskRes]",
166
- state_factory: StateFactory,
167
- f_stop: asyncio.Event,
136
+ def put_taskres_into_state(
137
+ state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
168
138
  ) -> None:
169
- """Remove TaskRes from queue and add into State."""
170
- state = state_factory.state()
139
+ """Put TaskRes into State from a queue."""
171
140
  while not f_stop.is_set():
172
- if queue.qsize():
173
- task_res = await queue.get()
174
- state.store_task_res(task_res)
175
- else:
176
- await asyncio.sleep(0.1)
141
+ try:
142
+ taskres = queue.get(timeout=1.0)
143
+ state.store_task_res(taskres)
144
+ except Empty:
145
+ # queue is empty when timeout was triggered
146
+ pass
177
147
 
178
148
 
179
- async def run(
149
+ def run(
180
150
  app_fn: Callable[[], ClientApp],
181
151
  backend_fn: Callable[[], Backend],
182
152
  nodes_mapping: NodeToPartitionMapping,
183
153
  state_factory: StateFactory,
184
154
  node_states: Dict[int, NodeState],
185
- f_stop: asyncio.Event,
155
+ f_stop: threading.Event,
186
156
  ) -> None:
187
- """Run the VCE async."""
188
- taskins_queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
189
- taskres_queue: "asyncio.Queue[TaskRes]" = asyncio.Queue(128)
157
+ """Run the VCE."""
158
+ taskins_queue: "Queue[TaskIns]" = Queue()
159
+ taskres_queue: "Queue[TaskRes]" = Queue()
190
160
 
191
161
  try:
192
162
 
@@ -194,42 +164,48 @@ async def run(
194
164
  backend = backend_fn()
195
165
 
196
166
  # Build backend
197
- await backend.build()
167
+ backend.build()
198
168
 
199
169
  # Add workers (they submit Messages to Backend)
200
- worker_tasks = [
201
- asyncio.create_task(
202
- worker(
203
- app_fn,
204
- taskins_queue,
205
- taskres_queue,
206
- node_states,
207
- backend,
208
- )
209
- )
210
- for _ in range(backend.num_workers)
211
- ]
212
- # Create producer (adds TaskIns into Queue)
213
- taskins_producer = asyncio.create_task(
214
- add_taskins_to_queue(
170
+ state = state_factory.state()
171
+
172
+ extractor_th = threading.Thread(
173
+ target=add_taskins_to_queue,
174
+ args=(
175
+ state,
215
176
  taskins_queue,
216
- state_factory,
217
177
  nodes_mapping,
218
- backend,
219
- worker_tasks,
220
178
  f_stop,
221
- )
179
+ ),
222
180
  )
181
+ extractor_th.start()
223
182
 
224
- taskres_consumer = asyncio.create_task(
225
- put_taskres_into_state(taskres_queue, state_factory, f_stop)
183
+ injector_th = threading.Thread(
184
+ target=put_taskres_into_state,
185
+ args=(
186
+ state,
187
+ taskres_queue,
188
+ f_stop,
189
+ ),
226
190
  )
191
+ injector_th.start()
227
192
 
228
- # Wait for asyncio taks pulling/pushing TaskIns/TaskRes.
229
- # These run forever until f_stop is set or until
230
- # all worker (consumer) coroutines are completed. Workers
231
- # also run forever and only end if an exception is raised.
232
- await asyncio.gather(*(taskins_producer, taskres_consumer))
193
+ with ThreadPoolExecutor() as executor:
194
+ _ = [
195
+ executor.submit(
196
+ worker,
197
+ app_fn,
198
+ taskins_queue,
199
+ taskres_queue,
200
+ node_states,
201
+ backend,
202
+ f_stop,
203
+ )
204
+ for _ in range(backend.num_workers)
205
+ ]
206
+
207
+ extractor_th.join()
208
+ injector_th.join()
233
209
 
234
210
  except Exception as ex:
235
211
 
@@ -244,18 +220,9 @@ async def run(
244
220
  raise RuntimeError("Simulation Engine crashed.") from ex
245
221
 
246
222
  finally:
247
- # Produced task terminated, now cancel worker tasks
248
- for w_t in worker_tasks:
249
- _ = w_t.cancel()
250
-
251
- while not all(w_t.done() for w_t in worker_tasks):
252
- log(DEBUG, "Terminating async workers...")
253
- await asyncio.sleep(0.5)
254
-
255
- await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
256
223
 
257
224
  # Terminate backend
258
- await backend.terminate()
225
+ backend.terminate()
259
226
 
260
227
 
261
228
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -264,7 +231,7 @@ def start_vce(
264
231
  backend_name: str,
265
232
  backend_config_json_stream: str,
266
233
  app_dir: str,
267
- f_stop: asyncio.Event,
234
+ f_stop: threading.Event,
268
235
  client_app: Optional[ClientApp] = None,
269
236
  client_app_attr: Optional[str] = None,
270
237
  num_supernodes: Optional[int] = None,
@@ -317,7 +284,9 @@ def start_vce(
317
284
  # Construct mapping of NodeStates
318
285
  node_states: Dict[int, NodeState] = {}
319
286
  for node_id, partition_id in nodes_mapping.items():
320
- node_states[node_id] = NodeState(partition_id=partition_id)
287
+ node_states[node_id] = NodeState(
288
+ node_id=node_id, node_config={}, partition_id=partition_id
289
+ )
321
290
 
322
291
  # Load backend config
323
292
  log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -368,15 +337,13 @@ def start_vce(
368
337
  _ = app_fn()
369
338
 
370
339
  # Run main simulation loop
371
- asyncio.run(
372
- run(
373
- app_fn,
374
- backend_fn,
375
- nodes_mapping,
376
- state_factory,
377
- node_states,
378
- f_stop,
379
- )
340
+ run(
341
+ app_fn,
342
+ backend_fn,
343
+ nodes_mapping,
344
+ state_factory,
345
+ node_states,
346
+ f_stop,
380
347
  )
381
348
  except LoadClientAppError as loadapp_ex:
382
349
  f_stop_delay = 10
flwr/server/typing.py CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
20
20
  from flwr.common import Context
21
21
 
22
22
  from .driver import Driver
23
+ from .serverapp_components import ServerAppComponents
23
24
 
24
25
  ServerAppCallable = Callable[[Driver, Context], None]
25
26
  Workflow = Callable[[Driver, Context], None]
27
+ ServerFn = Callable[[Context], ServerAppComponents]
flwr/simulation/ray_transport/ray_actor.py CHANGED
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  """Ray-based Flower Actor and ActorPool implementation."""
16
16
 
17
- import asyncio
18
17
  import threading
19
18
  from abc import ABC
20
19
  from logging import DEBUG, ERROR, WARNING
@@ -411,9 +410,7 @@ class BasicActorPool:
411
410
  self.client_resources = client_resources
412
411
 
413
412
  # Queue of idle actors
414
- self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
415
- maxsize=1024
416
- )
413
+ self.pool: List[VirtualClientEngineActor] = []
417
414
  self.num_actors = 0
418
415
 
419
416
  # Resolve arguments to pass during actor init
@@ -427,38 +424,37 @@ class BasicActorPool:
427
424
  # Figure out how many actors can be created given the cluster resources
428
425
  # and the resources the user indicates each VirtualClient will need
429
426
  self.actors_capacity = pool_size_from_resources(client_resources)
430
- self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
427
+ self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
431
428
 
432
429
  def is_actor_available(self) -> bool:
433
430
  """Return true if there is an idle actor."""
434
- return self.pool.qsize() > 0
431
+ return len(self.pool) > 0
435
432
 
436
- async def add_actors_to_pool(self, num_actors: int) -> None:
433
+ def add_actors_to_pool(self, num_actors: int) -> None:
437
434
  """Add actors to the pool.
438
435
 
439
436
  This method may be executed also if new resources are added to your Ray cluster
440
437
  (e.g. you add a new node).
441
438
  """
442
439
  for _ in range(num_actors):
443
- await self.pool.put(self.create_actor_fn()) # type: ignore
440
+ self.pool.append(self.create_actor_fn()) # type: ignore
444
441
  self.num_actors += num_actors
445
442
 
446
- async def terminate_all_actors(self) -> None:
443
+ def terminate_all_actors(self) -> None:
447
444
  """Terminate actors in pool."""
448
445
  num_terminated = 0
449
- while self.pool.qsize():
450
- actor = await self.pool.get()
446
+ for actor in self.pool:
451
447
  actor.terminate.remote() # type: ignore
452
448
  num_terminated += 1
453
449
 
454
450
  log(DEBUG, "Terminated %i actors", num_terminated)
455
451
 
456
- async def submit(
452
+ def submit(
457
453
  self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
458
454
  ) -> Any:
459
455
  """On idle actor, submit job and return future."""
460
456
  # Remove idle actor from pool
461
- actor = await self.pool.get()
457
+ actor = self.pool.pop()
462
458
  # Submit job to actor
463
459
  app_fn, mssg, cid, context = job
464
460
  future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -467,18 +463,18 @@ class BasicActorPool:
467
463
  self._future_to_actor[future] = actor
468
464
  return future
469
465
 
470
- async def add_actor_back_to_pool(self, future: Any) -> None:
466
+ def add_actor_back_to_pool(self, future: Any) -> None:
471
467
  """Ad actor assigned to run future back into the pool."""
472
468
  actor = self._future_to_actor.pop(future)
473
- await self.pool.put(actor)
469
+ self.pool.append(actor)
474
470
 
475
- async def fetch_result_and_return_actor_to_pool(
471
+ def fetch_result_and_return_actor_to_pool(
476
472
  self, future: Any
477
473
  ) -> Tuple[Message, Context]:
478
474
  """Pull result given a future and add actor back to pool."""
479
- # Get actor that ran job
480
- await self.add_actor_back_to_pool(future)
481
475
  # Retrieve result for object store
482
476
  # Instead of doing ray.get(future) we await it
483
- _, out_mssg, updated_context = await future
477
+ _, out_mssg, updated_context = ray.get(future)
478
+ # Get actor that ran job
479
+ self.add_actor_back_to_pool(future)
484
480
  return out_mssg, updated_context