flwr-nightly 1.10.0.dev20240709__py3-none-any.whl → 1.10.0.dev20240711__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

@@ -14,14 +14,18 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine API."""
16
16
 
17
- import asyncio
17
+
18
18
  import json
19
19
  import sys
20
+ import threading
20
21
  import time
21
22
  import traceback
23
+ from concurrent.futures import ThreadPoolExecutor
22
24
  from logging import DEBUG, ERROR, INFO, WARN
23
25
  from pathlib import Path
24
- from typing import Callable, Dict, List, Optional
26
+ from queue import Empty, Queue
27
+ from time import sleep
28
+ from typing import Callable, Dict, Optional
25
29
 
26
30
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
27
31
  from flwr.client.node_state import NodeState
@@ -30,8 +34,8 @@ from flwr.common.logger import log
30
34
  from flwr.common.message import Error
31
35
  from flwr.common.object_ref import load_app
32
36
  from flwr.common.serde import message_from_taskins, message_to_taskres
33
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
34
- from flwr.server.superlink.state import StateFactory
37
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
38
+ from flwr.server.superlink.state import State, StateFactory
35
39
 
36
40
  from .backend import Backend, error_messages_backends, supported_backends
37
41
 
@@ -52,19 +56,21 @@ def _register_nodes(
52
56
 
53
57
 
54
58
  # pylint: disable=too-many-arguments,too-many-locals
55
- async def worker(
59
+ def worker(
56
60
  app_fn: Callable[[], ClientApp],
57
- queue: "asyncio.Queue[TaskIns]",
61
+ taskins_queue: "Queue[TaskIns]",
62
+ taskres_queue: "Queue[TaskRes]",
58
63
  node_states: Dict[int, NodeState],
59
- state_factory: StateFactory,
60
64
  backend: Backend,
65
+ f_stop: threading.Event,
61
66
  ) -> None:
62
67
  """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
63
- state = state_factory.state()
64
- while True:
68
+ while not f_stop.is_set():
65
69
  out_mssg = None
66
70
  try:
67
- task_ins: TaskIns = await queue.get()
71
+ # Fetch from queue with timeout. We use a timeout so
72
+ # the stopping event can be evaluated even when the queue is empty.
73
+ task_ins: TaskIns = taskins_queue.get(timeout=1.0)
68
74
  node_id = task_ins.task.consumer.node_id
69
75
 
70
76
  # Register and retrieve runstate
@@ -75,7 +81,7 @@ async def worker(
75
81
  message = message_from_taskins(task_ins)
76
82
 
77
83
  # Let backend process message
78
- out_mssg, updated_context = await backend.process_message(
84
+ out_mssg, updated_context = backend.process_message(
79
85
  app_fn, message, context
80
86
  )
81
87
 
@@ -83,11 +89,9 @@ async def worker(
83
89
  node_states[node_id].update_context(
84
90
  task_ins.run_id, context=updated_context
85
91
  )
86
-
87
- except asyncio.CancelledError as e:
88
- log(DEBUG, "Terminating async worker: %s", e)
89
- break
90
-
92
+ except Empty:
93
+ # An exception raised if queue.get times out
94
+ pass
91
95
  # Exceptions aren't raised but reported as an error message
92
96
  except Exception as ex: # pylint: disable=broad-exception-caught
93
97
  log(ERROR, ex)
@@ -111,67 +115,48 @@ async def worker(
111
115
  task_res = message_to_taskres(out_mssg)
112
116
  # Store TaskRes in state
113
117
  task_res.task.pushed_at = time.time()
114
- state.store_task_res(task_res)
118
+ taskres_queue.put(task_res)
115
119
 
116
120
 
117
- async def add_taskins_to_queue(
118
- queue: "asyncio.Queue[TaskIns]",
119
- state_factory: StateFactory,
121
+ def add_taskins_to_queue(
122
+ state: State,
123
+ queue: "Queue[TaskIns]",
120
124
  nodes_mapping: NodeToPartitionMapping,
121
- backend: Backend,
122
- consumers: List["asyncio.Task[None]"],
123
- f_stop: asyncio.Event,
125
+ f_stop: threading.Event,
124
126
  ) -> None:
125
- """Retrieve TaskIns and add it to the queue."""
126
- state = state_factory.state()
127
- num_initial_consumers = len(consumers)
127
+ """Put TaskIns in a queue from State."""
128
128
  while not f_stop.is_set():
129
129
  for node_id in nodes_mapping.keys():
130
- task_ins = state.get_task_ins(node_id=node_id, limit=1)
131
- if task_ins:
132
- await queue.put(task_ins[0])
133
-
134
- # Count consumers that are running
135
- num_active = sum(not (cc.done()) for cc in consumers)
136
-
137
- # Alert if number of consumers decreased by half
138
- if num_active < num_initial_consumers // 2:
139
- log(
140
- WARN,
141
- "Number of active workers has more than halved: (%i/%i active)",
142
- num_active,
143
- num_initial_consumers,
144
- )
130
+ task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
131
+ for task_ins in task_ins_list:
132
+ queue.put(task_ins)
133
+ sleep(0.1)
145
134
 
146
- # Break if consumers died
147
- if num_active == 0:
148
- raise RuntimeError("All workers have died. Ending Simulation.")
149
135
 
150
- # Log some stats
151
- log(
152
- DEBUG,
153
- "Simulation Engine stats: "
154
- "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
155
- num_active,
156
- num_initial_consumers,
157
- backend.__class__.__name__,
158
- backend.num_workers,
159
- queue.qsize(),
160
- )
161
- await asyncio.sleep(1.0)
162
- log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
136
+ def put_taskres_into_state(
137
+ state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
138
+ ) -> None:
139
+ """Put TaskRes into State from a queue."""
140
+ while not f_stop.is_set():
141
+ try:
142
+ taskres = queue.get(timeout=1.0)
143
+ state.store_task_res(taskres)
144
+ except Empty:
145
+ # queue is empty when timeout was triggered
146
+ pass
163
147
 
164
148
 
165
- async def run(
149
+ def run(
166
150
  app_fn: Callable[[], ClientApp],
167
151
  backend_fn: Callable[[], Backend],
168
152
  nodes_mapping: NodeToPartitionMapping,
169
153
  state_factory: StateFactory,
170
154
  node_states: Dict[int, NodeState],
171
- f_stop: asyncio.Event,
155
+ f_stop: threading.Event,
172
156
  ) -> None:
173
- """Run the VCE async."""
174
- queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
157
+ """Run the VCE."""
158
+ taskins_queue: "Queue[TaskIns]" = Queue()
159
+ taskres_queue: "Queue[TaskRes]" = Queue()
175
160
 
176
161
  try:
177
162
 
@@ -179,27 +164,48 @@ async def run(
179
164
  backend = backend_fn()
180
165
 
181
166
  # Build backend
182
- await backend.build()
167
+ backend.build()
183
168
 
184
169
  # Add workers (they submit Messages to Backend)
185
- worker_tasks = [
186
- asyncio.create_task(
187
- worker(app_fn, queue, node_states, state_factory, backend)
188
- )
189
- for _ in range(backend.num_workers)
190
- ]
191
- # Create producer (adds TaskIns into Queue)
192
- producer = asyncio.create_task(
193
- add_taskins_to_queue(
194
- queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
195
- )
170
+ state = state_factory.state()
171
+
172
+ extractor_th = threading.Thread(
173
+ target=add_taskins_to_queue,
174
+ args=(
175
+ state,
176
+ taskins_queue,
177
+ nodes_mapping,
178
+ f_stop,
179
+ ),
196
180
  )
181
+ extractor_th.start()
197
182
 
198
- # Wait for producer to finish
199
- # The producer runs forever until f_stop is set or until
200
- # all worker (consumer) coroutines are completed. Workers
201
- # also run forever and only end if an exception is raised.
202
- await asyncio.gather(producer)
183
+ injector_th = threading.Thread(
184
+ target=put_taskres_into_state,
185
+ args=(
186
+ state,
187
+ taskres_queue,
188
+ f_stop,
189
+ ),
190
+ )
191
+ injector_th.start()
192
+
193
+ with ThreadPoolExecutor() as executor:
194
+ _ = [
195
+ executor.submit(
196
+ worker,
197
+ app_fn,
198
+ taskins_queue,
199
+ taskres_queue,
200
+ node_states,
201
+ backend,
202
+ f_stop,
203
+ )
204
+ for _ in range(backend.num_workers)
205
+ ]
206
+
207
+ extractor_th.join()
208
+ injector_th.join()
203
209
 
204
210
  except Exception as ex:
205
211
 
@@ -214,18 +220,9 @@ async def run(
214
220
  raise RuntimeError("Simulation Engine crashed.") from ex
215
221
 
216
222
  finally:
217
- # Produced task terminated, now cancel worker tasks
218
- for w_t in worker_tasks:
219
- _ = w_t.cancel()
220
-
221
- while not all(w_t.done() for w_t in worker_tasks):
222
- log(DEBUG, "Terminating async workers...")
223
- await asyncio.sleep(0.5)
224
-
225
- await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
226
223
 
227
224
  # Terminate backend
228
- await backend.terminate()
225
+ backend.terminate()
229
226
 
230
227
 
231
228
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -234,7 +231,7 @@ def start_vce(
234
231
  backend_name: str,
235
232
  backend_config_json_stream: str,
236
233
  app_dir: str,
237
- f_stop: asyncio.Event,
234
+ f_stop: threading.Event,
238
235
  client_app: Optional[ClientApp] = None,
239
236
  client_app_attr: Optional[str] = None,
240
237
  num_supernodes: Optional[int] = None,
@@ -338,15 +335,13 @@ def start_vce(
338
335
  _ = app_fn()
339
336
 
340
337
  # Run main simulation loop
341
- asyncio.run(
342
- run(
343
- app_fn,
344
- backend_fn,
345
- nodes_mapping,
346
- state_factory,
347
- node_states,
348
- f_stop,
349
- )
338
+ run(
339
+ app_fn,
340
+ backend_fn,
341
+ nodes_mapping,
342
+ state_factory,
343
+ node_states,
344
+ f_stop,
350
345
  )
351
346
  except LoadClientAppError as loadapp_ex:
352
347
  f_stop_delay = 10
flwr/server/typing.py CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
20
20
  from flwr.common import Context
21
21
 
22
22
  from .driver import Driver
23
+ from .serverapp_components import ServerAppComponents
23
24
 
24
25
  ServerAppCallable = Callable[[Driver, Context], None]
25
26
  Workflow = Callable[[Driver, Context], None]
27
+ ServerFn = Callable[[Context], ServerAppComponents]
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  """Ray-based Flower Actor and ActorPool implementation."""
16
16
 
17
- import asyncio
18
17
  import threading
19
18
  from abc import ABC
20
19
  from logging import DEBUG, ERROR, WARNING
@@ -411,9 +410,7 @@ class BasicActorPool:
411
410
  self.client_resources = client_resources
412
411
 
413
412
  # Queue of idle actors
414
- self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
415
- maxsize=1024
416
- )
413
+ self.pool: List[VirtualClientEngineActor] = []
417
414
  self.num_actors = 0
418
415
 
419
416
  # Resolve arguments to pass during actor init
@@ -427,38 +424,37 @@ class BasicActorPool:
427
424
  # Figure out how many actors can be created given the cluster resources
428
425
  # and the resources the user indicates each VirtualClient will need
429
426
  self.actors_capacity = pool_size_from_resources(client_resources)
430
- self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
427
+ self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
431
428
 
432
429
  def is_actor_available(self) -> bool:
433
430
  """Return true if there is an idle actor."""
434
- return self.pool.qsize() > 0
431
+ return len(self.pool) > 0
435
432
 
436
- async def add_actors_to_pool(self, num_actors: int) -> None:
433
+ def add_actors_to_pool(self, num_actors: int) -> None:
437
434
  """Add actors to the pool.
438
435
 
439
436
  This method may be executed also if new resources are added to your Ray cluster
440
437
  (e.g. you add a new node).
441
438
  """
442
439
  for _ in range(num_actors):
443
- await self.pool.put(self.create_actor_fn()) # type: ignore
440
+ self.pool.append(self.create_actor_fn()) # type: ignore
444
441
  self.num_actors += num_actors
445
442
 
446
- async def terminate_all_actors(self) -> None:
443
+ def terminate_all_actors(self) -> None:
447
444
  """Terminate actors in pool."""
448
445
  num_terminated = 0
449
- while self.pool.qsize():
450
- actor = await self.pool.get()
446
+ for actor in self.pool:
451
447
  actor.terminate.remote() # type: ignore
452
448
  num_terminated += 1
453
449
 
454
450
  log(DEBUG, "Terminated %i actors", num_terminated)
455
451
 
456
- async def submit(
452
+ def submit(
457
453
  self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
458
454
  ) -> Any:
459
455
  """On idle actor, submit job and return future."""
460
456
  # Remove idle actor from pool
461
- actor = await self.pool.get()
457
+ actor = self.pool.pop()
462
458
  # Submit job to actor
463
459
  app_fn, mssg, cid, context = job
464
460
  future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -467,18 +463,18 @@ class BasicActorPool:
467
463
  self._future_to_actor[future] = actor
468
464
  return future
469
465
 
470
- async def add_actor_back_to_pool(self, future: Any) -> None:
466
+ def add_actor_back_to_pool(self, future: Any) -> None:
471
467
  """Ad actor assigned to run future back into the pool."""
472
468
  actor = self._future_to_actor.pop(future)
473
- await self.pool.put(actor)
469
+ self.pool.append(actor)
474
470
 
475
- async def fetch_result_and_return_actor_to_pool(
471
+ def fetch_result_and_return_actor_to_pool(
476
472
  self, future: Any
477
473
  ) -> Tuple[Message, Context]:
478
474
  """Pull result given a future and add actor back to pool."""
479
- # Get actor that ran job
480
- await self.add_actor_back_to_pool(future)
481
475
  # Retrieve result for object store
482
476
  # Instead of doing ray.get(future) we await it
483
- _, out_mssg, updated_context = await future
477
+ _, out_mssg, updated_context = ray.get(future)
478
+ # Get actor that ran job
479
+ self.add_actor_back_to_pool(future)
484
480
  return out_mssg, updated_context
@@ -22,7 +22,7 @@ import threading
22
22
  import traceback
23
23
  from logging import DEBUG, ERROR, INFO, WARNING
24
24
  from time import sleep
25
- from typing import Optional
25
+ from typing import Dict, Optional
26
26
 
27
27
  from flwr.client import ClientApp
28
28
  from flwr.common import EventType, event, log
@@ -126,16 +126,25 @@ def run_simulation(
126
126
  def run_serverapp_th(
127
127
  server_app_attr: Optional[str],
128
128
  server_app: Optional[ServerApp],
129
+ server_app_run_config: Dict[str, str],
129
130
  driver: Driver,
130
131
  app_dir: str,
131
- f_stop: asyncio.Event,
132
+ f_stop: threading.Event,
133
+ has_exception: threading.Event,
132
134
  enable_tf_gpu_growth: bool,
133
135
  delay_launch: int = 3,
134
136
  ) -> threading.Thread:
135
137
  """Run SeverApp in a thread."""
136
138
 
137
- def server_th_with_start_checks( # type: ignore
138
- tf_gpu_growth: bool, stop_event: asyncio.Event, **kwargs
139
+ def server_th_with_start_checks(
140
+ tf_gpu_growth: bool,
141
+ stop_event: threading.Event,
142
+ exception_event: threading.Event,
143
+ _driver: Driver,
144
+ _server_app_dir: str,
145
+ _server_app_run_config: Dict[str, str],
146
+ _server_app_attr: Optional[str],
147
+ _server_app: Optional[ServerApp],
139
148
  ) -> None:
140
149
  """Run SeverApp, after check if GPU memory growth has to be set.
141
150
 
@@ -147,10 +156,18 @@ def run_serverapp_th(
147
156
  enable_gpu_growth()
148
157
 
149
158
  # Run ServerApp
150
- run(**kwargs)
159
+ run(
160
+ driver=_driver,
161
+ server_app_dir=_server_app_dir,
162
+ server_app_run_config=_server_app_run_config,
163
+ server_app_attr=_server_app_attr,
164
+ loaded_server_app=_server_app,
165
+ )
151
166
  except Exception as ex: # pylint: disable=broad-exception-caught
152
167
  log(ERROR, "ServerApp thread raised an exception: %s", ex)
153
168
  log(ERROR, traceback.format_exc())
169
+ exception_event.set()
170
+ raise
154
171
  finally:
155
172
  log(DEBUG, "ServerApp finished running.")
156
173
  # Upon completion, trigger stop event if one was passed
@@ -160,13 +177,16 @@ def run_serverapp_th(
160
177
 
161
178
  serverapp_th = threading.Thread(
162
179
  target=server_th_with_start_checks,
163
- args=(enable_tf_gpu_growth, f_stop),
164
- kwargs={
165
- "server_app_attr": server_app_attr,
166
- "loaded_server_app": server_app,
167
- "driver": driver,
168
- "server_app_dir": app_dir,
169
- },
180
+ args=(
181
+ enable_tf_gpu_growth,
182
+ f_stop,
183
+ has_exception,
184
+ driver,
185
+ app_dir,
186
+ server_app_run_config,
187
+ server_app_attr,
188
+ server_app,
189
+ ),
170
190
  )
171
191
  sleep(delay_launch)
172
192
  serverapp_th.start()
@@ -196,20 +216,18 @@ def _main_loop(
196
216
  server_app: Optional[ServerApp] = None,
197
217
  server_app_attr: Optional[str] = None,
198
218
  ) -> None:
199
- """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.
200
-
201
- Everything runs on the main thread or a separate one, depending on whether the main
202
- thread already contains a running Asyncio event loop. This is the case if running
203
- the Simulation Engine on a Jupyter/Colab notebook.
204
- """
219
+ """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread."""
205
220
  # Initialize StateFactory
206
221
  state_factory = StateFactory(":flwr-in-memory-state:")
207
222
 
208
- f_stop = asyncio.Event()
223
+ f_stop = threading.Event()
224
+ # A Threading event to indicate if an exception was raised in the ServerApp thread
225
+ server_app_thread_has_exception = threading.Event()
209
226
  serverapp_th = None
210
227
  try:
211
228
  # Create run (with empty fab_id and fab_version)
212
229
  run_id_ = state_factory.state().create_run("", "", {})
230
+ server_app_run_config: Dict[str, str] = {}
213
231
 
214
232
  if run_id:
215
233
  _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id)
@@ -222,9 +240,11 @@ def _main_loop(
222
240
  serverapp_th = run_serverapp_th(
223
241
  server_app_attr=server_app_attr,
224
242
  server_app=server_app,
243
+ server_app_run_config=server_app_run_config,
225
244
  driver=driver,
226
245
  app_dir=app_dir,
227
246
  f_stop=f_stop,
247
+ has_exception=server_app_thread_has_exception,
228
248
  enable_tf_gpu_growth=enable_tf_gpu_growth,
229
249
  )
230
250
 
@@ -253,6 +273,8 @@ def _main_loop(
253
273
  event(EventType.RUN_SUPERLINK_LEAVE)
254
274
  if serverapp_th:
255
275
  serverapp_th.join()
276
+ if server_app_thread_has_exception.is_set():
277
+ raise RuntimeError("Exception in ServerApp thread")
256
278
 
257
279
  log(DEBUG, "Stopping Simulation Engine now.")
258
280
 
@@ -349,7 +371,6 @@ def _run_simulation(
349
371
  # Convert config to original JSON-stream format
350
372
  backend_config_stream = json.dumps(backend_config)
351
373
 
352
- simulation_engine_th = None
353
374
  args = (
354
375
  num_supernodes,
355
376
  backend_name,
@@ -363,31 +384,26 @@ def _run_simulation(
363
384
  server_app_attr,
364
385
  )
365
386
  # Detect if there is an Asyncio event loop already running.
366
- # If yes, run everything on a separate thread. In environments
367
- # like Jupyter/Colab notebooks, there is an event loop present.
368
- run_in_thread = False
387
+ # If yes, disable logger propagation. In environmnets
388
+ # like Jupyter/Colab notebooks, it's often better to do this.
389
+ asyncio_loop_running = False
369
390
  try:
370
391
  _ = (
371
392
  asyncio.get_running_loop()
372
393
  ) # Raises RuntimeError if no event loop is present
373
394
  log(DEBUG, "Asyncio event loop already running.")
374
395
 
375
- run_in_thread = True
396
+ asyncio_loop_running = True
376
397
 
377
398
  except RuntimeError:
378
- log(DEBUG, "No asyncio event loop running")
399
+ pass
379
400
 
380
401
  finally:
381
- if run_in_thread:
402
+ if asyncio_loop_running:
382
403
  # Set logger propagation to False to prevent duplicated log output in Colab.
383
404
  logger = set_logger_propagation(logger, False)
384
- log(DEBUG, "Starting Simulation Engine on a new thread.")
385
- simulation_engine_th = threading.Thread(target=_main_loop, args=args)
386
- simulation_engine_th.start()
387
- simulation_engine_th.join()
388
- else:
389
- log(DEBUG, "Starting Simulation Engine on the main thread.")
390
- _main_loop(*args)
405
+
406
+ _main_loop(*args)
391
407
 
392
408
 
393
409
  def _parse_args_run_simulation() -> argparse.ArgumentParser:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.10.0.dev20240709
3
+ Version: 1.10.0.dev20240711
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -33,7 +33,7 @@ Classifier: Typing :: Typed
33
33
  Provides-Extra: rest
34
34
  Provides-Extra: simulation
35
35
  Requires-Dist: cryptography (>=42.0.4,<43.0.0)
36
- Requires-Dist: grpcio (>=1.60.0,<2.0.0)
36
+ Requires-Dist: grpcio (>=1.60.0,<2.0.0,!=1.64.2,!=1.65.0)
37
37
  Requires-Dist: iterators (>=0.0.2,<0.0.3)
38
38
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
39
39
  Requires-Dist: pathspec (>=0.12.1,<0.13.0)