flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Legacy default workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
import timeit
|
19
|
+
from logging import DEBUG, INFO
|
20
|
+
from typing import Optional, cast
|
21
|
+
|
22
|
+
import flwr.common.recordset_compat as compat
|
23
|
+
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
24
|
+
from flwr.common.constant import (
|
25
|
+
MESSAGE_TYPE_EVALUATE,
|
26
|
+
MESSAGE_TYPE_FIT,
|
27
|
+
MESSAGE_TYPE_GET_PARAMETERS,
|
28
|
+
)
|
29
|
+
|
30
|
+
from ..compat.app_utils import start_update_client_manager_thread
|
31
|
+
from ..compat.legacy_context import LegacyContext
|
32
|
+
from ..driver import Driver
|
33
|
+
from ..typing import Workflow
|
34
|
+
|
35
|
+
# Keys of the ConfigsRecord that ``DefaultWorkflow`` stores in
# ``context.state.configs_records``; the per-round workflows read them back.
KEY_CURRENT_ROUND = "current_round"  # round being executed (int, starting at 1)
KEY_START_TIME = "start_time"  # timeit.default_timer() value taken when FL starts
# Key under which that ConfigsRecord lives in ``context.state.configs_records``
CONFIGS_RECORD_KEY = "config"
# Key of the ParametersRecord holding the current global model parameters
PARAMS_RECORD_KEY = "parameters"
|
39
|
+
|
40
|
+
|
41
|
+
class DefaultWorkflow:
    """Default workflow in Flower.

    Drives the legacy, strategy-based training loop over a ``Driver``. It
    initializes the global parameters, then runs ``context.config.num_rounds``
    rounds. Each round consists of a fit step, a centralized evaluation step,
    and a federated evaluate step.

    Parameters
    ----------
    fit_workflow : Optional[Workflow] (default: None)
        Workflow run as the fit step of every round. If ``None``,
        ``default_fit_workflow`` is used.
    evaluate_workflow : Optional[Workflow] (default: None)
        Workflow run as the federated evaluate step of every round. If
        ``None``, ``default_evaluate_workflow`` is used.
    """

    def __init__(
        self,
        fit_workflow: Optional[Workflow] = None,
        evaluate_workflow: Optional[Workflow] = None,
    ) -> None:
        if fit_workflow is None:
            fit_workflow = default_fit_workflow
        if evaluate_workflow is None:
            evaluate_workflow = default_evaluate_workflow
        self.fit_workflow: Workflow = fit_workflow
        self.evaluate_workflow: Workflow = evaluate_workflow

    def __call__(self, driver: Driver, context: Context) -> None:
        """Execute the workflow.

        Parameters
        ----------
        driver : Driver
            Driver used to communicate with the connected nodes.
        context : Context
            Must be a ``LegacyContext``. Its ``strategy``, ``client_manager``,
            ``config``, ``state`` and ``history`` are read and updated.

        Raises
        ------
        TypeError
            If ``context`` is not a ``LegacyContext``.
        """
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expect a LegacyContext, but get {type(context).__name__}."
            )

        # Start the thread updating nodes
        thread, f_stop = start_update_client_manager_thread(
            driver, context.client_manager
        )

        try:
            # Initialize parameters
            default_init_params_workflow(driver, context)

            # Run federated learning for num_rounds
            log(INFO, "FL starting")
            start_time = timeit.default_timer()
            cfg = ConfigsRecord()
            cfg[KEY_START_TIME] = start_time
            context.state.configs_records[CONFIGS_RECORD_KEY] = cfg

            for current_round in range(1, context.config.num_rounds + 1):
                # Updated in place on the record stored above; the default
                # per-round workflows read the round number from there.
                cfg[KEY_CURRENT_ROUND] = current_round

                # Fit round
                self.fit_workflow(driver, context)

                # Centralized evaluation
                default_centralized_evaluation_workflow(driver, context)

                # Evaluate round
                self.evaluate_workflow(driver, context)

            # Bookkeeping
            end_time = timeit.default_timer()
            elapsed = end_time - start_time
            log(INFO, "FL finished in %s", elapsed)

            # Log results
            hist = context.history
            log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
            log(
                INFO,
                "app_fit: metrics_distributed_fit %s",
                str(hist.metrics_distributed_fit),
            )
            log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
            log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
            log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
        finally:
            # Terminate the node-updating thread even when one of the steps
            # above raises; previously it was only stopped on success and
            # kept running after this call had exited with an error.
            f_stop.set()
            del driver
            thread.join()
|
111
|
+
|
112
|
+
|
113
|
+
def default_init_params_workflow(driver: Driver, context: Context) -> None:
    """Execute the default workflow for parameters initialization.

    The initial global parameters come from
    ``context.strategy.initialize_parameters``. If that returns ``None``,
    they are requested with a ``GetParametersIns`` from one client sampled
    from ``context.client_manager``. The result is stored in
    ``context.state.parameters_records[PARAMS_RECORD_KEY]``. The parameters
    are then evaluated via ``context.strategy.evaluate`` as round 0. If that
    evaluation returns a result, its loss and metrics are recorded in
    ``context.history``.

    Parameters
    ----------
    driver : Driver
        Driver used to query a client when the strategy provides no initial
        parameters.
    context : Context
        Must be a ``LegacyContext``.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    RuntimeError
        If parameters must be requested from a client and no client could be
        sampled, no reply was received, or the reply carries no
        ``ParametersRecord``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    log(INFO, "Initializing global parameters")
    parameters = context.strategy.initialize_parameters(
        client_manager=context.client_manager
    )
    if parameters is not None:
        log(INFO, "Using initial parameters provided by strategy")
        paramsrecord = compat.parameters_to_parametersrecord(
            parameters, keep_input=True
        )
    else:
        # Get initial parameters from one of the clients
        log(INFO, "Requesting initial parameters from one random client")
        sampled = context.client_manager.sample(1)
        # Each failure below used to surface as a bare IndexError/StopIteration
        if not sampled:
            raise RuntimeError(
                "Could not sample a client to request initial parameters from."
            )
        random_client = sampled[0]
        # Send GetParametersIns and get the response
        content = compat.getparametersins_to_recordset(GetParametersIns({}))
        messages = list(
            driver.send_and_receive(
                [
                    driver.create_message(
                        content=content,
                        message_type=MESSAGE_TYPE_GET_PARAMETERS,
                        dst_node_id=random_client.node_id,
                        group_id="",
                        ttl="",
                    )
                ]
            )
        )
        if not messages:
            raise RuntimeError(
                "No reply received when requesting initial parameters "
                f"from node {random_client.node_id}."
            )
        log(INFO, "Received initial parameters from one random client")
        reply_record = next(
            iter(messages[0].content.parameters_records.values()), None
        )
        if reply_record is None:
            raise RuntimeError(
                "The reply to the initial parameters request from node "
                f"{random_client.node_id} contains no ParametersRecord."
            )
        paramsrecord = reply_record

    context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord

    # Evaluate initial parameters
    log(INFO, "Evaluating initial parameters")
    parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
    res = context.strategy.evaluate(0, parameters=parameters)
    if res is not None:
        log(
            INFO,
            "initial parameters (loss, other metrics): %s, %s",
            res[0],
            res[1],
        )
        context.history.add_loss_centralized(server_round=0, loss=res[0])
        context.history.add_metrics_centralized(server_round=0, metrics=res[1])
|
163
|
+
|
164
|
+
|
165
|
+
def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
    """Run the strategy's centralized (server-side) evaluation for this round.

    The round number and the FL start time are read from the
    ``CONFIGS_RECORD_KEY`` configs record. The current global parameters are
    read from the ``PARAMS_RECORD_KEY`` parameters record. When
    ``context.strategy.evaluate`` returns a ``(loss, metrics)`` pair, progress
    is logged and both values are added to ``context.history``. The driver
    argument is unused.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    server_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    t_start = cast(float, round_cfg[KEY_START_TIME])

    global_params = compat.parametersrecord_to_parameters(
        record=context.state.parameters_records[PARAMS_RECORD_KEY],
        keep_input=True,
    )
    evaluation = context.strategy.evaluate(server_round, parameters=global_params)

    # Strategies without a centralized evaluation function return None
    if evaluation is None:
        return

    loss, metrics = evaluation
    log(
        INFO,
        "fit progress: (%s, %s, %s, %s)",
        server_round,
        loss,
        metrics,
        timeit.default_timer() - t_start,
    )
    context.history.add_loss_centralized(server_round=server_round, loss=loss)
    context.history.add_metrics_centralized(server_round=server_round, metrics=metrics)
|
195
|
+
|
196
|
+
|
197
|
+
def default_fit_workflow(driver: Driver, context: Context) -> None:
    """Run a single federated training (fit) round with the strategy in ``context``.

    The round number is taken from the ``CONFIGS_RECORD_KEY`` configs record.
    The global model is taken from the ``PARAMS_RECORD_KEY`` parameters
    record. The strategy picks clients and ``FitIns`` via ``configure_fit``,
    and one ``MESSAGE_TYPE_FIT`` message is sent per client. The replies are
    converted to ``FitRes`` and handed to ``aggregate_fit``. Truthy
    aggregated parameters replace the stored global parameters. The
    aggregated fit metrics are always written to ``context.history``.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    server_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    global_params = compat.parametersrecord_to_parameters(
        context.state.parameters_records[PARAMS_RECORD_KEY], keep_input=True
    )

    # Ask the strategy which clients to train on and with which instructions
    instructions = context.strategy.configure_fit(
        server_round=server_round,
        parameters=global_params,
        client_manager=context.client_manager,
    )
    if not instructions:
        log(INFO, "fit_round %s: no clients selected, cancel", server_round)
        return
    log(
        DEBUG,
        "fit_round %s: strategy sampled %s clients (out of %s)",
        server_round,
        len(instructions),
        context.client_manager.num_available(),
    )

    # Replies are matched back to their ClientProxy via the sender's node ID
    proxy_of = {client.node_id: client for client, _ in instructions}

    fit_messages = [
        driver.create_message(
            content=compat.fitins_to_recordset(ins, keep_input=True),
            message_type=MESSAGE_TYPE_FIT,
            dst_node_id=client.node_id,
            group_id="",
            ttl="",
        )
        for client, ins in instructions
    ]

    # Send all instructions and collect every reply for this round
    replies = list(driver.send_and_receive(fit_messages))
    del fit_messages

    # Failures are not tracked yet, hence the constant 0
    log(
        DEBUG,
        "fit_round %s received %s results and %s failures",
        server_round,
        len(replies),
        0,
    )

    fit_results = [
        (
            proxy_of[reply.metadata.src_node_id],
            compat.recordset_to_fitres(reply.content, keep_input=False),
        )
        for reply in replies
    ]
    new_params, fit_metrics = context.strategy.aggregate_fit(
        server_round, fit_results, []
    )

    if new_params:
        context.state.parameters_records[PARAMS_RECORD_KEY] = (
            compat.parameters_to_parametersrecord(new_params, keep_input=True)
        )
    context.history.add_metrics_distributed_fit(
        server_round=server_round, metrics=fit_metrics
    )
|
277
|
+
|
278
|
+
|
279
|
+
def default_evaluate_workflow(driver: Driver, context: Context) -> None:
    """Run a single federated evaluation round with the strategy in ``context``.

    The round number is taken from the ``CONFIGS_RECORD_KEY`` configs record.
    The global model is taken from the ``PARAMS_RECORD_KEY`` parameters
    record. The strategy picks clients and ``EvaluateIns`` via
    ``configure_evaluate``, and one ``MESSAGE_TYPE_EVALUATE`` message is sent
    per client. The replies are converted to ``EvaluateRes`` and handed to
    ``aggregate_evaluate``. When the aggregated loss is not ``None``, it is
    written to ``context.history`` together with the aggregated metrics.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    server_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    global_params = compat.parametersrecord_to_parameters(
        context.state.parameters_records[PARAMS_RECORD_KEY], keep_input=True
    )

    # Ask the strategy which clients to evaluate on and with which instructions
    instructions = context.strategy.configure_evaluate(
        server_round=server_round,
        parameters=global_params,
        client_manager=context.client_manager,
    )
    if not instructions:
        log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
        return
    log(
        DEBUG,
        "evaluate_round %s: strategy sampled %s clients (out of %s)",
        server_round,
        len(instructions),
        context.client_manager.num_available(),
    )

    # Replies are matched back to their ClientProxy via the sender's node ID
    proxy_of = {client.node_id: client for client, _ in instructions}

    eval_messages = [
        driver.create_message(
            content=compat.evaluateins_to_recordset(ins, keep_input=True),
            message_type=MESSAGE_TYPE_EVALUATE,
            dst_node_id=client.node_id,
            group_id="",
            ttl="",
        )
        for client, ins in instructions
    ]

    # Send all instructions and collect every reply for this round
    replies = list(driver.send_and_receive(eval_messages))
    del eval_messages

    # Failures are not tracked yet, hence the constant 0
    log(
        DEBUG,
        "evaluate_round %s received %s results and %s failures",
        server_round,
        len(replies),
        0,
    )

    eval_results = [
        (
            proxy_of[reply.metadata.src_node_id],
            compat.recordset_to_evaluateres(reply.content),
        )
        for reply in replies
    ]
    loss_agg, metrics_agg = context.strategy.aggregate_evaluate(
        server_round, eval_results, []
    )

    # Nothing is recorded when the strategy produced no aggregated loss
    if loss_agg is None:
        return
    context.history.add_loss_distributed(server_round=server_round, loss=loss_agg)
    context.history.add_metrics_distributed(
        server_round=server_round, metrics=metrics_agg
    )
|
flwr/simulation/__init__.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
+
from flwr.simulation.run_simulation import run_simulation
|
21
|
+
|
20
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
21
23
|
|
22
24
|
if is_ray_installed:
|
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
38
|
|
37
39
|
__all__ = [
|
38
40
|
"start_simulation",
|
41
|
+
"run_simulation",
|
39
42
|
]
|
@@ -98,6 +98,7 @@ class RayActorClientProxy(ClientProxy):
|
|
98
98
|
recordset: RecordSet,
|
99
99
|
message_type: str,
|
100
100
|
timeout: Optional[float],
|
101
|
+
group_id: Optional[int],
|
101
102
|
) -> Message:
|
102
103
|
"""Wrap a RecordSet inside a Message."""
|
103
104
|
return Message(
|
@@ -105,7 +106,7 @@ class RayActorClientProxy(ClientProxy):
|
|
105
106
|
metadata=Metadata(
|
106
107
|
run_id=0,
|
107
108
|
message_id="",
|
108
|
-
group_id="",
|
109
|
+
group_id=str(group_id) if group_id is not None else "",
|
109
110
|
src_node_id=0,
|
110
111
|
dst_node_id=int(self.cid),
|
111
112
|
reply_to_message="",
|
@@ -116,7 +117,10 @@ class RayActorClientProxy(ClientProxy):
|
|
116
117
|
)
|
117
118
|
|
118
119
|
def get_properties(
|
119
|
-
self,
|
120
|
+
self,
|
121
|
+
ins: common.GetPropertiesIns,
|
122
|
+
timeout: Optional[float],
|
123
|
+
group_id: Optional[int],
|
120
124
|
) -> common.GetPropertiesRes:
|
121
125
|
"""Return client's properties."""
|
122
126
|
recordset = getpropertiesins_to_recordset(ins)
|
@@ -124,6 +128,7 @@ class RayActorClientProxy(ClientProxy):
|
|
124
128
|
recordset,
|
125
129
|
message_type=MESSAGE_TYPE_GET_PROPERTIES,
|
126
130
|
timeout=timeout,
|
131
|
+
group_id=group_id,
|
127
132
|
)
|
128
133
|
|
129
134
|
message_out = self._submit_job(message, timeout)
|
@@ -131,7 +136,10 @@ class RayActorClientProxy(ClientProxy):
|
|
131
136
|
return recordset_to_getpropertiesres(message_out.content)
|
132
137
|
|
133
138
|
def get_parameters(
|
134
|
-
self,
|
139
|
+
self,
|
140
|
+
ins: common.GetParametersIns,
|
141
|
+
timeout: Optional[float],
|
142
|
+
group_id: Optional[int],
|
135
143
|
) -> common.GetParametersRes:
|
136
144
|
"""Return the current local model parameters."""
|
137
145
|
recordset = getparametersins_to_recordset(ins)
|
@@ -139,19 +147,25 @@ class RayActorClientProxy(ClientProxy):
|
|
139
147
|
recordset,
|
140
148
|
message_type=MESSAGE_TYPE_GET_PARAMETERS,
|
141
149
|
timeout=timeout,
|
150
|
+
group_id=group_id,
|
142
151
|
)
|
143
152
|
|
144
153
|
message_out = self._submit_job(message, timeout)
|
145
154
|
|
146
155
|
return recordset_to_getparametersres(message_out.content, keep_input=False)
|
147
156
|
|
148
|
-
def fit(
|
157
|
+
def fit(
|
158
|
+
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
159
|
+
) -> common.FitRes:
|
149
160
|
"""Train model parameters on the locally held dataset."""
|
150
161
|
recordset = fitins_to_recordset(
|
151
162
|
ins, keep_input=True
|
152
163
|
) # This must stay TRUE since ins are in-memory
|
153
164
|
message = self._wrap_recordset_in_message(
|
154
|
-
recordset,
|
165
|
+
recordset,
|
166
|
+
message_type=MESSAGE_TYPE_FIT,
|
167
|
+
timeout=timeout,
|
168
|
+
group_id=group_id,
|
155
169
|
)
|
156
170
|
|
157
171
|
message_out = self._submit_job(message, timeout)
|
@@ -159,14 +173,17 @@ class RayActorClientProxy(ClientProxy):
|
|
159
173
|
return recordset_to_fitres(message_out.content, keep_input=False)
|
160
174
|
|
161
175
|
def evaluate(
|
162
|
-
self, ins: common.EvaluateIns, timeout: Optional[float]
|
176
|
+
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
163
177
|
) -> common.EvaluateRes:
|
164
178
|
"""Evaluate model parameters on the locally held dataset."""
|
165
179
|
recordset = evaluateins_to_recordset(
|
166
180
|
ins, keep_input=True
|
167
181
|
) # This must stay TRUE since ins are in-memory
|
168
182
|
message = self._wrap_recordset_in_message(
|
169
|
-
recordset,
|
183
|
+
recordset,
|
184
|
+
message_type=MESSAGE_TYPE_EVALUATE,
|
185
|
+
timeout=timeout,
|
186
|
+
group_id=group_id,
|
170
187
|
)
|
171
188
|
|
172
189
|
message_out = self._submit_job(message, timeout)
|
@@ -174,7 +191,10 @@ class RayActorClientProxy(ClientProxy):
|
|
174
191
|
return recordset_to_evaluateres(message_out.content)
|
175
192
|
|
176
193
|
def reconnect(
|
177
|
-
self,
|
194
|
+
self,
|
195
|
+
ins: common.ReconnectIns,
|
196
|
+
timeout: Optional[float],
|
197
|
+
group_id: Optional[int],
|
178
198
|
) -> common.DisconnectRes:
|
179
199
|
"""Disconnect and (optionally) reconnect later."""
|
180
200
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower Simulation."""
|
16
|
+
|
17
|
+
import argparse
|
18
|
+
import asyncio
|
19
|
+
import json
|
20
|
+
import threading
|
21
|
+
import traceback
|
22
|
+
from logging import ERROR, INFO, WARNING
|
23
|
+
|
24
|
+
import grpc
|
25
|
+
|
26
|
+
from flwr.common import EventType, event, log
|
27
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
28
|
+
from flwr.server.driver.driver import Driver
|
29
|
+
from flwr.server.run_serverapp import run
|
30
|
+
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
31
|
+
from flwr.server.superlink.fleet import vce
|
32
|
+
from flwr.server.superlink.state import StateFactory
|
33
|
+
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
34
|
+
|
35
|
+
|
36
|
+
def run_simulation() -> None:
    """Run Simulation Engine.

    Reads its configuration from the command line (see
    ``_parse_args_run_simulation``). It starts a Driver API server and the
    Simulation Engine, which runs in a separate non-daemon thread; both share
    one in-memory state. It then runs the ``ServerApp`` given by
    ``--server-app`` through a ``Driver`` connected to that Driver API.

    Whether the ``ServerApp`` finishes or fails, cleanup always runs:
    - the Simulation Engine thread is signalled to stop;
    - exit handlers are registered for the Driver API server and that thread;
    - the call waits for the thread to finish.

    Raises
    ------
    RuntimeError
        If creating the ``Driver`` or running the ``ServerApp`` raises. The
        original exception is chained as the cause.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    # NOTE(review): asyncio.Event is not thread-safe; sharing it with the
    # Simulation Engine thread presumably relies on start_vce only checking
    # is_set() — confirm.
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )
        try:
            # Launch server app
            run(args.server_app, driver, args.dir)
        finally:
            # Drop our reference to the Driver. This lives in an inner
            # `finally` so it only runs once `driver` is bound. Previously an
            # unconditional `del driver` raised UnboundLocalError when
            # `Driver(...)` failed; that masked the real error and skipped the
            # shutdown below, leaving the non-daemon thread running.
            del driver

    except Exception as ex:
        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()
|
116
|
+
|
117
|
+
|
118
|
+
def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
119
|
+
"""Parse flower-simulation command line arguments."""
|
120
|
+
parser = argparse.ArgumentParser(
|
121
|
+
description="Start a Flower simulation",
|
122
|
+
)
|
123
|
+
parser.add_argument(
|
124
|
+
"--client-app",
|
125
|
+
required=True,
|
126
|
+
help="For example: `client:app` or `project.package.module:wrapper.app`",
|
127
|
+
)
|
128
|
+
parser.add_argument(
|
129
|
+
"--server-app",
|
130
|
+
required=True,
|
131
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
132
|
+
)
|
133
|
+
parser.add_argument(
|
134
|
+
"--driver-api-address",
|
135
|
+
default="0.0.0.0:9091",
|
136
|
+
type=str,
|
137
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
138
|
+
)
|
139
|
+
parser.add_argument(
|
140
|
+
"--num-supernodes",
|
141
|
+
type=int,
|
142
|
+
required=True,
|
143
|
+
help="Number of simulated SuperNodes.",
|
144
|
+
)
|
145
|
+
parser.add_argument(
|
146
|
+
"--backend",
|
147
|
+
default="ray",
|
148
|
+
type=str,
|
149
|
+
help="Simulation backend that executes the ClientApp.",
|
150
|
+
)
|
151
|
+
parser.add_argument(
|
152
|
+
"--enable-tf-gpu-growth",
|
153
|
+
action="store_true",
|
154
|
+
help="Enables GPU growth on the main thread. This is desirable if you make "
|
155
|
+
"use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
|
156
|
+
"running on the same GPU. Without enabling this, you might encounter an "
|
157
|
+
"out-of-memory error becasue TensorFlow by default allocates all GPU memory."
|
158
|
+
"Read mor about how `tf.config.experimental.set_memory_growth()` works in "
|
159
|
+
"the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
|
160
|
+
)
|
161
|
+
parser.add_argument(
|
162
|
+
"--backend-config",
|
163
|
+
type=str,
|
164
|
+
default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
|
165
|
+
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
166
|
+
"configure a backend. Values supported in <value> are those included by "
|
167
|
+
"`flwr.common.typing.ConfigsRecordValues`. ",
|
168
|
+
)
|
169
|
+
parser.add_argument(
|
170
|
+
"--dir",
|
171
|
+
default="",
|
172
|
+
help="Add specified directory to the PYTHONPATH and load"
|
173
|
+
"ClientApp and ServerApp from there."
|
174
|
+
" Default: current working directory.",
|
175
|
+
)
|
176
|
+
|
177
|
+
return parser
|
{flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.8.0.
|
3
|
+
Version: 1.8.0.dev20240229
|
4
4
|
Summary: Flower: A Friendly Federated Learning Framework
|
5
5
|
Home-page: https://flower.ai
|
6
6
|
License: Apache-2.0
|
@@ -32,7 +32,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
32
32
|
Classifier: Typing :: Typed
|
33
33
|
Provides-Extra: rest
|
34
34
|
Provides-Extra: simulation
|
35
|
-
Requires-Dist: cryptography (>=
|
35
|
+
Requires-Dist: cryptography (>=42.0.4,<43.0.0)
|
36
36
|
Requires-Dist: grpcio (>=1.60.0,<2.0.0)
|
37
37
|
Requires-Dist: iterators (>=0.0.2,<0.0.3)
|
38
38
|
Requires-Dist: numpy (>=1.21.0,<2.0.0)
|
@@ -83,7 +83,7 @@ design of Flower is based on a few guiding principles:
|
|
83
83
|
|
84
84
|
- **Framework-agnostic**: Different machine learning frameworks have different
|
85
85
|
strengths. Flower can be used with any machine learning framework, for
|
86
|
-
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
|
86
|
+
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
|
87
87
|
for users who enjoy computing gradients by hand.
|
88
88
|
|
89
89
|
- **Understandable**: Flower is written with maintainability in mind. The
|
@@ -180,6 +180,7 @@ Quickstart examples:
|
|
180
180
|
- [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai)
|
181
181
|
- [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas)
|
182
182
|
- [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax)
|
183
|
+
- [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai)
|
183
184
|
- [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)
|
184
185
|
- [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
|
185
186
|
- [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android)
|