flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/utils/validator.py
CHANGED
@@ -31,13 +31,21 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
     if not tasks_ins_res.HasField("task"):
         validation_errors.append("`task` does not set field `task`")
 
-    # Created/delivered/TTL
-    if
+    # Created/delivered/TTL/Pushed
+    if (
+        tasks_ins_res.task.created_at < 1711497600.0
+    ):  # unix timestamp of 27 March 2024 00h:00m:00s UTC
+        validation_errors.append(
+            "`created_at` must be a float that records the unix timestamp "
+            "in seconds when the message was created."
+        )
     if tasks_ins_res.task.delivered_at != "":
         validation_errors.append("`delivered_at` must be an empty str")
-    if tasks_ins_res.task.ttl
-        validation_errors.append("`ttl` must be
+    if tasks_ins_res.task.ttl <= 0:
+        validation_errors.append("`ttl` must be higher than zero")
+    if tasks_ins_res.task.pushed_at < 1711497600.0:
+        # unix timestamp of 27 March 2024 00h:00m:00s UTC
+        validation_errors.append("`pushed_at` is not a recent timestamp")
 
     # TaskIns specific
     if isinstance(tasks_ins_res, TaskIns):
@@ -66,8 +74,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
         # Content check
         if tasks_ins_res.task.task_type == "":
             validation_errors.append("`task_type` MUST be set")
-        if not
+        if not (
+            tasks_ins_res.task.HasField("recordset")
+            ^ tasks_ins_res.task.HasField("error")
+        ):
+            validation_errors.append("Either `recordset` or `error` MUST be set")
 
         # Ancestors
         if len(tasks_ins_res.task.ancestry) != 0:
@@ -106,8 +117,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
         # Content check
         if tasks_ins_res.task.task_type == "":
             validation_errors.append("`task_type` MUST be set")
-        if not
+        if not (
+            tasks_ins_res.task.HasField("recordset")
+            ^ tasks_ins_res.task.HasField("error")
+        ):
+            validation_errors.append("Either `recordset` or `error` MUST be set")
 
         # Ancestors
         if len(tasks_ins_res.task.ancestry) == 0:
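Note: the new checks replace the old string-based timestamp/TTL handling with numeric validation. The following is a minimal, self-contained sketch of the same rules; the plain dataclass below is only a stand-in for the protobuf `task` message and is an assumption for illustration, not the actual flwr API.

from dataclasses import dataclass
from typing import List

MIN_TIMESTAMP = 1711497600.0  # 27 March 2024 00:00:00 UTC, the floor used in the diff above

@dataclass
class FakeTask:
    """Hypothetical stand-in for the real protobuf `task` message."""
    created_at: float
    pushed_at: float
    ttl: float
    delivered_at: str = ""

def validate(task: FakeTask) -> List[str]:
    errors: List[str] = []
    if task.created_at < MIN_TIMESTAMP:
        errors.append("`created_at` must be a recent unix timestamp in seconds")
    if task.delivered_at != "":
        errors.append("`delivered_at` must be an empty str")
    if task.ttl <= 0:
        errors.append("`ttl` must be higher than zero")
    if task.pushed_at < MIN_TIMESTAMP:
        errors.append("`pushed_at` is not a recent timestamp")
    return errors

print(validate(FakeTask(created_at=0.0, pushed_at=0.0, ttl=0.0)))  # three errors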
flwr/server/workflow/default_workflows.py
CHANGED
@@ -17,13 +17,23 @@
 
 import io
 import timeit
-from logging import INFO
-from typing import Optional, cast
+from logging import INFO, WARN
+from typing import List, Optional, Tuple, Union, cast
 
 import flwr.common.recordset_compat as compat
-from flwr.common import
+from flwr.common import (
+    Code,
+    ConfigsRecord,
+    Context,
+    EvaluateRes,
+    FitRes,
+    GetParametersIns,
+    ParametersRecord,
+    log,
+)
 from flwr.common.constant import MessageType, MessageTypeLegacy
 
+from ..client_proxy import ClientProxy
 from ..compat.app_utils import start_update_client_manager_thread
 from ..compat.legacy_context import LegacyContext
 from ..driver import Driver
@@ -88,7 +98,12 @@ class DefaultWorkflow:
         hist = context.history
         log(INFO, "")
         log(INFO, "[SUMMARY]")
-        log(
+        log(
+            INFO,
+            "Run finished %s round(s) in %.2fs",
+            context.config.num_rounds,
+            elapsed,
+        )
         for idx, line in enumerate(io.StringIO(str(hist))):
             if idx == 0:
                 log(INFO, "%s", line.strip("\n"))
@@ -127,13 +142,27 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
                 message_type=MessageTypeLegacy.GET_PARAMETERS,
                 dst_node_id=random_client.node_id,
                 group_id="0",
-                ttl="",
             )
         ]
     )
-    log(INFO, "Received initial parameters from one random client")
     msg = list(messages)[0]
-
+
+    if (
+        msg.has_content()
+        and compat._extract_status_from_recordset(  # pylint: disable=W0212
+            "getparametersres", msg.content
+        ).code
+        == Code.OK
+    ):
+        log(INFO, "Received initial parameters from one random client")
+        paramsrecord = next(iter(msg.content.parameters_records.values()))
+    else:
+        log(
+            WARN,
+            "Failed to receive initial parameters from the client."
+            " Empty initial parameters will be used.",
+        )
+        paramsrecord = ParametersRecord()
 
     context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
 
@@ -226,7 +255,6 @@ def default_fit_workflow(  # pylint: disable=R0914
             message_type=MessageType.TRAIN,
             dst_node_id=proxy.node_id,
             group_id=str(current_round),
-            ttl="",
         )
         for proxy, fitins in client_instructions
     ]
@@ -246,14 +274,20 @@ def default_fit_workflow(  # pylint: disable=R0914
     )
 
     # Aggregate training results
-    results = [
+    results: List[Tuple[ClientProxy, FitRes]] = []
+    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
+    for msg in messages:
+        if msg.has_content():
+            proxy = node_id_to_proxy[msg.metadata.src_node_id]
+            fitres = compat.recordset_to_fitres(msg.content, False)
+            if fitres.status.code == Code.OK:
+                results.append((proxy, fitres))
+            else:
+                failures.append((proxy, fitres))
+        else:
+            failures.append(Exception(msg.error))
+
+    aggregated_result = context.strategy.aggregate_fit(current_round, results, failures)
     parameters_aggregated, metrics_aggregated = aggregated_result
 
     # Update the parameters and write history
@@ -267,6 +301,7 @@ def default_fit_workflow(  # pylint: disable=R0914
     )
 
 
+# pylint: disable-next=R0914
 def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     """Execute the default workflow for a single evaluate round."""
     if not isinstance(context, LegacyContext):
@@ -306,7 +341,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
             message_type=MessageType.EVALUATE,
             dst_node_id=proxy.node_id,
             group_id=str(current_round),
-            ttl="",
         )
         for proxy, evalins in client_instructions
     ]
@@ -326,14 +360,22 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     )
 
     # Aggregate the evaluation results
-    results = [
+    results: List[Tuple[ClientProxy, EvaluateRes]] = []
+    failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
+    for msg in messages:
+        if msg.has_content():
+            proxy = node_id_to_proxy[msg.metadata.src_node_id]
+            evalres = compat.recordset_to_evaluateres(msg.content)
+            if evalres.status.code == Code.OK:
+                results.append((proxy, evalres))
+            else:
+                failures.append((proxy, evalres))
+        else:
+            failures.append(Exception(msg.error))
+
+    aggregated_result = context.strategy.aggregate_evaluate(
+        current_round, results, failures
+    )
 
     loss_aggregated, metrics_aggregated = aggregated_result
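Note: both the fit and evaluate workflows now sort client replies into `results` and `failures` rather than assuming every reply carries content with an OK status. A rough sketch of that pattern follows; the `Reply` class is a hypothetical stand-in, since the real code operates on flwr `Message` objects and converts their content with `compat.recordset_to_fitres` / `recordset_to_evaluateres`.

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Reply:
    """Hypothetical stand-in for a reply message."""
    src_node_id: int
    ok: bool                     # True if the client reported Code.OK
    error: Optional[str] = None  # set when the reply carries an error instead of content

def split_replies(replies: List[Reply]) -> Tuple[List[Reply], List[object]]:
    results: List[Reply] = []
    failures: List[object] = []
    for msg in replies:
        if msg.error is None:        # reply carries content
            if msg.ok:               # client reported success
                results.append(msg)
            else:                    # client reported a non-OK status
                failures.append(msg)
        else:                        # reply carries an error, surface it as an exception
            failures.append(Exception(msg.error))
    return results, failures

results, failures = split_replies(
    [Reply(1, ok=True), Reply(2, ok=False), Reply(3, ok=True, error="client crashed")]
)
print(len(results), len(failures))  # 1 2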
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py
CHANGED
@@ -81,6 +81,7 @@ class WorkflowState:  # pylint: disable=R0902
     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
     aggregate_ndarrays: NDArrays = field(default_factory=list)
     legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
+    failures: List[Exception] = field(default_factory=list)
 
 
 class SecAggPlusWorkflow:
@@ -373,7 +374,6 @@ class SecAggPlusWorkflow:
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl="",
             )
 
         log(
@@ -395,6 +395,7 @@ class SecAggPlusWorkflow:
 
         for msg in msgs:
             if msg.has_error():
+                state.failures.append(Exception(msg.error))
                 continue
             key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             node_id = msg.metadata.src_node_id
@@ -421,7 +422,6 @@ class SecAggPlusWorkflow:
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl="",
             )
 
         # Broadcast public keys to clients and receive secret key shares
@@ -453,6 +453,9 @@ class SecAggPlusWorkflow:
             nid: [] for nid in state.active_node_ids
         }  # dest node ID -> list of src node IDs
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             node_id = msg.metadata.src_node_id
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -492,7 +495,6 @@ class SecAggPlusWorkflow:
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
-                ttl="",
             )
 
         log(
@@ -518,6 +520,9 @@ class SecAggPlusWorkflow:
         # Sum collected masked vectors and compute active/dead node IDs
         masked_vector = None
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -531,6 +536,9 @@ class SecAggPlusWorkflow:
 
         # Backward compatibility with Strategy
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             fitres = compat.recordset_to_fitres(msg.content, True)
             proxy = state.nid_to_proxies[msg.metadata.src_node_id]
             state.legacy_results.append((proxy, fitres))
@@ -563,7 +571,6 @@ class SecAggPlusWorkflow:
                 message_type=MessageType.TRAIN,
                 dst_node_id=nid,
                 group_id=str(current_round),
-                ttl="",
             )
 
         log(
@@ -588,6 +595,9 @@ class SecAggPlusWorkflow:
         for nid in state.sampled_node_ids:
             collected_shares_dict[nid] = []
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
             shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -656,9 +666,11 @@ class SecAggPlusWorkflow:
             INFO,
             "aggregate_fit: received %s results and %s failures",
             len(results),
-
+            len(state.failures),
+        )
+        aggregated_result = context.strategy.aggregate_fit(
+            current_round, results, state.failures  # type: ignore
         )
-        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
        parameters_aggregated, metrics_aggregated = aggregated_result
 
         # Update the parameters and write history
flwr/simulation/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
 
 import importlib
 
-from flwr.simulation.run_simulation import run_simulation
+from flwr.simulation.run_simulation import run_simulation
 
 is_ray_installed = importlib.util.find_spec("ray") is not None
 
@@ -28,7 +28,7 @@ else:
 
 To install the necessary dependencies, install `flwr` with the `simulation` extra:
 
-    pip install -U flwr[
+    pip install -U "flwr[simulation]"
 """
 
     def start_simulation(*args, **kwargs):  # type: ignore
@@ -36,4 +36,7 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
         raise ImportError(RAY_IMPORT_ERROR)
 
 
-__all__ = [
+__all__ = [
+    "run_simulation",
+    "start_simulation",
+]
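Note: the module keeps exposing `start_simulation` even when Ray is missing by swapping in a stub that raises a helpful error, and the install hint is now quoted so shells such as zsh do not expand the square brackets. A minimal sketch of that optional-dependency pattern, written with generic names rather than the exact flwr internals:

import importlib.util

_INSTALL_HINT = 'pip install -U "flwr[simulation]"'  # quoting keeps zsh from treating [...] as a glob

is_ray_installed = importlib.util.find_spec("ray") is not None

if not is_ray_installed:
    def start_simulation(*args, **kwargs):  # type: ignore
        """Stand-in that tells the user how to enable the `simulation` extra."""
        raise ImportError(f"Unable to import `ray`. To fix this, run: {_INSTALL_HINT}")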
flwr/simulation/app.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2021 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
 """Flower simulation app."""
 
 
+import asyncio
+import logging
 import sys
 import threading
 import traceback
@@ -25,14 +27,16 @@ from typing import Any, Dict, List, Optional, Type, Union
 import ray
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 
-from flwr.client import
+from flwr.client import ClientFnExt
 from flwr.common import EventType, event
-from flwr.common.
+from flwr.common.constant import NODE_ID_NUM_BYTES
+from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
 from flwr.server.server import Server, init_defaults, run_fl
 from flwr.server.server_config import ServerConfig
 from flwr.server.strategy import Strategy
+from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
 from flwr.simulation.ray_transport.ray_actor import (
     ClientAppActor,
     VirtualClientEngineActor,
@@ -49,7 +53,7 @@ Invalid Arguments in method:
 `start_simulation(
     *,
     client_fn: ClientFn,
-    num_clients:
+    num_clients: int,
     clients_ids: Optional[List[str]] = None,
     client_resources: Optional[Dict[str, float]] = None,
     server: Optional[Server] = None,
@@ -68,13 +72,29 @@ REASON:
 
 """
 
+NodeToPartitionMapping = Dict[int, int]
+
+
+def _create_node_id_to_partition_mapping(
+    num_clients: int,
+) -> NodeToPartitionMapping:
+    """Generate a node_id:partition_id mapping."""
+    nodes_mapping: NodeToPartitionMapping = {}  # {node-id; partition-id}
+    for i in range(num_clients):
+        while True:
+            node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
+            if node_id not in nodes_mapping:
+                break
+        nodes_mapping[node_id] = i
+    return nodes_mapping
+
 
 # pylint: disable=too-many-arguments,too-many-statements,too-many-branches
 def start_simulation(
     *,
-    client_fn:
-    num_clients:
-    clients_ids: Optional[List[str]] = None,
+    client_fn: ClientFnExt,
+    num_clients: int,
+    clients_ids: Optional[List[str]] = None,  # UNSUPPORTED, WILL BE REMOVED
     client_resources: Optional[Dict[str, float]] = None,
     server: Optional[Server] = None,
     config: Optional[ServerConfig] = None,
@@ -90,23 +110,24 @@ def start_simulation(
 
     Parameters
     ----------
-    client_fn :
-        A function creating
-        `
-        of type Client
-        and will often be destroyed after a single method
-        instances are not long-lived, they should not attempt
-        method invocations. Any state required by the instance
-        hyperparameters, ...) should be (re-)created in either the
-        or the call to any of the client methods (e.g., load
-        `evaluate` method itself).
-    num_clients :
-        The total number of clients in this simulation.
-        `clients_ids` is not set and vice-versa.
+    client_fn : ClientFnExt
+        A function creating `Client` instances. The function must have the signature
+        `client_fn(context: Context). It should return
+        a single client instance of type `Client`. Note that the created client
+        instances are ephemeral and will often be destroyed after a single method
+        invocation. Since client instances are not long-lived, they should not attempt
+        to carry state over method invocations. Any state required by the instance
+        (model, dataset, hyperparameters, ...) should be (re-)created in either the
+        call to `client_fn` or the call to any of the client methods (e.g., load
+        evaluation data in the `evaluate` method itself).
+    num_clients : int
+        The total number of clients in this simulation.
     clients_ids : Optional[List[str]]
+        UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
         List `client_id`s for each client. This is only required if
         `num_clients` is not set. Setting both `num_clients` and `clients_ids`
         with `len(clients_ids)` not equal to `num_clients` generates an error.
+        Using this argument will raise an error.
     client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
         CPU and GPU resources for a single client. Supported keys
         are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
@@ -167,6 +188,26 @@ def start_simulation(
         {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
     )
 
+    if clients_ids is not None:
+        warn_unsupported_feature(
+            "Passing `clients_ids` to `start_simulation` is deprecated and not longer "
+            "used by `start_simulation`. Use `num_clients` exclusively instead."
+        )
+        log(ERROR, "`clients_ids` argument used.")
+        sys.exit()
+
+    # Set logger propagation
+    loop: Optional[asyncio.AbstractEventLoop] = None
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+    finally:
+        if loop and loop.is_running():
+            # Set logger propagation to False to prevent duplicated log output in Colab.
+            logger = logging.getLogger("flwr")
+            _ = set_logger_propagation(logger, False)
+
     # Initialize server and server config
     initialized_server, initialized_config = init_defaults(
         server=server,
@@ -181,20 +222,8 @@ def start_simulation(
         initialized_config,
     )
 
-    #
-
-    if clients_ids is not None:
-        if (num_clients is not None) and (len(clients_ids) != num_clients):
-            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
-            sys.exit()
-        else:
-            cids = clients_ids
-    else:
-        if num_clients is None:
-            log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
-            sys.exit()
-        else:
-            cids = [str(x) for x in range(num_clients)]
+    # Create node-id to partition-id mapping
+    nodes_mapping = _create_node_id_to_partition_mapping(num_clients)
 
     # Default arguments for Ray initialization
     if not ray_init_args:
@@ -293,10 +322,12 @@ def start_simulation(
     )
 
     # Register one RayClientProxy object for each client with the ClientManager
-    for
+    for node_id, partition_id in nodes_mapping.items():
         client_proxy = RayActorClientProxy(
             client_fn=client_fn,
-
+            node_id=node_id,
+            partition_id=partition_id,
+            num_partitions=num_clients,
             actor_pool=pool,
         )
         initialized_server.client_manager().register(client=client_proxy)
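Note: clients are no longer identified by sequential string `cid`s; each data partition is instead assigned a randomly drawn node ID. A rough sketch of that mapping using only the standard library follows; the real code uses flwr's `generate_rand_int_from_bytes` and the `NODE_ID_NUM_BYTES` constant, and the byte width below is an assumption for illustration.

import secrets
from typing import Dict

NODE_ID_NUM_BYTES = 8  # assumed width; the real constant lives in flwr.common.constant

def create_node_id_to_partition_mapping(num_clients: int) -> Dict[int, int]:
    """Map a randomly drawn node ID to each partition index."""
    mapping: Dict[int, int] = {}
    for partition_id in range(num_clients):
        while True:
            node_id = int.from_bytes(secrets.token_bytes(NODE_ID_NUM_BYTES), "little")
            if node_id not in mapping:  # re-draw on the (unlikely) collision
                break
        mapping[node_id] = partition_id
    return mapping

print(len(create_node_id_to_partition_mapping(5)))  # 5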
flwr/simulation/ray_transport/ray_actor.py
CHANGED
@@ -14,9 +14,7 @@
 # ==============================================================================
 """Ray-based Flower Actor and ActorPool implementation."""
 
-import asyncio
 import threading
-import traceback
 from abc import ABC
 from logging import DEBUG, ERROR, WARNING
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,22 +23,13 @@ import ray
 from ray import ObjectRef
 from ray.util.actor_pool import ActorPool
 
-from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.common import Context, Message
 from flwr.common.logger import log
 
 ClientAppFn = Callable[[], ClientApp]
 
 
-class ClientException(Exception):
-    """Raised when client side logic crashes with an exception."""
-
-    def __init__(self, message: str):
-        div = ">" * 7
-        self.message = "\n" + div + "A ClientException occurred." + message
-        super().__init__(self.message)
-
-
 class VirtualClientEngineActor(ABC):
     """Abstract base class for VirtualClientEngine Actors."""
 
@@ -71,17 +60,7 @@ class VirtualClientEngineActor(ABC):
             raise load_ex
 
         except Exception as ex:
-            mssg = (
-                "\n\tSomething went wrong when running your client run."
-                "\n\tClient "
-                + cid
-                + " crashed when the "
-                + self.__class__.__name__
-                + " was running its run."
-                "\n\tException triggered on the client side: " + client_trace,
-            )
-            raise ClientException(str(mssg)) from ex
+            raise ClientAppException(str(ex)) from ex
 
         return cid, out_message, context
 
@@ -419,12 +398,6 @@ class VirtualClientEngineActorPool(ActorPool):
         return self._fetch_future_result(cid)
 
 
-def init_ray(*args: Any, **kwargs: Any) -> None:
-    """Intialises Ray if not already initialised."""
-    if not ray.is_initialized():
-        ray.init(*args, **kwargs)
-
-
 class BasicActorPool:
     """A basic actor pool."""
 
@@ -437,9 +410,7 @@ class BasicActorPool:
         self.client_resources = client_resources
 
         # Queue of idle actors
-        self.pool:
-            maxsize=1024
-        )
+        self.pool: List[VirtualClientEngineActor] = []
         self.num_actors = 0
 
         # Resolve arguments to pass during actor init
@@ -453,38 +424,37 @@
         # Figure out how many actors can be created given the cluster resources
         # and the resources the user indicates each VirtualClient will need
         self.actors_capacity = pool_size_from_resources(client_resources)
-        self._future_to_actor: Dict[Any,
+        self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
 
     def is_actor_available(self) -> bool:
         """Return true if there is an idle actor."""
-        return self.pool
+        return len(self.pool) > 0
 
-
+    def add_actors_to_pool(self, num_actors: int) -> None:
         """Add actors to the pool.
 
         This method may be executed also if new resources are added to your Ray cluster
        (e.g. you add a new node).
         """
         for _ in range(num_actors):
-
+            self.pool.append(self.create_actor_fn())  # type: ignore
         self.num_actors += num_actors
 
-
+    def terminate_all_actors(self) -> None:
         """Terminate actors in pool."""
         num_terminated = 0
-        actor = await self.pool.get()
+        for actor in self.pool:
            actor.terminate.remote()  # type: ignore
            num_terminated += 1
 
         log(DEBUG, "Terminated %i actors", num_terminated)
 
-
+    def submit(
         self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
     ) -> Any:
         """On idle actor, submit job and return future."""
         # Remove idle actor from pool
-        actor =
+        actor = self.pool.pop()
         # Submit job to actor
         app_fn, mssg, cid, context = job
         future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -493,14 +463,18 @@
         self._future_to_actor[future] = actor
         return future
 
-
+    def add_actor_back_to_pool(self, future: Any) -> None:
+        """Ad actor assigned to run future back into the pool."""
+        actor = self._future_to_actor.pop(future)
+        self.pool.append(actor)
+
+    def fetch_result_and_return_actor_to_pool(
         self, future: Any
     ) -> Tuple[Message, Context]:
         """Pull result given a future and add actor back to pool."""
-        # Get actor that ran job
-        actor = self._future_to_actor.pop(future)
-        await self.pool.put(actor)
         # Retrieve result for object store
         # Instead of doing ray.get(future) we await it
-        _, out_mssg, updated_context =
+        _, out_mssg, updated_context = ray.get(future)
+        # Get actor that ran job
+        self.add_actor_back_to_pool(future)
         return out_mssg, updated_context
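Note: `BasicActorPool` drops its asyncio queue in favour of plain synchronous bookkeeping: idle actors live in a list, `submit` pops one, and it is appended back once the result has been fetched with `ray.get`. The following is a Ray-free sketch of the same bookkeeping; all names here are illustrative, not flwr or Ray API.

from typing import Any, Callable, Dict, List

class TinyWorkerPool:
    """Minimal synchronous pool: pop an idle worker on submit, push it back after fetch."""

    def __init__(self, workers: List[Any]) -> None:
        self.pool: List[Any] = list(workers)          # idle workers
        self._future_to_worker: Dict[Any, Any] = {}   # in-flight work -> worker

    def is_worker_available(self) -> bool:
        return len(self.pool) > 0

    def submit(self, fn: Callable[[Any], Any], job: Any) -> Any:
        worker = self.pool.pop()                      # remove an idle worker from the pool
        future = (worker, fn, job)                    # stand-in for a Ray ObjectRef
        self._future_to_worker[future] = worker
        return future

    def fetch_result_and_return_worker(self, future: Any) -> Any:
        worker, fn, job = future
        result = fn(job)                              # stand-in for ray.get(future)
        self.pool.append(self._future_to_worker.pop(future))  # worker becomes idle again
        return result

pool = TinyWorkerPool(workers=["w0", "w1"])
f = pool.submit(lambda x: x * 2, 21)
print(pool.fetch_result_and_return_worker(f))  # 42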