flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Driver gRPC API."""
|
16
|
+
|
17
|
+
from logging import INFO
|
18
|
+
from typing import Optional, Tuple
|
19
|
+
|
20
|
+
import grpc
|
21
|
+
|
22
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
23
|
+
from flwr.common.logger import log
|
24
|
+
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
25
|
+
add_DriverServicer_to_server,
|
26
|
+
)
|
27
|
+
from flwr.server.superlink.state import StateFactory
|
28
|
+
|
29
|
+
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
30
|
+
from .driver_servicer import DriverServicer
|
31
|
+
|
32
|
+
|
33
|
+
def run_driver_api_grpc(
    address: str,
    state_factory: StateFactory,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Run Driver API (gRPC, request-response).

    Parameters
    ----------
    address : str
        Server address the Driver API gRPC server is created for.
    state_factory : StateFactory
        Factory handed to the `DriverServicer` that serves the requests.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        Passed through unchanged to `generic_create_grpc_server`
        (presumably CA certificate, server certificate and server private
        key; `None` presumably means no TLS -- TODO confirm).

    Returns
    -------
    grpc.Server
        The gRPC server, already started.
    """
    # The servicer is a `DriverServicer`, not a `grpc.Server`; the server is
    # created below from the (servicer, registration function) pair.
    driver_servicer: DriverServicer = DriverServicer(
        state_factory=state_factory,
    )
    driver_grpc_server = generic_create_grpc_server(
        servicer_and_add_fn=(driver_servicer, add_DriverServicer_to_server),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
    driver_grpc_server.start()

    return driver_grpc_server
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Driver API servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import INFO
|
18
|
+
from logging import DEBUG, INFO
|
19
19
|
from typing import List, Optional, Set
|
20
20
|
from uuid import UUID
|
21
21
|
|
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
70
70
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
71
71
|
) -> PushTaskInsResponse:
|
72
72
|
"""Push a set of TaskIns."""
|
73
|
-
log(
|
73
|
+
log(DEBUG, "DriverServicer.PushTaskIns")
|
74
74
|
|
75
75
|
# Validate request
|
76
76
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
95
95
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
96
96
|
) -> PullTaskResResponse:
|
97
97
|
"""Pull a set of TaskRes."""
|
98
|
-
log(
|
98
|
+
log(DEBUG, "DriverServicer.PullTaskRes")
|
99
99
|
|
100
100
|
# Convert each task_id str to UUID
|
101
101
|
task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
105
105
|
|
106
106
|
# Register callback
|
107
107
|
def on_rpc_done() -> None:
|
108
|
-
log(
|
108
|
+
log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
|
109
109
|
|
110
110
|
if context.is_active():
|
111
111
|
return
|
@@ -46,6 +46,7 @@ class GrpcClientProxy(ClientProxy):
|
|
46
46
|
self,
|
47
47
|
ins: common.GetPropertiesIns,
|
48
48
|
timeout: Optional[float],
|
49
|
+
group_id: Optional[int],
|
49
50
|
) -> common.GetPropertiesRes:
|
50
51
|
"""Request client's set of internal properties."""
|
51
52
|
get_properties_msg = serde.get_properties_ins_to_proto(ins)
|
@@ -65,6 +66,7 @@ class GrpcClientProxy(ClientProxy):
|
|
65
66
|
self,
|
66
67
|
ins: common.GetParametersIns,
|
67
68
|
timeout: Optional[float],
|
69
|
+
group_id: Optional[int],
|
68
70
|
) -> common.GetParametersRes:
|
69
71
|
"""Return the current local model parameters."""
|
70
72
|
get_parameters_msg = serde.get_parameters_ins_to_proto(ins)
|
@@ -84,6 +86,7 @@ class GrpcClientProxy(ClientProxy):
|
|
84
86
|
self,
|
85
87
|
ins: common.FitIns,
|
86
88
|
timeout: Optional[float],
|
89
|
+
group_id: Optional[int],
|
87
90
|
) -> common.FitRes:
|
88
91
|
"""Refine the provided parameters using the locally held dataset."""
|
89
92
|
fit_ins_msg = serde.fit_ins_to_proto(ins)
|
@@ -102,6 +105,7 @@ class GrpcClientProxy(ClientProxy):
|
|
102
105
|
self,
|
103
106
|
ins: common.EvaluateIns,
|
104
107
|
timeout: Optional[float],
|
108
|
+
group_id: Optional[int],
|
105
109
|
) -> common.EvaluateRes:
|
106
110
|
"""Evaluate the provided parameters using the locally held dataset."""
|
107
111
|
evaluate_msg = serde.evaluate_ins_to_proto(ins)
|
@@ -119,6 +123,7 @@ class GrpcClientProxy(ClientProxy):
|
|
119
123
|
self,
|
120
124
|
ins: common.ReconnectIns,
|
121
125
|
timeout: Optional[float],
|
126
|
+
group_id: Optional[int],
|
122
127
|
) -> common.DisconnectRes:
|
123
128
|
"""Disconnect and (optionally) reconnect later."""
|
124
129
|
reconnect_ins_msg = serde.reconnect_ins_to_proto(ins)
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Fleet
|
15
|
+
"""Fleet Simulation Engine side."""
|
16
16
|
|
17
17
|
from .vce_api import start_vce
|
18
18
|
|
@@ -141,13 +141,13 @@ class RayBackend(Backend):
|
|
141
141
|
|
142
142
|
Return output message and updated context.
|
143
143
|
"""
|
144
|
-
|
144
|
+
partition_id = message.metadata.partition_id
|
145
145
|
|
146
146
|
try:
|
147
147
|
# Submite a task to the pool
|
148
148
|
future = await self.pool.submit(
|
149
149
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
150
|
-
(app, message, str(
|
150
|
+
(app, message, str(partition_id), context),
|
151
151
|
)
|
152
152
|
|
153
153
|
await future
|
@@ -163,10 +163,9 @@ class RayBackend(Backend):
|
|
163
163
|
except LoadClientAppError as load_ex:
|
164
164
|
log(
|
165
165
|
ERROR,
|
166
|
-
"An exception was raised when processing a message
|
166
|
+
"An exception was raised when processing a message by %s",
|
167
167
|
self.__class__.__name__,
|
168
168
|
)
|
169
|
-
await self.terminate()
|
170
169
|
raise load_ex
|
171
170
|
|
172
171
|
async def terminate(self) -> None:
|
@@ -12,19 +12,23 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Fleet
|
15
|
+
"""Fleet Simulation Engine API."""
|
16
|
+
|
16
17
|
|
17
18
|
import asyncio
|
18
19
|
import json
|
19
|
-
|
20
|
-
from
|
20
|
+
import traceback
|
21
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
22
|
+
from typing import Callable, Dict, List, Optional
|
21
23
|
|
22
|
-
from flwr.client.client_app import ClientApp, load_client_app
|
24
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
|
23
25
|
from flwr.client.node_state import NodeState
|
24
26
|
from flwr.common.logger import log
|
27
|
+
from flwr.common.serde import message_from_taskins, message_to_taskres
|
28
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
25
29
|
from flwr.server.superlink.state import StateFactory
|
26
30
|
|
27
|
-
from .backend import error_messages_backends, supported_backends
|
31
|
+
from .backend import Backend, error_messages_backends, supported_backends
|
28
32
|
|
29
33
|
NodeToPartitionMapping = Dict[int, int]
|
30
34
|
|
@@ -42,21 +46,223 @@ def _register_nodes(
|
|
42
46
|
return nodes_mapping
|
43
47
|
|
44
48
|
|
45
|
-
# pylint: disable=too-many-arguments,
|
49
|
+
# pylint: disable=too-many-arguments,too-many-locals
|
50
|
+
async def worker(
    app_fn: Callable[[], ClientApp],
    queue: "asyncio.Queue[TaskIns]",
    node_states: Dict[int, NodeState],
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
) -> None:
    """Consume TaskIns from `queue` and have `backend` execute each of them.

    For every TaskIns taken from the queue:

    1. register (then retrieve) the run's Context on the consuming node;
    2. convert the TaskIns into a Message tagged with the node's partition id;
    3. let `backend.process_message` run it and store the returned Context
       back into the node's state;
    4. convert the reply into a TaskRes and store it via `state_factory`.

    The coroutine returns when it is cancelled or when an unexpected exception
    occurs (logged, then the worker stops). A `LoadClientAppError` is logged
    and re-raised.
    """
    task_store = state_factory.state()

    while True:
        try:
            task_ins: TaskIns = await queue.get()
            dest_node = task_ins.task.consumer.node_id
            run_id = task_ins.run_id
            node_state = node_states[dest_node]

            # Make sure a Context exists for this run, then fetch it
            node_state.register_context(run_id=run_id)
            run_context = node_state.retrieve_context(run_id=run_id)

            # TaskIns -> Message, tagged with the data partition of this node
            in_message = message_from_taskins(task_ins)
            in_message.metadata.partition_id = nodes_mapping[dest_node]

            # Execute on the backend and persist the Context it hands back
            reply, new_context = await backend.process_message(
                app_fn, in_message, run_context
            )
            node_state.update_context(run_id, context=new_context)

            # Message -> TaskRes, made available through the state
            task_store.store_task_res(message_to_taskres(reply))

        except asyncio.CancelledError as e:
            log(DEBUG, "Async worker: %s", e)
            return

        except LoadClientAppError as app_ex:
            log(ERROR, "Async worker: %s", app_ex)
            log(ERROR, traceback.format_exc())
            raise

        except Exception as ex:  # pylint: disable=broad-exception-caught
            log(ERROR, ex)
            log(ERROR, traceback.format_exc())
            return
|
102
|
+
|
103
|
+
|
104
|
+
async def add_taskins_to_queue(
    queue: "asyncio.Queue[TaskIns]",
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
    consumers: List["asyncio.Task[None]"],
    f_stop: asyncio.Event,
) -> None:
    """Retrieve TaskIns and add it to the queue.

    Runs until `f_stop` is set. On each iteration, at most one TaskIns per
    node in `nodes_mapping` is fetched from the state and put into `queue`.
    Then the liveness of the workers (`consumers`) is checked and some stats
    are logged.

    Raises
    ------
    RuntimeError
        If every worker task has finished, i.e. nobody is left to drain the
        queue.
    """
    state = state_factory.state()
    num_initial_consumers = len(consumers)

    def _num_active_consumers() -> int:
        """Count worker tasks that have not finished yet."""
        return sum(not consumer.done() for consumer in consumers)

    while not f_stop.is_set():
        for node_id in nodes_mapping:
            task_ins = state.get_task_ins(node_id=node_id, limit=1)
            if not task_ins:
                continue
            # `queue` is bounded: a plain `await queue.put(...)` blocks forever
            # once it is full and no worker is left to drain it, so the
            # liveness check below would never be reached. Wait in bounded
            # steps instead and bail out if every worker has died. A timed-out
            # put is cancelled without enqueuing the item, so retrying it
            # cannot duplicate the TaskIns.
            while True:
                try:
                    await asyncio.wait_for(queue.put(task_ins[0]), timeout=1.0)
                    break
                except asyncio.TimeoutError:
                    if _num_active_consumers() == 0:
                        raise RuntimeError(
                            "All workers have died. Ending Simulation."
                        ) from None

        # Count consumers that are running
        num_active = _num_active_consumers()

        # Alert if number of consumers decreased by half
        if num_active < num_initial_consumers // 2:
            log(
                WARN,
                "Number of active workers has more than halved: (%i/%i active)",
                num_active,
                num_initial_consumers,
            )

        # Break if consumers died
        if num_active == 0:
            raise RuntimeError("All workers have died. Ending Simulation.")

        # Log some stats
        log(
            DEBUG,
            "Simulation Engine stats: "
            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i",
            num_active,
            num_initial_consumers,
            backend.__class__.__name__,
            backend.num_workers,
            queue.qsize(),
        )
        await asyncio.sleep(1.0)
    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
|
150
|
+
|
151
|
+
|
152
|
+
async def run(
    app_fn: Callable[[], ClientApp],
    backend_fn: Callable[[], Backend],
    nodes_mapping: NodeToPartitionMapping,
    state_factory: StateFactory,
    node_states: Dict[int, NodeState],
    f_stop: asyncio.Event,
) -> None:
    """Run the VCE async.

    Instantiates and builds the backend, then spawns:
    - one `worker` task per `backend.num_workers` (they submit Messages to
      the backend);
    - one producer task (`add_taskins_to_queue`) that adds TaskIns to the
      queue.

    It then waits for the producer to finish. On every exit path the worker
    tasks are cancelled and awaited, and the backend (if it was created) is
    terminated.

    Raises
    ------
    RuntimeError
        If any exception occurs while running; the original exception is
        chained. `f_stop` is set before raising.
    """
    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)

    # Bound before `try` so the `finally` clause also works when creating or
    # building the backend fails. Otherwise an UnboundLocalError would be
    # raised from `finally` and mask the real error.
    backend: Optional[Backend] = None
    worker_tasks: List["asyncio.Task[None]"] = []

    try:
        # Instantiate backend
        backend = backend_fn()

        # Build backend
        await backend.build()

        # Add workers (they submit Messages to Backend)
        worker_tasks = [
            asyncio.create_task(
                worker(
                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
                )
            )
            for _ in range(backend.num_workers)
        ]
        # Create producer (adds TaskIns into Queue)
        producer = asyncio.create_task(
            add_taskins_to_queue(
                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
            )
        )

        # Wait for producer to finish
        # The producer runs forever until f_stop is set or until
        # all worker (consumer) coroutines are completed. Workers
        # also run forever and only end if an exception is raised.
        await asyncio.gather(producer)

    except Exception as ex:  # pylint: disable=broad-exception-caught

        log(ERROR, "An exception occured!! %s", ex)
        log(ERROR, traceback.format_exc())
        log(WARN, "Stopping Simulation Engine.")

        # Manually trigger stopping event
        f_stop.set()

        # Raise exception
        raise RuntimeError("Simulation Engine crashed.") from ex

    finally:
        # Producer task terminated (or never started), now cancel worker tasks
        for w_t in worker_tasks:
            _ = w_t.cancel()

        if worker_tasks:
            log(DEBUG, "Terminating async workers...")
        # Wait for every worker to finish. `return_exceptions=True` collects
        # cancellations and worker errors (e.g. LoadClientAppError) instead of
        # raising them, so no task exception is left unretrieved.
        await asyncio.gather(*worker_tasks, return_exceptions=True)

        # Terminate backend
        if backend is not None:
            await backend.terminate()
|
218
|
+
|
219
|
+
|
220
|
+
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
46
221
|
def start_vce(
|
47
|
-
num_supernodes: int,
|
48
222
|
client_app_module_name: str,
|
49
223
|
backend_name: str,
|
50
224
|
backend_config_json_stream: str,
|
51
|
-
state_factory: StateFactory,
|
52
225
|
working_dir: str,
|
53
|
-
f_stop:
|
226
|
+
f_stop: asyncio.Event,
|
227
|
+
num_supernodes: Optional[int] = None,
|
228
|
+
state_factory: Optional[StateFactory] = None,
|
229
|
+
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
54
230
|
) -> None:
|
55
|
-
"""Start Fleet API with the
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
231
|
+
"""Start Fleet API with the Simulation Engine."""
|
232
|
+
if num_supernodes is not None and existing_nodes_mapping is not None:
|
233
|
+
raise ValueError(
|
234
|
+
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
|
235
|
+
"but only one is allowed."
|
236
|
+
)
|
237
|
+
if num_supernodes is None:
|
238
|
+
if state_factory is None or existing_nodes_mapping is None:
|
239
|
+
raise ValueError(
|
240
|
+
"If not passing an existing `state_factory` and associated "
|
241
|
+
"`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
|
242
|
+
"how many nodes to insert into a new StateFactory that will be created."
|
243
|
+
)
|
244
|
+
if existing_nodes_mapping:
|
245
|
+
if state_factory is None:
|
246
|
+
raise ValueError(
|
247
|
+
"`existing_nodes_mapping` was passed, but no `state_factory` was "
|
248
|
+
"passed."
|
249
|
+
)
|
250
|
+
log(INFO, "Using exiting NodeToPartitionMapping and StateFactory.")
|
251
|
+
# Use mapping constructed externally. This also means nodes
|
252
|
+
# have previously being registered.
|
253
|
+
nodes_mapping = existing_nodes_mapping
|
254
|
+
|
255
|
+
if not state_factory:
|
256
|
+
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
257
|
+
# Create an empty in-memory state factory
|
258
|
+
state_factory = StateFactory(":flwr-in-memory-state:")
|
259
|
+
log(INFO, "Created new %s.", state_factory.__class__.__name__)
|
260
|
+
|
261
|
+
if num_supernodes:
|
262
|
+
# Register SuperNodes
|
263
|
+
nodes_mapping = _register_nodes(
|
264
|
+
num_nodes=num_supernodes, state_factory=state_factory
|
265
|
+
)
|
60
266
|
|
61
267
|
# Construct mapping of NodeStates
|
62
268
|
node_states: Dict[int, NodeState] = {}
|
@@ -69,7 +275,6 @@ def start_vce(
|
|
69
275
|
|
70
276
|
try:
|
71
277
|
backend_type = supported_backends[backend_name]
|
72
|
-
_ = backend_type(backend_config, work_dir=working_dir)
|
73
278
|
except KeyError as ex:
|
74
279
|
log(
|
75
280
|
ERROR,
|
@@ -83,10 +288,25 @@ def start_vce(
|
|
83
288
|
|
84
289
|
raise ex
|
85
290
|
|
291
|
+
def backend_fn() -> Backend:
|
292
|
+
"""Instantiate a Backend."""
|
293
|
+
return backend_type(backend_config, work_dir=working_dir)
|
294
|
+
|
86
295
|
log(INFO, "client_app_module_name = %s", client_app_module_name)
|
87
296
|
|
88
297
|
def _load() -> ClientApp:
|
89
298
|
app: ClientApp = load_client_app(client_app_module_name)
|
90
299
|
return app
|
91
300
|
|
92
|
-
|
301
|
+
app_fn = _load
|
302
|
+
|
303
|
+
asyncio.run(
|
304
|
+
run(
|
305
|
+
app_fn,
|
306
|
+
backend_fn,
|
307
|
+
nodes_mapping,
|
308
|
+
state_factory,
|
309
|
+
node_states,
|
310
|
+
f_stop,
|
311
|
+
)
|
312
|
+
)
|
flwr/server/typing.py
CHANGED
@@ -0,0 +1,22 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
from .default_workflows import DefaultWorkflow
|
19
|
+
|
20
|
+
__all__ = [
|
21
|
+
"DefaultWorkflow",
|
22
|
+
]
|