flwr-nightly 1.9.0.dev20240507__py3-none-any.whl → 1.9.0.dev20240520__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (30)
  1. flwr/cli/new/new.py +4 -0
  2. flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
  3. flwr/cli/new/templates/app/code/client.jax.py.tpl +55 -0
  4. flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
  5. flwr/cli/new/templates/app/code/server.jax.py.tpl +12 -0
  6. flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
  7. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  8. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
  9. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +28 -0
  10. flwr/client/mod/comms_mods.py +4 -4
  11. flwr/client/mod/localdp_mod.py +1 -2
  12. flwr/server/__init__.py +0 -2
  13. flwr/server/app.py +7 -1
  14. flwr/server/compat/app.py +6 -57
  15. flwr/server/driver/__init__.py +3 -2
  16. flwr/server/driver/inmemory_driver.py +181 -0
  17. flwr/server/history.py +20 -20
  18. flwr/server/server.py +11 -7
  19. flwr/server/strategy/dp_adaptive_clipping.py +2 -4
  20. flwr/server/strategy/dp_fixed_clipping.py +2 -4
  21. flwr/server/superlink/driver/driver_servicer.py +2 -2
  22. flwr/server/superlink/fleet/vce/backend/raybackend.py +11 -3
  23. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  24. flwr/server/workflow/default_workflows.py +67 -22
  25. flwr/simulation/run_simulation.py +7 -34
  26. {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/METADATA +2 -1
  27. {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/RECORD +30 -21
  28. {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,181 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower in-memory Driver."""
16
+
17
+
18
+ import time
19
+ import warnings
20
+ from typing import Iterable, List, Optional
21
+ from uuid import UUID
22
+
23
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
24
+ from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
26
+ from flwr.server.superlink.state import StateFactory
27
+
28
+ from .driver import Driver
29
+
30
+
31
+ class InMemoryDriver(Driver):
32
+ """`InMemoryDriver` class provides an interface to the Driver API.
33
+
34
+ Parameters
35
+ ----------
36
+ state_factory : StateFactory
37
+ A StateFactory embedding a state that this driver can interface with.
38
+ fab_id : str (default: None)
39
+ The identifier of the FAB used in the run.
40
+ fab_version : str (default: None)
41
+ The version of the FAB used in the run.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ state_factory: StateFactory,
47
+ fab_id: Optional[str] = None,
48
+ fab_version: Optional[str] = None,
49
+ ) -> None:
50
+ self.run_id: Optional[int] = None
51
+ self.fab_id = fab_id if fab_id is not None else ""
52
+ self.fab_version = fab_version if fab_version is not None else ""
53
+ self.node = Node(node_id=0, anonymous=True)
54
+ self.state = state_factory.state()
55
+
56
+ def _check_message(self, message: Message) -> None:
57
+ # Check if the message is valid
58
+ if not (
59
+ message.metadata.run_id == self.run_id
60
+ and message.metadata.src_node_id == self.node.node_id
61
+ and message.metadata.message_id == ""
62
+ and message.metadata.reply_to_message == ""
63
+ and message.metadata.ttl > 0
64
+ ):
65
+ raise ValueError(f"Invalid message: {message}")
66
+
67
+ def _get_run_id(self) -> int:
68
+ """Return run_id.
69
+
70
+ If unset, create a new run.
71
+ """
72
+ if self.run_id is None:
73
+ self.run_id = self.state.create_run(
74
+ fab_id=self.fab_id, fab_version=self.fab_version
75
+ )
76
+ return self.run_id
77
+
78
+ def create_message( # pylint: disable=too-many-arguments
79
+ self,
80
+ content: RecordSet,
81
+ message_type: str,
82
+ dst_node_id: int,
83
+ group_id: str,
84
+ ttl: Optional[float] = None,
85
+ ) -> Message:
86
+ """Create a new message with specified parameters.
87
+
88
+ This method constructs a new `Message` with given content and metadata.
89
+ The `run_id` and `src_node_id` will be set automatically.
90
+ """
91
+ run_id = self._get_run_id()
92
+ if ttl:
93
+ warnings.warn(
94
+ "A custom TTL was set, but note that the SuperLink does not enforce "
95
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
96
+ "version of Flower.",
97
+ stacklevel=2,
98
+ )
99
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
100
+
101
+ metadata = Metadata(
102
+ run_id=run_id,
103
+ message_id="", # Will be set by the server
104
+ src_node_id=self.node.node_id,
105
+ dst_node_id=dst_node_id,
106
+ reply_to_message="",
107
+ group_id=group_id,
108
+ ttl=ttl_,
109
+ message_type=message_type,
110
+ )
111
+ return Message(metadata=metadata, content=content)
112
+
113
+ def get_node_ids(self) -> List[int]:
114
+ """Get node IDs."""
115
+ run_id = self._get_run_id()
116
+ return list(self.state.get_nodes(run_id))
117
+
118
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
119
+ """Push messages to specified node IDs.
120
+
121
+ This method takes an iterable of messages and sends each message
122
+ to the node specified in `dst_node_id`.
123
+ """
124
+ task_ids: List[str] = []
125
+ for msg in messages:
126
+ # Check message
127
+ self._check_message(msg)
128
+ # Convert Message to TaskIns
129
+ taskins = message_to_taskins(msg)
130
+ # Store in state
131
+ taskins.task.pushed_at = time.time()
132
+ task_id = self.state.store_task_ins(taskins)
133
+ if task_id:
134
+ task_ids.append(str(task_id))
135
+
136
+ return task_ids
137
+
138
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
139
+ """Pull messages based on message IDs.
140
+
141
+ This method is used to collect messages from the SuperLink that correspond to a
142
+ set of given message IDs.
143
+ """
144
+ msg_ids = {UUID(msg_id) for msg_id in message_ids}
145
+ # Pull TaskRes
146
+ task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
147
+ # Delete tasks in state
148
+ self.state.delete_tasks(msg_ids)
149
+ # Convert TaskRes to Message
150
+ msgs = [message_from_taskres(taskres) for taskres in task_res_list]
151
+ return msgs
152
+
153
+ def send_and_receive(
154
+ self,
155
+ messages: Iterable[Message],
156
+ *,
157
+ timeout: Optional[float] = None,
158
+ ) -> Iterable[Message]:
159
+ """Push messages to specified node IDs and pull the reply messages.
160
+
161
+ This method sends a list of messages to their destination node IDs and then
162
+ waits for the replies. It continues to pull replies until either all replies are
163
+ received or the specified timeout duration is exceeded.
164
+ """
165
+ # Push messages
166
+ msg_ids = set(self.push_messages(messages))
167
+
168
+ # Pull messages
169
+ end_time = time.time() + (timeout if timeout is not None else 0.0)
170
+ ret: List[Message] = []
171
+ while timeout is None or time.time() < end_time:
172
+ res_msgs = self.pull_messages(msg_ids)
173
+ ret.extend(res_msgs)
174
+ msg_ids.difference_update(
175
+ {msg.metadata.reply_to_message for msg in res_msgs}
176
+ )
177
+ if len(msg_ids) == 0:
178
+ break
179
+ # Sleep
180
+ time.sleep(3)
181
+ return ret
flwr/server/history.py CHANGED
@@ -91,32 +91,32 @@ class History:
91
91
  """
92
92
  rep = ""
93
93
  if self.losses_distributed:
94
- rep += "History (loss, distributed):\n" + pprint.pformat(
95
- reduce(
96
- lambda a, b: a + b,
97
- [
98
- f"\tround {server_round}: {loss}\n"
99
- for server_round, loss in self.losses_distributed
100
- ],
101
- )
94
+ rep += "History (loss, distributed):\n" + reduce(
95
+ lambda a, b: a + b,
96
+ [
97
+ f"\tround {server_round}: {loss}\n"
98
+ for server_round, loss in self.losses_distributed
99
+ ],
102
100
  )
103
101
  if self.losses_centralized:
104
- rep += "History (loss, centralized):\n" + pprint.pformat(
105
- reduce(
106
- lambda a, b: a + b,
107
- [
108
- f"\tround {server_round}: {loss}\n"
109
- for server_round, loss in self.losses_centralized
110
- ],
111
- )
102
+ rep += "History (loss, centralized):\n" + reduce(
103
+ lambda a, b: a + b,
104
+ [
105
+ f"\tround {server_round}: {loss}\n"
106
+ for server_round, loss in self.losses_centralized
107
+ ],
112
108
  )
113
109
  if self.metrics_distributed_fit:
114
- rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
115
- self.metrics_distributed_fit
110
+ rep += (
111
+ "History (metrics, distributed, fit):\n"
112
+ + pprint.pformat(self.metrics_distributed_fit)
113
+ + "\n"
116
114
  )
117
115
  if self.metrics_distributed:
118
- rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
119
- self.metrics_distributed
116
+ rep += (
117
+ "History (metrics, distributed, evaluate):\n"
118
+ + pprint.pformat(self.metrics_distributed)
119
+ + "\n"
120
120
  )
121
121
  if self.metrics_centralized:
122
122
  rep += "History (metrics, centralized):\n" + pprint.pformat(
flwr/server/server.py CHANGED
@@ -282,7 +282,14 @@ class Server:
282
282
  get_parameters_res = random_client.get_parameters(
283
283
  ins=ins, timeout=timeout, group_id=server_round
284
284
  )
285
- log(INFO, "Received initial parameters from one random client")
285
+ if get_parameters_res.status.code == Code.OK:
286
+ log(INFO, "Received initial parameters from one random client")
287
+ else:
288
+ log(
289
+ WARN,
290
+ "Failed to receive initial parameters from the client."
291
+ " Empty initial parameters will be used.",
292
+ )
286
293
  return get_parameters_res.parameters
287
294
 
288
295
 
@@ -486,12 +493,9 @@ def run_fl(
486
493
 
487
494
  log(INFO, "")
488
495
  log(INFO, "[SUMMARY]")
489
- log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
490
- for idx, line in enumerate(io.StringIO(str(hist))):
491
- if idx == 0:
492
- log(INFO, "%s", line.strip("\n"))
493
- else:
494
- log(INFO, "\t%s", line.strip("\n"))
496
+ log(INFO, "Run finished %s round(s) in %.2fs", config.num_rounds, elapsed_time)
497
+ for line in io.StringIO(str(hist)):
498
+ log(INFO, "\t%s", line.strip("\n"))
495
499
  log(INFO, "")
496
500
 
497
501
  # Graceful shutdown
@@ -234,8 +234,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
234
234
  )
235
235
  log(
236
236
  INFO,
237
- "aggregate_fit: central DP noise with "
238
- "standard deviation: %.4f added to parameters.",
237
+ "aggregate_fit: central DP noise with %.4f stdev added",
239
238
  compute_stdv(
240
239
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
241
240
  ),
@@ -425,8 +424,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
425
424
  )
426
425
  log(
427
426
  INFO,
428
- "aggregate_fit: central DP noise with "
429
- "standard deviation: %.4f added to parameters.",
427
+ "aggregate_fit: central DP noise with %.4f stdev added",
430
428
  compute_stdv(
431
429
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
432
430
  ),
@@ -180,8 +180,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
180
180
 
181
181
  log(
182
182
  INFO,
183
- "aggregate_fit: central DP noise with "
184
- "standard deviation: %.4f added to parameters.",
183
+ "aggregate_fit: central DP noise with %.4f stdev added",
185
184
  compute_stdv(
186
185
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
187
186
  ),
@@ -338,8 +337,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
338
337
  )
339
338
  log(
340
339
  INFO,
341
- "aggregate_fit: central DP noise with "
342
- "standard deviation: %.4f added to parameters.",
340
+ "aggregate_fit: central DP noise with %.4f stdev added",
343
341
  compute_stdv(
344
342
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
345
343
  ),
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import time
19
- from logging import DEBUG, INFO
19
+ from logging import DEBUG
20
20
  from typing import List, Optional, Set
21
21
  from uuid import UUID
22
22
 
@@ -62,7 +62,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
62
62
  self, request: CreateRunRequest, context: grpc.ServicerContext
63
63
  ) -> CreateRunResponse:
64
64
  """Create run ID."""
65
- log(INFO, "DriverServicer.CreateRun")
65
+ log(DEBUG, "DriverServicer.CreateRun")
66
66
  state: State = self.state_factory.state()
67
67
  run_id = state.create_run(request.fab_id, request.fab_version)
68
68
  return CreateRunResponse(run_id=run_id)
@@ -15,7 +15,7 @@
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
17
  import pathlib
18
- from logging import DEBUG, ERROR, INFO
18
+ from logging import DEBUG, ERROR, WARNING
19
19
  from typing import Callable, Dict, List, Tuple, Union
20
20
 
21
21
  import ray
@@ -45,7 +45,7 @@ class RayBackend(Backend):
45
45
  work_dir: str,
46
46
  ) -> None:
47
47
  """Prepare RayBackend by initialising Ray and creating the ActorPool."""
48
- log(INFO, "Initialising: %s", self.__class__.__name__)
48
+ log(DEBUG, "Initialising: %s", self.__class__.__name__)
49
49
  log(DEBUG, "Backend config: %s", backend_config)
50
50
 
51
51
  if not pathlib.Path(work_dir).exists():
@@ -55,7 +55,15 @@ class RayBackend(Backend):
55
55
  runtime_env = (
56
56
  self._configure_runtime_env(work_dir=work_dir) if work_dir else None
57
57
  )
58
- init_ray(runtime_env=runtime_env)
58
+
59
+ if backend_config.get("mute_logging", False):
60
+ init_ray(
61
+ logging_level=WARNING, log_to_driver=False, runtime_env=runtime_env
62
+ )
63
+ elif backend_config.get("silent", False):
64
+ init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
65
+ else:
66
+ init_ray(runtime_env=runtime_env)
59
67
 
60
68
  # Validate client resources
61
69
  self.client_resources_key = "client_resources"
@@ -46,7 +46,7 @@ def _register_nodes(
46
46
  for i in range(num_nodes):
47
47
  node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
48
48
  nodes_mapping[node_id] = i
49
- log(INFO, "Registered %i nodes", len(nodes_mapping))
49
+ log(DEBUG, "Registered %i nodes", len(nodes_mapping))
50
50
  return nodes_mapping
51
51
 
52
52
 
@@ -17,13 +17,23 @@
17
17
 
18
18
  import io
19
19
  import timeit
20
- from logging import INFO
21
- from typing import Optional, cast
20
+ from logging import INFO, WARN
21
+ from typing import List, Optional, Tuple, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
- from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
+ from flwr.common import (
25
+ Code,
26
+ ConfigsRecord,
27
+ Context,
28
+ EvaluateRes,
29
+ FitRes,
30
+ GetParametersIns,
31
+ ParametersRecord,
32
+ log,
33
+ )
25
34
  from flwr.common.constant import MessageType, MessageTypeLegacy
26
35
 
36
+ from ..client_proxy import ClientProxy
27
37
  from ..compat.app_utils import start_update_client_manager_thread
28
38
  from ..compat.legacy_context import LegacyContext
29
39
  from ..driver import Driver
@@ -88,7 +98,12 @@ class DefaultWorkflow:
88
98
  hist = context.history
89
99
  log(INFO, "")
90
100
  log(INFO, "[SUMMARY]")
91
- log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
101
+ log(
102
+ INFO,
103
+ "Run finished %s round(s) in %.2fs",
104
+ context.config.num_rounds,
105
+ elapsed,
106
+ )
92
107
  for idx, line in enumerate(io.StringIO(str(hist))):
93
108
  if idx == 0:
94
109
  log(INFO, "%s", line.strip("\n"))
@@ -130,9 +145,24 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
130
145
  )
131
146
  ]
132
147
  )
133
- log(INFO, "Received initial parameters from one random client")
134
148
  msg = list(messages)[0]
135
- paramsrecord = next(iter(msg.content.parameters_records.values()))
149
+
150
+ if (
151
+ msg.has_content()
152
+ and compat._extract_status_from_recordset( # pylint: disable=W0212
153
+ "getparametersres", msg.content
154
+ ).code
155
+ == Code.OK
156
+ ):
157
+ log(INFO, "Received initial parameters from one random client")
158
+ paramsrecord = next(iter(msg.content.parameters_records.values()))
159
+ else:
160
+ log(
161
+ WARN,
162
+ "Failed to receive initial parameters from the client."
163
+ " Empty initial parameters will be used.",
164
+ )
165
+ paramsrecord = ParametersRecord()
136
166
 
137
167
  context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
138
168
 
@@ -244,14 +274,20 @@ def default_fit_workflow( # pylint: disable=R0914
244
274
  )
245
275
 
246
276
  # Aggregate training results
247
- results = [
248
- (
249
- node_id_to_proxy[msg.metadata.src_node_id],
250
- compat.recordset_to_fitres(msg.content, False),
251
- )
252
- for msg in messages
253
- ]
254
- aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
277
+ results: List[Tuple[ClientProxy, FitRes]] = []
278
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
279
+ for msg in messages:
280
+ if msg.has_content():
281
+ proxy = node_id_to_proxy[msg.metadata.src_node_id]
282
+ fitres = compat.recordset_to_fitres(msg.content, False)
283
+ if fitres.status.code == Code.OK:
284
+ results.append((proxy, fitres))
285
+ else:
286
+ failures.append((proxy, fitres))
287
+ else:
288
+ failures.append(Exception(msg.error))
289
+
290
+ aggregated_result = context.strategy.aggregate_fit(current_round, results, failures)
255
291
  parameters_aggregated, metrics_aggregated = aggregated_result
256
292
 
257
293
  # Update the parameters and write history
@@ -265,6 +301,7 @@ def default_fit_workflow( # pylint: disable=R0914
265
301
  )
266
302
 
267
303
 
304
+ # pylint: disable-next=R0914
268
305
  def default_evaluate_workflow(driver: Driver, context: Context) -> None:
269
306
  """Execute the default workflow for a single evaluate round."""
270
307
  if not isinstance(context, LegacyContext):
@@ -323,14 +360,22 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
323
360
  )
324
361
 
325
362
  # Aggregate the evaluation results
326
- results = [
327
- (
328
- node_id_to_proxy[msg.metadata.src_node_id],
329
- compat.recordset_to_evaluateres(msg.content),
330
- )
331
- for msg in messages
332
- ]
333
- aggregated_result = context.strategy.aggregate_evaluate(current_round, results, [])
363
+ results: List[Tuple[ClientProxy, EvaluateRes]] = []
364
+ failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
365
+ for msg in messages:
366
+ if msg.has_content():
367
+ proxy = node_id_to_proxy[msg.metadata.src_node_id]
368
+ evalres = compat.recordset_to_evaluateres(msg.content)
369
+ if evalres.status.code == Code.OK:
370
+ results.append((proxy, evalres))
371
+ else:
372
+ failures.append((proxy, evalres))
373
+ else:
374
+ failures.append(Exception(msg.error))
375
+
376
+ aggregated_result = context.strategy.aggregate_evaluate(
377
+ current_round, results, failures
378
+ )
334
379
 
335
380
  loss_aggregated, metrics_aggregated = aggregated_result
336
381
 
@@ -24,16 +24,13 @@ from logging import DEBUG, ERROR, INFO, WARNING
24
24
  from time import sleep
25
25
  from typing import Dict, Optional
26
26
 
27
- import grpc
28
-
29
27
  from flwr.client import ClientApp
30
28
  from flwr.common import EventType, event, log
31
29
  from flwr.common.logger import set_logger_propagation, update_console_handler
32
30
  from flwr.common.typing import ConfigsRecordValues
33
- from flwr.server.driver import Driver, GrpcDriver
31
+ from flwr.server.driver import Driver, InMemoryDriver
34
32
  from flwr.server.run_serverapp import run
35
33
  from flwr.server.server_app import ServerApp
36
- from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
37
34
  from flwr.server.superlink.fleet import vce
38
35
  from flwr.server.superlink.state import StateFactory
39
36
  from flwr.simulation.ray_transport.utils import (
@@ -56,7 +53,6 @@ def run_simulation_from_cli() -> None:
56
53
  backend_name=args.backend,
57
54
  backend_config=backend_config_dict,
58
55
  app_dir=args.app_dir,
59
- driver_api_address=args.driver_api_address,
60
56
  enable_tf_gpu_growth=args.enable_tf_gpu_growth,
61
57
  verbose_logging=args.verbose,
62
58
  )
@@ -177,7 +173,6 @@ def _main_loop(
177
173
  num_supernodes: int,
178
174
  backend_name: str,
179
175
  backend_config_stream: str,
180
- driver_api_address: str,
181
176
  app_dir: str,
182
177
  enable_tf_gpu_growth: bool,
183
178
  client_app: Optional[ClientApp] = None,
@@ -194,21 +189,11 @@ def _main_loop(
194
189
  # Initialize StateFactory
195
190
  state_factory = StateFactory(":flwr-in-memory-state:")
196
191
 
197
- # Start Driver API
198
- driver_server: grpc.Server = run_driver_api_grpc(
199
- address=driver_api_address,
200
- state_factory=state_factory,
201
- certificates=None,
202
- )
203
-
204
192
  f_stop = asyncio.Event()
205
193
  serverapp_th = None
206
194
  try:
207
195
  # Initialize Driver
208
- driver = GrpcDriver(
209
- driver_service_address=driver_api_address,
210
- root_certificates=None,
211
- )
196
+ driver = InMemoryDriver(state_factory)
212
197
 
213
198
  # Get and run ServerApp thread
214
199
  serverapp_th = run_serverapp_th(
@@ -239,9 +224,6 @@ def _main_loop(
239
224
  raise RuntimeError("An error was encountered. Ending simulation.") from ex
240
225
 
241
226
  finally:
242
- # Stop Driver
243
- driver_server.stop(grace=0)
244
- driver.close()
245
227
  # Trigger stop event
246
228
  f_stop.set()
247
229
 
@@ -262,7 +244,6 @@ def _run_simulation(
262
244
  client_app_attr: Optional[str] = None,
263
245
  server_app_attr: Optional[str] = None,
264
246
  app_dir: str = "",
265
- driver_api_address: str = "0.0.0.0:9091",
266
247
  enable_tf_gpu_growth: bool = False,
267
248
  verbose_logging: bool = False,
268
249
  ) -> None:
@@ -302,9 +283,6 @@ def _run_simulation(
302
283
  Add specified directory to the PYTHONPATH and load `ClientApp` from there.
303
284
  (Default: current working directory.)
304
285
 
305
- driver_api_address : str (default: "0.0.0.0:9091")
306
- Driver API (gRPC) server address (IPv4, IPv6, or a domain name)
307
-
308
286
  enable_tf_gpu_growth : bool (default: False)
309
287
  A boolean to indicate whether to enable GPU growth on the main thread. This is
310
288
  desirable if you make use of a TensorFlow model on your `ServerApp` while
@@ -317,13 +295,15 @@ def _run_simulation(
317
295
  When diabled, only INFO, WARNING and ERROR log messages will be shown. If
318
296
  enabled, DEBUG-level logs will be displayed.
319
297
  """
298
+ if backend_config is None:
299
+ backend_config = {}
300
+
320
301
  # Set logging level
321
302
  logger = logging.getLogger("flwr")
322
303
  if verbose_logging:
323
304
  update_console_handler(level=DEBUG, timestamps=True, colored=True)
324
-
325
- if backend_config is None:
326
- backend_config = {}
305
+ else:
306
+ backend_config["silent"] = True
327
307
 
328
308
  if enable_tf_gpu_growth:
329
309
  # Check that Backend config has also enabled using GPU growth
@@ -340,7 +320,6 @@ def _run_simulation(
340
320
  num_supernodes,
341
321
  backend_name,
342
322
  backend_config_stream,
343
- driver_api_address,
344
323
  app_dir,
345
324
  enable_tf_gpu_growth,
346
325
  client_app,
@@ -397,12 +376,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
397
376
  required=True,
398
377
  help="Number of simulated SuperNodes.",
399
378
  )
400
- parser.add_argument(
401
- "--driver-api-address",
402
- default="0.0.0.0:9091",
403
- type=str,
404
- help="For example: `server:app` or `project.package.module:wrapper.app`",
405
- )
406
379
  parser.add_argument(
407
380
  "--backend",
408
381
  default="ray",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.9.0.dev20240507
3
+ Version: 1.9.0.dev20240520
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -202,6 +202,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
202
202
  - [Comprehensive Flower+XGBoost](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
203
203
  - [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose)
204
204
  - [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplna-meier-fitter)
205
+ - [Sample Level Privacy with Opacus](https://github.com/adap/flower/tree/main/examples/opacus)
205
206
 
206
207
  ## Community
207
208