flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +47 -27
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +32 -21
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +133 -54
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +39 -39
- flwr/client/typing.py +2 -2
- flwr/common/config.py +92 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +20 -11
- flwr/proto/exec_pb2.pyi +41 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -18
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +269 -70
- flwr/superexec/app.py +17 -11
- flwr/superexec/deployment.py +111 -35
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +6 -1
- flwr/superexec/executor.py +21 -0
- flwr/superexec/simulation.py +181 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
- flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
|
@@ -14,24 +14,33 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Fleet Simulation Engine API."""
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
|
|
18
18
|
import json
|
|
19
|
-
import
|
|
19
|
+
import threading
|
|
20
20
|
import time
|
|
21
21
|
import traceback
|
|
22
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
22
23
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
23
24
|
from pathlib import Path
|
|
24
|
-
from
|
|
25
|
+
from queue import Empty, Queue
|
|
26
|
+
from time import sleep
|
|
27
|
+
from typing import Callable, Dict, Optional
|
|
25
28
|
|
|
26
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
27
30
|
from flwr.client.node_state import NodeState
|
|
28
|
-
from flwr.
|
|
31
|
+
from flwr.client.supernode.app import _get_load_client_app_fn
|
|
32
|
+
from flwr.common.constant import (
|
|
33
|
+
NUM_PARTITIONS_KEY,
|
|
34
|
+
PARTITION_ID_KEY,
|
|
35
|
+
PING_MAX_INTERVAL,
|
|
36
|
+
ErrorCode,
|
|
37
|
+
)
|
|
29
38
|
from flwr.common.logger import log
|
|
30
39
|
from flwr.common.message import Error
|
|
31
|
-
from flwr.common.object_ref import load_app
|
|
32
40
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
33
|
-
from flwr.
|
|
34
|
-
from flwr.
|
|
41
|
+
from flwr.common.typing import Run
|
|
42
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
|
+
from flwr.server.superlink.state import State, StateFactory
|
|
35
44
|
|
|
36
45
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
37
46
|
|
|
@@ -51,31 +60,57 @@ def _register_nodes(
|
|
|
51
60
|
return nodes_mapping
|
|
52
61
|
|
|
53
62
|
|
|
63
|
+
def _register_node_states(
|
|
64
|
+
nodes_mapping: NodeToPartitionMapping,
|
|
65
|
+
run: Run,
|
|
66
|
+
app_dir: Optional[str] = None,
|
|
67
|
+
) -> Dict[int, NodeState]:
|
|
68
|
+
"""Create NodeState objects and pre-register the context for the run."""
|
|
69
|
+
node_states: Dict[int, NodeState] = {}
|
|
70
|
+
num_partitions = len(set(nodes_mapping.values()))
|
|
71
|
+
for node_id, partition_id in nodes_mapping.items():
|
|
72
|
+
node_states[node_id] = NodeState(
|
|
73
|
+
node_id=node_id,
|
|
74
|
+
node_config={
|
|
75
|
+
PARTITION_ID_KEY: partition_id,
|
|
76
|
+
NUM_PARTITIONS_KEY: num_partitions,
|
|
77
|
+
},
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# Pre-register Context objects
|
|
81
|
+
node_states[node_id].register_context(
|
|
82
|
+
run_id=run.run_id, run=run, app_dir=app_dir
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return node_states
|
|
86
|
+
|
|
87
|
+
|
|
54
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
55
|
-
|
|
89
|
+
def worker(
|
|
56
90
|
app_fn: Callable[[], ClientApp],
|
|
57
|
-
|
|
91
|
+
taskins_queue: "Queue[TaskIns]",
|
|
92
|
+
taskres_queue: "Queue[TaskRes]",
|
|
58
93
|
node_states: Dict[int, NodeState],
|
|
59
|
-
state_factory: StateFactory,
|
|
60
94
|
backend: Backend,
|
|
95
|
+
f_stop: threading.Event,
|
|
61
96
|
) -> None:
|
|
62
97
|
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
|
|
63
|
-
|
|
64
|
-
while True:
|
|
98
|
+
while not f_stop.is_set():
|
|
65
99
|
out_mssg = None
|
|
66
100
|
try:
|
|
67
|
-
|
|
101
|
+
# Fetch from queue with timeout. We use a timeout so
|
|
102
|
+
# the stopping event can be evaluated even when the queue is empty.
|
|
103
|
+
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
|
|
68
104
|
node_id = task_ins.task.consumer.node_id
|
|
69
105
|
|
|
70
|
-
#
|
|
71
|
-
node_states[node_id].register_context(run_id=task_ins.run_id)
|
|
106
|
+
# Retrieve context
|
|
72
107
|
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
|
|
73
108
|
|
|
74
109
|
# Convert TaskIns to Message
|
|
75
110
|
message = message_from_taskins(task_ins)
|
|
76
111
|
|
|
77
112
|
# Let backend process message
|
|
78
|
-
out_mssg, updated_context =
|
|
113
|
+
out_mssg, updated_context = backend.process_message(
|
|
79
114
|
app_fn, message, context
|
|
80
115
|
)
|
|
81
116
|
|
|
@@ -83,11 +118,9 @@ async def worker(
|
|
|
83
118
|
node_states[node_id].update_context(
|
|
84
119
|
task_ins.run_id, context=updated_context
|
|
85
120
|
)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
break
|
|
90
|
-
|
|
121
|
+
except Empty:
|
|
122
|
+
# An exception raised if queue.get times out
|
|
123
|
+
pass
|
|
91
124
|
# Exceptions aren't raised but reported as an error message
|
|
92
125
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
93
126
|
log(ERROR, ex)
|
|
@@ -111,67 +144,48 @@ async def worker(
|
|
|
111
144
|
task_res = message_to_taskres(out_mssg)
|
|
112
145
|
# Store TaskRes in state
|
|
113
146
|
task_res.task.pushed_at = time.time()
|
|
114
|
-
|
|
147
|
+
taskres_queue.put(task_res)
|
|
115
148
|
|
|
116
149
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
150
|
+
def add_taskins_to_queue(
|
|
151
|
+
state: State,
|
|
152
|
+
queue: "Queue[TaskIns]",
|
|
120
153
|
nodes_mapping: NodeToPartitionMapping,
|
|
121
|
-
|
|
122
|
-
consumers: List["asyncio.Task[None]"],
|
|
123
|
-
f_stop: asyncio.Event,
|
|
154
|
+
f_stop: threading.Event,
|
|
124
155
|
) -> None:
|
|
125
|
-
"""
|
|
126
|
-
state = state_factory.state()
|
|
127
|
-
num_initial_consumers = len(consumers)
|
|
156
|
+
"""Put TaskIns in a queue from State."""
|
|
128
157
|
while not f_stop.is_set():
|
|
129
158
|
for node_id in nodes_mapping.keys():
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
# Count consumers that are running
|
|
135
|
-
num_active = sum(not (cc.done()) for cc in consumers)
|
|
136
|
-
|
|
137
|
-
# Alert if number of consumers decreased by half
|
|
138
|
-
if num_active < num_initial_consumers // 2:
|
|
139
|
-
log(
|
|
140
|
-
WARN,
|
|
141
|
-
"Number of active workers has more than halved: (%i/%i active)",
|
|
142
|
-
num_active,
|
|
143
|
-
num_initial_consumers,
|
|
144
|
-
)
|
|
159
|
+
task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
|
|
160
|
+
for task_ins in task_ins_list:
|
|
161
|
+
queue.put(task_ins)
|
|
162
|
+
sleep(0.1)
|
|
145
163
|
|
|
146
|
-
# Break if consumers died
|
|
147
|
-
if num_active == 0:
|
|
148
|
-
raise RuntimeError("All workers have died. Ending Simulation.")
|
|
149
164
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
queue
|
|
160
|
-
|
|
161
|
-
await asyncio.sleep(1.0)
|
|
162
|
-
log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
|
|
165
|
+
def put_taskres_into_state(
|
|
166
|
+
state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Put TaskRes into State from a queue."""
|
|
169
|
+
while not f_stop.is_set():
|
|
170
|
+
try:
|
|
171
|
+
taskres = queue.get(timeout=1.0)
|
|
172
|
+
state.store_task_res(taskres)
|
|
173
|
+
except Empty:
|
|
174
|
+
# queue is empty when timeout was triggered
|
|
175
|
+
pass
|
|
163
176
|
|
|
164
177
|
|
|
165
|
-
|
|
178
|
+
def run_api(
|
|
166
179
|
app_fn: Callable[[], ClientApp],
|
|
167
180
|
backend_fn: Callable[[], Backend],
|
|
168
181
|
nodes_mapping: NodeToPartitionMapping,
|
|
169
182
|
state_factory: StateFactory,
|
|
170
183
|
node_states: Dict[int, NodeState],
|
|
171
|
-
f_stop:
|
|
184
|
+
f_stop: threading.Event,
|
|
172
185
|
) -> None:
|
|
173
|
-
"""Run the VCE
|
|
174
|
-
|
|
186
|
+
"""Run the VCE."""
|
|
187
|
+
taskins_queue: "Queue[TaskIns]" = Queue()
|
|
188
|
+
taskres_queue: "Queue[TaskRes]" = Queue()
|
|
175
189
|
|
|
176
190
|
try:
|
|
177
191
|
|
|
@@ -179,27 +193,48 @@ async def run(
|
|
|
179
193
|
backend = backend_fn()
|
|
180
194
|
|
|
181
195
|
# Build backend
|
|
182
|
-
|
|
196
|
+
backend.build()
|
|
183
197
|
|
|
184
198
|
# Add workers (they submit Messages to Backend)
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
)
|
|
199
|
+
state = state_factory.state()
|
|
200
|
+
|
|
201
|
+
extractor_th = threading.Thread(
|
|
202
|
+
target=add_taskins_to_queue,
|
|
203
|
+
args=(
|
|
204
|
+
state,
|
|
205
|
+
taskins_queue,
|
|
206
|
+
nodes_mapping,
|
|
207
|
+
f_stop,
|
|
208
|
+
),
|
|
196
209
|
)
|
|
210
|
+
extractor_th.start()
|
|
197
211
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
212
|
+
injector_th = threading.Thread(
|
|
213
|
+
target=put_taskres_into_state,
|
|
214
|
+
args=(
|
|
215
|
+
state,
|
|
216
|
+
taskres_queue,
|
|
217
|
+
f_stop,
|
|
218
|
+
),
|
|
219
|
+
)
|
|
220
|
+
injector_th.start()
|
|
221
|
+
|
|
222
|
+
with ThreadPoolExecutor() as executor:
|
|
223
|
+
_ = [
|
|
224
|
+
executor.submit(
|
|
225
|
+
worker,
|
|
226
|
+
app_fn,
|
|
227
|
+
taskins_queue,
|
|
228
|
+
taskres_queue,
|
|
229
|
+
node_states,
|
|
230
|
+
backend,
|
|
231
|
+
f_stop,
|
|
232
|
+
)
|
|
233
|
+
for _ in range(backend.num_workers)
|
|
234
|
+
]
|
|
235
|
+
|
|
236
|
+
extractor_th.join()
|
|
237
|
+
injector_th.join()
|
|
203
238
|
|
|
204
239
|
except Exception as ex:
|
|
205
240
|
|
|
@@ -214,18 +249,9 @@ async def run(
|
|
|
214
249
|
raise RuntimeError("Simulation Engine crashed.") from ex
|
|
215
250
|
|
|
216
251
|
finally:
|
|
217
|
-
# Produced task terminated, now cancel worker tasks
|
|
218
|
-
for w_t in worker_tasks:
|
|
219
|
-
_ = w_t.cancel()
|
|
220
|
-
|
|
221
|
-
while not all(w_t.done() for w_t in worker_tasks):
|
|
222
|
-
log(DEBUG, "Terminating async workers...")
|
|
223
|
-
await asyncio.sleep(0.5)
|
|
224
|
-
|
|
225
|
-
await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
|
|
226
252
|
|
|
227
253
|
# Terminate backend
|
|
228
|
-
|
|
254
|
+
backend.terminate()
|
|
229
255
|
|
|
230
256
|
|
|
231
257
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
@@ -234,7 +260,10 @@ def start_vce(
|
|
|
234
260
|
backend_name: str,
|
|
235
261
|
backend_config_json_stream: str,
|
|
236
262
|
app_dir: str,
|
|
237
|
-
|
|
263
|
+
is_app: bool,
|
|
264
|
+
f_stop: threading.Event,
|
|
265
|
+
run: Run,
|
|
266
|
+
flwr_dir: Optional[str] = None,
|
|
238
267
|
client_app: Optional[ClientApp] = None,
|
|
239
268
|
client_app_attr: Optional[str] = None,
|
|
240
269
|
num_supernodes: Optional[int] = None,
|
|
@@ -285,9 +314,9 @@ def start_vce(
|
|
|
285
314
|
)
|
|
286
315
|
|
|
287
316
|
# Construct mapping of NodeStates
|
|
288
|
-
node_states
|
|
289
|
-
|
|
290
|
-
|
|
317
|
+
node_states = _register_node_states(
|
|
318
|
+
nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
|
|
319
|
+
)
|
|
291
320
|
|
|
292
321
|
# Load backend config
|
|
293
322
|
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
@@ -316,16 +345,12 @@ def start_vce(
|
|
|
316
345
|
def _load() -> ClientApp:
|
|
317
346
|
|
|
318
347
|
if client_app_attr:
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
if not isinstance(app, ClientApp):
|
|
326
|
-
raise LoadClientAppError(
|
|
327
|
-
f"Attribute {client_app_attr} is not of type {ClientApp}",
|
|
328
|
-
) from None
|
|
348
|
+
app = _get_load_client_app_fn(
|
|
349
|
+
default_app_ref=client_app_attr,
|
|
350
|
+
project_dir=app_dir,
|
|
351
|
+
flwr_dir=flwr_dir,
|
|
352
|
+
multi_app=True,
|
|
353
|
+
)(run.fab_id, run.fab_version)
|
|
329
354
|
|
|
330
355
|
if client_app:
|
|
331
356
|
app = client_app
|
|
@@ -335,18 +360,25 @@ def start_vce(
|
|
|
335
360
|
|
|
336
361
|
try:
|
|
337
362
|
# Test if ClientApp can be loaded
|
|
338
|
-
|
|
363
|
+
client_app = app_fn()
|
|
364
|
+
|
|
365
|
+
# Cache `ClientApp`
|
|
366
|
+
if client_app_attr:
|
|
367
|
+
# Now wrap the loaded ClientApp in a dummy function
|
|
368
|
+
# this prevent unnecesary low-level loading of ClientApp
|
|
369
|
+
def _load_client_app() -> ClientApp:
|
|
370
|
+
return client_app
|
|
371
|
+
|
|
372
|
+
app_fn = _load_client_app
|
|
339
373
|
|
|
340
374
|
# Run main simulation loop
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
f_stop,
|
|
349
|
-
)
|
|
375
|
+
run_api(
|
|
376
|
+
app_fn,
|
|
377
|
+
backend_fn,
|
|
378
|
+
nodes_mapping,
|
|
379
|
+
state_factory,
|
|
380
|
+
node_states,
|
|
381
|
+
f_stop,
|
|
350
382
|
)
|
|
351
383
|
except LoadClientAppError as loadapp_ex:
|
|
352
384
|
f_stop_delay = 10
|
|
@@ -23,7 +23,7 @@ from uuid import UUID, uuid4
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import log, now
|
|
25
25
|
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
26
|
-
from flwr.common.typing import Run
|
|
26
|
+
from flwr.common.typing import Run, UserConfig
|
|
27
27
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
28
28
|
from flwr.server.superlink.state.state import State
|
|
29
29
|
from flwr.server.utils import validate_task_ins_or_res
|
|
@@ -275,7 +275,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
275
275
|
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
276
276
|
return self.public_key_to_node_id.get(client_public_key)
|
|
277
277
|
|
|
278
|
-
def create_run(
|
|
278
|
+
def create_run(
|
|
279
|
+
self,
|
|
280
|
+
fab_id: str,
|
|
281
|
+
fab_version: str,
|
|
282
|
+
override_config: UserConfig,
|
|
283
|
+
) -> int:
|
|
279
284
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
280
285
|
# Sample a random int64 as run_id
|
|
281
286
|
with self.lock:
|
|
@@ -283,7 +288,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
283
288
|
|
|
284
289
|
if run_id not in self.run_ids:
|
|
285
290
|
self.run_ids[run_id] = Run(
|
|
286
|
-
run_id=run_id,
|
|
291
|
+
run_id=run_id,
|
|
292
|
+
fab_id=fab_id,
|
|
293
|
+
fab_version=fab_version,
|
|
294
|
+
override_config=override_config,
|
|
287
295
|
)
|
|
288
296
|
return run_id
|
|
289
297
|
log(ERROR, "Unexpected run creation failure.")
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""SQLite based implemenation of server state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import json
|
|
18
19
|
import re
|
|
19
20
|
import sqlite3
|
|
20
21
|
import time
|
|
@@ -24,7 +25,7 @@ from uuid import UUID, uuid4
|
|
|
24
25
|
|
|
25
26
|
from flwr.common import log, now
|
|
26
27
|
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
27
|
-
from flwr.common.typing import Run
|
|
28
|
+
from flwr.common.typing import Run, UserConfig
|
|
28
29
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
29
30
|
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
|
|
30
31
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
|
61
62
|
|
|
62
63
|
SQL_CREATE_TABLE_RUN = """
|
|
63
64
|
CREATE TABLE IF NOT EXISTS run(
|
|
64
|
-
run_id
|
|
65
|
-
fab_id
|
|
66
|
-
fab_version
|
|
65
|
+
run_id INTEGER UNIQUE,
|
|
66
|
+
fab_id TEXT,
|
|
67
|
+
fab_version TEXT,
|
|
68
|
+
override_config TEXT
|
|
67
69
|
);
|
|
68
70
|
"""
|
|
69
71
|
|
|
@@ -613,7 +615,12 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
613
615
|
return node_id
|
|
614
616
|
return None
|
|
615
617
|
|
|
616
|
-
def create_run(
|
|
618
|
+
def create_run(
|
|
619
|
+
self,
|
|
620
|
+
fab_id: str,
|
|
621
|
+
fab_version: str,
|
|
622
|
+
override_config: UserConfig,
|
|
623
|
+
) -> int:
|
|
617
624
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
618
625
|
# Sample a random int64 as run_id
|
|
619
626
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -622,8 +629,13 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
622
629
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
623
630
|
# If run_id does not exist
|
|
624
631
|
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
|
|
625
|
-
query =
|
|
626
|
-
|
|
632
|
+
query = (
|
|
633
|
+
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
|
|
634
|
+
"VALUES (?, ?, ?, ?);"
|
|
635
|
+
)
|
|
636
|
+
self.query(
|
|
637
|
+
query, (run_id, fab_id, fab_version, json.dumps(override_config))
|
|
638
|
+
)
|
|
627
639
|
return run_id
|
|
628
640
|
log(ERROR, "Unexpected run creation failure.")
|
|
629
641
|
return 0
|
|
@@ -687,7 +699,10 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
687
699
|
try:
|
|
688
700
|
row = self.query(query, (run_id,))[0]
|
|
689
701
|
return Run(
|
|
690
|
-
run_id=run_id,
|
|
702
|
+
run_id=run_id,
|
|
703
|
+
fab_id=row["fab_id"],
|
|
704
|
+
fab_version=row["fab_version"],
|
|
705
|
+
override_config=json.loads(row["override_config"]),
|
|
691
706
|
)
|
|
692
707
|
except sqlite3.IntegrityError:
|
|
693
708
|
log(ERROR, "`run_id` does not exist.")
|
|
@@ -19,7 +19,7 @@ import abc
|
|
|
19
19
|
from typing import List, Optional, Set
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common.typing import Run
|
|
22
|
+
from flwr.common.typing import Run, UserConfig
|
|
23
23
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
24
24
|
|
|
25
25
|
|
|
@@ -157,7 +157,12 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
157
157
|
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
158
158
|
|
|
159
159
|
@abc.abstractmethod
|
|
160
|
-
def create_run(
|
|
160
|
+
def create_run(
|
|
161
|
+
self,
|
|
162
|
+
fab_id: str,
|
|
163
|
+
fab_version: str,
|
|
164
|
+
override_config: UserConfig,
|
|
165
|
+
) -> int:
|
|
161
166
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
162
167
|
|
|
163
168
|
@abc.abstractmethod
|
flwr/server/typing.py
CHANGED
|
@@ -20,6 +20,8 @@ from typing import Callable
|
|
|
20
20
|
from flwr.common import Context
|
|
21
21
|
|
|
22
22
|
from .driver import Driver
|
|
23
|
+
from .serverapp_components import ServerAppComponents
|
|
23
24
|
|
|
24
25
|
ServerAppCallable = Callable[[Driver, Context], None]
|
|
25
26
|
Workflow = Callable[[Driver, Context], None]
|
|
27
|
+
ServerFn = Callable[[Context], ServerAppComponents]
|
|
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
|
81
81
|
forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
|
|
82
82
|
aggregate_ndarrays: NDArrays = field(default_factory=list)
|
|
83
83
|
legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
|
|
84
|
+
failures: List[Exception] = field(default_factory=list)
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
class SecAggPlusWorkflow:
|
|
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
|
|
|
394
395
|
|
|
395
396
|
for msg in msgs:
|
|
396
397
|
if msg.has_error():
|
|
398
|
+
state.failures.append(Exception(msg.error))
|
|
397
399
|
continue
|
|
398
400
|
key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
399
401
|
node_id = msg.metadata.src_node_id
|
|
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
|
|
|
451
453
|
nid: [] for nid in state.active_node_ids
|
|
452
454
|
} # dest node ID -> list of src node IDs
|
|
453
455
|
for msg in msgs:
|
|
456
|
+
if msg.has_error():
|
|
457
|
+
state.failures.append(Exception(msg.error))
|
|
458
|
+
continue
|
|
454
459
|
node_id = msg.metadata.src_node_id
|
|
455
460
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
456
461
|
dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
|
|
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
|
|
|
515
520
|
# Sum collected masked vectors and compute active/dead node IDs
|
|
516
521
|
masked_vector = None
|
|
517
522
|
for msg in msgs:
|
|
523
|
+
if msg.has_error():
|
|
524
|
+
state.failures.append(Exception(msg.error))
|
|
525
|
+
continue
|
|
518
526
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
519
527
|
bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
|
|
520
528
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
|
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
|
|
|
528
536
|
|
|
529
537
|
# Backward compatibility with Strategy
|
|
530
538
|
for msg in msgs:
|
|
539
|
+
if msg.has_error():
|
|
540
|
+
state.failures.append(Exception(msg.error))
|
|
541
|
+
continue
|
|
531
542
|
fitres = compat.recordset_to_fitres(msg.content, True)
|
|
532
543
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
|
533
544
|
state.legacy_results.append((proxy, fitres))
|
|
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
|
|
|
584
595
|
for nid in state.sampled_node_ids:
|
|
585
596
|
collected_shares_dict[nid] = []
|
|
586
597
|
for msg in msgs:
|
|
598
|
+
if msg.has_error():
|
|
599
|
+
state.failures.append(Exception(msg.error))
|
|
600
|
+
continue
|
|
587
601
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
588
602
|
nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
|
|
589
603
|
shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
|
|
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
|
|
|
652
666
|
INFO,
|
|
653
667
|
"aggregate_fit: received %s results and %s failures",
|
|
654
668
|
len(results),
|
|
655
|
-
|
|
669
|
+
len(state.failures),
|
|
670
|
+
)
|
|
671
|
+
aggregated_result = context.strategy.aggregate_fit(
|
|
672
|
+
current_round, results, state.failures # type: ignore
|
|
656
673
|
)
|
|
657
|
-
aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
|
|
658
674
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
659
675
|
|
|
660
676
|
# Update the parameters and write history
|
flwr/simulation/__init__.py
CHANGED
flwr/simulation/app.py
CHANGED
|
@@ -111,9 +111,9 @@ def start_simulation(
|
|
|
111
111
|
Parameters
|
|
112
112
|
----------
|
|
113
113
|
client_fn : ClientFnExt
|
|
114
|
-
A function creating Client instances. The function must have the signature
|
|
115
|
-
`client_fn(
|
|
116
|
-
a single client instance of type Client
|
|
114
|
+
A function creating `Client` instances. The function must have the signature
|
|
115
|
+
`client_fn(context: Context). It should return
|
|
116
|
+
a single client instance of type `Client`. Note that the created client
|
|
117
117
|
instances are ephemeral and will often be destroyed after a single method
|
|
118
118
|
invocation. Since client instances are not long-lived, they should not attempt
|
|
119
119
|
to carry state over method invocations. Any state required by the instance
|
|
@@ -327,6 +327,7 @@ def start_simulation(
|
|
|
327
327
|
client_fn=client_fn,
|
|
328
328
|
node_id=node_id,
|
|
329
329
|
partition_id=partition_id,
|
|
330
|
+
num_partitions=num_clients,
|
|
330
331
|
actor_pool=pool,
|
|
331
332
|
)
|
|
332
333
|
initialized_server.client_manager().register(client=client_proxy)
|