flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +18 -4
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +135 -51
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +63 -26
- flwr/client/client_app.py +49 -4
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +3 -4
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +60 -21
- flwr/client/typing.py +1 -0
- flwr/common/config.py +87 -2
- flwr/common/constant.py +6 -0
- flwr/common/context.py +26 -1
- flwr/common/logger.py +38 -0
- flwr/common/message.py +0 -17
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -122
- flwr/server/superlink/state/in_memory_state.py +15 -7
- flwr/server/superlink/state/sqlite_state.py +27 -12
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +52 -36
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +186 -0
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/typing.py
CHANGED
|
@@ -20,6 +20,8 @@ from typing import Callable
|
|
|
20
20
|
from flwr.common import Context
|
|
21
21
|
|
|
22
22
|
from .driver import Driver
|
|
23
|
+
from .serverapp_components import ServerAppComponents
|
|
23
24
|
|
|
24
25
|
ServerAppCallable = Callable[[Driver, Context], None]
|
|
25
26
|
Workflow = Callable[[Driver, Context], None]
|
|
27
|
+
ServerFn = Callable[[Context], ServerAppComponents]
|
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py
CHANGED
|
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
|
81
81
|
forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
|
|
82
82
|
aggregate_ndarrays: NDArrays = field(default_factory=list)
|
|
83
83
|
legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
|
|
84
|
+
failures: List[Exception] = field(default_factory=list)
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
class SecAggPlusWorkflow:
|
|
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
|
|
|
394
395
|
|
|
395
396
|
for msg in msgs:
|
|
396
397
|
if msg.has_error():
|
|
398
|
+
state.failures.append(Exception(msg.error))
|
|
397
399
|
continue
|
|
398
400
|
key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
399
401
|
node_id = msg.metadata.src_node_id
|
|
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
|
|
|
451
453
|
nid: [] for nid in state.active_node_ids
|
|
452
454
|
} # dest node ID -> list of src node IDs
|
|
453
455
|
for msg in msgs:
|
|
456
|
+
if msg.has_error():
|
|
457
|
+
state.failures.append(Exception(msg.error))
|
|
458
|
+
continue
|
|
454
459
|
node_id = msg.metadata.src_node_id
|
|
455
460
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
456
461
|
dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
|
|
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
|
|
|
515
520
|
# Sum collected masked vectors and compute active/dead node IDs
|
|
516
521
|
masked_vector = None
|
|
517
522
|
for msg in msgs:
|
|
523
|
+
if msg.has_error():
|
|
524
|
+
state.failures.append(Exception(msg.error))
|
|
525
|
+
continue
|
|
518
526
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
519
527
|
bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
|
|
520
528
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
|
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
|
|
|
528
536
|
|
|
529
537
|
# Backward compatibility with Strategy
|
|
530
538
|
for msg in msgs:
|
|
539
|
+
if msg.has_error():
|
|
540
|
+
state.failures.append(Exception(msg.error))
|
|
541
|
+
continue
|
|
531
542
|
fitres = compat.recordset_to_fitres(msg.content, True)
|
|
532
543
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
|
533
544
|
state.legacy_results.append((proxy, fitres))
|
|
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
|
|
|
584
595
|
for nid in state.sampled_node_ids:
|
|
585
596
|
collected_shares_dict[nid] = []
|
|
586
597
|
for msg in msgs:
|
|
598
|
+
if msg.has_error():
|
|
599
|
+
state.failures.append(Exception(msg.error))
|
|
600
|
+
continue
|
|
587
601
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
588
602
|
nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
|
|
589
603
|
shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
|
|
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
|
|
|
652
666
|
INFO,
|
|
653
667
|
"aggregate_fit: received %s results and %s failures",
|
|
654
668
|
len(results),
|
|
655
|
-
|
|
669
|
+
len(state.failures),
|
|
670
|
+
)
|
|
671
|
+
aggregated_result = context.strategy.aggregate_fit(
|
|
672
|
+
current_round, results, state.failures # type: ignore
|
|
656
673
|
)
|
|
657
|
-
aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
|
|
658
674
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
659
675
|
|
|
660
676
|
# Update the parameters and write history
|
flwr/simulation/app.py
CHANGED
|
@@ -27,14 +27,16 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|
|
27
27
|
import ray
|
|
28
28
|
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
|
29
29
|
|
|
30
|
-
from flwr.client import
|
|
30
|
+
from flwr.client import ClientFnExt
|
|
31
31
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.
|
|
32
|
+
from flwr.common.constant import NODE_ID_NUM_BYTES
|
|
33
|
+
from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
|
|
33
34
|
from flwr.server.client_manager import ClientManager
|
|
34
35
|
from flwr.server.history import History
|
|
35
36
|
from flwr.server.server import Server, init_defaults, run_fl
|
|
36
37
|
from flwr.server.server_config import ServerConfig
|
|
37
38
|
from flwr.server.strategy import Strategy
|
|
39
|
+
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
|
|
38
40
|
from flwr.simulation.ray_transport.ray_actor import (
|
|
39
41
|
ClientAppActor,
|
|
40
42
|
VirtualClientEngineActor,
|
|
@@ -51,7 +53,7 @@ Invalid Arguments in method:
|
|
|
51
53
|
`start_simulation(
|
|
52
54
|
*,
|
|
53
55
|
client_fn: ClientFn,
|
|
54
|
-
num_clients:
|
|
56
|
+
num_clients: int,
|
|
55
57
|
clients_ids: Optional[List[str]] = None,
|
|
56
58
|
client_resources: Optional[Dict[str, float]] = None,
|
|
57
59
|
server: Optional[Server] = None,
|
|
@@ -70,13 +72,29 @@ REASON:
|
|
|
70
72
|
|
|
71
73
|
"""
|
|
72
74
|
|
|
75
|
+
NodeToPartitionMapping = Dict[int, int]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _create_node_id_to_partition_mapping(
|
|
79
|
+
num_clients: int,
|
|
80
|
+
) -> NodeToPartitionMapping:
|
|
81
|
+
"""Generate a node_id:partition_id mapping."""
|
|
82
|
+
nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id}
|
|
83
|
+
for i in range(num_clients):
|
|
84
|
+
while True:
|
|
85
|
+
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
86
|
+
if node_id not in nodes_mapping:
|
|
87
|
+
break
|
|
88
|
+
nodes_mapping[node_id] = i
|
|
89
|
+
return nodes_mapping
|
|
90
|
+
|
|
73
91
|
|
|
74
92
|
# pylint: disable=too-many-arguments,too-many-statements,too-many-branches
|
|
75
93
|
def start_simulation(
|
|
76
94
|
*,
|
|
77
|
-
client_fn:
|
|
78
|
-
num_clients:
|
|
79
|
-
clients_ids: Optional[List[str]] = None,
|
|
95
|
+
client_fn: ClientFnExt,
|
|
96
|
+
num_clients: int,
|
|
97
|
+
clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED
|
|
80
98
|
client_resources: Optional[Dict[str, float]] = None,
|
|
81
99
|
server: Optional[Server] = None,
|
|
82
100
|
config: Optional[ServerConfig] = None,
|
|
@@ -92,23 +110,24 @@ def start_simulation(
|
|
|
92
110
|
|
|
93
111
|
Parameters
|
|
94
112
|
----------
|
|
95
|
-
client_fn :
|
|
96
|
-
A function creating
|
|
97
|
-
`
|
|
98
|
-
of type Client
|
|
99
|
-
and will often be destroyed after a single method
|
|
100
|
-
instances are not long-lived, they should not attempt
|
|
101
|
-
method invocations. Any state required by the instance
|
|
102
|
-
hyperparameters, ...) should be (re-)created in either the
|
|
103
|
-
or the call to any of the client methods (e.g., load
|
|
104
|
-
`evaluate` method itself).
|
|
105
|
-
num_clients :
|
|
106
|
-
The total number of clients in this simulation.
|
|
107
|
-
`clients_ids` is not set and vice-versa.
|
|
113
|
+
client_fn : ClientFnExt
|
|
114
|
+
A function creating `Client` instances. The function must have the signature
|
|
115
|
+
`client_fn(context: Context). It should return
|
|
116
|
+
a single client instance of type `Client`. Note that the created client
|
|
117
|
+
instances are ephemeral and will often be destroyed after a single method
|
|
118
|
+
invocation. Since client instances are not long-lived, they should not attempt
|
|
119
|
+
to carry state over method invocations. Any state required by the instance
|
|
120
|
+
(model, dataset, hyperparameters, ...) should be (re-)created in either the
|
|
121
|
+
call to `client_fn` or the call to any of the client methods (e.g., load
|
|
122
|
+
evaluation data in the `evaluate` method itself).
|
|
123
|
+
num_clients : int
|
|
124
|
+
The total number of clients in this simulation.
|
|
108
125
|
clients_ids : Optional[List[str]]
|
|
126
|
+
UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
|
|
109
127
|
List `client_id`s for each client. This is only required if
|
|
110
128
|
`num_clients` is not set. Setting both `num_clients` and `clients_ids`
|
|
111
129
|
with `len(clients_ids)` not equal to `num_clients` generates an error.
|
|
130
|
+
Using this argument will raise an error.
|
|
112
131
|
client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
|
|
113
132
|
CPU and GPU resources for a single client. Supported keys
|
|
114
133
|
are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
|
|
@@ -158,7 +177,6 @@ def start_simulation(
|
|
|
158
177
|
is an advanced feature. For all details, please refer to the Ray documentation:
|
|
159
178
|
https://docs.ray.io/en/latest/ray-core/scheduling/index.html
|
|
160
179
|
|
|
161
|
-
|
|
162
180
|
Returns
|
|
163
181
|
-------
|
|
164
182
|
hist : flwr.server.history.History
|
|
@@ -170,6 +188,14 @@ def start_simulation(
|
|
|
170
188
|
{"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
|
|
171
189
|
)
|
|
172
190
|
|
|
191
|
+
if clients_ids is not None:
|
|
192
|
+
warn_unsupported_feature(
|
|
193
|
+
"Passing `clients_ids` to `start_simulation` is deprecated and not longer "
|
|
194
|
+
"used by `start_simulation`. Use `num_clients` exclusively instead."
|
|
195
|
+
)
|
|
196
|
+
log(ERROR, "`clients_ids` argument used.")
|
|
197
|
+
sys.exit()
|
|
198
|
+
|
|
173
199
|
# Set logger propagation
|
|
174
200
|
loop: Optional[asyncio.AbstractEventLoop] = None
|
|
175
201
|
try:
|
|
@@ -196,20 +222,8 @@ def start_simulation(
|
|
|
196
222
|
initialized_config,
|
|
197
223
|
)
|
|
198
224
|
|
|
199
|
-
#
|
|
200
|
-
|
|
201
|
-
if clients_ids is not None:
|
|
202
|
-
if (num_clients is not None) and (len(clients_ids) != num_clients):
|
|
203
|
-
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
|
|
204
|
-
sys.exit()
|
|
205
|
-
else:
|
|
206
|
-
cids = clients_ids
|
|
207
|
-
else:
|
|
208
|
-
if num_clients is None:
|
|
209
|
-
log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
|
|
210
|
-
sys.exit()
|
|
211
|
-
else:
|
|
212
|
-
cids = [str(x) for x in range(num_clients)]
|
|
225
|
+
# Create node-id to partition-id mapping
|
|
226
|
+
nodes_mapping = _create_node_id_to_partition_mapping(num_clients)
|
|
213
227
|
|
|
214
228
|
# Default arguments for Ray initialization
|
|
215
229
|
if not ray_init_args:
|
|
@@ -308,10 +322,12 @@ def start_simulation(
|
|
|
308
322
|
)
|
|
309
323
|
|
|
310
324
|
# Register one RayClientProxy object for each client with the ClientManager
|
|
311
|
-
for
|
|
325
|
+
for node_id, partition_id in nodes_mapping.items():
|
|
312
326
|
client_proxy = RayActorClientProxy(
|
|
313
327
|
client_fn=client_fn,
|
|
314
|
-
|
|
328
|
+
node_id=node_id,
|
|
329
|
+
partition_id=partition_id,
|
|
330
|
+
num_partitions=num_clients,
|
|
315
331
|
actor_pool=pool,
|
|
316
332
|
)
|
|
317
333
|
initialized_server.client_manager().register(client=client_proxy)
|
flwr/simulation/ray_transport/ray_actor.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Ray-based Flower Actor and ActorPool implementation."""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import threading
|
|
19
18
|
from abc import ABC
|
|
20
19
|
from logging import DEBUG, ERROR, WARNING
|
|
@@ -411,9 +410,7 @@ class BasicActorPool:
|
|
|
411
410
|
self.client_resources = client_resources
|
|
412
411
|
|
|
413
412
|
# Queue of idle actors
|
|
414
|
-
self.pool:
|
|
415
|
-
maxsize=1024
|
|
416
|
-
)
|
|
413
|
+
self.pool: List[VirtualClientEngineActor] = []
|
|
417
414
|
self.num_actors = 0
|
|
418
415
|
|
|
419
416
|
# Resolve arguments to pass during actor init
|
|
@@ -427,38 +424,37 @@ class BasicActorPool:
|
|
|
427
424
|
# Figure out how many actors can be created given the cluster resources
|
|
428
425
|
# and the resources the user indicates each VirtualClient will need
|
|
429
426
|
self.actors_capacity = pool_size_from_resources(client_resources)
|
|
430
|
-
self._future_to_actor: Dict[Any,
|
|
427
|
+
self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
|
|
431
428
|
|
|
432
429
|
def is_actor_available(self) -> bool:
|
|
433
430
|
"""Return true if there is an idle actor."""
|
|
434
|
-
return self.pool
|
|
431
|
+
return len(self.pool) > 0
|
|
435
432
|
|
|
436
|
-
|
|
433
|
+
def add_actors_to_pool(self, num_actors: int) -> None:
|
|
437
434
|
"""Add actors to the pool.
|
|
438
435
|
|
|
439
436
|
This method may be executed also if new resources are added to your Ray cluster
|
|
440
437
|
(e.g. you add a new node).
|
|
441
438
|
"""
|
|
442
439
|
for _ in range(num_actors):
|
|
443
|
-
|
|
440
|
+
self.pool.append(self.create_actor_fn()) # type: ignore
|
|
444
441
|
self.num_actors += num_actors
|
|
445
442
|
|
|
446
|
-
|
|
443
|
+
def terminate_all_actors(self) -> None:
|
|
447
444
|
"""Terminate actors in pool."""
|
|
448
445
|
num_terminated = 0
|
|
449
|
-
|
|
450
|
-
actor = await self.pool.get()
|
|
446
|
+
for actor in self.pool:
|
|
451
447
|
actor.terminate.remote() # type: ignore
|
|
452
448
|
num_terminated += 1
|
|
453
449
|
|
|
454
450
|
log(DEBUG, "Terminated %i actors", num_terminated)
|
|
455
451
|
|
|
456
|
-
|
|
452
|
+
def submit(
|
|
457
453
|
self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
|
|
458
454
|
) -> Any:
|
|
459
455
|
"""On idle actor, submit job and return future."""
|
|
460
456
|
# Remove idle actor from pool
|
|
461
|
-
actor =
|
|
457
|
+
actor = self.pool.pop()
|
|
462
458
|
# Submit job to actor
|
|
463
459
|
app_fn, mssg, cid, context = job
|
|
464
460
|
future = actor_fn(actor, app_fn, mssg, cid, context)
|
|
@@ -467,18 +463,18 @@ class BasicActorPool:
|
|
|
467
463
|
self._future_to_actor[future] = actor
|
|
468
464
|
return future
|
|
469
465
|
|
|
470
|
-
|
|
466
|
+
def add_actor_back_to_pool(self, future: Any) -> None:
|
|
471
467
|
"""Ad actor assigned to run future back into the pool."""
|
|
472
468
|
actor = self._future_to_actor.pop(future)
|
|
473
|
-
|
|
469
|
+
self.pool.append(actor)
|
|
474
470
|
|
|
475
|
-
|
|
471
|
+
def fetch_result_and_return_actor_to_pool(
|
|
476
472
|
self, future: Any
|
|
477
473
|
) -> Tuple[Message, Context]:
|
|
478
474
|
"""Pull result given a future and add actor back to pool."""
|
|
479
|
-
# Get actor that ran job
|
|
480
|
-
await self.add_actor_back_to_pool(future)
|
|
481
475
|
# Retrieve result for object store
|
|
482
476
|
# Instead of doing ray.get(future) we await it
|
|
483
|
-
_, out_mssg, updated_context =
|
|
477
|
+
_, out_mssg, updated_context = ray.get(future)
|
|
478
|
+
# Get actor that ran job
|
|
479
|
+
self.add_actor_back_to_pool(future)
|
|
484
480
|
return out_mssg, updated_context
|
flwr/simulation/ray_transport/ray_client_proxy.py
CHANGED
|
@@ -20,11 +20,16 @@ from logging import ERROR
|
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
22
|
from flwr import common
|
|
23
|
-
from flwr.client import
|
|
23
|
+
from flwr.client import ClientFnExt
|
|
24
24
|
from flwr.client.client_app import ClientApp
|
|
25
25
|
from flwr.client.node_state import NodeState
|
|
26
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
27
|
-
from flwr.common.constant import
|
|
27
|
+
from flwr.common.constant import (
|
|
28
|
+
NUM_PARTITIONS_KEY,
|
|
29
|
+
PARTITION_ID_KEY,
|
|
30
|
+
MessageType,
|
|
31
|
+
MessageTypeLegacy,
|
|
32
|
+
)
|
|
28
33
|
from flwr.common.logger import log
|
|
29
34
|
from flwr.common.recordset_compat import (
|
|
30
35
|
evaluateins_to_recordset,
|
|
@@ -43,17 +48,30 @@ from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
|
|
43
48
|
class RayActorClientProxy(ClientProxy):
|
|
44
49
|
"""Flower client proxy which delegates work using Ray."""
|
|
45
50
|
|
|
46
|
-
def __init__(
|
|
47
|
-
self,
|
|
51
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
52
|
+
self,
|
|
53
|
+
client_fn: ClientFnExt,
|
|
54
|
+
node_id: int,
|
|
55
|
+
partition_id: int,
|
|
56
|
+
num_partitions: int,
|
|
57
|
+
actor_pool: VirtualClientEngineActorPool,
|
|
48
58
|
):
|
|
49
|
-
super().__init__(cid)
|
|
59
|
+
super().__init__(cid=str(node_id))
|
|
60
|
+
self.node_id = node_id
|
|
61
|
+
self.partition_id = partition_id
|
|
50
62
|
|
|
51
63
|
def _load_app() -> ClientApp:
|
|
52
64
|
return ClientApp(client_fn=client_fn)
|
|
53
65
|
|
|
54
66
|
self.app_fn = _load_app
|
|
55
67
|
self.actor_pool = actor_pool
|
|
56
|
-
self.proxy_state = NodeState(
|
|
68
|
+
self.proxy_state = NodeState(
|
|
69
|
+
node_id=node_id,
|
|
70
|
+
node_config={
|
|
71
|
+
PARTITION_ID_KEY: str(partition_id),
|
|
72
|
+
NUM_PARTITIONS_KEY: str(num_partitions),
|
|
73
|
+
},
|
|
74
|
+
)
|
|
57
75
|
|
|
58
76
|
def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
|
|
59
77
|
"""Sumbit a message to the ActorPool."""
|
|
@@ -62,16 +80,19 @@ class RayActorClientProxy(ClientProxy):
|
|
|
62
80
|
# Register state
|
|
63
81
|
self.proxy_state.register_context(run_id=run_id)
|
|
64
82
|
|
|
65
|
-
# Retrieve
|
|
66
|
-
|
|
83
|
+
# Retrieve context
|
|
84
|
+
context = self.proxy_state.retrieve_context(run_id=run_id)
|
|
85
|
+
partition_id_str = str(context.node_config[PARTITION_ID_KEY])
|
|
67
86
|
|
|
68
87
|
try:
|
|
69
88
|
self.actor_pool.submit_client_job(
|
|
70
|
-
lambda a, a_fn, mssg,
|
|
71
|
-
|
|
89
|
+
lambda a, a_fn, mssg, partition_id, context: a.run.remote(
|
|
90
|
+
a_fn, mssg, partition_id, context
|
|
91
|
+
),
|
|
92
|
+
(self.app_fn, message, partition_id_str, context),
|
|
72
93
|
)
|
|
73
94
|
out_mssg, updated_context = self.actor_pool.get_client_result(
|
|
74
|
-
|
|
95
|
+
partition_id_str, timeout
|
|
75
96
|
)
|
|
76
97
|
|
|
77
98
|
# Update state
|
|
@@ -103,11 +124,10 @@ class RayActorClientProxy(ClientProxy):
|
|
|
103
124
|
message_id="",
|
|
104
125
|
group_id=str(group_id) if group_id is not None else "",
|
|
105
126
|
src_node_id=0,
|
|
106
|
-
dst_node_id=
|
|
127
|
+
dst_node_id=self.node_id,
|
|
107
128
|
reply_to_message="",
|
|
108
129
|
ttl=timeout if timeout else DEFAULT_TTL,
|
|
109
130
|
message_type=message_type,
|
|
110
|
-
partition_id=int(self.cid),
|
|
111
131
|
),
|
|
112
132
|
)
|
|
113
133
|
|