flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of flwr-nightly might be problematic.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +128 -53
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +55 -24
- flwr/client/typing.py +2 -2
- flwr/common/config.py +87 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +110 -33
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/fleet/vce/backend/backend.py
CHANGED

@@ -33,8 +33,8 @@ class Backend(ABC):
         """Construct a backend."""
 
     @abstractmethod
-    async def build(self) -> None:
-        """Build backend
+    def build(self) -> None:
+        """Build backend.
 
         Different components need to be in place before workers in a backend are ready
         to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
         """Report whether a backend worker is idle and can therefore run a ClientApp."""
 
     @abstractmethod
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate backend."""
 
     @abstractmethod
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
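Note: both hunks drop `async` from the Backend ABC, so build, terminate, and process_message are now plain synchronous methods. A minimal sketch of the resulting contract, using a toy stand-in for the ABC and a hypothetical InProcessBackend (neither is flwr code):

from abc import ABC, abstractmethod
from typing import Any, Callable, Tuple


class Backend(ABC):  # toy stand-in mirroring the synchronous interface above
    @abstractmethod
    def build(self) -> None: ...

    @abstractmethod
    def process_message(
        self, app: Callable[[], Any], message: Any, context: Any
    ) -> Tuple[Any, Any]: ...

    @abstractmethod
    def terminate(self) -> None: ...


class InProcessBackend(Backend):  # hypothetical example implementation
    def build(self) -> None:
        pass  # nothing to set up in this toy

    def process_message(self, app, message, context):
        # Plain call; callers no longer need `await backend.process_message(...)`
        return app()(message, context), context

    def terminate(self) -> None:
        pass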
flwr/server/superlink/fleet/vce/backend/raybackend.py
CHANGED

@@ -21,6 +21,7 @@ from typing import Callable, Dict, List, Tuple, Union
 import ray
 
 from flwr.client.client_app import ClientApp
+from flwr.common.constant import PARTITION_ID_KEY
 from flwr.common.context import Context
 from flwr.common.logger import log
 from flwr.common.message import Message
@@ -153,12 +154,12 @@ class RayBackend(Backend):
         """Report whether the pool has idle actors."""
         return self.pool.is_actor_available()
 
-    async def build(self) -> None:
+    def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
-        await self.pool.add_actors_to_pool(self.pool.actors_capacity)
+        self.pool.add_actors_to_pool(self.pool.actors_capacity)
         log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
 
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
@@ -168,21 +169,20 @@ class RayBackend(Backend):
 
         Return output message and updated context.
         """
-        partition_id = context.partition_id
+        partition_id = context.node_config[PARTITION_ID_KEY]
 
         try:
             # Submit a task to the pool
-            future = await self.pool.submit(
+            future = self.pool.submit(
                 lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
                 (app, message, str(partition_id), context),
             )
 
-            await future
             # Fetch result
             (
                 out_mssg,
                 updated_context,
-            ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
+            ) = self.pool.fetch_result_and_return_actor_to_pool(future)
 
             return out_mssg, updated_context
 
@@ -193,11 +193,11 @@ class RayBackend(Backend):
                 self.__class__.__name__,
             )
             # add actor back into pool
-            await self.pool.add_actor_back_to_pool(future)
+            self.pool.add_actor_back_to_pool(future)
             raise ex
 
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate all actors in actor pool."""
-        await self.pool.terminate_all_actors()
+        self.pool.terminate_all_actors()
         ray.shutdown()
         log(DEBUG, "Terminated %s", self.__class__.__name__)
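Note: RayBackend now takes the partition id from `context.node_config[PARTITION_ID_KEY]` instead of a dedicated Context attribute; the keys are written as strings by `_register_node_states` in vce_api.py below. A hedged sketch of reading them on the ClientApp side (the `client_fn` shown is illustrative, not a template from this release):

from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY
from flwr.common.context import Context


def client_fn(context: Context) -> None:  # illustrative only
    # node_config values are registered as strings, hence the int() casts
    partition_id = int(context.node_config[PARTITION_ID_KEY])
    num_partitions = int(context.node_config[NUM_PARTITIONS_KEY])
    print(f"ClientApp running on partition {partition_id} of {num_partitions}")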
flwr/server/superlink/fleet/vce/vce_api.py
CHANGED

@@ -14,24 +14,33 @@
 # ==============================================================================
 """Fleet Simulation Engine API."""
 
-import asyncio
+
 import json
-import sys
+import threading
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from logging import DEBUG, ERROR, INFO, WARN
 from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from queue import Empty, Queue
+from time import sleep
+from typing import Callable, Dict, Optional
 
 from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.node_state import NodeState
-from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
+from flwr.client.supernode.app import _get_load_client_app_fn
+from flwr.common.constant import (
+    NUM_PARTITIONS_KEY,
+    PARTITION_ID_KEY,
+    PING_MAX_INTERVAL,
+    ErrorCode,
+)
 from flwr.common.logger import log
 from flwr.common.message import Error
-from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
-from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
-from flwr.server.superlink.state import StateFactory
+from flwr.common.typing import Run
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
+from flwr.server.superlink.state import State, StateFactory
 
 from .backend import Backend, error_messages_backends, supported_backends
 
@@ -51,31 +60,57 @@ def _register_nodes(
     return nodes_mapping
 
 
+def _register_node_states(
+    nodes_mapping: NodeToPartitionMapping,
+    run: Run,
+    app_dir: Optional[str] = None,
+) -> Dict[int, NodeState]:
+    """Create NodeState objects and pre-register the context for the run."""
+    node_states: Dict[int, NodeState] = {}
+    num_partitions = len(set(nodes_mapping.values()))
+    for node_id, partition_id in nodes_mapping.items():
+        node_states[node_id] = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(num_partitions),
+            },
+        )
+
+        # Pre-register Context objects
+        node_states[node_id].register_context(
+            run_id=run.run_id, run=run, app_dir=app_dir
+        )
+
+    return node_states
+
+
 # pylint: disable=too-many-arguments,too-many-locals
-async def worker(
+def worker(
     app_fn: Callable[[], ClientApp],
-    queue: "asyncio.Queue[TaskIns]",
+    taskins_queue: "Queue[TaskIns]",
+    taskres_queue: "Queue[TaskRes]",
     node_states: Dict[int, NodeState],
-    state_factory: StateFactory,
     backend: Backend,
+    f_stop: threading.Event,
 ) -> None:
     """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
-    state = state_factory.state()
-    while True:
+    while not f_stop.is_set():
         out_mssg = None
         try:
-            task_ins: TaskIns = await queue.get()
+            # Fetch from queue with timeout. We use a timeout so
+            # the stopping event can be evaluated even when the queue is empty.
+            task_ins: TaskIns = taskins_queue.get(timeout=1.0)
             node_id = task_ins.task.consumer.node_id
 
-            # Register and retrieve context
-            node_states[node_id].register_context(run_id=task_ins.run_id)
+            # Retrieve context
             context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
 
             # Convert TaskIns to Message
             message = message_from_taskins(task_ins)
 
             # Let backend process message
-            out_mssg, updated_context = await backend.process_message(
+            out_mssg, updated_context = backend.process_message(
                 app_fn, message, context
             )
 
@@ -83,11 +118,9 @@ async def worker(
             node_states[node_id].update_context(
                 task_ins.run_id, context=updated_context
             )
-
-        except asyncio.CancelledError as e:
-            log(DEBUG, "Terminating async worker: %s", e)
-            break
-
+        except Empty:
+            # An exception raised if queue.get times out
+            pass
         # Exceptions aren't raised but reported as an error message
         except Exception as ex:  # pylint: disable=broad-exception-caught
             log(ERROR, ex)
@@ -111,67 +144,48 @@ async def worker(
             task_res = message_to_taskres(out_mssg)
             # Store TaskRes in state
             task_res.task.pushed_at = time.time()
-            state.store_task_res(task_res)
+            taskres_queue.put(task_res)
 
 
-async def add_taskins_to_queue(
-    queue: "asyncio.Queue[TaskIns]",
-    state_factory: StateFactory,
+def add_taskins_to_queue(
+    state: State,
+    queue: "Queue[TaskIns]",
     nodes_mapping: NodeToPartitionMapping,
-    backend: Backend,
-    consumers: List["asyncio.Task[None]"],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Retrieve TaskIns and add it to the queue."""
-    state = state_factory.state()
-    num_initial_consumers = len(consumers)
+    """Put TaskIns in a queue from State."""
     while not f_stop.is_set():
         for node_id in nodes_mapping.keys():
-            task_ins = state.get_task_ins(node_id=node_id, limit=1)
-            if task_ins:
-                await queue.put(task_ins[0])
-
-            # Count consumers that are running
-            num_active = sum(not (cc.done()) for cc in consumers)
-
-            # Alert if number of consumers decreased by half
-            if num_active < num_initial_consumers // 2:
-                log(
-                    WARN,
-                    "Number of active workers has more than halved: (%i/%i active)",
-                    num_active,
-                    num_initial_consumers,
-                )
+            task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
+            for task_ins in task_ins_list:
+                queue.put(task_ins)
+        sleep(0.1)
 
-            # Break if consumers died
-            if num_active == 0:
-                raise RuntimeError("All workers have died. Ending Simulation.")
 
-
-
-
-
-
-
-
-
-
-            queue
-
-        await asyncio.sleep(1.0)
-    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
+def put_taskres_into_state(
+    state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
+) -> None:
+    """Put TaskRes into State from a queue."""
+    while not f_stop.is_set():
+        try:
+            taskres = queue.get(timeout=1.0)
+            state.store_task_res(taskres)
+        except Empty:
+            # queue is empty when timeout was triggered
+            pass
 
 
-async def run(
+def run_api(
    app_fn: Callable[[], ClientApp],
     backend_fn: Callable[[], Backend],
     nodes_mapping: NodeToPartitionMapping,
     state_factory: StateFactory,
     node_states: Dict[int, NodeState],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Run the VCE async."""
-
+    """Run the VCE."""
+    taskins_queue: "Queue[TaskIns]" = Queue()
+    taskres_queue: "Queue[TaskRes]" = Queue()
 
     try:
 
@@ -179,27 +193,48 @@ async def run(
         backend = backend_fn()
 
         # Build backend
-        await backend.build()
+        backend.build()
 
         # Add workers (they submit Messages to Backend)
-
-
-
-
-
-
-
-
-
-
-            )
+        state = state_factory.state()
+
+        extractor_th = threading.Thread(
+            target=add_taskins_to_queue,
+            args=(
+                state,
+                taskins_queue,
+                nodes_mapping,
+                f_stop,
+            ),
         )
+        extractor_th.start()
 
-
-
-
-
-
+        injector_th = threading.Thread(
+            target=put_taskres_into_state,
+            args=(
+                state,
+                taskres_queue,
+                f_stop,
+            ),
+        )
+        injector_th.start()
+
+        with ThreadPoolExecutor() as executor:
+            _ = [
+                executor.submit(
+                    worker,
+                    app_fn,
+                    taskins_queue,
+                    taskres_queue,
+                    node_states,
+                    backend,
+                    f_stop,
+                )
+                for _ in range(backend.num_workers)
+            ]
+
+        extractor_th.join()
+        injector_th.join()
 
     except Exception as ex:
 
@@ -214,18 +249,9 @@ async def run(
         raise RuntimeError("Simulation Engine crashed.") from ex
 
     finally:
-        # Produced task terminated, now cancel worker tasks
-        for w_t in worker_tasks:
-            _ = w_t.cancel()
-
-        while not all(w_t.done() for w_t in worker_tasks):
-            log(DEBUG, "Terminating async workers...")
-            await asyncio.sleep(0.5)
-
-        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
 
         # Terminate backend
-        await backend.terminate()
+        backend.terminate()
 
 
 # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -234,7 +260,10 @@ def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
     app_dir: str,
-    f_stop: asyncio.Event,
+    is_app: bool,
+    f_stop: threading.Event,
+    run: Run,
+    flwr_dir: Optional[str] = None,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
     num_supernodes: Optional[int] = None,
@@ -285,9 +314,9 @@ def start_vce(
     )
 
     # Construct mapping of NodeStates
-    node_states: Dict[int, NodeState] = {}
-    for node_id, partition_id in nodes_mapping.items():
-        node_states[node_id] = NodeState(partition_id=partition_id)
+    node_states = _register_node_states(
+        nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
+    )
 
     # Load backend config
     log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -316,16 +345,12 @@ def start_vce(
     def _load() -> ClientApp:
 
         if client_app_attr:
-
-
-
-
-
-
-            if not isinstance(app, ClientApp):
-                raise LoadClientAppError(
-                    f"Attribute {client_app_attr} is not of type {ClientApp}",
-                ) from None
+            app = _get_load_client_app_fn(
+                default_app_ref=client_app_attr,
+                dir_arg=app_dir,
+                flwr_dir_arg=flwr_dir,
+                multi_app=True,
+            )(run.fab_id, run.fab_version)
 
         if client_app:
             app = client_app
@@ -335,18 +360,25 @@ def start_vce(
 
     try:
         # Test if ClientApp can be loaded
-
+        client_app = app_fn()
+
+        # Cache `ClientApp`
+        if client_app_attr:
+            # Now wrap the loaded ClientApp in a dummy function
+            # this prevent unnecesary low-level loading of ClientApp
+            def _load_client_app() -> ClientApp:
+                return client_app
+
+            app_fn = _load_client_app
 
         # Run main simulation loop
-        asyncio.run(
-            run(
-                app_fn,
-                backend_fn,
-                nodes_mapping,
-                state_factory,
-                node_states,
-                f_stop,
-            )
+        run_api(
+            app_fn,
+            backend_fn,
+            nodes_mapping,
+            state_factory,
+            node_states,
+            f_stop,
         )
     except LoadClientAppError as loadapp_ex:
         f_stop_delay = 10
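Note: this file carries the core of the release change: the simulation engine's asyncio event loop is replaced by queues, worker threads, and a stop event. A self-contained sketch of the same producer/worker pattern with toy payloads (no flwr imports; all names here are illustrative):

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue

f_stop = threading.Event()
work_q: "Queue[int]" = Queue()
out_q: "Queue[int]" = Queue()


def worker() -> None:
    # Same shape as the new vce_api worker: the timeout on get() lets the
    # thread re-check the stop event instead of blocking forever.
    while not f_stop.is_set():
        try:
            item = work_q.get(timeout=0.1)
        except Empty:
            continue
        out_q.put(item * item)


for i in range(8):  # stand-in for the TaskIns extractor thread
    work_q.put(i)

with ThreadPoolExecutor(max_workers=2) as pool:
    for _ in range(2):
        pool.submit(worker)
    while out_q.qsize() < 8:  # wait until all work has drained
        time.sleep(0.05)
    f_stop.set()  # replaces asyncio task cancellation

print(sorted(out_q.get() for _ in range(8)))  # [0, 1, 4, 9, 16, 25, 36, 49]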
flwr/server/superlink/state/in_memory_state.py
CHANGED

@@ -23,7 +23,7 @@ from uuid import UUID, uuid4
 
 from flwr.common import log, now
 from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
-from flwr.common.typing import Run
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.superlink.state.state import State
 from flwr.server.utils import validate_task_ins_or_res
@@ -275,7 +275,12 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""
         return self.public_key_to_node_id.get(client_public_key)
 
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
         with self.lock:
@@ -283,7 +288,10 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
 
             if run_id not in self.run_ids:
                 self.run_ids[run_id] = Run(
-                    run_id=run_id, fab_id=fab_id, fab_version=fab_version
+                    run_id=run_id,
+                    fab_id=fab_id,
+                    fab_version=fab_version,
+                    override_config=override_config,
                 )
                 return run_id
         log(ERROR, "Unexpected run creation failure.")
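Note: `create_run` gains an `override_config` argument that is stored on the Run record. A toy mirror of the updated record, assuming UserConfig is the str-to-scalar mapping from flwr.common.typing (values below are illustrative):

from dataclasses import dataclass, field
from typing import Dict, Union

UserConfig = Dict[str, Union[bool, float, int, str]]  # assumed alias


@dataclass
class Run:  # toy mirror of flwr.common.typing.Run after this change
    run_id: int
    fab_id: str
    fab_version: str
    override_config: UserConfig = field(default_factory=dict)


run = Run(
    run_id=1,
    fab_id="flwrlabs/quickstart",  # illustrative values
    fab_version="1.0.0",
    override_config={"num-server-rounds": 3, "lr": 0.01},
)
assert run.override_config["lr"] == 0.01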
flwr/server/superlink/state/sqlite_state.py
CHANGED

@@ -15,6 +15,7 @@
 """SQLite based implemenation of server state."""
 
 
+import json
 import re
 import sqlite3
 import time
@@ -24,7 +25,7 @@ from uuid import UUID, uuid4
 
 from flwr.common import log, now
 from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
-from flwr.common.typing import Run
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
 from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
 
 SQL_CREATE_TABLE_RUN = """
 CREATE TABLE IF NOT EXISTS run(
-    run_id
-    fab_id
-    fab_version
+    run_id INTEGER UNIQUE,
+    fab_id TEXT,
+    fab_version TEXT,
+    override_config TEXT
 );
 """
 
@@ -613,7 +615,12 @@ class SqliteState(State):  # pylint: disable=R0904
             return node_id
         return None
 
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
         run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -622,8 +629,13 @@ class SqliteState(State):  # pylint: disable=R0904
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
         # If run_id does not exist
         if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
-            query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
-            self.query(query, (run_id, fab_id, fab_version))
+            query = (
+                "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
+                "VALUES (?, ?, ?, ?);"
+            )
+            self.query(
+                query, (run_id, fab_id, fab_version, json.dumps(override_config))
+            )
             return run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0
@@ -687,7 +699,10 @@ class SqliteState(State):  # pylint: disable=R0904
         try:
             row = self.query(query, (run_id,))[0]
             return Run(
-                run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
+                run_id=run_id,
+                fab_id=row["fab_id"],
+                fab_version=row["fab_version"],
+                override_config=json.loads(row["override_config"]),
             )
         except sqlite3.IntegrityError:
             log(ERROR, "`run_id` does not exist.")
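Note: unlike the in-memory state, SQLite needs the config serialized: the hunks above store it via json.dumps in a TEXT column and rebuild it with json.loads. A self-contained sketch of that round trip using plain sqlite3 (illustrative values):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS run("
    "run_id INTEGER UNIQUE, fab_id TEXT, fab_version TEXT, override_config TEXT);"
)
cfg = {"num-server-rounds": 3, "lr": 0.01}
conn.execute(
    "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
    "VALUES (?, ?, ?, ?);",
    (1, "flwrlabs/quickstart", "1.0.0", json.dumps(cfg)),
)
row = conn.execute("SELECT override_config FROM run WHERE run_id = 1;").fetchone()
assert json.loads(row[0]) == cfg  # lossless for JSON-compatible scalar values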
flwr/server/superlink/state/state.py
CHANGED

@@ -19,7 +19,7 @@ import abc
 from typing import List, Optional, Set
 from uuid import UUID
 
-from flwr.common.typing import Run
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 
@@ -157,7 +157,12 @@ class State(abc.ABC):  # pylint: disable=R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""
 
     @abc.abstractmethod
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
 
     @abc.abstractmethod
flwr/server/typing.py
CHANGED

@@ -20,6 +20,8 @@ from typing import Callable
 from flwr.common import Context
 
 from .driver import Driver
+from .serverapp_components import ServerAppComponents
 
 ServerAppCallable = Callable[[Driver, Context], None]
 Workflow = Callable[[Driver, Context], None]
+ServerFn = Callable[[Context], ServerAppComponents]
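Note: `ServerFn` types the new `server_fn` entry point that a ServerApp can be constructed from (see server_app.py and the new serverapp_components.py in the file list). A hedged sketch of its intended use; it assumes ServerAppComponents is re-exported from flwr.server (the `__init__.py +2` entry above) and accepts strategy/config keywords, as in later stable releases:

from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg


def server_fn(context: Context) -> ServerAppComponents:
    # Components are created per run; run_config carries any overrides
    num_rounds = int(context.run_config.get("num-server-rounds", 3))
    return ServerAppComponents(
        strategy=FedAvg(), config=ServerConfig(num_rounds=num_rounds)
    )


app = ServerApp(server_fn=server_fn)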
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py
CHANGED

@@ -81,6 +81,7 @@ class WorkflowState:  # pylint: disable=R0902
     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
     aggregate_ndarrays: NDArrays = field(default_factory=list)
     legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
+    failures: List[Exception] = field(default_factory=list)
 
 
 class SecAggPlusWorkflow:
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
 
         for msg in msgs:
             if msg.has_error():
+                state.failures.append(Exception(msg.error))
                 continue
             key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             node_id = msg.metadata.src_node_id
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
             nid: [] for nid in state.active_node_ids
         }  # dest node ID -> list of src node IDs
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             node_id = msg.metadata.src_node_id
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
         # Sum collected masked vectors and compute active/dead node IDs
         masked_vector = None
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
 
         # Backward compatibility with Strategy
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             fitres = compat.recordset_to_fitres(msg.content, True)
             proxy = state.nid_to_proxies[msg.metadata.src_node_id]
             state.legacy_results.append((proxy, fitres))
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
         for nid in state.sampled_node_ids:
             collected_shares_dict[nid] = []
         for msg in msgs:
+            if msg.has_error():
+                state.failures.append(Exception(msg.error))
+                continue
             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
             nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
             shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
             INFO,
             "aggregate_fit: received %s results and %s failures",
             len(results),
-
+            len(state.failures),
+        )
+        aggregated_result = context.strategy.aggregate_fit(
+            current_round, results, state.failures  # type: ignore
         )
-        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
         parameters_aggregated, metrics_aggregated = aggregated_result
 
         # Update the parameters and write history
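Note: with these hunks, message-level errors are recorded in `state.failures` at every collection step and finally forwarded to `Strategy.aggregate_fit`, which previously always received an empty list. A hedged sketch of a strategy that inspects them (FedAvgWithFailureLog is illustrative; the aggregate_fit signature matches flwr's Strategy API):

from logging import WARNING

from flwr.common.logger import log
from flwr.server.strategy import FedAvg


class FedAvgWithFailureLog(FedAvg):  # illustrative subclass, not flwr code
    def aggregate_fit(self, server_round, results, failures):
        for failure in failures:
            log(WARNING, "SecAgg+ round %s failure: %s", server_round, failure)
        # FedAvg tolerates failures by default (accept_failures=True) and
        # returns (None, {}) when accept_failures=False and failures exist.
        return super().aggregate_fit(server_round, results, failures)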
|