flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/build.py +18 -4
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +135 -51
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +63 -26
- flwr/client/client_app.py +49 -4
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +3 -4
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +60 -21
- flwr/client/typing.py +1 -0
- flwr/common/config.py +87 -2
- flwr/common/constant.py +6 -0
- flwr/common/context.py +26 -1
- flwr/common/logger.py +38 -0
- flwr/common/message.py +0 -17
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -122
- flwr/server/superlink/state/in_memory_state.py +15 -7
- flwr/server/superlink/state/sqlite_state.py +27 -12
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +52 -36
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +186 -0
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/fleet/vce/backend/backend.py

@@ -33,8 +33,8 @@ class Backend(ABC):
         """Construct a backend."""

     @abstractmethod
-
-        """Build backend
+    def build(self) -> None:
+        """Build backend.

         Different components need to be in place before workers in a backend are ready
         to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
         """Report whether a backend worker is idle and can therefore run a ClientApp."""

     @abstractmethod
-
+    def terminate(self) -> None:
         """Terminate backend."""

     @abstractmethod
-
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,

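The two hunks above turn the abstract `build`, `terminate`, and `process_message` methods of `Backend` into plain synchronous methods. Below is a minimal, self-contained sketch of that interface shape; the names (`SyncBackend`, `InProcessBackend`) and the simplified message type are illustrative only and are not flwr's actual `Backend` API, which also defines `is_worker_idle` and operates on `ClientApp`/`Message` objects.

from abc import ABC, abstractmethod
from typing import Callable, Tuple


class SyncBackend(ABC):
    """Illustrative synchronous backend interface (hypothetical, not flwr's)."""

    @abstractmethod
    def build(self) -> None:
        """Put every component in place before workers accept jobs."""

    @abstractmethod
    def process_message(
        self, app: Callable[[], object], message: str
    ) -> Tuple[str, object]:
        """Run one message through a freshly loaded app and return the reply."""

    @abstractmethod
    def terminate(self) -> None:
        """Release every resource held by the backend."""


class InProcessBackend(SyncBackend):
    """Trivial backend that runs everything in the calling thread."""

    def build(self) -> None:
        pass

    def process_message(self, app, message):
        return f"processed: {message}", app()

    def terminate(self) -> None:
        pass


if __name__ == "__main__":
    backend = InProcessBackend()
    backend.build()
    print(backend.process_message(lambda: object(), "hello"))
    backend.terminate()
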
flwr/server/superlink/fleet/vce/backend/raybackend.py

@@ -21,6 +21,7 @@ from typing import Callable, Dict, List, Tuple, Union
 import ray

 from flwr.client.client_app import ClientApp
+from flwr.common.constant import PARTITION_ID_KEY
 from flwr.common.context import Context
 from flwr.common.logger import log
 from flwr.common.message import Message
@@ -153,12 +154,12 @@ class RayBackend(Backend):
         """Report whether the pool has idle actors."""
         return self.pool.is_actor_available()

-
+    def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
-
+        self.pool.add_actors_to_pool(self.pool.actors_capacity)
         log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)

-
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
@@ -168,21 +169,20 @@ class RayBackend(Backend):

         Return output message and updated context.
         """
-        partition_id =
+        partition_id = context.node_config[PARTITION_ID_KEY]

         try:
             # Submit a task to the pool
-            future =
+            future = self.pool.submit(
                 lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
                 (app, message, str(partition_id), context),
             )

-            await future
             # Fetch result
             (
                 out_mssg,
                 updated_context,
-            ) =
+            ) = self.pool.fetch_result_and_return_actor_to_pool(future)

             return out_mssg, updated_context

@@ -193,11 +193,11 @@ class RayBackend(Backend):
                 self.__class__.__name__,
             )
             # add actor back into pool
-
+            self.pool.add_actor_back_to_pool(future)
             raise ex

-
+    def terminate(self) -> None:
         """Terminate all actors in actor pool."""
-
+        self.pool.terminate_all_actors()
         ray.shutdown()
         log(DEBUG, "Terminated %s", self.__class__.__name__)

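With this change `RayBackend.process_message` reads the partition id from `context.node_config[PARTITION_ID_KEY]` and then, as before, hands the job to its actor pool and collects the result (`submit`, then `fetch_result_and_return_actor_to_pool`). That pool is flwr-specific, but the submit-then-collect pattern mirrors Ray's stock `ray.util.ActorPool`. A small self-contained sketch of that underlying pattern (plain Ray with a hypothetical `EchoActor`, not flwr's pool):

import ray
from ray.util import ActorPool


@ray.remote
class EchoActor:
    def run(self, payload: str, partition_id: str) -> str:
        # Pretend to process one message on behalf of one partition
        return f"partition {partition_id} handled {payload!r}"


if __name__ == "__main__":
    ray.init(num_cpus=2)
    pool = ActorPool([EchoActor.remote() for _ in range(2)])

    # Submit work: the lambda receives (actor, value) and returns an ObjectRef
    pool.submit(lambda actor, value: actor.run.remote(*value), ("hello", "0"))

    # Collect the result; the actor goes back into the pool afterwards
    print(pool.get_next())

    ray.shutdown()
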
flwr/server/superlink/fleet/vce/vce_api.py

@@ -14,24 +14,33 @@
 # ==============================================================================
 """Fleet Simulation Engine API."""

-
+
 import json
-import
+import threading
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from logging import DEBUG, ERROR, INFO, WARN
 from pathlib import Path
-from
+from queue import Empty, Queue
+from time import sleep
+from typing import Callable, Dict, Optional

 from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.node_state import NodeState
-from flwr.
+from flwr.client.supernode.app import _get_load_client_app_fn
+from flwr.common.constant import (
+    NUM_PARTITIONS_KEY,
+    PARTITION_ID_KEY,
+    PING_MAX_INTERVAL,
+    ErrorCode,
+)
 from flwr.common.logger import log
 from flwr.common.message import Error
-from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
-from flwr.
-from flwr.
+from flwr.common.typing import Run
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
+from flwr.server.superlink.state import State, StateFactory

 from .backend import Backend, error_messages_backends, supported_backends

@@ -51,34 +60,57 @@ def _register_nodes(
     return nodes_mapping


+def _register_node_states(
+    nodes_mapping: NodeToPartitionMapping,
+    run: Run,
+    app_dir: Optional[str] = None,
+) -> Dict[int, NodeState]:
+    """Create NodeState objects and pre-register the context for the run."""
+    node_states: Dict[int, NodeState] = {}
+    num_partitions = len(set(nodes_mapping.values()))
+    for node_id, partition_id in nodes_mapping.items():
+        node_states[node_id] = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(num_partitions),
+            },
+        )
+
+        # Pre-register Context objects
+        node_states[node_id].register_context(
+            run_id=run.run_id, run=run, app_dir=app_dir
+        )
+
+    return node_states
+
+
 # pylint: disable=too-many-arguments,too-many-locals
-
+def worker(
     app_fn: Callable[[], ClientApp],
-
+    taskins_queue: "Queue[TaskIns]",
+    taskres_queue: "Queue[TaskRes]",
     node_states: Dict[int, NodeState],
-    state_factory: StateFactory,
-    nodes_mapping: NodeToPartitionMapping,
     backend: Backend,
+    f_stop: threading.Event,
 ) -> None:
     """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
-
-    while True:
+    while not f_stop.is_set():
         out_mssg = None
         try:
-
+            # Fetch from queue with timeout. We use a timeout so
+            # the stopping event can be evaluated even when the queue is empty.
+            task_ins: TaskIns = taskins_queue.get(timeout=1.0)
             node_id = task_ins.task.consumer.node_id

-            #
-            node_states[node_id].register_context(run_id=task_ins.run_id)
+            # Retrieve context
             context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)

             # Convert TaskIns to Message
             message = message_from_taskins(task_ins)
-            # Set partition_id
-            message.metadata.partition_id = nodes_mapping[node_id]

             # Let backend process message
-            out_mssg, updated_context =
+            out_mssg, updated_context = backend.process_message(
                 app_fn, message, context
             )

@@ -86,11 +118,9 @@ async def worker(
             node_states[node_id].update_context(
                 task_ins.run_id, context=updated_context
             )
-
-
-
-            break
-
+        except Empty:
+            # An exception raised if queue.get times out
+            pass
         # Exceptions aren't raised but reported as an error message
         except Exception as ex:  # pylint: disable=broad-exception-caught
             log(ERROR, ex)
@@ -114,67 +144,48 @@ async def worker(
             task_res = message_to_taskres(out_mssg)
             # Store TaskRes in state
             task_res.task.pushed_at = time.time()
-
+            taskres_queue.put(task_res)


-
-
-
+def add_taskins_to_queue(
+    state: State,
+    queue: "Queue[TaskIns]",
     nodes_mapping: NodeToPartitionMapping,
-
-    consumers: List["asyncio.Task[None]"],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """
-    state = state_factory.state()
-    num_initial_consumers = len(consumers)
+    """Put TaskIns in a queue from State."""
     while not f_stop.is_set():
         for node_id in nodes_mapping.keys():
-
-
-
-
-            # Count consumers that are running
-            num_active = sum(not (cc.done()) for cc in consumers)
-
-            # Alert if number of consumers decreased by half
-            if num_active < num_initial_consumers // 2:
-                log(
-                    WARN,
-                    "Number of active workers has more than halved: (%i/%i active)",
-                    num_active,
-                    num_initial_consumers,
-                )
+            task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
+            for task_ins in task_ins_list:
+                queue.put(task_ins)
+        sleep(0.1)

-        # Break if consumers died
-        if num_active == 0:
-            raise RuntimeError("All workers have died. Ending Simulation.")

-
-
-
-
-
-
-
-
-
-            queue
-
-        await asyncio.sleep(1.0)
-    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
+def put_taskres_into_state(
+    state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
+) -> None:
+    """Put TaskRes into State from a queue."""
+    while not f_stop.is_set():
+        try:
+            taskres = queue.get(timeout=1.0)
+            state.store_task_res(taskres)
+        except Empty:
+            # queue is empty when timeout was triggered
+            pass


-
+def run_api(
     app_fn: Callable[[], ClientApp],
     backend_fn: Callable[[], Backend],
     nodes_mapping: NodeToPartitionMapping,
     state_factory: StateFactory,
     node_states: Dict[int, NodeState],
-    f_stop:
+    f_stop: threading.Event,
 ) -> None:
-    """Run the VCE
-
+    """Run the VCE."""
+    taskins_queue: "Queue[TaskIns]" = Queue()
+    taskres_queue: "Queue[TaskRes]" = Queue()

     try:

@@ -182,29 +193,48 @@ async def run(
         backend = backend_fn()

         # Build backend
-
+        backend.build()

         # Add workers (they submit Messages to Backend)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        state = state_factory.state()
+
+        extractor_th = threading.Thread(
+            target=add_taskins_to_queue,
+            args=(
+                state,
+                taskins_queue,
+                nodes_mapping,
+                f_stop,
+            ),
+        )
+        extractor_th.start()
+
+        injector_th = threading.Thread(
+            target=put_taskres_into_state,
+            args=(
+                state,
+                taskres_queue,
+                f_stop,
+            ),
         )
+        injector_th.start()
+
+        with ThreadPoolExecutor() as executor:
+            _ = [
+                executor.submit(
+                    worker,
+                    app_fn,
+                    taskins_queue,
+                    taskres_queue,
+                    node_states,
+                    backend,
+                    f_stop,
+                )
+                for _ in range(backend.num_workers)
+            ]

-
-
-        # all worker (consumer) coroutines are completed. Workers
-        # also run forever and only end if an exception is raised.
-        await asyncio.gather(producer)
+        extractor_th.join()
+        injector_th.join()

     except Exception as ex:

@@ -219,18 +249,9 @@ async def run(
         raise RuntimeError("Simulation Engine crashed.") from ex

     finally:
-        # Produced task terminated, now cancel worker tasks
-        for w_t in worker_tasks:
-            _ = w_t.cancel()
-
-        while not all(w_t.done() for w_t in worker_tasks):
-            log(DEBUG, "Terminating async workers...")
-            await asyncio.sleep(0.5)
-
-        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])

         # Terminate backend
-
+        backend.terminate()


 # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -239,7 +260,10 @@ def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
     app_dir: str,
-
+    is_app: bool,
+    f_stop: threading.Event,
+    run: Run,
+    flwr_dir: Optional[str] = None,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
     num_supernodes: Optional[int] = None,
@@ -290,9 +314,9 @@ def start_vce(
     )

     # Construct mapping of NodeStates
-    node_states
-
-
+    node_states = _register_node_states(
+        nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
+    )

     # Load backend config
     log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -321,16 +345,12 @@ def start_vce(
     def _load() -> ClientApp:

         if client_app_attr:
-
-
-
-
-
-
-            if not isinstance(app, ClientApp):
-                raise LoadClientAppError(
-                    f"Attribute {client_app_attr} is not of type {ClientApp}",
-                ) from None
+            app = _get_load_client_app_fn(
+                default_app_ref=client_app_attr,
+                dir_arg=app_dir,
+                flwr_dir_arg=flwr_dir,
+                multi_app=True,
+            )(run.fab_id, run.fab_version)

         if client_app:
             app = client_app
@@ -340,18 +360,25 @@ def start_vce(

     try:
         # Test if ClientApp can be loaded
-
+        client_app = app_fn()
+
+        # Cache `ClientApp`
+        if client_app_attr:
+            # Now wrap the loaded ClientApp in a dummy function
+            # this prevent unnecesary low-level loading of ClientApp
+            def _load_client_app() -> ClientApp:
+                return client_app
+
+            app_fn = _load_client_app

         # Run main simulation loop
-
-
-
-
-
-
-
-            f_stop,
-        )
+        run_api(
+            app_fn,
+            backend_fn,
+            nodes_mapping,
+            state_factory,
+            node_states,
+            f_stop,
         )
     except LoadClientAppError as loadapp_ex:
         f_stop_delay = 10

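Taken together, the vce_api.py hunks replace the asyncio event loop with plain threads: one thread pulls `TaskIns` from `State` into a queue, a `ThreadPoolExecutor` runs `worker` functions that call `get(timeout=1.0)` on that queue so the shared `threading.Event` can be re-checked, and a second thread drains the result queue back into `State`. A stripped-down, self-contained sketch of that producer/worker/collector pipeline (generic integer work items instead of `TaskIns`/`TaskRes`, no flwr imports):

import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from time import sleep


def producer(tasks: "Queue[int]", f_stop: threading.Event) -> None:
    """Feed work items into the queue until asked to stop."""
    item = 0
    while not f_stop.is_set():
        tasks.put(item)
        item += 1
        sleep(0.1)


def worker(tasks: "Queue[int]", results: "Queue[int]", f_stop: threading.Event) -> None:
    """Consume items; the timeout lets the stop event be re-checked regularly."""
    while not f_stop.is_set():
        try:
            item = tasks.get(timeout=1.0)
            results.put(item * item)
        except Empty:
            # queue.get timed out; loop around and check f_stop again
            pass


if __name__ == "__main__":
    tasks: "Queue[int]" = Queue()
    results: "Queue[int]" = Queue()
    f_stop = threading.Event()

    producer_th = threading.Thread(target=producer, args=(tasks, f_stop))
    producer_th.start()

    with ThreadPoolExecutor() as executor:
        _ = [executor.submit(worker, tasks, results, f_stop) for _ in range(2)]
        sleep(1.0)    # let some work flow through the pipeline
        f_stop.set()  # signal the producer and all workers to exit

    producer_th.join()
    print(f"collected {results.qsize()} results")
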
flwr/server/superlink/state/in_memory_state.py

@@ -15,7 +15,6 @@
 """In-memory State implementation."""


-import os
 import threading
 import time
 from logging import ERROR
@@ -23,12 +22,13 @@ from typing import Dict, List, Optional, Set, Tuple
 from uuid import UUID, uuid4

 from flwr.common import log, now
-from flwr.common.
+from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.superlink.state.state import State
 from flwr.server.utils import validate_task_ins_or_res

-from .utils import make_node_unavailable_taskres
+from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres


 class InMemoryState(State):  # pylint: disable=R0902,R0904
@@ -216,7 +216,7 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
     ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
-        node_id
+        node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)

         with self.lock:
             if node_id in self.node_ids:
@@ -275,15 +275,23 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""
         return self.public_key_to_node_id.get(client_public_key)

-    def create_run(
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
         with self.lock:
-            run_id
+            run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

             if run_id not in self.run_ids:
                 self.run_ids[run_id] = Run(
-                    run_id=run_id,
+                    run_id=run_id,
+                    fab_id=fab_id,
+                    fab_version=fab_version,
+                    override_config=override_config,
                 )
                 return run_id
         log(ERROR, "Unexpected run creation failure.")

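Both `create_node` and the extended `create_run` follow the same pattern: draw a random identifier with `generate_rand_int_from_bytes`, then register it only if it is not already taken, with the membership check and the insertion done under a lock. A generic sketch of that pattern (the `Registry` class and `register_new_id` helper are hypothetical, not flwr's `InMemoryState`):

import threading
from os import urandom
from typing import Set


def generate_rand_int_from_bytes(num_bytes: int) -> int:
    """Same idea as the helper added in utils.py: a signed random integer."""
    return int.from_bytes(urandom(num_bytes), "little", signed=True)


class Registry:
    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.known_ids: Set[int] = set()

    def register_new_id(self, num_bytes: int = 8) -> int:
        """Draw a random id and store it if it does not collide; return 0 on failure."""
        new_id = generate_rand_int_from_bytes(num_bytes)
        with self.lock:
            if new_id not in self.known_ids:
                self.known_ids.add(new_id)
                return new_id
        return 0


if __name__ == "__main__":
    registry = Registry()
    print(registry.register_new_id())
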
flwr/server/superlink/state/sqlite_state.py

@@ -15,7 +15,7 @@
 """SQLite based implemenation of server state."""


-import
+import json
 import re
 import sqlite3
 import time
@@ -24,14 +24,15 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import UUID, uuid4

 from flwr.common import log, now
-from flwr.common.
+from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
 from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.utils.validator import validate_task_ins_or_res

 from .state import State
-from .utils import make_node_unavailable_taskres
+from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres

 SQL_CREATE_TABLE_NODE = """
 CREATE TABLE IF NOT EXISTS node(
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);

 SQL_CREATE_TABLE_RUN = """
 CREATE TABLE IF NOT EXISTS run(
-    run_id
-    fab_id
-    fab_version
+    run_id INTEGER UNIQUE,
+    fab_id TEXT,
+    fab_version TEXT,
+    override_config TEXT
 );
 """

@@ -541,7 +543,7 @@ class SqliteState(State):  # pylint: disable=R0904
     ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
-        node_id
+        node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)

         query = "SELECT node_id FROM node WHERE public_key = :public_key;"
         row = self.query(query, {"public_key": public_key})
@@ -613,17 +615,27 @@ class SqliteState(State):  # pylint: disable=R0904
             return node_id
         return None

-    def create_run(
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
-        run_id
+        run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

         # Check conflicts
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
         # If run_id does not exist
         if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
-            query =
-
+            query = (
+                "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
+                "VALUES (?, ?, ?, ?);"
+            )
+            self.query(
+                query, (run_id, fab_id, fab_version, json.dumps(override_config))
+            )
             return run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0
@@ -687,7 +699,10 @@ class SqliteState(State):  # pylint: disable=R0904
         try:
             row = self.query(query, (run_id,))[0]
             return Run(
-                run_id=run_id,
+                run_id=run_id,
+                fab_id=row["fab_id"],
+                fab_version=row["fab_version"],
+                override_config=json.loads(row["override_config"]),
             )
         except sqlite3.IntegrityError:
             log(ERROR, "`run_id` does not exist.")

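In the SQLite implementation the new `override_config` is stored in a TEXT column: it is serialized with `json.dumps` when the run row is inserted and restored with `json.loads` when the `Run` is rebuilt from the row. A minimal standalone round trip of that idea (in-memory database; the example config keys are made up, and the schema below is only what the hunks above show):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
    "CREATE TABLE IF NOT EXISTS run("
    "    run_id INTEGER UNIQUE,"
    "    fab_id TEXT,"
    "    fab_version TEXT,"
    "    override_config TEXT"
    ");"
)

# Serialize the config dict to JSON text on insert
override_config = {"num-server-rounds": 3, "lr": 0.01}
conn.execute(
    "INSERT INTO run (run_id, fab_id, fab_version, override_config) VALUES (?, ?, ?, ?);",
    (42, "example/app", "1.0.0", json.dumps(override_config)),
)

# Deserialize it when reading the row back
row = conn.execute("SELECT * FROM run WHERE run_id = ?;", (42,)).fetchone()
restored = json.loads(row["override_config"])
assert restored == override_config
print(restored)
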
flwr/server/superlink/state/state.py

@@ -19,7 +19,7 @@ import abc
 from typing import List, Optional, Set
 from uuid import UUID

-from flwr.common.typing import Run
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611


@@ -157,7 +157,12 @@ class State(abc.ABC):  # pylint: disable=R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""

     @abc.abstractmethod
-    def create_run(
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""

     @abc.abstractmethod

flwr/server/superlink/state/utils.py

@@ -17,6 +17,7 @@

 import time
 from logging import ERROR
+from os import urandom
 from uuid import uuid4

 from flwr.common import log
@@ -31,6 +32,11 @@ NODE_UNAVAILABLE_ERROR_REASON = (
 )


+def generate_rand_int_from_bytes(num_bytes: int) -> int:
+    """Generate a random `num_bytes` integer."""
+    return int.from_bytes(urandom(num_bytes), "little", signed=True)
+
+
 def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
     """Generate a TaskRes with a node unavailable error from a TaskIns."""
     current_time = time.time()