flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,357 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Legacy default workflows."""
16
+
17
+
18
+ import timeit
19
+ from logging import DEBUG, INFO
20
+ from typing import Optional, cast
21
+
22
+ import flwr.common.recordset_compat as compat
23
+ from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
+ from flwr.common.constant import (
25
+ MESSAGE_TYPE_EVALUATE,
26
+ MESSAGE_TYPE_FIT,
27
+ MESSAGE_TYPE_GET_PARAMETERS,
28
+ )
29
+
30
+ from ..compat.app_utils import start_update_client_manager_thread
31
+ from ..compat.legacy_context import LegacyContext
32
+ from ..driver import Driver
33
+ from ..typing import Workflow
34
+
35
+ KEY_CURRENT_ROUND = "current_round"
36
+ KEY_START_TIME = "start_time"
37
+ CONFIGS_RECORD_KEY = "config"
38
+ PARAMS_RECORD_KEY = "parameters"
39
+
40
+
41
+ class DefaultWorkflow:
42
+ """Default workflow in Flower."""
43
+
44
+ def __init__(
45
+ self,
46
+ fit_workflow: Optional[Workflow] = None,
47
+ evaluate_workflow: Optional[Workflow] = None,
48
+ ) -> None:
49
+ if fit_workflow is None:
50
+ fit_workflow = default_fit_workflow
51
+ if evaluate_workflow is None:
52
+ evaluate_workflow = default_evaluate_workflow
53
+ self.fit_workflow: Workflow = fit_workflow
54
+ self.evaluate_workflow: Workflow = evaluate_workflow
55
+
56
+ def __call__(self, driver: Driver, context: Context) -> None:
57
+ """Execute the workflow."""
58
+ if not isinstance(context, LegacyContext):
59
+ raise TypeError(
60
+ f"Expected a LegacyContext, but got {type(context).__name__}."
61
+ )
62
+
63
+ # Start the thread updating nodes
64
+ thread, f_stop = start_update_client_manager_thread(
65
+ driver, context.client_manager
66
+ )
67
+
68
+ # Initialize parameters
69
+ default_init_params_workflow(driver, context)
70
+
71
+ # Run federated learning for num_rounds
72
+ log(INFO, "FL starting")
73
+ start_time = timeit.default_timer()
74
+ cfg = ConfigsRecord()
75
+ cfg[KEY_START_TIME] = start_time
76
+ context.state.configs_records[CONFIGS_RECORD_KEY] = cfg
77
+
78
+ for current_round in range(1, context.config.num_rounds + 1):
79
+ cfg[KEY_CURRENT_ROUND] = current_round
80
+
81
+ # Fit round
82
+ self.fit_workflow(driver, context)
83
+
84
+ # Centralized evaluation
85
+ default_centralized_evaluation_workflow(driver, context)
86
+
87
+ # Evaluate round
88
+ self.evaluate_workflow(driver, context)
89
+
90
+ # Bookkeeping
91
+ end_time = timeit.default_timer()
92
+ elapsed = end_time - start_time
93
+ log(INFO, "FL finished in %s", elapsed)
94
+
95
+ # Log results
96
+ hist = context.history
97
+ log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
98
+ log(
99
+ INFO,
100
+ "app_fit: metrics_distributed_fit %s",
101
+ str(hist.metrics_distributed_fit),
102
+ )
103
+ log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
104
+ log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
105
+ log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
106
+
107
+ # Terminate the thread
108
+ f_stop.set()
109
+ del driver
110
+ thread.join()
111
+
112
+
113
+ def default_init_params_workflow(driver: Driver, context: Context) -> None:
114
+ """Execute the default workflow for parameters initialization."""
115
+ if not isinstance(context, LegacyContext):
116
+ raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
117
+
118
+ log(INFO, "Initializing global parameters")
119
+ parameters = context.strategy.initialize_parameters(
120
+ client_manager=context.client_manager
121
+ )
122
+ if parameters is not None:
123
+ log(INFO, "Using initial parameters provided by strategy")
124
+ paramsrecord = compat.parameters_to_parametersrecord(
125
+ parameters, keep_input=True
126
+ )
127
+ else:
128
+ # Get initial parameters from one of the clients
129
+ log(INFO, "Requesting initial parameters from one random client")
130
+ random_client = context.client_manager.sample(1)[0]
131
+ # Send GetParametersIns and get the response
132
+ content = compat.getparametersins_to_recordset(GetParametersIns({}))
133
+ messages = driver.send_and_receive(
134
+ [
135
+ driver.create_message(
136
+ content=content,
137
+ message_type=MESSAGE_TYPE_GET_PARAMETERS,
138
+ dst_node_id=random_client.node_id,
139
+ group_id="",
140
+ ttl="",
141
+ )
142
+ ]
143
+ )
144
+ log(INFO, "Received initial parameters from one random client")
145
+ msg = list(messages)[0]
146
+ paramsrecord = next(iter(msg.content.parameters_records.values()))
147
+
148
+ context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
149
+
150
+ # Evaluate initial parameters
151
+ log(INFO, "Evaluating initial parameters")
152
+ parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
153
+ res = context.strategy.evaluate(0, parameters=parameters)
154
+ if res is not None:
155
+ log(
156
+ INFO,
157
+ "initial parameters (loss, other metrics): %s, %s",
158
+ res[0],
159
+ res[1],
160
+ )
161
+ context.history.add_loss_centralized(server_round=0, loss=res[0])
162
+ context.history.add_metrics_centralized(server_round=0, metrics=res[1])
163
+
164
+
165
+ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
166
+ """Execute the default workflow for centralized evaluation."""
167
+ if not isinstance(context, LegacyContext):
168
+ raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
169
+
170
+ # Retrieve current_round and start_time from the context
171
+ cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
172
+ current_round = cast(int, cfg[KEY_CURRENT_ROUND])
173
+ start_time = cast(float, cfg[KEY_START_TIME])
174
+
175
+ # Centralized evaluation
176
+ parameters = compat.parametersrecord_to_parameters(
177
+ record=context.state.parameters_records[PARAMS_RECORD_KEY],
178
+ keep_input=True,
179
+ )
180
+ res_cen = context.strategy.evaluate(current_round, parameters=parameters)
181
+ if res_cen is not None:
182
+ loss_cen, metrics_cen = res_cen
183
+ log(
184
+ INFO,
185
+ "fit progress: (%s, %s, %s, %s)",
186
+ current_round,
187
+ loss_cen,
188
+ metrics_cen,
189
+ timeit.default_timer() - start_time,
190
+ )
191
+ context.history.add_loss_centralized(server_round=current_round, loss=loss_cen)
192
+ context.history.add_metrics_centralized(
193
+ server_round=current_round, metrics=metrics_cen
194
+ )
195
+
196
+
197
+ def default_fit_workflow(driver: Driver, context: Context) -> None:
198
+ """Execute the default workflow for a single fit round."""
199
+ if not isinstance(context, LegacyContext):
200
+ raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
201
+
202
+ # Get current_round and parameters
203
+ cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
204
+ current_round = cast(int, cfg[KEY_CURRENT_ROUND])
205
+ parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
206
+ parameters = compat.parametersrecord_to_parameters(
207
+ parametersrecord, keep_input=True
208
+ )
209
+
210
+ # Get clients and their respective instructions from strategy
211
+ client_instructions = context.strategy.configure_fit(
212
+ server_round=current_round,
213
+ parameters=parameters,
214
+ client_manager=context.client_manager,
215
+ )
216
+
217
+ if not client_instructions:
218
+ log(INFO, "fit_round %s: no clients selected, cancel", current_round)
219
+ return
220
+ log(
221
+ DEBUG,
222
+ "fit_round %s: strategy sampled %s clients (out of %s)",
223
+ current_round,
224
+ len(client_instructions),
225
+ context.client_manager.num_available(),
226
+ )
227
+
228
+ # Build dictionary mapping node_id to ClientProxy
229
+ node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
230
+
231
+ # Build out messages
232
+ out_messages = [
233
+ driver.create_message(
234
+ content=compat.fitins_to_recordset(fitins, True),
235
+ message_type=MESSAGE_TYPE_FIT,
236
+ dst_node_id=proxy.node_id,
237
+ group_id="",
238
+ ttl="",
239
+ )
240
+ for proxy, fitins in client_instructions
241
+ ]
242
+
243
+ # Send instructions to clients and
244
+ # collect `fit` results from all clients participating in this round
245
+ messages = list(driver.send_and_receive(out_messages))
246
+ del out_messages
247
+
248
+ # No exception/failure handling currently
249
+ log(
250
+ DEBUG,
251
+ "fit_round %s received %s results and %s failures",
252
+ current_round,
253
+ len(messages),
254
+ 0,
255
+ )
256
+
257
+ # Aggregate training results
258
+ results = [
259
+ (
260
+ node_id_to_proxy[msg.metadata.src_node_id],
261
+ compat.recordset_to_fitres(msg.content, False),
262
+ )
263
+ for msg in messages
264
+ ]
265
+ aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
266
+ parameters_aggregated, metrics_aggregated = aggregated_result
267
+
268
+ # Update the parameters and write history
269
+ if parameters_aggregated:
270
+ paramsrecord = compat.parameters_to_parametersrecord(
271
+ parameters_aggregated, True
272
+ )
273
+ context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
274
+ context.history.add_metrics_distributed_fit(
275
+ server_round=current_round, metrics=metrics_aggregated
276
+ )
277
+
278
+
279
+ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
280
+ """Execute the default workflow for a single evaluate round."""
281
+ if not isinstance(context, LegacyContext):
282
+ raise TypeError(f"Expected a LegacyContext, but got {type(context).__name__}.")
283
+
284
+ # Get current_round and parameters
285
+ cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
286
+ current_round = cast(int, cfg[KEY_CURRENT_ROUND])
287
+ parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
288
+ parameters = compat.parametersrecord_to_parameters(
289
+ parametersrecord, keep_input=True
290
+ )
291
+
292
+ # Get clients and their respective instructions from strategy
293
+ client_instructions = context.strategy.configure_evaluate(
294
+ server_round=current_round,
295
+ parameters=parameters,
296
+ client_manager=context.client_manager,
297
+ )
298
+ if not client_instructions:
299
+ log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
300
+ return
301
+ log(
302
+ DEBUG,
303
+ "evaluate_round %s: strategy sampled %s clients (out of %s)",
304
+ current_round,
305
+ len(client_instructions),
306
+ context.client_manager.num_available(),
307
+ )
308
+
309
+ # Build dictionary mapping node_id to ClientProxy
310
+ node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions}
311
+
312
+ # Build out messages
313
+ out_messages = [
314
+ driver.create_message(
315
+ content=compat.evaluateins_to_recordset(evalins, True),
316
+ message_type=MESSAGE_TYPE_EVALUATE,
317
+ dst_node_id=proxy.node_id,
318
+ group_id="",
319
+ ttl="",
320
+ )
321
+ for proxy, evalins in client_instructions
322
+ ]
323
+
324
+ # Send instructions to clients and
325
+ # collect `evaluate` results from all clients participating in this round
326
+ messages = list(driver.send_and_receive(out_messages))
327
+ del out_messages
328
+
329
+ # No exception/failure handling currently
330
+ log(
331
+ DEBUG,
332
+ "evaluate_round %s received %s results and %s failures",
333
+ current_round,
334
+ len(messages),
335
+ 0,
336
+ )
337
+
338
+ # Aggregate the evaluation results
339
+ results = [
340
+ (
341
+ node_id_to_proxy[msg.metadata.src_node_id],
342
+ compat.recordset_to_evaluateres(msg.content),
343
+ )
344
+ for msg in messages
345
+ ]
346
+ aggregated_result = context.strategy.aggregate_evaluate(current_round, results, [])
347
+
348
+ loss_aggregated, metrics_aggregated = aggregated_result
349
+
350
+ # Write history
351
+ if loss_aggregated is not None:
352
+ context.history.add_loss_distributed(
353
+ server_round=current_round, loss=loss_aggregated
354
+ )
355
+ context.history.add_metrics_distributed(
356
+ server_round=current_round, metrics=metrics_aggregated
357
+ )
@@ -17,6 +17,8 @@
17
17
 
18
18
  import importlib
19
19
 
20
+ from flwr.simulation.run_simulation import run_simulation
21
+
20
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
21
23
 
22
24
  if is_ray_installed:
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
38
 
37
39
  __all__ = [
38
40
  "start_simulation",
41
+ "run_simulation",
39
42
  ]
@@ -98,6 +98,7 @@ class RayActorClientProxy(ClientProxy):
98
98
  recordset: RecordSet,
99
99
  message_type: str,
100
100
  timeout: Optional[float],
101
+ group_id: Optional[int],
101
102
  ) -> Message:
102
103
  """Wrap a RecordSet inside a Message."""
103
104
  return Message(
@@ -105,7 +106,7 @@ class RayActorClientProxy(ClientProxy):
105
106
  metadata=Metadata(
106
107
  run_id=0,
107
108
  message_id="",
108
- group_id="",
109
+ group_id=str(group_id) if group_id is not None else "",
109
110
  src_node_id=0,
110
111
  dst_node_id=int(self.cid),
111
112
  reply_to_message="",
@@ -116,7 +117,10 @@ class RayActorClientProxy(ClientProxy):
116
117
  )
117
118
 
118
119
  def get_properties(
119
- self, ins: common.GetPropertiesIns, timeout: Optional[float]
120
+ self,
121
+ ins: common.GetPropertiesIns,
122
+ timeout: Optional[float],
123
+ group_id: Optional[int],
120
124
  ) -> common.GetPropertiesRes:
121
125
  """Return client's properties."""
122
126
  recordset = getpropertiesins_to_recordset(ins)
@@ -124,6 +128,7 @@ class RayActorClientProxy(ClientProxy):
124
128
  recordset,
125
129
  message_type=MESSAGE_TYPE_GET_PROPERTIES,
126
130
  timeout=timeout,
131
+ group_id=group_id,
127
132
  )
128
133
 
129
134
  message_out = self._submit_job(message, timeout)
@@ -131,7 +136,10 @@ class RayActorClientProxy(ClientProxy):
131
136
  return recordset_to_getpropertiesres(message_out.content)
132
137
 
133
138
  def get_parameters(
134
- self, ins: common.GetParametersIns, timeout: Optional[float]
139
+ self,
140
+ ins: common.GetParametersIns,
141
+ timeout: Optional[float],
142
+ group_id: Optional[int],
135
143
  ) -> common.GetParametersRes:
136
144
  """Return the current local model parameters."""
137
145
  recordset = getparametersins_to_recordset(ins)
@@ -139,19 +147,25 @@ class RayActorClientProxy(ClientProxy):
139
147
  recordset,
140
148
  message_type=MESSAGE_TYPE_GET_PARAMETERS,
141
149
  timeout=timeout,
150
+ group_id=group_id,
142
151
  )
143
152
 
144
153
  message_out = self._submit_job(message, timeout)
145
154
 
146
155
  return recordset_to_getparametersres(message_out.content, keep_input=False)
147
156
 
148
- def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
157
+ def fit(
158
+ self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
159
+ ) -> common.FitRes:
149
160
  """Train model parameters on the locally held dataset."""
150
161
  recordset = fitins_to_recordset(
151
162
  ins, keep_input=True
152
163
  ) # This must stay TRUE since ins are in-memory
153
164
  message = self._wrap_recordset_in_message(
154
- recordset, message_type=MESSAGE_TYPE_FIT, timeout=timeout
165
+ recordset,
166
+ message_type=MESSAGE_TYPE_FIT,
167
+ timeout=timeout,
168
+ group_id=group_id,
155
169
  )
156
170
 
157
171
  message_out = self._submit_job(message, timeout)
@@ -159,14 +173,17 @@ class RayActorClientProxy(ClientProxy):
159
173
  return recordset_to_fitres(message_out.content, keep_input=False)
160
174
 
161
175
  def evaluate(
162
- self, ins: common.EvaluateIns, timeout: Optional[float]
176
+ self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
163
177
  ) -> common.EvaluateRes:
164
178
  """Evaluate model parameters on the locally held dataset."""
165
179
  recordset = evaluateins_to_recordset(
166
180
  ins, keep_input=True
167
181
  ) # This must stay TRUE since ins are in-memory
168
182
  message = self._wrap_recordset_in_message(
169
- recordset, message_type=MESSAGE_TYPE_EVALUATE, timeout=timeout
183
+ recordset,
184
+ message_type=MESSAGE_TYPE_EVALUATE,
185
+ timeout=timeout,
186
+ group_id=group_id,
170
187
  )
171
188
 
172
189
  message_out = self._submit_job(message, timeout)
@@ -174,7 +191,10 @@ class RayActorClientProxy(ClientProxy):
174
191
  return recordset_to_evaluateres(message_out.content)
175
192
 
176
193
  def reconnect(
177
- self, ins: common.ReconnectIns, timeout: Optional[float]
194
+ self,
195
+ ins: common.ReconnectIns,
196
+ timeout: Optional[float],
197
+ group_id: Optional[int],
178
198
  ) -> common.DisconnectRes:
179
199
  """Disconnect and (optionally) reconnect later."""
180
200
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower Simulation."""
16
+
17
+ import argparse
18
+ import asyncio
19
+ import json
20
+ import threading
21
+ import traceback
22
+ from logging import ERROR, INFO, WARNING
23
+
24
+ import grpc
25
+
26
+ from flwr.common import EventType, event, log
27
+ from flwr.common.exit_handlers import register_exit_handlers
28
+ from flwr.server.driver.driver import Driver
29
+ from flwr.server.run_serverapp import run
30
+ from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
31
+ from flwr.server.superlink.fleet import vce
32
+ from flwr.server.superlink.state import StateFactory
33
+ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
34
+
35
+
36
+ def run_simulation() -> None:
37
+ """Run Simulation Engine."""
38
+ args = _parse_args_run_simulation().parse_args()
39
+
40
+ # Load JSON config
41
+ backend_config_dict = json.loads(args.backend_config)
42
+
43
+ # Enable GPU memory growth (relevant only for TF)
44
+ if args.enable_tf_gpu_growth:
45
+ log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
46
+ enable_tf_gpu_growth()
47
+ # Check that Backend config has also enabled using GPU growth
48
+ use_tf = backend_config_dict.get("tensorflow", False)
49
+ if not use_tf:
50
+ log(WARNING, "Enabling GPU growth for your backend.")
51
+ backend_config_dict["tensorflow"] = True
52
+
53
+ # Convert back to JSON stream
54
+ backend_config = json.dumps(backend_config_dict)
55
+
56
+ # Initialize StateFactory
57
+ state_factory = StateFactory(":flwr-in-memory-state:")
58
+
59
+ # Start Driver API
60
+ driver_server: grpc.Server = run_driver_api_grpc(
61
+ address=args.driver_api_address,
62
+ state_factory=state_factory,
63
+ certificates=None,
64
+ )
65
+
66
+ # SuperLink with Simulation Engine
67
+ f_stop = asyncio.Event()
68
+ superlink_th = threading.Thread(
69
+ target=vce.start_vce,
70
+ kwargs={
71
+ "num_supernodes": args.num_supernodes,
72
+ "client_app_module_name": args.client_app,
73
+ "backend_name": args.backend,
74
+ "backend_config_json_stream": backend_config,
75
+ "working_dir": args.dir,
76
+ "state_factory": state_factory,
77
+ "f_stop": f_stop,
78
+ },
79
+ daemon=False,
80
+ )
81
+
82
+ superlink_th.start()
83
+ event(EventType.RUN_SUPERLINK_ENTER)
84
+
85
+ try:
86
+ # Initialize Driver
87
+ driver = Driver(
88
+ driver_service_address=args.driver_api_address,
89
+ root_certificates=None,
90
+ )
91
+
92
+ # Launch server app
93
+ run(args.server_app, driver, args.dir)
94
+
95
+ except Exception as ex:
96
+
97
+ log(ERROR, "An exception occurred: %s", ex)
98
+ log(ERROR, traceback.format_exc())
99
+ raise RuntimeError(
100
+ "An error was encountered by the Simulation Engine. Ending simulation."
101
+ ) from ex
102
+
103
+ finally:
104
+
105
+ del driver
106
+
107
+ # Trigger stop event
108
+ f_stop.set()
109
+
110
+ register_exit_handlers(
111
+ grpc_servers=[driver_server],
112
+ bckg_threads=[superlink_th],
113
+ event_type=EventType.RUN_SUPERLINK_LEAVE,
114
+ )
115
+ superlink_th.join()
116
+
117
+
118
+ def _parse_args_run_simulation() -> argparse.ArgumentParser:
119
+ """Parse flower-simulation command line arguments."""
120
+ parser = argparse.ArgumentParser(
121
+ description="Start a Flower simulation",
122
+ )
123
+ parser.add_argument(
124
+ "--client-app",
125
+ required=True,
126
+ help="For example: `client:app` or `project.package.module:wrapper.app`",
127
+ )
128
+ parser.add_argument(
129
+ "--server-app",
130
+ required=True,
131
+ help="For example: `server:app` or `project.package.module:wrapper.app`",
132
+ )
133
+ parser.add_argument(
134
+ "--driver-api-address",
135
+ default="0.0.0.0:9091",
136
+ type=str,
137
+ help="Address for the Driver API gRPC server, for example: `0.0.0.0:9091`",
138
+ )
139
+ parser.add_argument(
140
+ "--num-supernodes",
141
+ type=int,
142
+ required=True,
143
+ help="Number of simulated SuperNodes.",
144
+ )
145
+ parser.add_argument(
146
+ "--backend",
147
+ default="ray",
148
+ type=str,
149
+ help="Simulation backend that executes the ClientApp.",
150
+ )
151
+ parser.add_argument(
152
+ "--enable-tf-gpu-growth",
153
+ action="store_true",
154
+ help="Enables GPU growth on the main thread. This is desirable if you make "
155
+ "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
156
+ "running on the same GPU. Without enabling this, you might encounter an "
157
+ "out-of-memory error because TensorFlow by default allocates all GPU memory. "
158
+ "Read more about how `tf.config.experimental.set_memory_growth()` works in "
159
+ "the TensorFlow documentation: https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth.",
160
+ )
161
+ parser.add_argument(
162
+ "--backend-config",
163
+ type=str,
164
+ default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
165
+ help='A JSON formatted stream, e.g. \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
166
+ "configure a backend. Values supported in <value> are those included by "
167
+ "`flwr.common.typing.ConfigsRecordValues`. ",
168
+ )
169
+ parser.add_argument(
170
+ "--dir",
171
+ default="",
172
+ help="Add specified directory to the PYTHONPATH and load "
173
+ "ClientApp and ServerApp from there."
174
+ " Default: current working directory.",
175
+ )
176
+
177
+ return parser
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240227
3
+ Version: 1.8.0.dev20240229
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -32,7 +32,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
32
32
  Classifier: Typing :: Typed
33
33
  Provides-Extra: rest
34
34
  Provides-Extra: simulation
35
- Requires-Dist: cryptography (>=41.0.2,<42.0.0)
35
+ Requires-Dist: cryptography (>=42.0.4,<43.0.0)
36
36
  Requires-Dist: grpcio (>=1.60.0,<2.0.0)
37
37
  Requires-Dist: iterators (>=0.0.2,<0.0.3)
38
38
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
@@ -83,7 +83,7 @@ design of Flower is based on a few guiding principles:
83
83
 
84
84
  - **Framework-agnostic**: Different machine learning frameworks have different
85
85
  strengths. Flower can be used with any machine learning framework, for
86
- example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
86
+ example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
87
87
  for users who enjoy computing gradients by hand.
88
88
 
89
89
  - **Understandable**: Flower is written with maintainability in mind. The
@@ -180,6 +180,7 @@ Quickstart examples:
180
180
  - [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai)
181
181
  - [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas)
182
182
  - [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax)
183
+ - [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai)
183
184
  - [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)
184
185
  - [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
185
186
  - [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android)