flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Driver gRPC API."""
|
16
|
+
|
17
|
+
from logging import INFO
|
18
|
+
from typing import Optional, Tuple
|
19
|
+
|
20
|
+
import grpc
|
21
|
+
|
22
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
23
|
+
from flwr.common.logger import log
|
24
|
+
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
25
|
+
add_DriverServicer_to_server,
|
26
|
+
)
|
27
|
+
from flwr.server.superlink.state import StateFactory
|
28
|
+
|
29
|
+
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
30
|
+
from .driver_servicer import DriverServicer
|
31
|
+
|
32
|
+
|
33
|
+
def run_driver_api_grpc(
    address: str,
    state_factory: StateFactory,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Run Driver API (gRPC, request-response).

    Parameters
    ----------
    address : str
        Address passed to the gRPC server as `server_address`.
    state_factory : StateFactory
        Factory handed to the `DriverServicer` that serves the Driver API.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        Forwarded as-is to `generic_create_grpc_server`.

    Returns
    -------
    grpc.Server
        The Driver API server; `start()` has already been called on it.
    """
    # Create Driver API gRPC server
    # The servicer is a `DriverServicer` (it is not itself a `grpc.Server`).
    driver_servicer: DriverServicer = DriverServicer(
        state_factory=state_factory,
    )
    driver_grpc_server = generic_create_grpc_server(
        servicer_and_add_fn=(driver_servicer, add_DriverServicer_to_server),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
    driver_grpc_server.start()

    return driver_grpc_server
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Driver API servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import INFO
|
18
|
+
from logging import DEBUG, INFO
|
19
19
|
from typing import List, Optional, Set
|
20
20
|
from uuid import UUID
|
21
21
|
|
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
70
70
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
71
71
|
) -> PushTaskInsResponse:
|
72
72
|
"""Push a set of TaskIns."""
|
73
|
-
log(
|
73
|
+
log(DEBUG, "DriverServicer.PushTaskIns")
|
74
74
|
|
75
75
|
# Validate request
|
76
76
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
95
95
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
96
96
|
) -> PullTaskResResponse:
|
97
97
|
"""Pull a set of TaskRes."""
|
98
|
-
log(
|
98
|
+
log(DEBUG, "DriverServicer.PullTaskRes")
|
99
99
|
|
100
100
|
# Convert each task_id str to UUID
|
101
101
|
task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
105
105
|
|
106
106
|
# Register callback
|
107
107
|
def on_rpc_done() -> None:
|
108
|
-
log(
|
108
|
+
log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
|
109
109
|
|
110
110
|
if context.is_active():
|
111
111
|
return
|
@@ -46,6 +46,7 @@ class GrpcClientProxy(ClientProxy):
|
|
46
46
|
self,
|
47
47
|
ins: common.GetPropertiesIns,
|
48
48
|
timeout: Optional[float],
|
49
|
+
group_id: Optional[int],
|
49
50
|
) -> common.GetPropertiesRes:
|
50
51
|
"""Request client's set of internal properties."""
|
51
52
|
get_properties_msg = serde.get_properties_ins_to_proto(ins)
|
@@ -65,6 +66,7 @@ class GrpcClientProxy(ClientProxy):
|
|
65
66
|
self,
|
66
67
|
ins: common.GetParametersIns,
|
67
68
|
timeout: Optional[float],
|
69
|
+
group_id: Optional[int],
|
68
70
|
) -> common.GetParametersRes:
|
69
71
|
"""Return the current local model parameters."""
|
70
72
|
get_parameters_msg = serde.get_parameters_ins_to_proto(ins)
|
@@ -84,6 +86,7 @@ class GrpcClientProxy(ClientProxy):
|
|
84
86
|
self,
|
85
87
|
ins: common.FitIns,
|
86
88
|
timeout: Optional[float],
|
89
|
+
group_id: Optional[int],
|
87
90
|
) -> common.FitRes:
|
88
91
|
"""Refine the provided parameters using the locally held dataset."""
|
89
92
|
fit_ins_msg = serde.fit_ins_to_proto(ins)
|
@@ -102,6 +105,7 @@ class GrpcClientProxy(ClientProxy):
|
|
102
105
|
self,
|
103
106
|
ins: common.EvaluateIns,
|
104
107
|
timeout: Optional[float],
|
108
|
+
group_id: Optional[int],
|
105
109
|
) -> common.EvaluateRes:
|
106
110
|
"""Evaluate the provided parameters using the locally held dataset."""
|
107
111
|
evaluate_msg = serde.evaluate_ins_to_proto(ins)
|
@@ -119,6 +123,7 @@ class GrpcClientProxy(ClientProxy):
|
|
119
123
|
self,
|
120
124
|
ins: common.ReconnectIns,
|
121
125
|
timeout: Optional[float],
|
126
|
+
group_id: Optional[int],
|
122
127
|
) -> common.DisconnectRes:
|
123
128
|
"""Disconnect and (optionally) reconnect later."""
|
124
129
|
reconnect_ins_msg = serde.reconnect_ins_to_proto(ins)
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Fleet
|
15
|
+
"""Fleet Simulation Engine side."""
|
16
16
|
|
17
17
|
from .vce_api import start_vce
|
18
18
|
|
@@ -141,13 +141,13 @@ class RayBackend(Backend):
|
|
141
141
|
|
142
142
|
Return output message and updated context.
|
143
143
|
"""
|
144
|
-
|
144
|
+
partition_id = message.metadata.partition_id
|
145
145
|
|
146
146
|
try:
|
147
147
|
# Submite a task to the pool
|
148
148
|
future = await self.pool.submit(
|
149
149
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
150
|
-
(app, message, str(
|
150
|
+
(app, message, str(partition_id), context),
|
151
151
|
)
|
152
152
|
|
153
153
|
await future
|
@@ -163,10 +163,9 @@ class RayBackend(Backend):
|
|
163
163
|
except LoadClientAppError as load_ex:
|
164
164
|
log(
|
165
165
|
ERROR,
|
166
|
-
"An exception was raised when processing a message
|
166
|
+
"An exception was raised when processing a message by %s",
|
167
167
|
self.__class__.__name__,
|
168
168
|
)
|
169
|
-
await self.terminate()
|
170
169
|
raise load_ex
|
171
170
|
|
172
171
|
async def terminate(self) -> None:
|
@@ -12,19 +12,23 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Fleet
|
15
|
+
"""Fleet Simulation Engine API."""
|
16
|
+
|
16
17
|
|
17
18
|
import asyncio
|
18
19
|
import json
|
19
|
-
|
20
|
-
from
|
20
|
+
import traceback
|
21
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
22
|
+
from typing import Callable, Dict, List, Optional
|
21
23
|
|
22
|
-
from flwr.client.client_app import ClientApp, load_client_app
|
24
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
|
23
25
|
from flwr.client.node_state import NodeState
|
24
26
|
from flwr.common.logger import log
|
27
|
+
from flwr.common.serde import message_from_taskins, message_to_taskres
|
28
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
25
29
|
from flwr.server.superlink.state import StateFactory
|
26
30
|
|
27
|
-
from .backend import error_messages_backends, supported_backends
|
31
|
+
from .backend import Backend, error_messages_backends, supported_backends
|
28
32
|
|
29
33
|
NodeToPartitionMapping = Dict[int, int]
|
30
34
|
|
@@ -42,21 +46,223 @@ def _register_nodes(
|
|
42
46
|
return nodes_mapping
|
43
47
|
|
44
48
|
|
45
|
-
# pylint: disable=too-many-arguments,
|
49
|
+
# pylint: disable=too-many-arguments,too-many-locals
|
50
|
+
async def worker(
    app_fn: Callable[[], ClientApp],
    queue: "asyncio.Queue[TaskIns]",
    node_states: Dict[int, NodeState],
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
) -> None:
    """Get TaskIns from queue and pass it to an actor in the pool to execute it.

    Each iteration takes one TaskIns from `queue`, converts it to a Message,
    lets `backend` process it with the node's context for the run, writes the
    updated context back and stores the resulting TaskRes in the state
    obtained from `state_factory`.

    The loop ends when:
    - the task is cancelled (logged at DEBUG, returns normally),
    - a `LoadClientAppError` occurs (logged at ERROR, then re-raised), or
    - any other exception occurs (logged at ERROR, returns normally).
    """
    state = state_factory.state()
    while True:
        try:
            task_ins: TaskIns = await queue.get()
            run_id = task_ins.run_id
            node_id = task_ins.task.consumer.node_id
            node_state = node_states[node_id]

            # Make sure a context exists for this run, then fetch it
            node_state.register_context(run_id=run_id)
            context = node_state.retrieve_context(run_id=run_id)

            # TaskIns -> Message, tagged with the partition of the target node
            message = message_from_taskins(task_ins)
            message.metadata.partition_id = nodes_mapping[node_id]

            # Hand the message over to the backend
            reply, new_context = await backend.process_message(
                app_fn, message, context
            )

            # Persist the context returned by the backend
            node_state.update_context(run_id, context=new_context)

            # Message -> TaskRes, then store it
            state.store_task_res(message_to_taskres(reply))

        except asyncio.CancelledError as e:
            log(DEBUG, "Async worker: %s", e)
            return

        except LoadClientAppError as app_ex:
            log(ERROR, "Async worker: %s", app_ex)
            log(ERROR, traceback.format_exc())
            raise

        except Exception as ex:  # pylint: disable=broad-exception-caught
            log(ERROR, ex)
            log(ERROR, traceback.format_exc())
            return
|
102
|
+
|
103
|
+
|
104
|
+
async def add_taskins_to_queue(
    queue: "asyncio.Queue[TaskIns]",
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
    consumers: List["asyncio.Task[None]"],
    f_stop: asyncio.Event,
) -> None:
    """Retrieve TaskIns and add it to the queue.

    Until `f_stop` is set, each iteration asks the state for at most one
    TaskIns per node in `nodes_mapping`, puts every TaskIns found into
    `queue`, checks how many `consumers` are still running, logs some stats
    and then sleeps for one second.

    Raises
    ------
    RuntimeError
        If none of the consumer tasks is running anymore.
    """
    state = state_factory.state()
    num_initial_consumers = len(consumers)
    while not f_stop.is_set():
        for node_id in nodes_mapping:
            task_ins = state.get_task_ins(node_id=node_id, limit=1)
            if task_ins:
                # NOTE(review): `put` waits while the queue is full; if every
                # consumer has stopped by then, the liveness check below is
                # never reached -- confirm whether that can happen in practice.
                await queue.put(task_ins[0])

        # Count consumers that are running
        num_active = sum(not cc.done() for cc in consumers)

        # Alert if fewer than half of the initial consumers remain. Comparing
        # `2 * num_active` avoids the rounding of `// 2`, which missed cases
        # such as 2 of 5 workers remaining.
        if 2 * num_active < num_initial_consumers:
            log(
                WARN,
                "Number of active workers has more than halved: (%i/%i active)",
                num_active,
                num_initial_consumers,
            )

        # Break if consumers died
        if num_active == 0:
            raise RuntimeError("All workers have died. Ending Simulation.")

        # Log some stats
        log(
            DEBUG,
            "Simulation Engine stats: "
            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
            num_active,
            num_initial_consumers,
            backend.__class__.__name__,
            backend.num_workers,
            queue.qsize(),
        )
        await asyncio.sleep(1.0)
    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
|
150
|
+
|
151
|
+
|
152
|
+
async def run(
    app_fn: Callable[[], ClientApp],
    backend_fn: Callable[[], Backend],
    nodes_mapping: NodeToPartitionMapping,
    state_factory: StateFactory,
    node_states: Dict[int, NodeState],
    f_stop: asyncio.Event,
) -> None:
    """Run the VCE async.

    Instantiates and builds the backend, starts `backend.num_workers` worker
    tasks plus one producer task, and waits for the producer to finish. On
    exit (normal or not) the worker tasks are cancelled and the backend is
    terminated.

    Raises
    ------
    RuntimeError
        If any exception is raised while setting up or running the engine.
        The original exception is chained as the cause and `f_stop` is set.
    """
    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)

    # Bind these before `try` so the `finally` clause can always refer to
    # them, even when `backend_fn()` or `backend.build()` raises. Otherwise an
    # UnboundLocalError raised during cleanup would replace the RuntimeError.
    worker_tasks: List["asyncio.Task[None]"] = []
    backend: Optional[Backend] = None

    try:
        # Instantiate backend
        backend = backend_fn()

        # Build backend
        await backend.build()

        # Add workers (they submit Messages to Backend)
        worker_tasks = [
            asyncio.create_task(
                worker(
                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
                )
            )
            for _ in range(backend.num_workers)
        ]
        # Create producer (adds TaskIns into Queue)
        producer = asyncio.create_task(
            add_taskins_to_queue(
                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
            )
        )

        # Wait for producer to finish
        # The producer runs forever until f_stop is set or until
        # all worker (consumer) coroutines are completed. Workers
        # also run forever and only end if an exception is raised.
        await asyncio.gather(producer)

    except Exception as ex:
        log(ERROR, "An exception occured!! %s", ex)
        log(ERROR, traceback.format_exc())
        log(WARN, "Stopping Simulation Engine.")

        # Manually trigger stopping event
        f_stop.set()

        # Raise exception
        raise RuntimeError("Simulation Engine crashed.") from ex

    finally:
        # Producer task terminated, now cancel worker tasks
        for w_t in worker_tasks:
            _ = w_t.cancel()

        # Wait until every worker has finished handling its cancellation
        while not all(w_t.done() for w_t in worker_tasks):
            log(DEBUG, "Terminating async workers...")
            await asyncio.sleep(0.5)

        # Terminate backend (only if it was instantiated)
        if backend is not None:
            await backend.terminate()
|
218
|
+
|
219
|
+
|
220
|
+
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
46
221
|
def start_vce(
|
47
|
-
num_supernodes: int,
|
48
222
|
client_app_module_name: str,
|
49
223
|
backend_name: str,
|
50
224
|
backend_config_json_stream: str,
|
51
|
-
state_factory: StateFactory,
|
52
225
|
working_dir: str,
|
53
|
-
f_stop:
|
226
|
+
f_stop: asyncio.Event,
|
227
|
+
num_supernodes: Optional[int] = None,
|
228
|
+
state_factory: Optional[StateFactory] = None,
|
229
|
+
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
54
230
|
) -> None:
|
55
|
-
"""Start Fleet API with the
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
231
|
+
"""Start Fleet API with the Simulation Engine."""
|
232
|
+
if num_supernodes is not None and existing_nodes_mapping is not None:
|
233
|
+
raise ValueError(
|
234
|
+
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
|
235
|
+
"but only one is allowed."
|
236
|
+
)
|
237
|
+
if num_supernodes is None:
|
238
|
+
if state_factory is None or existing_nodes_mapping is None:
|
239
|
+
raise ValueError(
|
240
|
+
"If not passing an existing `state_factory` and associated "
|
241
|
+
"`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
|
242
|
+
"how many nodes to insert into a new StateFactory that will be created."
|
243
|
+
)
|
244
|
+
if existing_nodes_mapping:
|
245
|
+
if state_factory is None:
|
246
|
+
raise ValueError(
|
247
|
+
"`existing_nodes_mapping` was passed, but no `state_factory` was "
|
248
|
+
"passed."
|
249
|
+
)
|
250
|
+
log(INFO, "Using exiting NodeToPartitionMapping and StateFactory.")
|
251
|
+
# Use mapping constructed externally. This also means nodes
|
252
|
+
# have previously being registered.
|
253
|
+
nodes_mapping = existing_nodes_mapping
|
254
|
+
|
255
|
+
if not state_factory:
|
256
|
+
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
257
|
+
# Create an empty in-memory state factory
|
258
|
+
state_factory = StateFactory(":flwr-in-memory-state:")
|
259
|
+
log(INFO, "Created new %s.", state_factory.__class__.__name__)
|
260
|
+
|
261
|
+
if num_supernodes:
|
262
|
+
# Register SuperNodes
|
263
|
+
nodes_mapping = _register_nodes(
|
264
|
+
num_nodes=num_supernodes, state_factory=state_factory
|
265
|
+
)
|
60
266
|
|
61
267
|
# Construct mapping of NodeStates
|
62
268
|
node_states: Dict[int, NodeState] = {}
|
@@ -69,7 +275,6 @@ def start_vce(
|
|
69
275
|
|
70
276
|
try:
|
71
277
|
backend_type = supported_backends[backend_name]
|
72
|
-
_ = backend_type(backend_config, work_dir=working_dir)
|
73
278
|
except KeyError as ex:
|
74
279
|
log(
|
75
280
|
ERROR,
|
@@ -83,10 +288,25 @@ def start_vce(
|
|
83
288
|
|
84
289
|
raise ex
|
85
290
|
|
291
|
+
def backend_fn() -> Backend:
|
292
|
+
"""Instantiate a Backend."""
|
293
|
+
return backend_type(backend_config, work_dir=working_dir)
|
294
|
+
|
86
295
|
log(INFO, "client_app_module_name = %s", client_app_module_name)
|
87
296
|
|
88
297
|
def _load() -> ClientApp:
|
89
298
|
app: ClientApp = load_client_app(client_app_module_name)
|
90
299
|
return app
|
91
300
|
|
92
|
-
|
301
|
+
app_fn = _load
|
302
|
+
|
303
|
+
asyncio.run(
|
304
|
+
run(
|
305
|
+
app_fn,
|
306
|
+
backend_fn,
|
307
|
+
nodes_mapping,
|
308
|
+
state_factory,
|
309
|
+
node_states,
|
310
|
+
f_stop,
|
311
|
+
)
|
312
|
+
)
|
flwr/server/typing.py
CHANGED
@@ -0,0 +1,22 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
from .default_workflows import DefaultWorkflow
|
19
|
+
|
20
|
+
__all__ = [
|
21
|
+
"DefaultWorkflow",
|
22
|
+
]
|