flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,357 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Legacy default workflows."""
16
+
17
+
18
+ import timeit
19
+ from logging import DEBUG, INFO
20
+ from typing import Optional, cast
21
+
22
+ import flwr.common.recordset_compat as compat
23
+ from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
+ from flwr.common.constant import (
25
+ MESSAGE_TYPE_EVALUATE,
26
+ MESSAGE_TYPE_FIT,
27
+ MESSAGE_TYPE_GET_PARAMETERS,
28
+ )
29
+
30
+ from ..compat.app_utils import start_update_client_manager_thread
31
+ from ..compat.legacy_context import LegacyContext
32
+ from ..driver import Driver
33
+ from ..typing import Workflow
34
+
35
+ KEY_CURRENT_ROUND = "current_round"
36
+ KEY_START_TIME = "start_time"
37
+ CONFIGS_RECORD_KEY = "config"
38
+ PARAMS_RECORD_KEY = "parameters"
39
+
40
+
41
class DefaultWorkflow:
    """Default workflow in Flower.

    The workflow initializes the global parameters and then, for each round in
    ``context.config.num_rounds``, runs the fit workflow, the centralized
    evaluation workflow and the evaluate workflow, in that order.

    Parameters
    ----------
    fit_workflow : Optional[Workflow]
        Workflow executed for the fit step of every round. Falls back to
        ``default_fit_workflow`` when ``None``.
    evaluate_workflow : Optional[Workflow]
        Workflow executed for the evaluate step of every round. Falls back to
        ``default_evaluate_workflow`` when ``None``.
    """

    def __init__(
        self,
        fit_workflow: Optional[Workflow] = None,
        evaluate_workflow: Optional[Workflow] = None,
    ) -> None:
        if fit_workflow is None:
            fit_workflow = default_fit_workflow
        if evaluate_workflow is None:
            evaluate_workflow = default_evaluate_workflow
        self.fit_workflow: Workflow = fit_workflow
        self.evaluate_workflow: Workflow = evaluate_workflow

    def __call__(self, driver: Driver, context: Context) -> None:
        """Execute the workflow.

        Raises
        ------
        TypeError
            If ``context`` is not a ``LegacyContext``.
        """
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expect a LegacyContext, but get {type(context).__name__}."
            )

        # Start the thread updating nodes
        thread, f_stop = start_update_client_manager_thread(
            driver, context.client_manager
        )

        # Everything below runs while the background thread is alive, so the
        # shutdown lives in `finally`: an exception raised by any workflow must
        # not leave the thread running with its stop flag unset.
        try:
            # Initialize parameters
            default_init_params_workflow(driver, context)

            # Run federated learning for num_rounds
            log(INFO, "FL starting")
            start_time = timeit.default_timer()
            cfg = ConfigsRecord()
            cfg[KEY_START_TIME] = start_time
            context.state.configs_records[CONFIGS_RECORD_KEY] = cfg

            for current_round in range(1, context.config.num_rounds + 1):
                # The per-round workflows read the round number from this record
                cfg[KEY_CURRENT_ROUND] = current_round

                # Fit round
                self.fit_workflow(driver, context)

                # Centralized evaluation
                default_centralized_evaluation_workflow(driver, context)

                # Evaluate round
                self.evaluate_workflow(driver, context)

            # Bookkeeping
            end_time = timeit.default_timer()
            elapsed = end_time - start_time
            log(INFO, "FL finished in %s", elapsed)

            # Log results
            hist = context.history
            log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
            log(
                INFO,
                "app_fit: metrics_distributed_fit %s",
                str(hist.metrics_distributed_fit),
            )
            log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
            log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
            log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
        finally:
            # Terminate the thread (also on failure)
            f_stop.set()
            del driver
            thread.join()
111
+
112
+
113
def default_init_params_workflow(driver: Driver, context: Context) -> None:
    """Initialize the global parameters and evaluate them once (round 0).

    The strategy is asked for initial parameters first. If it returns ``None``,
    a `GetParametersIns` is sent to one client sampled from the client manager
    and the first parameters record of its reply is used instead. The chosen
    record is stored in ``context.state`` under ``PARAMS_RECORD_KEY``.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    log(INFO, "Initializing global parameters")
    initial = context.strategy.initialize_parameters(
        client_manager=context.client_manager
    )

    if initial is None:
        # The strategy has nothing to offer: ask a single client instead
        log(INFO, "Requesting initial parameters from one random client")
        chosen = context.client_manager.sample(1)[0]
        request = driver.create_message(
            content=compat.getparametersins_to_recordset(GetParametersIns({})),
            message_type=MESSAGE_TYPE_GET_PARAMETERS,
            dst_node_id=chosen.node_id,
            group_id="",
            ttl="",
        )
        replies = driver.send_and_receive([request])
        log(INFO, "Received initial parameters from one random client")
        reply = list(replies)[0]
        # Only the first parameters record of the reply is used
        paramsrecord = next(iter(reply.content.parameters_records.values()))
    else:
        log(INFO, "Using initial parameters provided by strategy")
        paramsrecord = compat.parameters_to_parametersrecord(initial, keep_input=True)

    context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord

    # Evaluate initial parameters
    log(INFO, "Evaluating initial parameters")
    res = context.strategy.evaluate(
        0,
        parameters=compat.parametersrecord_to_parameters(paramsrecord, keep_input=True),
    )
    if res is None:
        # Nothing to record when the strategy returns no evaluation result
        return

    log(INFO, "initial parameters (loss, other metrics): %s, %s", res[0], res[1])
    history = context.history
    history.add_loss_centralized(server_round=0, loss=res[0])
    history.add_metrics_centralized(server_round=0, metrics=res[1])
163
+
164
+
165
def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
    """Run the strategy's centralized evaluation for the current round.

    The round number and the start time are read from the configs record
    stored under ``CONFIGS_RECORD_KEY``; the parameters come from the record
    stored under ``PARAMS_RECORD_KEY``. If the strategy returns a result, the
    loss and metrics are logged and appended to ``context.history``.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Round bookkeeping stored in the context state
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    server_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    started_at = cast(float, round_cfg[KEY_START_TIME])

    # Centralized evaluation on the current global parameters
    stored_record = context.state.parameters_records[PARAMS_RECORD_KEY]
    global_params = compat.parametersrecord_to_parameters(
        record=stored_record, keep_input=True
    )
    outcome = context.strategy.evaluate(server_round, parameters=global_params)
    if outcome is None:
        # Nothing to log or record when the strategy returns no result
        return

    loss, metrics = outcome
    log(
        INFO,
        "fit progress: (%s, %s, %s, %s)",
        server_round,
        loss,
        metrics,
        timeit.default_timer() - started_at,
    )
    history = context.history
    history.add_loss_centralized(server_round=server_round, loss=loss)
    history.add_metrics_centralized(server_round=server_round, metrics=metrics)
195
+
196
+
197
def default_fit_workflow(driver: Driver, context: Context) -> None:
    """Run one fit round: configure, dispatch, collect and aggregate.

    The strategy selects clients and their `FitIns`; one message per client is
    sent through the driver, the replies are converted back to `FitRes` and
    handed to ``strategy.aggregate_fit``. Aggregated parameters (if truthy)
    replace the record stored under ``PARAMS_RECORD_KEY`` and the aggregated
    metrics are appended to the history. Failures are not tracked: the
    failure list passed to the strategy is always empty.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Current round and current global parameters
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    parameters = compat.parametersrecord_to_parameters(
        context.state.parameters_records[PARAMS_RECORD_KEY], keep_input=True
    )

    # Ask the strategy which clients to train and with which instructions
    client_instructions = context.strategy.configure_fit(
        server_round=current_round,
        parameters=parameters,
        client_manager=context.client_manager,
    )

    if not client_instructions:
        log(INFO, "fit_round %s: no clients selected, cancel", current_round)
        return
    log(
        DEBUG,
        "fit_round %s: strategy sampled %s clients (out of %s)",
        current_round,
        len(client_instructions),
        context.client_manager.num_available(),
    )

    # One pass over the instructions: remember which proxy belongs to which
    # node id and build the outgoing message for that node
    node_id_to_proxy = {}
    out_messages = []
    for proxy, fitins in client_instructions:
        node_id_to_proxy[proxy.node_id] = proxy
        out_messages.append(
            driver.create_message(
                content=compat.fitins_to_recordset(fitins, True),
                message_type=MESSAGE_TYPE_FIT,
                dst_node_id=proxy.node_id,
                group_id="",
                ttl="",
            )
        )

    # Send instructions to clients and collect their `fit` replies
    messages = list(driver.send_and_receive(out_messages))
    del out_messages

    # No exception/failure handling currently
    log(
        DEBUG,
        "fit_round %s received %s results and %s failures",
        current_round,
        len(messages),
        0,
    )

    # Pair every reply with the proxy it originated from
    results = []
    for msg in messages:
        origin = node_id_to_proxy[msg.metadata.src_node_id]
        results.append((origin, compat.recordset_to_fitres(msg.content, False)))

    # Aggregate training results
    parameters_aggregated, metrics_aggregated = context.strategy.aggregate_fit(
        current_round, results, []
    )

    # Update the parameters and write history
    if parameters_aggregated:
        context.state.parameters_records[PARAMS_RECORD_KEY] = (
            compat.parameters_to_parametersrecord(parameters_aggregated, True)
        )
    context.history.add_metrics_distributed_fit(
        server_round=current_round, metrics=metrics_aggregated
    )
277
+
278
+
279
def default_evaluate_workflow(driver: Driver, context: Context) -> None:
    """Run one evaluate round: configure, dispatch, collect and aggregate.

    The strategy selects clients and their `EvaluateIns`; one message per
    client is sent through the driver, the replies are converted back to
    `EvaluateRes` and handed to ``strategy.aggregate_evaluate``. The
    aggregated loss is recorded only when it is not ``None``; the aggregated
    metrics are always recorded. Failures are not tracked: the failure list
    passed to the strategy is always empty.

    Raises
    ------
    TypeError
        If ``context`` is not a ``LegacyContext``.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Current round and current global parameters
    round_cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, round_cfg[KEY_CURRENT_ROUND])
    parameters = compat.parametersrecord_to_parameters(
        context.state.parameters_records[PARAMS_RECORD_KEY], keep_input=True
    )

    # Ask the strategy which clients to evaluate and with which instructions
    client_instructions = context.strategy.configure_evaluate(
        server_round=current_round,
        parameters=parameters,
        client_manager=context.client_manager,
    )
    if not client_instructions:
        log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
        return
    log(
        DEBUG,
        "evaluate_round %s: strategy sampled %s clients (out of %s)",
        current_round,
        len(client_instructions),
        context.client_manager.num_available(),
    )

    # One pass over the instructions: remember which proxy belongs to which
    # node id and build the outgoing message for that node
    node_id_to_proxy = {}
    out_messages = []
    for proxy, evalins in client_instructions:
        node_id_to_proxy[proxy.node_id] = proxy
        out_messages.append(
            driver.create_message(
                content=compat.evaluateins_to_recordset(evalins, True),
                message_type=MESSAGE_TYPE_EVALUATE,
                dst_node_id=proxy.node_id,
                group_id="",
                ttl="",
            )
        )

    # Send instructions to clients and collect their `evaluate` replies
    messages = list(driver.send_and_receive(out_messages))
    del out_messages

    # No exception/failure handling currently
    log(
        DEBUG,
        "evaluate_round %s received %s results and %s failures",
        current_round,
        len(messages),
        0,
    )

    # Pair every reply with the proxy it originated from
    results = []
    for msg in messages:
        origin = node_id_to_proxy[msg.metadata.src_node_id]
        results.append((origin, compat.recordset_to_evaluateres(msg.content)))

    # Aggregate the evaluation results
    loss_aggregated, metrics_aggregated = context.strategy.aggregate_evaluate(
        current_round, results, []
    )

    # Write history
    history = context.history
    if loss_aggregated is not None:
        history.add_loss_distributed(server_round=current_round, loss=loss_aggregated)
    history.add_metrics_distributed(
        server_round=current_round, metrics=metrics_aggregated
    )
@@ -17,6 +17,8 @@
17
17
 
18
18
  import importlib
19
19
 
20
+ from flwr.simulation.run_simulation import run_simulation
21
+
20
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
21
23
 
22
24
  if is_ray_installed:
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
38
 
37
39
  __all__ = [
38
40
  "start_simulation",
41
+ "run_simulation",
39
42
  ]
@@ -98,6 +98,7 @@ class RayActorClientProxy(ClientProxy):
98
98
  recordset: RecordSet,
99
99
  message_type: str,
100
100
  timeout: Optional[float],
101
+ group_id: Optional[int],
101
102
  ) -> Message:
102
103
  """Wrap a RecordSet inside a Message."""
103
104
  return Message(
@@ -105,7 +106,7 @@ class RayActorClientProxy(ClientProxy):
105
106
  metadata=Metadata(
106
107
  run_id=0,
107
108
  message_id="",
108
- group_id="",
109
+ group_id=str(group_id) if group_id is not None else "",
109
110
  src_node_id=0,
110
111
  dst_node_id=int(self.cid),
111
112
  reply_to_message="",
@@ -116,7 +117,10 @@ class RayActorClientProxy(ClientProxy):
116
117
  )
117
118
 
118
119
  def get_properties(
119
- self, ins: common.GetPropertiesIns, timeout: Optional[float]
120
+ self,
121
+ ins: common.GetPropertiesIns,
122
+ timeout: Optional[float],
123
+ group_id: Optional[int],
120
124
  ) -> common.GetPropertiesRes:
121
125
  """Return client's properties."""
122
126
  recordset = getpropertiesins_to_recordset(ins)
@@ -124,6 +128,7 @@ class RayActorClientProxy(ClientProxy):
124
128
  recordset,
125
129
  message_type=MESSAGE_TYPE_GET_PROPERTIES,
126
130
  timeout=timeout,
131
+ group_id=group_id,
127
132
  )
128
133
 
129
134
  message_out = self._submit_job(message, timeout)
@@ -131,7 +136,10 @@ class RayActorClientProxy(ClientProxy):
131
136
  return recordset_to_getpropertiesres(message_out.content)
132
137
 
133
138
  def get_parameters(
134
- self, ins: common.GetParametersIns, timeout: Optional[float]
139
+ self,
140
+ ins: common.GetParametersIns,
141
+ timeout: Optional[float],
142
+ group_id: Optional[int],
135
143
  ) -> common.GetParametersRes:
136
144
  """Return the current local model parameters."""
137
145
  recordset = getparametersins_to_recordset(ins)
@@ -139,19 +147,25 @@ class RayActorClientProxy(ClientProxy):
139
147
  recordset,
140
148
  message_type=MESSAGE_TYPE_GET_PARAMETERS,
141
149
  timeout=timeout,
150
+ group_id=group_id,
142
151
  )
143
152
 
144
153
  message_out = self._submit_job(message, timeout)
145
154
 
146
155
  return recordset_to_getparametersres(message_out.content, keep_input=False)
147
156
 
148
- def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
157
+ def fit(
158
+ self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
159
+ ) -> common.FitRes:
149
160
  """Train model parameters on the locally held dataset."""
150
161
  recordset = fitins_to_recordset(
151
162
  ins, keep_input=True
152
163
  ) # This must stay TRUE since ins are in-memory
153
164
  message = self._wrap_recordset_in_message(
154
- recordset, message_type=MESSAGE_TYPE_FIT, timeout=timeout
165
+ recordset,
166
+ message_type=MESSAGE_TYPE_FIT,
167
+ timeout=timeout,
168
+ group_id=group_id,
155
169
  )
156
170
 
157
171
  message_out = self._submit_job(message, timeout)
@@ -159,14 +173,17 @@ class RayActorClientProxy(ClientProxy):
159
173
  return recordset_to_fitres(message_out.content, keep_input=False)
160
174
 
161
175
  def evaluate(
162
- self, ins: common.EvaluateIns, timeout: Optional[float]
176
+ self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
163
177
  ) -> common.EvaluateRes:
164
178
  """Evaluate model parameters on the locally held dataset."""
165
179
  recordset = evaluateins_to_recordset(
166
180
  ins, keep_input=True
167
181
  ) # This must stay TRUE since ins are in-memory
168
182
  message = self._wrap_recordset_in_message(
169
- recordset, message_type=MESSAGE_TYPE_EVALUATE, timeout=timeout
183
+ recordset,
184
+ message_type=MESSAGE_TYPE_EVALUATE,
185
+ timeout=timeout,
186
+ group_id=group_id,
170
187
  )
171
188
 
172
189
  message_out = self._submit_job(message, timeout)
@@ -174,7 +191,10 @@ class RayActorClientProxy(ClientProxy):
174
191
  return recordset_to_evaluateres(message_out.content)
175
192
 
176
193
  def reconnect(
177
- self, ins: common.ReconnectIns, timeout: Optional[float]
194
+ self,
195
+ ins: common.ReconnectIns,
196
+ timeout: Optional[float],
197
+ group_id: Optional[int],
178
198
  ) -> common.DisconnectRes:
179
199
  """Disconnect and (optionally) reconnect later."""
180
200
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower Simulation."""
16
+
17
+ import argparse
18
+ import asyncio
19
+ import json
20
+ import threading
21
+ import traceback
22
+ from logging import ERROR, INFO, WARNING
23
+
24
+ import grpc
25
+
26
+ from flwr.common import EventType, event, log
27
+ from flwr.common.exit_handlers import register_exit_handlers
28
+ from flwr.server.driver.driver import Driver
29
+ from flwr.server.run_serverapp import run
30
+ from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
31
+ from flwr.server.superlink.fleet import vce
32
+ from flwr.server.superlink.state import StateFactory
33
+ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
34
+
35
+
36
def run_simulation() -> None:
    """Run Simulation Engine.

    Parses the command line, starts the Driver API gRPC server, starts the
    Simulation Engine (``vce.start_vce``) in a background thread, and then
    runs the ServerApp against a ``Driver`` connected to that Driver API.

    Raises
    ------
    RuntimeError
        If creating the ``Driver`` or running the ServerApp raises; the
        original exception is attached as the cause.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    # Bind the name before entering `try`: if `Driver(...)` itself raises, the
    # `del driver` in `finally` would otherwise fail with UnboundLocalError,
    # replacing the RuntimeError below and skipping `f_stop.set()`, so the
    # non-daemon SuperLink thread would never be told to stop.
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )

        # Launch server app
        run(args.server_app, driver, args.dir)

    except Exception as ex:

        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:

        # Drop this function's reference to the driver (may still be None)
        del driver

        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()
116
+
117
+
118
def _parse_args_run_simulation() -> argparse.ArgumentParser:
    """Build the argument parser for the flower-simulation command line.

    The parser is returned without parsing anything; the caller is expected
    to invoke ``parse_args()`` on it.

    Returns
    -------
    argparse.ArgumentParser
        Parser with the required ``--client-app``, ``--server-app`` and
        ``--num-supernodes`` options and the optional ``--driver-api-address``,
        ``--backend``, ``--enable-tf-gpu-growth``, ``--backend-config`` and
        ``--dir`` options.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower simulation",
    )
    parser.add_argument(
        "--client-app",
        required=True,
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )
    parser.add_argument(
        "--server-app",
        required=True,
        help="For example: `server:app` or `project.package.module:wrapper.app`",
    )
    parser.add_argument(
        "--driver-api-address",
        default="0.0.0.0:9091",
        type=str,
        # Previously a copy of the `--server-app` help text
        help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
    )
    parser.add_argument(
        "--num-supernodes",
        type=int,
        required=True,
        help="Number of simulated SuperNodes.",
    )
    parser.add_argument(
        "--backend",
        default="ray",
        type=str,
        help="Simulation backend that executes the ClientApp.",
    )
    parser.add_argument(
        "--enable-tf-gpu-growth",
        action="store_true",
        help="Enables GPU growth on the main thread. This is desirable if you make "
        "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
        "running on the same GPU. Without enabling this, you might encounter an "
        "out-of-memory error because TensorFlow by default allocates all GPU memory. "
        "Read more about how `tf.config.experimental.set_memory_growth()` works in "
        "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
    )
    parser.add_argument(
        "--backend-config",
        type=str,
        default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
        help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
        "configure a backend. Values supported in <value> are those included by "
        "`flwr.common.typing.ConfigsRecordValues`. ",
    )
    parser.add_argument(
        "--dir",
        default="",
        help="Add specified directory to the PYTHONPATH and load "
        "ClientApp and ServerApp from there."
        " Default: current working directory.",
    )

    return parser
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240227
3
+ Version: 1.8.0.dev20240229
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -32,7 +32,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
32
32
  Classifier: Typing :: Typed
33
33
  Provides-Extra: rest
34
34
  Provides-Extra: simulation
35
- Requires-Dist: cryptography (>=41.0.2,<42.0.0)
35
+ Requires-Dist: cryptography (>=42.0.4,<43.0.0)
36
36
  Requires-Dist: grpcio (>=1.60.0,<2.0.0)
37
37
  Requires-Dist: iterators (>=0.0.2,<0.0.3)
38
38
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
@@ -83,7 +83,7 @@ design of Flower is based on a few guiding principles:
83
83
 
84
84
  - **Framework-agnostic**: Different machine learning frameworks have different
85
85
  strengths. Flower can be used with any machine learning framework, for
86
- example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
86
+ example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
87
87
  for users who enjoy computing gradients by hand.
88
88
 
89
89
  - **Understandable**: Flower is written with maintainability in mind. The
@@ -180,6 +180,7 @@ Quickstart examples:
180
180
  - [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai)
181
181
  - [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas)
182
182
  - [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax)
183
+ - [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai)
183
184
  - [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)
184
185
  - [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
185
186
  - [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android)