flwr-nightly 1.8.0.dev20240303__py3-none-any.whl → 1.8.0.dev20240305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/cli/new/new.py CHANGED
@@ -28,6 +28,7 @@ from ..utils import prompt_options
28
28
  class MlFramework(str, Enum):
29
29
  """Available frameworks."""
30
30
 
31
+ NUMPY = "NumPy"
31
32
  PYTORCH = "PyTorch"
32
33
  TENSORFLOW = "TensorFlow"
33
34
 
@@ -0,0 +1,24 @@
1
+ """$project_name: A Flower / NumPy app."""
2
+
3
+ import flwr as fl
4
+ import numpy as np
5
+
6
+
7
+ # Flower client, adapted from PyTorch quickstart example
8
+ class FlowerClient(fl.client.NumPyClient):
9
+ def get_parameters(self, config):
10
+ return [np.ones((1, 1))]
11
+
12
+ def fit(self, parameters, config):
13
+ return ([np.ones((1, 1))], 1, {})
14
+
15
+ def evaluate(self, parameters, config):
16
+ return float(0.0), 1, {"accuracy": float(1.0)}
17
+
18
+
19
+ def client_fn(cid: str):
20
+ return FlowerClient().to_client()
21
+
22
+
23
+ # ClientApp for Flower-Next
24
+ app = fl.client.ClientApp(client_fn=client_fn)
@@ -0,0 +1,12 @@
1
+ """$project_name: A Flower / NumPy app."""
2
+
3
+ import flwr as fl
4
+
5
+ # Configure the strategy
6
+ strategy = fl.server.strategy.FedAvg()
7
+
8
+ # Flower ServerApp
9
+ app = fl.server.ServerApp(
10
+ config=fl.server.ServerConfig(num_rounds=1),
11
+ strategy=strategy,
12
+ )
@@ -1,10 +1,10 @@
1
- [flower]
1
+ [project]
2
2
  name = "$project_name"
3
3
  version = "1.0.0"
4
4
  description = ""
5
5
  license = "Apache-2.0"
6
6
  authors = ["The Flower Authors <hello@flower.ai>"]
7
7
 
8
- [components]
8
+ [flower.components]
9
9
  serverapp = "$project_name.server:app"
10
10
  clientapp = "$project_name.client:app"
@@ -0,0 +1,2 @@
1
+ flwr>=1.8, <2.0
2
+ numpy >= 1.21.0
flwr/client/app.py CHANGED
@@ -20,7 +20,9 @@ import sys
20
20
  import time
21
21
  from logging import DEBUG, INFO, WARN
22
22
  from pathlib import Path
23
- from typing import Callable, ContextManager, Optional, Tuple, Union
23
+ from typing import Callable, ContextManager, Optional, Tuple, Type, Union
24
+
25
+ from grpc import RpcError
24
26
 
25
27
  from flwr.client.client import Client
26
28
  from flwr.client.client_app import ClientApp
@@ -36,6 +38,7 @@ from flwr.common.constant import (
36
38
  )
37
39
  from flwr.common.exit_handlers import register_exit_handlers
38
40
  from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
41
+ from flwr.common.retry_invoker import RetryInvoker, exponential
39
42
 
40
43
  from .client_app import load_client_app
41
44
  from .grpc_client.connection import grpc_connection
@@ -104,6 +107,8 @@ def run_client_app() -> None:
104
107
  transport="rest" if args.rest else "grpc-rere",
105
108
  root_certificates=root_certificates,
106
109
  insecure=args.insecure,
110
+ max_retries=args.max_retries,
111
+ max_wait_time=args.max_wait_time,
107
112
  )
108
113
  register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
109
114
 
@@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
141
146
  default="0.0.0.0:9092",
142
147
  help="Server address",
143
148
  )
149
+ parser.add_argument(
150
+ "--max-retries",
151
+ type=int,
152
+ default=None,
153
+ help="The maximum number of times the client will try to connect to the "
154
+ "server before giving up in case of a connection error. By default, "
155
+ "it is set to None, meaning there is no limit to the number of tries.",
156
+ )
157
+ parser.add_argument(
158
+ "--max-wait-time",
159
+ type=float,
160
+ default=None,
161
+ help="The maximum duration before the client stops trying to "
162
+ "connect to the server in case of connection error. By default, it "
163
+ "is set to None, meaning there is no limit to the total time.",
164
+ )
144
165
  parser.add_argument(
145
166
  "--dir",
146
167
  default="",
@@ -180,6 +201,8 @@ def start_client(
180
201
  root_certificates: Optional[Union[bytes, str]] = None,
181
202
  insecure: Optional[bool] = None,
182
203
  transport: Optional[str] = None,
204
+ max_retries: Optional[int] = None,
205
+ max_wait_time: Optional[float] = None,
183
206
  ) -> None:
184
207
  """Start a Flower client node which connects to a Flower server.
185
208
 
@@ -213,6 +236,14 @@ def start_client(
213
236
  - 'grpc-bidi': gRPC, bidirectional streaming
214
237
  - 'grpc-rere': gRPC, request-response (experimental)
215
238
  - 'rest': HTTP (experimental)
239
+ max_retries: Optional[int] (default: None)
240
+ The maximum number of times the client will try to connect to the
241
+ server before giving up in case of a connection error. If set to None,
242
+ there is no limit to the number of tries.
243
+ max_wait_time: Optional[float] (default: None)
244
+ The maximum duration before the client stops trying to
245
+ connect to the server in case of connection error.
246
+ If set to None, there is no limit to the total time.
216
247
 
217
248
  Examples
218
249
  --------
@@ -254,6 +285,8 @@ def start_client(
254
285
  root_certificates=root_certificates,
255
286
  insecure=insecure,
256
287
  transport=transport,
288
+ max_retries=max_retries,
289
+ max_wait_time=max_wait_time,
257
290
  )
258
291
  event(EventType.START_CLIENT_LEAVE)
259
292
 
@@ -272,6 +305,8 @@ def _start_client_internal(
272
305
  root_certificates: Optional[Union[bytes, str]] = None,
273
306
  insecure: Optional[bool] = None,
274
307
  transport: Optional[str] = None,
308
+ max_retries: Optional[int] = None,
309
+ max_wait_time: Optional[float] = None,
275
310
  ) -> None:
276
311
  """Start a Flower client node which connects to a Flower server.
277
312
 
@@ -299,7 +334,7 @@ def _start_client_internal(
299
334
  The PEM-encoded root certificates as a byte string or a path string.
300
335
  If provided, a secure connection using the certificates will be
301
336
  established to an SSL-enabled Flower server.
302
- insecure : bool (default: True)
337
+ insecure : Optional[bool] (default: None)
303
338
  Starts an insecure gRPC connection when True. Enables HTTPS connection
304
339
  when False, using system certificates if `root_certificates` is None.
305
340
  transport : Optional[str] (default: None)
@@ -307,6 +342,14 @@ def _start_client_internal(
307
342
  - 'grpc-bidi': gRPC, bidirectional streaming
308
343
  - 'grpc-rere': gRPC, request-response (experimental)
309
344
  - 'rest': HTTP (experimental)
345
+ max_retries: Optional[int] (default: None)
346
+ The maximum number of times the client will try to connect to the
347
+ server before giving up in case of a connection error. If set to None,
348
+ there is no limit to the number of tries.
349
+ max_wait_time: Optional[float] (default: None)
350
+ The maximum duration before the client stops trying to
351
+ connect to the server in case of connection error.
352
+ If set to None, there is no limit to the total time.
310
353
  """
311
354
  if insecure is None:
312
355
  insecure = root_certificates is None
@@ -338,7 +381,45 @@ def _start_client_internal(
338
381
  # Both `client` and `client_fn` must not be used directly
339
382
 
340
383
  # Initialize connection context manager
341
- connection, address = _init_connection(transport, server_address)
384
+ connection, address, connection_error_type = _init_connection(
385
+ transport, server_address
386
+ )
387
+
388
+ retry_invoker = RetryInvoker(
389
+ wait_factory=exponential,
390
+ recoverable_exceptions=connection_error_type,
391
+ max_tries=max_retries,
392
+ max_time=max_wait_time,
393
+ on_giveup=lambda retry_state: (
394
+ log(
395
+ WARN,
396
+ "Giving up reconnection after %.2f seconds and %s tries.",
397
+ retry_state.elapsed_time,
398
+ retry_state.tries,
399
+ )
400
+ if retry_state.tries > 1
401
+ else None
402
+ ),
403
+ on_success=lambda retry_state: (
404
+ log(
405
+ INFO,
406
+ "Connection successful after %.2f seconds and %s tries.",
407
+ retry_state.elapsed_time,
408
+ retry_state.tries,
409
+ )
410
+ if retry_state.tries > 1
411
+ else None
412
+ ),
413
+ on_backoff=lambda retry_state: (
414
+ log(WARN, "Connection attempt failed, retrying...")
415
+ if retry_state.tries == 1
416
+ else log(
417
+ DEBUG,
418
+ "Connection attempt failed, retrying in %.2f seconds",
419
+ retry_state.actual_wait,
420
+ )
421
+ ),
422
+ )
342
423
 
343
424
  node_state = NodeState()
344
425
 
@@ -347,6 +428,7 @@ def _start_client_internal(
347
428
  with connection(
348
429
  address,
349
430
  insecure,
431
+ retry_invoker,
350
432
  grpc_max_message_length,
351
433
  root_certificates,
352
434
  ) as conn:
@@ -509,7 +591,7 @@ def start_numpy_client(
509
591
 
510
592
  def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
511
593
  Callable[
512
- [str, bool, int, Union[bytes, str, None]],
594
+ [str, bool, RetryInvoker, int, Union[bytes, str, None]],
513
595
  ContextManager[
514
596
  Tuple[
515
597
  Callable[[], Optional[Message]],
@@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
520
602
  ],
521
603
  ],
522
604
  str,
605
+ Type[Exception],
523
606
  ]:
524
607
  # Parse IP address
525
608
  parsed_address = parse_address(server_address)
@@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
535
618
  # Use either gRPC bidirectional streaming or REST request/response
536
619
  if transport == TRANSPORT_TYPE_REST:
537
620
  try:
621
+ from requests.exceptions import ConnectionError as RequestsConnectionError
622
+
538
623
  from .rest_client.connection import http_request_response
539
624
  except ModuleNotFoundError:
540
625
  sys.exit(MISSING_EXTRA_REST)
@@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
543
628
  "When using the REST API, please provide `https://` or "
544
629
  "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
545
630
  )
546
- connection = http_request_response
631
+ connection, error_type = http_request_response, RequestsConnectionError
547
632
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
548
- connection = grpc_request_response
633
+ connection, error_type = grpc_request_response, RpcError
549
634
  elif transport == TRANSPORT_TYPE_GRPC_BIDI:
550
- connection = grpc_connection
635
+ connection, error_type = grpc_connection, RpcError
551
636
  else:
552
637
  raise ValueError(
553
638
  f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
554
639
  )
555
640
 
556
- return connection, address
641
+ return connection, address, error_type
@@ -39,6 +39,7 @@ from flwr.common.constant import (
39
39
  )
40
40
  from flwr.common.grpc import create_channel
41
41
  from flwr.common.logger import log
42
+ from flwr.common.retry_invoker import RetryInvoker
42
43
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
43
44
  ClientMessage,
44
45
  Reason,
@@ -62,6 +63,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
62
63
  def grpc_connection( # pylint: disable=R0915
63
64
  server_address: str,
64
65
  insecure: bool,
66
+ retry_invoker: RetryInvoker, # pylint: disable=unused-argument
65
67
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
66
68
  root_certificates: Optional[Union[bytes, str]] = None,
67
69
  ) -> Iterator[
@@ -80,6 +82,11 @@ def grpc_connection( # pylint: disable=R0915
80
82
  The IPv4 or IPv6 address of the server. If the Flower server runs on the same
81
83
  machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
82
84
  `"[::]:8080"`.
85
+ insecure : bool
86
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
87
+ when False, using system certificates if `root_certificates` is None.
88
+ retry_invoker: RetryInvoker
89
+ Unused argument present for compatibilty.
83
90
  max_message_length : int
84
91
  The maximum length of gRPC messages that can be exchanged with the Flower
85
92
  server. The default should be sufficient for most models. Users who train
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.grpc import create_channel
28
28
  from flwr.common.logger import log, warn_experimental_feature
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
51
52
  def grpc_request_response(
52
53
  server_address: str,
53
54
  insecure: bool,
55
+ retry_invoker: RetryInvoker,
54
56
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
55
57
  root_certificates: Optional[Union[bytes, str]] = None,
56
58
  ) -> Iterator[
@@ -72,6 +74,13 @@ def grpc_request_response(
72
74
  The IPv6 address of the server with `http://` or `https://`.
73
75
  If the Flower server runs on the same machine
74
76
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
77
+ insecure : bool
78
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
79
+ when False, using system certificates if `root_certificates` is None.
80
+ retry_invoker: RetryInvoker
81
+ `RetryInvoker` object that will try to reconnect the client to the server
82
+ after gRPC errors. If None, the client will only try to
83
+ reconnect once after a failure.
75
84
  max_message_length : int
76
85
  Ignored, only present to preserve API-compatibility.
77
86
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -113,7 +122,8 @@ def grpc_request_response(
113
122
  def create_node() -> None:
114
123
  """Set create_node."""
115
124
  create_node_request = CreateNodeRequest()
116
- create_node_response = stub.CreateNode(
125
+ create_node_response = retry_invoker.invoke(
126
+ stub.CreateNode,
117
127
  request=create_node_request,
118
128
  )
119
129
  node_store[KEY_NODE] = create_node_response.node
@@ -127,7 +137,7 @@ def grpc_request_response(
127
137
  node: Node = cast(Node, node_store[KEY_NODE])
128
138
 
129
139
  delete_node_request = DeleteNodeRequest(node=node)
130
- stub.DeleteNode(request=delete_node_request)
140
+ retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
131
141
 
132
142
  del node_store[KEY_NODE]
133
143
 
@@ -141,7 +151,7 @@ def grpc_request_response(
141
151
 
142
152
  # Request instructions (task) from server
143
153
  request = PullTaskInsRequest(node=node)
144
- response = stub.PullTaskIns(request=request)
154
+ response = retry_invoker.invoke(stub.PullTaskIns, request=request)
145
155
 
146
156
  # Get the current TaskIns
147
157
  task_ins: Optional[TaskIns] = get_task_ins(response)
@@ -185,7 +195,7 @@ def grpc_request_response(
185
195
 
186
196
  # Serialize ProtoBuf to bytes
187
197
  request = PushTaskResRequest(task_res_list=[task_res])
188
- _ = stub.PushTaskRes(request)
198
+ _ = retry_invoker.invoke(stub.PushTaskRes, request)
189
199
 
190
200
  state[KEY_METADATA] = None
191
201
 
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.constant import MISSING_EXTRA_REST
28
28
  from flwr.common.logger import log
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
61
62
  def http_request_response(
62
63
  server_address: str,
63
64
  insecure: bool, # pylint: disable=unused-argument
65
+ retry_invoker: RetryInvoker,
64
66
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
65
67
  root_certificates: Optional[
66
68
  Union[bytes, str]
@@ -84,6 +86,12 @@ def http_request_response(
84
86
  The IPv6 address of the server with `http://` or `https://`.
85
87
  If the Flower server runs on the same machine
86
88
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
89
+ insecure : bool
90
+ Unused argument present for compatibility.
91
+ retry_invoker: RetryInvoker
92
+ `RetryInvoker` object that will try to reconnect the client to the server
93
+ after REST connection errors. If None, the client will only try to
94
+ reconnect once after a failure.
87
95
  max_message_length : int
88
96
  Ignored, only present to preserve API-compatibility.
89
97
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -134,7 +142,8 @@ def http_request_response(
134
142
  create_node_req_proto = CreateNodeRequest()
135
143
  create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
136
144
 
137
- res = requests.post(
145
+ res = retry_invoker.invoke(
146
+ requests.post,
138
147
  url=f"{base_url}/{PATH_CREATE_NODE}",
139
148
  headers={
140
149
  "Accept": "application/protobuf",
@@ -177,7 +186,8 @@ def http_request_response(
177
186
  node: Node = cast(Node, node_store[KEY_NODE])
178
187
  delete_node_req_proto = DeleteNodeRequest(node=node)
179
188
  delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
180
- res = requests.post(
189
+ res = retry_invoker.invoke(
190
+ requests.post,
181
191
  url=f"{base_url}/{PATH_DELETE_NODE}",
182
192
  headers={
183
193
  "Accept": "application/protobuf",
@@ -218,7 +228,8 @@ def http_request_response(
218
228
  pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
219
229
 
220
230
  # Request instructions (task) from server
221
- res = requests.post(
231
+ res = retry_invoker.invoke(
232
+ requests.post,
222
233
  url=f"{base_url}/{PATH_PULL_TASK_INS}",
223
234
  headers={
224
235
  "Accept": "application/protobuf",
@@ -298,7 +309,8 @@ def http_request_response(
298
309
  )
299
310
 
300
311
  # Send ClientMessage to server
301
- res = requests.post(
312
+ res = retry_invoker.invoke(
313
+ requests.post,
302
314
  url=f"{base_url}/{PATH_PUSH_TASK_RES}",
303
315
  headers={
304
316
  "Accept": "application/protobuf",
flwr/common/__init__.py CHANGED
@@ -15,11 +15,13 @@
15
15
  """Common components shared between server and client."""
16
16
 
17
17
 
18
+ from .constant import MessageType as MessageType
18
19
  from .context import Context as Context
19
20
  from .date import now as now
20
21
  from .grpc import GRPC_MAX_MESSAGE_LENGTH
21
22
  from .logger import configure as configure
22
23
  from .logger import log as log
24
+ from .message import Error as Error
23
25
  from .message import Message as Message
24
26
  from .message import Metadata as Metadata
25
27
  from .parameter import bytes_to_ndarray as bytes_to_ndarray
@@ -74,6 +76,7 @@ __all__ = [
74
76
  "EventType",
75
77
  "FitIns",
76
78
  "FitRes",
79
+ "Error",
77
80
  "GetParametersIns",
78
81
  "GetParametersRes",
79
82
  "GetPropertiesIns",
@@ -81,6 +84,7 @@ __all__ = [
81
84
  "GRPC_MAX_MESSAGE_LENGTH",
82
85
  "log",
83
86
  "Message",
87
+ "MessageType",
84
88
  "Metadata",
85
89
  "Metrics",
86
90
  "MetricsAggregationFn",
flwr/common/constant.py CHANGED
@@ -42,6 +42,17 @@ MESSAGE_TYPE_FIT = "fit"
42
42
  MESSAGE_TYPE_EVALUATE = "evaluate"
43
43
 
44
44
 
45
+ class MessageType:
46
+ """Message type."""
47
+
48
+ TRAIN = "train"
49
+ EVALUATE = "evaluate"
50
+
51
+ def __new__(cls) -> MessageType:
52
+ """Prevent instantiation."""
53
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
54
+
55
+
45
56
  class SType:
46
57
  """Serialisation type."""
47
58
 
flwr/server/app.py CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
362
362
  f_stop = asyncio.Event() # Does nothing
363
363
  _run_fleet_api_vce(
364
364
  num_supernodes=args.num_supernodes,
365
- client_app_module_name=args.client_app,
365
+ client_app_attr=args.client_app,
366
366
  backend_name=args.backend,
367
367
  backend_config_json_stream=args.backend_config,
368
- working_dir=args.dir,
368
+ app_dir=args.app_dir,
369
369
  state_factory=state_factory,
370
370
  f_stop=f_stop,
371
371
  )
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
438
438
  # pylint: disable=too-many-arguments
439
439
  def _run_fleet_api_vce(
440
440
  num_supernodes: int,
441
- client_app_module_name: str,
441
+ client_app_attr: str,
442
442
  backend_name: str,
443
443
  backend_config_json_stream: str,
444
- working_dir: str,
444
+ app_dir: str,
445
445
  state_factory: StateFactory,
446
446
  f_stop: asyncio.Event,
447
447
  ) -> None:
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
449
449
 
450
450
  start_vce(
451
451
  num_supernodes=num_supernodes,
452
- client_app_module_name=client_app_module_name,
452
+ client_app_attr=client_app_attr,
453
453
  backend_name=backend_name,
454
454
  backend_config_json_stream=backend_config_json_stream,
455
455
  state_factory=state_factory,
456
- working_dir=working_dir,
456
+ app_dir=app_dir,
457
457
  f_stop=f_stop,
458
458
  )
459
459
 
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
705
705
  "`flwr.common.typing.ConfigsRecordValues`. ",
706
706
  )
707
707
  parser.add_argument(
708
- "--dir",
708
+ "--app-dir",
709
709
  default="",
710
710
  help="Add specified directory to the PYTHONPATH and load "
711
711
  "ClientApp from there."
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import asyncio
20
19
  import sys
21
20
  from logging import DEBUG, WARN
22
21
  from pathlib import Path
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
30
29
 
31
30
 
32
31
  def run(
33
- server_app_attr: str,
34
32
  driver: Driver,
35
33
  server_app_dir: str,
36
- stop_event: Optional[asyncio.Event] = None,
34
+ server_app_attr: Optional[str] = None,
35
+ loaded_server_app: Optional[ServerApp] = None,
37
36
  ) -> None:
38
37
  """Run ServerApp with a given Driver."""
38
+ if not (server_app_attr is None) ^ (loaded_server_app is None):
39
+ raise ValueError(
40
+ "Either `server_app_attr` or `loaded_server_app` should be set "
41
+ "but not both. "
42
+ )
43
+
39
44
  if server_app_dir is not None:
40
45
  sys.path.insert(0, server_app_dir)
41
46
 
47
+ # Load ServerApp if needed
42
48
  def _load() -> ServerApp:
43
- server_app: ServerApp = load_server_app(server_app_attr)
49
+ if server_app_attr:
50
+ server_app: ServerApp = load_server_app(server_app_attr)
51
+ if loaded_server_app:
52
+ server_app = loaded_server_app
44
53
  return server_app
45
54
 
46
55
  server_app = _load()
@@ -52,10 +61,6 @@ def run(
52
61
  server_app(driver=driver, context=context)
53
62
 
54
63
  log(DEBUG, "ServerApp finished running.")
55
- # Upon completion, trigger stop event if one was passed
56
- if stop_event is not None:
57
- log(DEBUG, "Triggering stop event.")
58
- stop_event.set()
59
64
 
60
65
 
61
66
  def run_server_app() -> None:
@@ -117,7 +122,7 @@ def run_server_app() -> None:
117
122
  )
118
123
 
119
124
  # Run the Server App with the Driver
120
- run(server_app_attr, driver, server_app_dir)
125
+ run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
121
126
 
122
127
  # Clean up
123
128
  driver.__del__() # pylint: disable=unnecessary-dunder-call
flwr/server/server.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import concurrent.futures
19
19
  import timeit
20
- from logging import DEBUG, INFO, WARN
20
+ from logging import INFO, WARN
21
21
  from typing import Dict, List, Optional, Tuple, Union
22
22
 
23
23
  from flwr.common import (
@@ -173,7 +173,7 @@ class Server:
173
173
  log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
174
174
  return None
175
175
  log(
176
- DEBUG,
176
+ INFO,
177
177
  "evaluate_round %s: strategy sampled %s clients (out of %s)",
178
178
  server_round,
179
179
  len(client_instructions),
@@ -188,7 +188,7 @@ class Server:
188
188
  group_id=server_round,
189
189
  )
190
190
  log(
191
- DEBUG,
191
+ INFO,
192
192
  "evaluate_round %s received %s results and %s failures",
193
193
  server_round,
194
194
  len(results),
@@ -223,7 +223,7 @@ class Server:
223
223
  log(INFO, "fit_round %s: no clients selected, cancel", server_round)
224
224
  return None
225
225
  log(
226
- DEBUG,
226
+ INFO,
227
227
  "fit_round %s: strategy sampled %s clients (out of %s)",
228
228
  server_round,
229
229
  len(client_instructions),
@@ -238,7 +238,7 @@ class Server:
238
238
  group_id=server_round,
239
239
  )
240
240
  log(
241
- DEBUG,
241
+ INFO,
242
242
  "fit_round %s received %s results and %s failures",
243
243
  server_round,
244
244
  len(results),
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
49
49
  self, request: GetNodesRequest, context: grpc.ServicerContext
50
50
  ) -> GetNodesResponse:
51
51
  """Get available nodes."""
52
- log(INFO, "DriverServicer.GetNodes")
52
+ log(DEBUG, "DriverServicer.GetNodes")
53
53
  state: State = self.state_factory.state()
54
54
  all_ids: Set[int] = state.get_nodes(request.run_id)
55
55
  nodes: List[Node] = [