flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
@@ -15,19 +15,32 @@
|
|
|
15
15
|
"""Fleet Simulation Engine API."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import asyncio
|
|
19
18
|
import json
|
|
19
|
+
import threading
|
|
20
|
+
import time
|
|
20
21
|
import traceback
|
|
22
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
21
23
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
22
|
-
from
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from queue import Empty, Queue
|
|
26
|
+
from time import sleep
|
|
27
|
+
from typing import Callable, Dict, Optional
|
|
23
28
|
|
|
24
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
29
|
+
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
25
30
|
from flwr.client.node_state import NodeState
|
|
31
|
+
from flwr.client.supernode.app import _get_load_client_app_fn
|
|
32
|
+
from flwr.common.constant import (
|
|
33
|
+
NUM_PARTITIONS_KEY,
|
|
34
|
+
PARTITION_ID_KEY,
|
|
35
|
+
PING_MAX_INTERVAL,
|
|
36
|
+
ErrorCode,
|
|
37
|
+
)
|
|
26
38
|
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.
|
|
39
|
+
from flwr.common.message import Error
|
|
28
40
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
29
|
-
from flwr.
|
|
30
|
-
from flwr.
|
|
41
|
+
from flwr.common.typing import Run
|
|
42
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
|
+
from flwr.server.superlink.state import State, StateFactory
|
|
31
44
|
|
|
32
45
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
33
46
|
|
|
@@ -41,39 +54,63 @@ def _register_nodes(
|
|
|
41
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
42
55
|
state = state_factory.state()
|
|
43
56
|
for i in range(num_nodes):
|
|
44
|
-
node_id = state.create_node()
|
|
57
|
+
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
|
|
45
58
|
nodes_mapping[node_id] = i
|
|
46
|
-
log(
|
|
59
|
+
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
|
47
60
|
return nodes_mapping
|
|
48
61
|
|
|
49
62
|
|
|
63
|
+
def _register_node_states(
|
|
64
|
+
nodes_mapping: NodeToPartitionMapping,
|
|
65
|
+
run: Run,
|
|
66
|
+
app_dir: Optional[str] = None,
|
|
67
|
+
) -> Dict[int, NodeState]:
|
|
68
|
+
"""Create NodeState objects and pre-register the context for the run."""
|
|
69
|
+
node_states: Dict[int, NodeState] = {}
|
|
70
|
+
num_partitions = len(set(nodes_mapping.values()))
|
|
71
|
+
for node_id, partition_id in nodes_mapping.items():
|
|
72
|
+
node_states[node_id] = NodeState(
|
|
73
|
+
node_id=node_id,
|
|
74
|
+
node_config={
|
|
75
|
+
PARTITION_ID_KEY: partition_id,
|
|
76
|
+
NUM_PARTITIONS_KEY: num_partitions,
|
|
77
|
+
},
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# Pre-register Context objects
|
|
81
|
+
node_states[node_id].register_context(
|
|
82
|
+
run_id=run.run_id, run=run, app_dir=app_dir
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return node_states
|
|
86
|
+
|
|
87
|
+
|
|
50
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
51
|
-
|
|
89
|
+
def worker(
|
|
52
90
|
app_fn: Callable[[], ClientApp],
|
|
53
|
-
|
|
91
|
+
taskins_queue: "Queue[TaskIns]",
|
|
92
|
+
taskres_queue: "Queue[TaskRes]",
|
|
54
93
|
node_states: Dict[int, NodeState],
|
|
55
|
-
state_factory: StateFactory,
|
|
56
|
-
nodes_mapping: NodeToPartitionMapping,
|
|
57
94
|
backend: Backend,
|
|
95
|
+
f_stop: threading.Event,
|
|
58
96
|
) -> None:
|
|
59
97
|
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
|
|
60
|
-
|
|
61
|
-
|
|
98
|
+
while not f_stop.is_set():
|
|
99
|
+
out_mssg = None
|
|
62
100
|
try:
|
|
63
|
-
|
|
101
|
+
# Fetch from queue with timeout. We use a timeout so
|
|
102
|
+
# the stopping event can be evaluated even when the queue is empty.
|
|
103
|
+
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
|
|
64
104
|
node_id = task_ins.task.consumer.node_id
|
|
65
105
|
|
|
66
|
-
#
|
|
67
|
-
node_states[node_id].register_context(run_id=task_ins.run_id)
|
|
106
|
+
# Retrieve context
|
|
68
107
|
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
|
|
69
108
|
|
|
70
109
|
# Convert TaskIns to Message
|
|
71
110
|
message = message_from_taskins(task_ins)
|
|
72
|
-
# Set partition_id
|
|
73
|
-
message.metadata.partition_id = nodes_mapping[node_id]
|
|
74
111
|
|
|
75
112
|
# Let backend process message
|
|
76
|
-
out_mssg, updated_context =
|
|
113
|
+
out_mssg, updated_context = backend.process_message(
|
|
77
114
|
app_fn, message, context
|
|
78
115
|
)
|
|
79
116
|
|
|
@@ -81,85 +118,74 @@ async def worker(
|
|
|
81
118
|
node_states[node_id].update_context(
|
|
82
119
|
task_ins.run_id, context=updated_context
|
|
83
120
|
)
|
|
84
|
-
|
|
85
|
-
#
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
state.store_task_res(task_res)
|
|
89
|
-
|
|
90
|
-
except asyncio.CancelledError as e:
|
|
91
|
-
log(DEBUG, "Async worker: %s", e)
|
|
92
|
-
break
|
|
93
|
-
|
|
94
|
-
except LoadClientAppError as app_ex:
|
|
95
|
-
log(ERROR, "Async worker: %s", app_ex)
|
|
96
|
-
log(ERROR, traceback.format_exc())
|
|
97
|
-
raise
|
|
98
|
-
|
|
121
|
+
except Empty:
|
|
122
|
+
# An exception raised if queue.get times out
|
|
123
|
+
pass
|
|
124
|
+
# Exceptions aren't raised but reported as an error message
|
|
99
125
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
100
126
|
log(ERROR, ex)
|
|
101
127
|
log(ERROR, traceback.format_exc())
|
|
102
|
-
break
|
|
103
128
|
|
|
129
|
+
if isinstance(ex, ClientAppException):
|
|
130
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
131
|
+
elif isinstance(ex, LoadClientAppError):
|
|
132
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
133
|
+
else:
|
|
134
|
+
e_code = ErrorCode.UNKNOWN
|
|
104
135
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
136
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
137
|
+
out_mssg = message.create_error_reply(
|
|
138
|
+
error=Error(code=e_code, reason=reason)
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
finally:
|
|
142
|
+
if out_mssg:
|
|
143
|
+
# Convert to TaskRes
|
|
144
|
+
task_res = message_to_taskres(out_mssg)
|
|
145
|
+
# Store TaskRes in state
|
|
146
|
+
task_res.task.pushed_at = time.time()
|
|
147
|
+
taskres_queue.put(task_res)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def add_taskins_to_queue(
|
|
151
|
+
state: State,
|
|
152
|
+
queue: "Queue[TaskIns]",
|
|
108
153
|
nodes_mapping: NodeToPartitionMapping,
|
|
109
|
-
|
|
110
|
-
consumers: List["asyncio.Task[None]"],
|
|
111
|
-
f_stop: asyncio.Event,
|
|
154
|
+
f_stop: threading.Event,
|
|
112
155
|
) -> None:
|
|
113
|
-
"""
|
|
114
|
-
state = state_factory.state()
|
|
115
|
-
num_initial_consumers = len(consumers)
|
|
156
|
+
"""Put TaskIns in a queue from State."""
|
|
116
157
|
while not f_stop.is_set():
|
|
117
158
|
for node_id in nodes_mapping.keys():
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
# Count consumers that are running
|
|
123
|
-
num_active = sum(not (cc.done()) for cc in consumers)
|
|
124
|
-
|
|
125
|
-
# Alert if number of consumers decreased by half
|
|
126
|
-
if num_active < num_initial_consumers // 2:
|
|
127
|
-
log(
|
|
128
|
-
WARN,
|
|
129
|
-
"Number of active workers has more than halved: (%i/%i active)",
|
|
130
|
-
num_active,
|
|
131
|
-
num_initial_consumers,
|
|
132
|
-
)
|
|
159
|
+
task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
|
|
160
|
+
for task_ins in task_ins_list:
|
|
161
|
+
queue.put(task_ins)
|
|
162
|
+
sleep(0.1)
|
|
133
163
|
|
|
134
|
-
# Break if consumers died
|
|
135
|
-
if num_active == 0:
|
|
136
|
-
raise RuntimeError("All workers have died. Ending Simulation.")
|
|
137
164
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
queue
|
|
148
|
-
|
|
149
|
-
await asyncio.sleep(1.0)
|
|
150
|
-
log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
|
|
165
|
+
def put_taskres_into_state(
|
|
166
|
+
state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Put TaskRes into State from a queue."""
|
|
169
|
+
while not f_stop.is_set():
|
|
170
|
+
try:
|
|
171
|
+
taskres = queue.get(timeout=1.0)
|
|
172
|
+
state.store_task_res(taskres)
|
|
173
|
+
except Empty:
|
|
174
|
+
# queue is empty when timeout was triggered
|
|
175
|
+
pass
|
|
151
176
|
|
|
152
177
|
|
|
153
|
-
|
|
178
|
+
def run_api(
|
|
154
179
|
app_fn: Callable[[], ClientApp],
|
|
155
180
|
backend_fn: Callable[[], Backend],
|
|
156
181
|
nodes_mapping: NodeToPartitionMapping,
|
|
157
182
|
state_factory: StateFactory,
|
|
158
183
|
node_states: Dict[int, NodeState],
|
|
159
|
-
f_stop:
|
|
184
|
+
f_stop: threading.Event,
|
|
160
185
|
) -> None:
|
|
161
|
-
"""Run the VCE
|
|
162
|
-
|
|
186
|
+
"""Run the VCE."""
|
|
187
|
+
taskins_queue: "Queue[TaskIns]" = Queue()
|
|
188
|
+
taskres_queue: "Queue[TaskRes]" = Queue()
|
|
163
189
|
|
|
164
190
|
try:
|
|
165
191
|
|
|
@@ -167,29 +193,48 @@ async def run(
|
|
|
167
193
|
backend = backend_fn()
|
|
168
194
|
|
|
169
195
|
# Build backend
|
|
170
|
-
|
|
196
|
+
backend.build()
|
|
171
197
|
|
|
172
198
|
# Add workers (they submit Messages to Backend)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
add_taskins_to_queue(
|
|
184
|
-
queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
|
|
185
|
-
)
|
|
199
|
+
state = state_factory.state()
|
|
200
|
+
|
|
201
|
+
extractor_th = threading.Thread(
|
|
202
|
+
target=add_taskins_to_queue,
|
|
203
|
+
args=(
|
|
204
|
+
state,
|
|
205
|
+
taskins_queue,
|
|
206
|
+
nodes_mapping,
|
|
207
|
+
f_stop,
|
|
208
|
+
),
|
|
186
209
|
)
|
|
210
|
+
extractor_th.start()
|
|
211
|
+
|
|
212
|
+
injector_th = threading.Thread(
|
|
213
|
+
target=put_taskres_into_state,
|
|
214
|
+
args=(
|
|
215
|
+
state,
|
|
216
|
+
taskres_queue,
|
|
217
|
+
f_stop,
|
|
218
|
+
),
|
|
219
|
+
)
|
|
220
|
+
injector_th.start()
|
|
221
|
+
|
|
222
|
+
with ThreadPoolExecutor() as executor:
|
|
223
|
+
_ = [
|
|
224
|
+
executor.submit(
|
|
225
|
+
worker,
|
|
226
|
+
app_fn,
|
|
227
|
+
taskins_queue,
|
|
228
|
+
taskres_queue,
|
|
229
|
+
node_states,
|
|
230
|
+
backend,
|
|
231
|
+
f_stop,
|
|
232
|
+
)
|
|
233
|
+
for _ in range(backend.num_workers)
|
|
234
|
+
]
|
|
187
235
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
# all worker (consumer) coroutines are completed. Workers
|
|
191
|
-
# also run forever and only end if an exception is raised.
|
|
192
|
-
await asyncio.gather(producer)
|
|
236
|
+
extractor_th.join()
|
|
237
|
+
injector_th.join()
|
|
193
238
|
|
|
194
239
|
except Exception as ex:
|
|
195
240
|
|
|
@@ -204,26 +249,21 @@ async def run(
|
|
|
204
249
|
raise RuntimeError("Simulation Engine crashed.") from ex
|
|
205
250
|
|
|
206
251
|
finally:
|
|
207
|
-
# Produced task terminated, now cancel worker tasks
|
|
208
|
-
for w_t in worker_tasks:
|
|
209
|
-
_ = w_t.cancel()
|
|
210
|
-
|
|
211
|
-
while not all(w_t.done() for w_t in worker_tasks):
|
|
212
|
-
log(DEBUG, "Terminating async workers...")
|
|
213
|
-
await asyncio.sleep(0.5)
|
|
214
|
-
|
|
215
|
-
await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
|
|
216
252
|
|
|
217
253
|
# Terminate backend
|
|
218
|
-
|
|
254
|
+
backend.terminate()
|
|
219
255
|
|
|
220
256
|
|
|
221
|
-
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
|
257
|
+
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
258
|
+
# pylint: disable=too-many-statements
|
|
222
259
|
def start_vce(
|
|
223
260
|
backend_name: str,
|
|
224
261
|
backend_config_json_stream: str,
|
|
225
262
|
app_dir: str,
|
|
226
|
-
|
|
263
|
+
is_app: bool,
|
|
264
|
+
f_stop: threading.Event,
|
|
265
|
+
run: Run,
|
|
266
|
+
flwr_dir: Optional[str] = None,
|
|
227
267
|
client_app: Optional[ClientApp] = None,
|
|
228
268
|
client_app_attr: Optional[str] = None,
|
|
229
269
|
num_supernodes: Optional[int] = None,
|
|
@@ -259,6 +299,7 @@ def start_vce(
|
|
|
259
299
|
# Use mapping constructed externally. This also means nodes
|
|
260
300
|
# have previously been registered.
|
|
261
301
|
nodes_mapping = existing_nodes_mapping
|
|
302
|
+
app_dir = str(Path(app_dir).absolute())
|
|
262
303
|
|
|
263
304
|
if not state_factory:
|
|
264
305
|
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
|
@@ -273,12 +314,12 @@ def start_vce(
|
|
|
273
314
|
)
|
|
274
315
|
|
|
275
316
|
# Construct mapping of NodeStates
|
|
276
|
-
node_states
|
|
277
|
-
|
|
278
|
-
|
|
317
|
+
node_states = _register_node_states(
|
|
318
|
+
nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
|
|
319
|
+
)
|
|
279
320
|
|
|
280
321
|
# Load backend config
|
|
281
|
-
log(
|
|
322
|
+
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
282
323
|
backend_config = json.loads(backend_config_json_stream)
|
|
283
324
|
|
|
284
325
|
try:
|
|
@@ -298,20 +339,18 @@ def start_vce(
|
|
|
298
339
|
|
|
299
340
|
def backend_fn() -> Backend:
|
|
300
341
|
"""Instantiate a Backend."""
|
|
301
|
-
return backend_type(backend_config
|
|
302
|
-
|
|
303
|
-
log(INFO, "client_app_attr = %s", client_app_attr)
|
|
342
|
+
return backend_type(backend_config)
|
|
304
343
|
|
|
305
344
|
# Load ClientApp if needed
|
|
306
345
|
def _load() -> ClientApp:
|
|
307
346
|
|
|
308
347
|
if client_app_attr:
|
|
309
|
-
app
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
348
|
+
app = _get_load_client_app_fn(
|
|
349
|
+
default_app_ref=client_app_attr,
|
|
350
|
+
app_path=app_dir,
|
|
351
|
+
flwr_dir=flwr_dir,
|
|
352
|
+
multi_app=False,
|
|
353
|
+
)(run.fab_id, run.fab_version)
|
|
315
354
|
|
|
316
355
|
if client_app:
|
|
317
356
|
app = client_app
|
|
@@ -319,8 +358,21 @@ def start_vce(
|
|
|
319
358
|
|
|
320
359
|
app_fn = _load
|
|
321
360
|
|
|
322
|
-
|
|
323
|
-
|
|
361
|
+
try:
|
|
362
|
+
# Test if ClientApp can be loaded
|
|
363
|
+
client_app = app_fn()
|
|
364
|
+
|
|
365
|
+
# Cache `ClientApp`
|
|
366
|
+
if client_app_attr:
|
|
367
|
+
# Now wrap the loaded ClientApp in a dummy function
|
|
368
|
+
# this prevents unnecessary low-level loading of ClientApp
|
|
369
|
+
def _load_client_app() -> ClientApp:
|
|
370
|
+
return client_app
|
|
371
|
+
|
|
372
|
+
app_fn = _load_client_app
|
|
373
|
+
|
|
374
|
+
# Run main simulation loop
|
|
375
|
+
run_api(
|
|
324
376
|
app_fn,
|
|
325
377
|
backend_fn,
|
|
326
378
|
nodes_mapping,
|
|
@@ -328,4 +380,15 @@ def start_vce(
|
|
|
328
380
|
node_states,
|
|
329
381
|
f_stop,
|
|
330
382
|
)
|
|
331
|
-
|
|
383
|
+
except LoadClientAppError as loadapp_ex:
|
|
384
|
+
f_stop_delay = 10
|
|
385
|
+
log(
|
|
386
|
+
ERROR,
|
|
387
|
+
"LoadClientAppError exception encountered. Terminating simulation in %is",
|
|
388
|
+
f_stop_delay,
|
|
389
|
+
)
|
|
390
|
+
time.sleep(f_stop_delay)
|
|
391
|
+
f_stop.set() # set termination event
|
|
392
|
+
raise loadapp_ex
|
|
393
|
+
except Exception as ex:
|
|
394
|
+
raise ex
|