flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Legacy default workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
import timeit
|
19
|
+
from logging import DEBUG, INFO
|
20
|
+
from typing import Optional, cast
|
21
|
+
|
22
|
+
import flwr.common.recordset_compat as compat
|
23
|
+
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
24
|
+
from flwr.common.constant import (
|
25
|
+
MESSAGE_TYPE_EVALUATE,
|
26
|
+
MESSAGE_TYPE_FIT,
|
27
|
+
MESSAGE_TYPE_GET_PARAMETERS,
|
28
|
+
)
|
29
|
+
|
30
|
+
from ..compat.app_utils import start_update_client_manager_thread
|
31
|
+
from ..compat.legacy_context import LegacyContext
|
32
|
+
from ..driver import Driver
|
33
|
+
from ..typing import Workflow
|
34
|
+
|
35
|
+
# Key in the shared ConfigsRecord holding the current round number.
KEY_CURRENT_ROUND = "current_round"
|
36
|
+
# Key in the shared ConfigsRecord holding the `timeit` start timestamp.
KEY_START_TIME = "start_time"
|
37
|
+
# Key under which the shared ConfigsRecord is stored in
# `context.state.configs_records`.
CONFIGS_RECORD_KEY = "config"
|
38
|
+
# Key under which the current global parameters are stored in
# `context.state.parameters_records`.
PARAMS_RECORD_KEY = "parameters"
|
39
|
+
|
40
|
+
|
41
|
+
class DefaultWorkflow:
    """Default workflow in Flower.

    Initializes the global parameters once and then, for every round, runs the
    fit workflow, the centralized evaluation workflow and the evaluate
    workflow, in that order.

    Parameters
    ----------
    fit_workflow : Optional[Workflow] (default: None)
        Callable executed for the fit step of each round. When `None`,
        `default_fit_workflow` is used.
    evaluate_workflow : Optional[Workflow] (default: None)
        Callable executed for the evaluate step of each round. When `None`,
        `default_evaluate_workflow` is used.
    """

    def __init__(
        self,
        fit_workflow: Optional[Workflow] = None,
        evaluate_workflow: Optional[Workflow] = None,
    ) -> None:
        if fit_workflow is None:
            fit_workflow = default_fit_workflow
        if evaluate_workflow is None:
            evaluate_workflow = default_evaluate_workflow
        self.fit_workflow: Workflow = fit_workflow
        self.evaluate_workflow: Workflow = evaluate_workflow

    def __call__(self, driver: Driver, context: Context) -> None:
        """Execute the workflow.

        Parameters
        ----------
        driver : Driver
            Driver handed to every sub-workflow and to the thread that keeps
            the client manager up to date.
        context : Context
            Must be a `LegacyContext`; its `state`, `config`, `strategy`,
            `client_manager` and `history` attributes are used.

        Raises
        ------
        TypeError
            If `context` is not a `LegacyContext`.
        """
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expect a LegacyContext, but get {type(context).__name__}."
            )

        # Start the thread updating nodes
        thread, f_stop = start_update_client_manager_thread(
            driver, context.client_manager
        )

        # Everything below runs while the updater thread is alive, so the
        # thread must be stopped on every exit path, including exceptions.
        try:
            # Initialize parameters
            default_init_params_workflow(driver, context)

            # Run federated learning for num_rounds
            log(INFO, "FL starting")
            start_time = timeit.default_timer()
            cfg = ConfigsRecord()
            cfg[KEY_START_TIME] = start_time
            # The sub-workflows read the round number and start time from here
            context.state.configs_records[CONFIGS_RECORD_KEY] = cfg

            for current_round in range(1, context.config.num_rounds + 1):
                cfg[KEY_CURRENT_ROUND] = current_round

                # Fit round
                self.fit_workflow(driver, context)

                # Centralized evaluation
                default_centralized_evaluation_workflow(driver, context)

                # Evaluate round
                self.evaluate_workflow(driver, context)

            # Bookkeeping
            end_time = timeit.default_timer()
            elapsed = end_time - start_time
            log(INFO, "FL finished in %s", elapsed)

            # Log results
            hist = context.history
            log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
            log(
                INFO,
                "app_fit: metrics_distributed_fit %s",
                str(hist.metrics_distributed_fit),
            )
            log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
            log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
            log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
        finally:
            # Terminate the thread, even when a workflow step raised
            f_stop.set()
            del driver
            thread.join()
|
111
|
+
|
112
|
+
|
113
|
+
def default_init_params_workflow(driver: Driver, context: Context) -> None:
    """Initialize the global parameters and evaluate them once as round 0.

    The strategy is asked for initial parameters first. If it provides none,
    one sampled client is asked for its parameters instead. The result is
    stored in `context.state.parameters_records` and passed to the strategy's
    centralized `evaluate`; a non-`None` result is written to the history.

    Raises
    ------
    TypeError
        If `context` is not a `LegacyContext`.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    log(INFO, "Initializing global parameters")
    strategy_params = context.strategy.initialize_parameters(
        client_manager=context.client_manager
    )
    if strategy_params is None:
        # The strategy has nothing to offer: ask one sampled client instead
        log(INFO, "Requesting initial parameters from one random client")
        chosen_client = context.client_manager.sample(1)[0]
        request = compat.getparametersins_to_recordset(GetParametersIns({}))
        out_message = driver.create_message(
            content=request,
            message_type=MESSAGE_TYPE_GET_PARAMETERS,
            dst_node_id=chosen_client.node_id,
            group_id="",
            ttl="",
        )
        replies = driver.send_and_receive([out_message])
        log(INFO, "Received initial parameters from one random client")
        first_reply = list(replies)[0]
        # Take the first parameters record contained in the reply
        paramsrecord = next(iter(first_reply.content.parameters_records.values()))
    else:
        log(INFO, "Using initial parameters provided by strategy")
        paramsrecord = compat.parameters_to_parametersrecord(
            strategy_params, keep_input=True
        )

    context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord

    # Centralized evaluation of the initial parameters, recorded as round 0
    log(INFO, "Evaluating initial parameters")
    initial_params = compat.parametersrecord_to_parameters(
        paramsrecord, keep_input=True
    )
    outcome = context.strategy.evaluate(0, parameters=initial_params)
    if outcome is None:
        return

    initial_loss = outcome[0]
    initial_metrics = outcome[1]
    log(
        INFO,
        "initial parameters (loss, other metrics): %s, %s",
        initial_loss,
        initial_metrics,
    )
    context.history.add_loss_centralized(server_round=0, loss=initial_loss)
    context.history.add_metrics_centralized(server_round=0, metrics=initial_metrics)
|
163
|
+
|
164
|
+
|
165
|
+
def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
    """Evaluate the current global parameters through the strategy.

    The round number and start time are read from the shared ConfigsRecord in
    the context. When the strategy's `evaluate` returns a result, it is logged
    together with the elapsed time and appended to the history. The driver
    argument is unused.

    Raises
    ------
    TypeError
        If `context` is not a `LegacyContext`.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Round bookkeeping written by the calling workflow
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    started_at = cast(float, round_cfg[KEY_START_TIME])

    # Evaluate the stored global parameters
    global_params = compat.parametersrecord_to_parameters(
        record=context.state.parameters_records[PARAMS_RECORD_KEY],
        keep_input=True,
    )
    outcome = context.strategy.evaluate(current_round, parameters=global_params)
    if outcome is None:
        # Nothing to log or record for this round
        return

    loss_cen, metrics_cen = outcome
    log(
        INFO,
        "fit progress: (%s, %s, %s, %s)",
        current_round,
        loss_cen,
        metrics_cen,
        timeit.default_timer() - started_at,
    )
    history = context.history
    history.add_loss_centralized(server_round=current_round, loss=loss_cen)
    history.add_metrics_centralized(server_round=current_round, metrics=metrics_cen)
|
195
|
+
|
196
|
+
|
197
|
+
def default_fit_workflow(driver: Driver, context: Context) -> None:
    """Run one federated fit round.

    The strategy selects clients and their fit instructions, one message per
    client is sent through the driver, and the replies are handed to the
    strategy's `aggregate_fit`. Truthy aggregated parameters replace the
    stored global parameters; the aggregated metrics are always written to the
    history. Failures are not tracked: an empty failure list is passed on.

    Raises
    ------
    TypeError
        If `context` is not a `LegacyContext`.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Round number and current global parameters come from the context
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    stored_record = context.state.parameters_records[PARAMS_RECORD_KEY]
    global_params = compat.parametersrecord_to_parameters(
        stored_record, keep_input=True
    )

    # Let the strategy pick clients and build their instructions
    instructions = context.strategy.configure_fit(
        server_round=current_round,
        parameters=global_params,
        client_manager=context.client_manager,
    )
    if not instructions:
        log(INFO, "fit_round %s: no clients selected, cancel", current_round)
        return

    log(
        DEBUG,
        "fit_round %s: strategy sampled %s clients (out of %s)",
        current_round,
        len(instructions),
        context.client_manager.num_available(),
    )

    # Remember which proxy belongs to which node id to match replies later
    proxy_by_node_id = {}
    for proxy, _ in instructions:
        proxy_by_node_id[proxy.node_id] = proxy

    # One outgoing fit message per selected client
    outgoing = []
    for proxy, fitins in instructions:
        outgoing.append(
            driver.create_message(
                content=compat.fitins_to_recordset(fitins, True),
                message_type=MESSAGE_TYPE_FIT,
                dst_node_id=proxy.node_id,
                group_id="",
                ttl="",
            )
        )

    # Send everything and collect the `fit` replies
    replies = list(driver.send_and_receive(outgoing))
    del outgoing

    # No exception/failure handling currently, hence the constant 0
    log(
        DEBUG,
        "fit_round %s received %s results and %s failures",
        current_round,
        len(replies),
        0,
    )

    # Pair every reply with the proxy it originated from
    results = []
    for reply in replies:
        origin = proxy_by_node_id[reply.metadata.src_node_id]
        results.append((origin, compat.recordset_to_fitres(reply.content, False)))

    parameters_aggregated, metrics_aggregated = context.strategy.aggregate_fit(
        current_round, results, []
    )

    # Store the new global parameters (if any) and record the metrics
    if parameters_aggregated:
        context.state.parameters_records[PARAMS_RECORD_KEY] = (
            compat.parameters_to_parametersrecord(parameters_aggregated, True)
        )
    context.history.add_metrics_distributed_fit(
        server_round=current_round, metrics=metrics_aggregated
    )
|
277
|
+
|
278
|
+
|
279
|
+
def default_evaluate_workflow(driver: Driver, context: Context) -> None:
    """Run one federated evaluate round.

    The strategy selects clients and their evaluate instructions, one message
    per client is sent through the driver, and the replies are handed to the
    strategy's `aggregate_evaluate`. A non-`None` aggregated loss and the
    aggregated metrics are written to the history. Failures are not tracked:
    an empty failure list is passed on.

    Raises
    ------
    TypeError
        If `context` is not a `LegacyContext`.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Round number and current global parameters come from the context
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    stored_record = context.state.parameters_records[PARAMS_RECORD_KEY]
    global_params = compat.parametersrecord_to_parameters(
        stored_record, keep_input=True
    )

    # Let the strategy pick clients and build their instructions
    instructions = context.strategy.configure_evaluate(
        server_round=current_round,
        parameters=global_params,
        client_manager=context.client_manager,
    )
    if not instructions:
        log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
        return

    log(
        DEBUG,
        "evaluate_round %s: strategy sampled %s clients (out of %s)",
        current_round,
        len(instructions),
        context.client_manager.num_available(),
    )

    # Remember which proxy belongs to which node id to match replies later
    proxy_by_node_id = {}
    for proxy, _ in instructions:
        proxy_by_node_id[proxy.node_id] = proxy

    # One outgoing evaluate message per selected client
    outgoing = []
    for proxy, evalins in instructions:
        outgoing.append(
            driver.create_message(
                content=compat.evaluateins_to_recordset(evalins, True),
                message_type=MESSAGE_TYPE_EVALUATE,
                dst_node_id=proxy.node_id,
                group_id="",
                ttl="",
            )
        )

    # Send everything and collect the `evaluate` replies
    replies = list(driver.send_and_receive(outgoing))
    del outgoing

    # No exception/failure handling currently, hence the constant 0
    log(
        DEBUG,
        "evaluate_round %s received %s results and %s failures",
        current_round,
        len(replies),
        0,
    )

    # Pair every reply with the proxy it originated from
    results = []
    for reply in replies:
        origin = proxy_by_node_id[reply.metadata.src_node_id]
        results.append((origin, compat.recordset_to_evaluateres(reply.content)))

    loss_aggregated, metrics_aggregated = context.strategy.aggregate_evaluate(
        current_round, results, []
    )

    # Record the outcome; the loss is only stored when one was produced
    history = context.history
    if loss_aggregated is not None:
        history.add_loss_distributed(server_round=current_round, loss=loss_aggregated)
    history.add_metrics_distributed(
        server_round=current_round, metrics=metrics_aggregated
    )
|
flwr/simulation/__init__.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
+
from flwr.simulation.run_simulation import run_simulation
|
21
|
+
|
20
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
21
23
|
|
22
24
|
if is_ray_installed:
|
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
38
|
|
37
39
|
__all__ = [
|
38
40
|
"start_simulation",
|
41
|
+
"run_simulation",
|
39
42
|
]
|
@@ -98,6 +98,7 @@ class RayActorClientProxy(ClientProxy):
|
|
98
98
|
recordset: RecordSet,
|
99
99
|
message_type: str,
|
100
100
|
timeout: Optional[float],
|
101
|
+
group_id: Optional[int],
|
101
102
|
) -> Message:
|
102
103
|
"""Wrap a RecordSet inside a Message."""
|
103
104
|
return Message(
|
@@ -105,7 +106,7 @@ class RayActorClientProxy(ClientProxy):
|
|
105
106
|
metadata=Metadata(
|
106
107
|
run_id=0,
|
107
108
|
message_id="",
|
108
|
-
group_id="",
|
109
|
+
group_id=str(group_id) if group_id is not None else "",
|
109
110
|
src_node_id=0,
|
110
111
|
dst_node_id=int(self.cid),
|
111
112
|
reply_to_message="",
|
@@ -116,7 +117,10 @@ class RayActorClientProxy(ClientProxy):
|
|
116
117
|
)
|
117
118
|
|
118
119
|
def get_properties(
|
119
|
-
self,
|
120
|
+
self,
|
121
|
+
ins: common.GetPropertiesIns,
|
122
|
+
timeout: Optional[float],
|
123
|
+
group_id: Optional[int],
|
120
124
|
) -> common.GetPropertiesRes:
|
121
125
|
"""Return client's properties."""
|
122
126
|
recordset = getpropertiesins_to_recordset(ins)
|
@@ -124,6 +128,7 @@ class RayActorClientProxy(ClientProxy):
|
|
124
128
|
recordset,
|
125
129
|
message_type=MESSAGE_TYPE_GET_PROPERTIES,
|
126
130
|
timeout=timeout,
|
131
|
+
group_id=group_id,
|
127
132
|
)
|
128
133
|
|
129
134
|
message_out = self._submit_job(message, timeout)
|
@@ -131,7 +136,10 @@ class RayActorClientProxy(ClientProxy):
|
|
131
136
|
return recordset_to_getpropertiesres(message_out.content)
|
132
137
|
|
133
138
|
def get_parameters(
|
134
|
-
self,
|
139
|
+
self,
|
140
|
+
ins: common.GetParametersIns,
|
141
|
+
timeout: Optional[float],
|
142
|
+
group_id: Optional[int],
|
135
143
|
) -> common.GetParametersRes:
|
136
144
|
"""Return the current local model parameters."""
|
137
145
|
recordset = getparametersins_to_recordset(ins)
|
@@ -139,19 +147,25 @@ class RayActorClientProxy(ClientProxy):
|
|
139
147
|
recordset,
|
140
148
|
message_type=MESSAGE_TYPE_GET_PARAMETERS,
|
141
149
|
timeout=timeout,
|
150
|
+
group_id=group_id,
|
142
151
|
)
|
143
152
|
|
144
153
|
message_out = self._submit_job(message, timeout)
|
145
154
|
|
146
155
|
return recordset_to_getparametersres(message_out.content, keep_input=False)
|
147
156
|
|
148
|
-
def fit(
|
157
|
+
def fit(
|
158
|
+
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
159
|
+
) -> common.FitRes:
|
149
160
|
"""Train model parameters on the locally held dataset."""
|
150
161
|
recordset = fitins_to_recordset(
|
151
162
|
ins, keep_input=True
|
152
163
|
) # This must stay TRUE since ins are in-memory
|
153
164
|
message = self._wrap_recordset_in_message(
|
154
|
-
recordset,
|
165
|
+
recordset,
|
166
|
+
message_type=MESSAGE_TYPE_FIT,
|
167
|
+
timeout=timeout,
|
168
|
+
group_id=group_id,
|
155
169
|
)
|
156
170
|
|
157
171
|
message_out = self._submit_job(message, timeout)
|
@@ -159,14 +173,17 @@ class RayActorClientProxy(ClientProxy):
|
|
159
173
|
return recordset_to_fitres(message_out.content, keep_input=False)
|
160
174
|
|
161
175
|
def evaluate(
|
162
|
-
self, ins: common.EvaluateIns, timeout: Optional[float]
|
176
|
+
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
163
177
|
) -> common.EvaluateRes:
|
164
178
|
"""Evaluate model parameters on the locally held dataset."""
|
165
179
|
recordset = evaluateins_to_recordset(
|
166
180
|
ins, keep_input=True
|
167
181
|
) # This must stay TRUE since ins are in-memory
|
168
182
|
message = self._wrap_recordset_in_message(
|
169
|
-
recordset,
|
183
|
+
recordset,
|
184
|
+
message_type=MESSAGE_TYPE_EVALUATE,
|
185
|
+
timeout=timeout,
|
186
|
+
group_id=group_id,
|
170
187
|
)
|
171
188
|
|
172
189
|
message_out = self._submit_job(message, timeout)
|
@@ -174,7 +191,10 @@ class RayActorClientProxy(ClientProxy):
|
|
174
191
|
return recordset_to_evaluateres(message_out.content)
|
175
192
|
|
176
193
|
def reconnect(
|
177
|
-
self,
|
194
|
+
self,
|
195
|
+
ins: common.ReconnectIns,
|
196
|
+
timeout: Optional[float],
|
197
|
+
group_id: Optional[int],
|
178
198
|
) -> common.DisconnectRes:
|
179
199
|
"""Disconnect and (optionally) reconnect later."""
|
180
200
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower Simulation."""
|
16
|
+
|
17
|
+
import argparse
|
18
|
+
import asyncio
|
19
|
+
import json
|
20
|
+
import threading
|
21
|
+
import traceback
|
22
|
+
from logging import ERROR, INFO, WARNING
|
23
|
+
|
24
|
+
import grpc
|
25
|
+
|
26
|
+
from flwr.common import EventType, event, log
|
27
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
28
|
+
from flwr.server.driver.driver import Driver
|
29
|
+
from flwr.server.run_serverapp import run
|
30
|
+
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
31
|
+
from flwr.server.superlink.fleet import vce
|
32
|
+
from flwr.server.superlink.state import StateFactory
|
33
|
+
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
34
|
+
|
35
|
+
|
36
|
+
def run_simulation() -> None:
    """Run Simulation Engine.

    Parses the command line, starts the Driver API gRPC server and a
    background thread running `vce.start_vce`, then creates a `Driver` and
    runs the ServerApp against it. On exit the stop event is set, exit
    handlers are registered and the background thread is joined.

    Raises
    ------
    RuntimeError
        If creating the `Driver` or running the ServerApp raises; the original
        exception is chained as the cause.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    # Bind the name before `try`: if `Driver(...)` itself raises, the `del` in
    # `finally` would otherwise fail with UnboundLocalError, hiding the real
    # error and skipping the stop event, so the non-daemon thread never ends.
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )

        # Launch server app
        run(args.server_app, driver, args.dir)

    except Exception as ex:
        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Drop the local reference to the driver before shutting down
        del driver

        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()
|
116
|
+
|
117
|
+
|
118
|
+
def _parse_args_run_simulation() -> argparse.ArgumentParser:
    """Build the argument parser for the flower-simulation command line.

    Returns
    -------
    argparse.ArgumentParser
        A configured parser; the caller is responsible for calling
        `parse_args()` on it.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower simulation",
    )
    parser.add_argument(
        "--client-app",
        required=True,
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )
    parser.add_argument(
        "--server-app",
        required=True,
        help="For example: `server:app` or `project.package.module:wrapper.app`",
    )
    parser.add_argument(
        "--driver-api-address",
        default="0.0.0.0:9091",
        type=str,
        # Previously carried the `--server-app` help text by mistake
        help="Address of the Driver API server. For example: `0.0.0.0:9091`",
    )
    parser.add_argument(
        "--num-supernodes",
        type=int,
        required=True,
        help="Number of simulated SuperNodes.",
    )
    parser.add_argument(
        "--backend",
        default="ray",
        type=str,
        help="Simulation backend that executes the ClientApp.",
    )
    parser.add_argument(
        "--enable-tf-gpu-growth",
        action="store_true",
        help="Enables GPU growth on the main thread. This is desirable if you make "
        "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
        "running on the same GPU. Without enabling this, you might encounter an "
        "out-of-memory error because TensorFlow by default allocates all GPU memory. "
        "Read more about how `tf.config.experimental.set_memory_growth()` works in "
        "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
    )
    parser.add_argument(
        "--backend-config",
        type=str,
        default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
        "configure a backend. Values supported in <value> are those included by "
        "`flwr.common.typing.ConfigsRecordValues`. ",
    )
    parser.add_argument(
        "--dir",
        default="",
        help="Add specified directory to the PYTHONPATH and load "
        "ClientApp and ServerApp from there."
        " Default: current working directory.",
    )

    return parser
|
{flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.8.0.
|
3
|
+
Version: 1.8.0.dev20240229
|
4
4
|
Summary: Flower: A Friendly Federated Learning Framework
|
5
5
|
Home-page: https://flower.ai
|
6
6
|
License: Apache-2.0
|
@@ -32,7 +32,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
32
32
|
Classifier: Typing :: Typed
|
33
33
|
Provides-Extra: rest
|
34
34
|
Provides-Extra: simulation
|
35
|
-
Requires-Dist: cryptography (>=
|
35
|
+
Requires-Dist: cryptography (>=42.0.4,<43.0.0)
|
36
36
|
Requires-Dist: grpcio (>=1.60.0,<2.0.0)
|
37
37
|
Requires-Dist: iterators (>=0.0.2,<0.0.3)
|
38
38
|
Requires-Dist: numpy (>=1.21.0,<2.0.0)
|
@@ -83,7 +83,7 @@ design of Flower is based on a few guiding principles:
|
|
83
83
|
|
84
84
|
- **Framework-agnostic**: Different machine learning frameworks have different
|
85
85
|
strengths. Flower can be used with any machine learning framework, for
|
86
|
-
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
|
86
|
+
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
|
87
87
|
for users who enjoy computing gradients by hand.
|
88
88
|
|
89
89
|
- **Understandable**: Flower is written with maintainability in mind. The
|
@@ -180,6 +180,7 @@ Quickstart examples:
|
|
180
180
|
- [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai)
|
181
181
|
- [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas)
|
182
182
|
- [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax)
|
183
|
+
- [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai)
|
183
184
|
- [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)
|
184
185
|
- [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
|
185
186
|
- [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android)
|