flwr-nightly 1.8.0.dev20240327__py3-none-any.whl → 1.8.0.dev20240402__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of flwr-nightly has been flagged as possibly problematic.
Files changed (30)
  1. flwr/client/app.py +53 -29
  2. flwr/client/client_app.py +16 -0
  3. flwr/client/grpc_rere_client/connection.py +71 -29
  4. flwr/client/heartbeat.py +72 -0
  5. flwr/client/rest_client/connection.py +102 -28
  6. flwr/common/constant.py +20 -0
  7. flwr/common/logger.py +4 -4
  8. flwr/common/message.py +53 -14
  9. flwr/common/retry_invoker.py +24 -13
  10. flwr/proto/fleet_pb2.py +26 -26
  11. flwr/proto/fleet_pb2.pyi +5 -0
  12. flwr/server/compat/driver_client_proxy.py +16 -0
  13. flwr/server/driver/driver.py +15 -5
  14. flwr/server/server_app.py +3 -0
  15. flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
  16. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
  17. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
  18. flwr/server/superlink/fleet/vce/vce_api.py +61 -27
  19. flwr/server/superlink/state/in_memory_state.py +25 -8
  20. flwr/server/superlink/state/sqlite_state.py +53 -5
  21. flwr/server/superlink/state/state.py +1 -1
  22. flwr/server/superlink/state/utils.py +56 -0
  23. flwr/server/workflow/default_workflows.py +1 -4
  24. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
  25. flwr/simulation/ray_transport/ray_actor.py +8 -24
  26. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/RECORD +30 -28
  28. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/fleet/vce/vce_api.py

@@ -14,16 +14,19 @@
 # ==============================================================================
 """Fleet Simulation Engine API."""

-
 import asyncio
 import json
+import sys
+import time
 import traceback
 from logging import DEBUG, ERROR, INFO, WARN
 from typing import Callable, Dict, List, Optional

-from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.node_state import NodeState
+from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
 from flwr.common.logger import log
+from flwr.common.message import Error
 from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
@@ -41,7 +44,7 @@ def _register_nodes(
     nodes_mapping: NodeToPartitionMapping = {}
     state = state_factory.state()
     for i in range(num_nodes):
-        node_id = state.create_node()
+        node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
         nodes_mapping[node_id] = i
     log(INFO, "Registered %i nodes", len(nodes_mapping))
     return nodes_mapping
@@ -59,6 +62,7 @@ async def worker(
     """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
     state = state_factory.state()
     while True:
+        out_mssg = None
         try:
             task_ins: TaskIns = await queue.get()
             node_id = task_ins.task.consumer.node_id
@@ -82,24 +86,34 @@ async def worker(
                 task_ins.run_id, context=updated_context
             )

-            # Convert to TaskRes
-            task_res = message_to_taskres(out_mssg)
-            # Store TaskRes in state
-            state.store_task_res(task_res)
-
         except asyncio.CancelledError as e:
-            log(DEBUG, "Async worker: %s", e)
+            log(DEBUG, "Terminating async worker: %s", e)
             break

-        except LoadClientAppError as app_ex:
-            log(ERROR, "Async worker: %s", app_ex)
-            log(ERROR, traceback.format_exc())
-            raise
-
+        # Exceptions aren't raised but reported as an error message
         except Exception as ex:  # pylint: disable=broad-exception-caught
             log(ERROR, ex)
             log(ERROR, traceback.format_exc())
-            break
+
+            if isinstance(ex, ClientAppException):
+                e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
+            elif isinstance(ex, LoadClientAppError):
+                e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
+            else:
+                e_code = ErrorCode.UNKNOWN
+
+            reason = str(type(ex)) + ":<'" + str(ex) + "'>"
+            out_mssg = message.create_error_reply(
+                error=Error(code=e_code, reason=reason)
+            )
+
+        finally:
+            if out_mssg:
+                # Convert to TaskRes
+                task_res = message_to_taskres(out_mssg)
+                # Store TaskRes in state
+                task_res.task.pushed_at = time.time()
+                state.store_task_res(task_res)


 async def add_taskins_to_queue(
@@ -218,7 +232,8 @@ async def run(
     await backend.terminate()


-# pylint: disable=too-many-arguments,unused-argument,too-many-locals
+# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
+# pylint: disable=too-many-statements
 def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
@@ -300,12 +315,14 @@ def start_vce(
         """Instantiate a Backend."""
         return backend_type(backend_config, work_dir=app_dir)

-    log(INFO, "client_app_attr = %s", client_app_attr)
-
     # Load ClientApp if needed
     def _load() -> ClientApp:

         if client_app_attr:
+
+            if app_dir is not None:
+                sys.path.insert(0, app_dir)
+
             app: ClientApp = load_app(client_app_attr, LoadClientAppError)

             if not isinstance(app, ClientApp):
@@ -319,13 +336,30 @@ def start_vce(

     app_fn = _load

-    asyncio.run(
-        run(
-            app_fn,
-            backend_fn,
-            nodes_mapping,
-            state_factory,
-            node_states,
-            f_stop,
+    try:
+        # Test if ClientApp can be loaded
+        _ = app_fn()
+
+        # Run main simulation loop
+        asyncio.run(
+            run(
+                app_fn,
+                backend_fn,
+                nodes_mapping,
+                state_factory,
+                node_states,
+                f_stop,
+            )
         )
-    )
+    except LoadClientAppError as loadapp_ex:
+        f_stop_delay = 10
+        log(
+            ERROR,
+            "LoadClientAppError exception encountered. Terminating simulation in %is",
+            f_stop_delay,
+        )
+        time.sleep(f_stop_delay)
+        f_stop.set()  # set termination event
+        raise loadapp_ex
+    except Exception as ex:
+        raise ex
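
Note on the worker changes above: an exception inside a ClientApp no longer terminates the async worker; it is mapped to an ErrorCode and stored as an error reply instead of a regular result. Below is a minimal, hypothetical sketch of how a server-side caller might branch on such a reply; it assumes the Flower 1.8 Message API (has_error(), error, content), and handle_reply is not part of this diff.

    from typing import Optional

    from flwr.common import Message, RecordSet
    from flwr.common.constant import ErrorCode


    def handle_reply(msg: Message) -> Optional[RecordSet]:
        """Return the reply content, or None if the reply carries an error."""
        # Assumption: Message exposes has_error()/error/content as in Flower 1.8.
        if msg.has_error():
            if msg.error.code == ErrorCode.CLIENT_APP_RAISED_EXCEPTION:
                print(f"ClientApp raised an exception: {msg.error.reason}")
            elif msg.error.code == ErrorCode.LOAD_CLIENT_APP_EXCEPTION:
                print(f"ClientApp could not be loaded: {msg.error.reason}")
            else:
                print(f"Error reply ({msg.error.code}): {msg.error.reason}")
            return None
        return msg.content  # a regular reply carries a RecordSet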
flwr/server/superlink/state/in_memory_state.py

@@ -27,6 +27,8 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.superlink.state.state import State
 from flwr.server.utils import validate_task_ins_or_res

+from .utils import make_node_unavailable_taskres
+

 class InMemoryState(State):
     """In-memory State implementation."""
@@ -129,15 +131,32 @@ class InMemoryState(State):
         with self.lock:
             # Find TaskRes that were not delivered yet
             task_res_list: List[TaskRes] = []
+            replied_task_ids: Set[UUID] = set()
             for _, task_res in self.task_res_store.items():
-                if (
-                    UUID(task_res.task.ancestry[0]) in task_ids
-                    and task_res.task.delivered_at == ""
-                ):
+                reply_to = UUID(task_res.task.ancestry[0])
+                if reply_to in task_ids and task_res.task.delivered_at == "":
                     task_res_list.append(task_res)
+                    replied_task_ids.add(reply_to)
                 if limit and len(task_res_list) == limit:
                     break

+            # Check if the node is offline
+            for task_id in task_ids - replied_task_ids:
+                if limit and len(task_res_list) == limit:
+                    break
+                task_ins = self.task_ins_store.get(task_id)
+                if task_ins is None:
+                    continue
+                node_id = task_ins.task.consumer.node_id
+                online_until, _ = self.node_ids[node_id]
+                # Generate a TaskRes containing an error reply if the node is offline.
+                if online_until < time.time():
+                    err_taskres = make_node_unavailable_taskres(
+                        ref_taskins=task_ins,
+                    )
+                    self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
+                    task_res_list.append(err_taskres)
+
             # Mark all of them as delivered
             delivered_at = now().isoformat()
             for task_res in task_res_list:
@@ -182,16 +201,14 @@ class InMemoryState(State):
         """
         return len(self.task_res_store)

-    def create_node(self) -> int:
+    def create_node(self, ping_interval: float) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
         node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

         with self.lock:
             if node_id not in self.node_ids:
-                # Default ping interval is 30s
-                # TODO: change 1e9 to 30s  # pylint: disable=W0511
-                self.node_ids[node_id] = (time.time() + 1e9, 1e9)
+                self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
                 return node_id
         log(ERROR, "Unexpected node registration failure.")
         return 0
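
The offline check above relies on the (online_until, ping_interval) pair stored per node: create_node(ping_interval=...) records online_until = now + ping_interval, and a node counts as offline once that deadline passes. A minimal sketch of the rule (the 30-second interval is an illustrative value, not taken from this diff):

    import time

    # Registration: the node promises to ping again within `ping_interval` seconds.
    ping_interval = 30.0  # hypothetical value; the simulation registers nodes with PING_MAX_INTERVAL
    online_until = time.time() + ping_interval


    def node_is_offline(online_until: float) -> bool:
        """Return True if the node has missed its ping deadline."""
        return online_until < time.time()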
flwr/server/superlink/state/sqlite_state.py

@@ -30,6 +30,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.utils.validator import validate_task_ins_or_res

 from .state import State
+from .utils import make_node_unavailable_taskres

 SQL_CREATE_TABLE_NODE = """
 CREATE TABLE IF NOT EXISTS node(
@@ -344,6 +345,7 @@ class SqliteState(State):

         return task_id

+    # pylint: disable-next=R0914
     def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
         """Get TaskRes for task_ids.

@@ -374,7 +376,7 @@
             AND delivered_at = ""
         """

-        data: Dict[str, Union[str, int]] = {}
+        data: Dict[str, Union[str, float, int]] = {}

         if limit is not None:
             query += " LIMIT :limit"
@@ -408,6 +410,54 @@
         rows = self.query(query, data)

         result = [dict_to_task_res(row) for row in rows]
+
+        # 1. Query: Fetch consumer_node_id of remaining task_ids
+        # Assume the ancestry field only contains one element
+        data.clear()
+        replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
+        remaining_task_ids = task_ids - replied_task_ids
+        placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
+        query = f"""
+            SELECT consumer_node_id
+            FROM task_ins
+            WHERE task_id IN ({placeholders});
+        """
+        for index, task_id in enumerate(remaining_task_ids):
+            data[f"id_{index}"] = str(task_id)
+        node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
+
+        # 2. Query: Select offline nodes
+        placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
+        query = f"""
+            SELECT node_id
+            FROM node
+            WHERE node_id IN ({placeholders})
+            AND online_until < :time;
+        """
+        data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
+        data["time"] = time.time()
+        offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
+
+        # 3. Query: Select TaskIns for offline nodes
+        placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
+        query = f"""
+            SELECT *
+            FROM task_ins
+            WHERE consumer_node_id IN ({placeholders});
+        """
+        data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
+        task_ins_rows = self.query(query, data)
+
+        # Make TaskRes containing node unavailabe error
+        for row in task_ins_rows:
+            if limit and len(result) == limit:
+                break
+            task_ins = dict_to_task_ins(row)
+            err_taskres = make_node_unavailable_taskres(
+                ref_taskins=task_ins,
+            )
+            result.append(err_taskres)
+
         return result

     def num_task_ins(self) -> int:
@@ -468,7 +518,7 @@ class SqliteState(State):

         return None

-    def create_node(self) -> int:
+    def create_node(self, ping_interval: float) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
         node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
@@ -478,9 +528,7 @@ class SqliteState(State):
         )

         try:
-            # Default ping interval is 30s
-            # TODO: change 1e9 to 30s  # pylint: disable=W0511
-            self.query(query, (node_id, time.time() + 1e9, 1e9))
+            self.query(query, (node_id, time.time() + ping_interval, ping_interval))
         except sqlite3.IntegrityError:
             log(ERROR, "Unexpected node registration failure.")
             return 0
flwr/server/superlink/state/state.py

@@ -132,7 +132,7 @@ class State(abc.ABC):
         """Delete all delivered TaskIns/TaskRes pairs."""

     @abc.abstractmethod
-    def create_node(self) -> int:
+    def create_node(self, ping_interval: float) -> int:
         """Create, store in state, and return `node_id`."""

     @abc.abstractmethod
flwr/server/superlink/state/utils.py (new file)

@@ -0,0 +1,56 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for State."""
+
+
+import time
+from logging import ERROR
+from uuid import uuid4
+
+from flwr.common import log
+from flwr.common.constant import ErrorCode
+from flwr.proto.error_pb2 import Error  # pylint: disable=E0611
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
+
+NODE_UNAVAILABLE_ERROR_REASON = (
+    "Error: Node Unavailable - The destination node is currently unavailable. "
+    "It exceeds the time limit specified in its last ping."
+)
+
+
+def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
+    """Generate a TaskRes with a node unavailable error from a TaskIns."""
+    current_time = time.time()
+    ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
+    if ttl < 0:
+        log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
+        ttl = 0
+    return TaskRes(
+        task_id=str(uuid4()),
+        group_id=ref_taskins.group_id,
+        run_id=ref_taskins.run_id,
+        task=Task(
+            producer=Node(node_id=ref_taskins.task.consumer.node_id, anonymous=False),
+            consumer=Node(node_id=ref_taskins.task.producer.node_id, anonymous=False),
+            created_at=current_time,
+            ttl=ttl,
+            ancestry=[ref_taskins.task_id],
+            task_type=ref_taskins.task.task_type,
+            error=Error(
+                code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
+            ),
+        ),
+    )
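
A rough, self-contained usage sketch of the new helper (not part of the package): build a TaskIns by hand, then derive the node-unavailable TaskRes the way both State implementations above do. All field values here are hypothetical.

    import time

    from flwr.proto.node_pb2 import Node
    from flwr.proto.task_pb2 import Task, TaskIns
    from flwr.server.superlink.state.utils import make_node_unavailable_taskres

    task_ins = TaskIns(
        task_id="550e8400-e29b-41d4-a716-446655440000",  # hypothetical UUID
        group_id="0",
        run_id=1,
        task=Task(
            producer=Node(node_id=0, anonymous=True),
            consumer=Node(node_id=123, anonymous=False),  # the offline node
            created_at=time.time(),
            ttl=3600.0,
            task_type="train",
        ),
    )

    err_taskres = make_node_unavailable_taskres(ref_taskins=task_ins)
    # The reply chains back to the TaskIns, swaps producer/consumer, and carries
    # an Error with code ErrorCode.NODE_UNAVAILABLE.
    assert err_taskres.task.ancestry[0] == task_ins.task_id
    assert err_taskres.task.producer.node_id == 123
    assert err_taskres.task.error.reason.startswith("Error: Node Unavailable")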
flwr/server/workflow/default_workflows.py

@@ -21,7 +21,7 @@ from logging import INFO
 from typing import Optional, cast

 import flwr.common.recordset_compat as compat
-from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log
+from flwr.common import ConfigsRecord, Context, GetParametersIns, log
 from flwr.common.constant import MessageType, MessageTypeLegacy

 from ..compat.app_utils import start_update_client_manager_thread
@@ -127,7 +127,6 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
                 message_type=MessageTypeLegacy.GET_PARAMETERS,
                 dst_node_id=random_client.node_id,
                 group_id="0",
-                ttl=DEFAULT_TTL,
             )
         ]
     )
@@ -226,7 +225,6 @@ def default_fit_workflow(  # pylint: disable=R0914
             message_type=MessageType.TRAIN,
             dst_node_id=proxy.node_id,
             group_id=str(current_round),
-            ttl=DEFAULT_TTL,
         )
         for proxy, fitins in client_instructions
     ]
@@ -306,7 +304,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
             message_type=MessageType.EVALUATE,
             dst_node_id=proxy.node_id,
             group_id=str(current_round),
-            ttl=DEFAULT_TTL,
         )
         for proxy, evalins in client_instructions
     ]
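
The workflows above simply stop passing ttl=DEFAULT_TTL; judging from the flwr/server/driver/driver.py entry in the file list, the default TTL is presumably now applied inside Driver.create_message itself. A hedged sketch of the resulting call site, matching the keyword arguments left in the diff (build_train_message is a hypothetical helper, not part of this release):

    from flwr.common import Message, RecordSet
    from flwr.common.constant import MessageType
    from flwr.server import Driver


    def build_train_message(driver: Driver, recordset: RecordSet, node_id: int) -> Message:
        """Create a TRAIN message without passing an explicit TTL."""
        return driver.create_message(
            content=recordset,
            message_type=MessageType.TRAIN,
            dst_node_id=node_id,
            group_id="1",  # hypothetical round/group identifier
            # ttl is omitted; the framework default is used
        )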
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py

@@ -22,7 +22,6 @@ from typing import Dict, List, Optional, Set, Tuple, Union, cast

 import flwr.common.recordset_compat as compat
 from flwr.common import (
-    DEFAULT_TTL,
     ConfigsRecord,
     Context,
     FitRes,
@@ -374,7 +373,6 @@ class SecAggPlusWorkflow:
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl=DEFAULT_TTL,
             )

         log(
@@ -422,7 +420,6 @@
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl=DEFAULT_TTL,
             )

         # Broadcast public keys to clients and receive secret key shares
@@ -493,7 +490,6 @@
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl=DEFAULT_TTL,
             )

         log(
@@ -564,7 +560,6 @@
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(current_round),
-                ttl=DEFAULT_TTL,
             )

         log(
flwr/simulation/ray_transport/ray_actor.py

@@ -16,7 +16,6 @@

 import asyncio
 import threading
-import traceback
 from abc import ABC
 from logging import DEBUG, ERROR, WARNING
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,22 +24,13 @@ import ray
 from ray import ObjectRef
 from ray.util.actor_pool import ActorPool

-from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.common import Context, Message
 from flwr.common.logger import log

 ClientAppFn = Callable[[], ClientApp]


-class ClientException(Exception):
-    """Raised when client side logic crashes with an exception."""
-
-    def __init__(self, message: str):
-        div = ">" * 7
-        self.message = "\n" + div + "A ClientException occurred." + message
-        super().__init__(self.message)
-
-
 class VirtualClientEngineActor(ABC):
     """Abstract base class for VirtualClientEngine Actors."""

@@ -71,17 +61,7 @@ class VirtualClientEngineActor(ABC):
             raise load_ex

         except Exception as ex:
-            client_trace = traceback.format_exc()
-            mssg = (
-                "\n\tSomething went wrong when running your client run."
-                "\n\tClient "
-                + cid
-                + " crashed when the "
-                + self.__class__.__name__
-                + " was running its run."
-                "\n\tException triggered on the client side: " + client_trace,
-            )
-            raise ClientException(str(mssg)) from ex
+            raise ClientAppException(str(ex)) from ex

         return cid, out_message, context

@@ -493,13 +473,17 @@ class BasicActorPool:
         self._future_to_actor[future] = actor
         return future

+    async def add_actor_back_to_pool(self, future: Any) -> None:
+        """Ad actor assigned to run future back into the pool."""
+        actor = self._future_to_actor.pop(future)
+        await self.pool.put(actor)
+
     async def fetch_result_and_return_actor_to_pool(
         self, future: Any
     ) -> Tuple[Message, Context]:
         """Pull result given a future and add actor back to pool."""
         # Get actor that ran job
-        actor = self._future_to_actor.pop(future)
-        await self.pool.put(actor)
+        await self.add_actor_back_to_pool(future)
         # Retrieve result for object store
         # Instead of doing ray.get(future) we await it
         _, out_mssg, updated_context = await future
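
With the local ClientException removed, the actor now wraps client-side failures in ClientAppException (from flwr.client.client_app), which the simulation worker in vce_api.py above maps to ErrorCode.CLIENT_APP_RAISED_EXCEPTION. A minimal sketch of that wrap-and-classify pattern (run_client_app_step is a hypothetical stand-in for the actor's run logic, not a Flower API):

    from flwr.client.client_app import ClientAppException, LoadClientAppError
    from flwr.common.constant import ErrorCode


    def run_client_app_step() -> None:
        """Hypothetical client-side step that fails."""
        raise ValueError("boom")


    def classify(exc: Exception) -> int:
        """Map an exception to the ErrorCode used by the simulation worker."""
        if isinstance(exc, ClientAppException):
            return ErrorCode.CLIENT_APP_RAISED_EXCEPTION
        if isinstance(exc, LoadClientAppError):
            return ErrorCode.LOAD_CLIENT_APP_EXCEPTION
        return ErrorCode.UNKNOWN


    try:
        try:
            run_client_app_step()
        except Exception as ex:  # actor side: wrap, as in ray_actor.py
            raise ClientAppException(str(ex)) from ex
    except Exception as ex:  # worker side: classify, as in vce_api.py
        print(classify(ex))  # prints the value of CLIENT_APP_RAISED_EXCEPTION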
{flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.8.0.dev20240327
+Version: 1.8.0.dev20240402
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0