flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/flower_toml.py +151 -0
  3. flwr/cli/new/new.py +1 -0
  4. flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
  5. flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
  6. flwr/cli/new/templates/app/flower.toml.tpl +2 -2
  7. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
  8. flwr/cli/run/__init__.py +21 -0
  9. flwr/cli/run/run.py +102 -0
  10. flwr/client/app.py +93 -8
  11. flwr/client/grpc_client/connection.py +16 -14
  12. flwr/client/grpc_rere_client/connection.py +14 -4
  13. flwr/client/message_handler/message_handler.py +5 -10
  14. flwr/client/mod/centraldp_mods.py +5 -5
  15. flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
  16. flwr/client/rest_client/connection.py +16 -4
  17. flwr/common/__init__.py +6 -0
  18. flwr/common/constant.py +21 -4
  19. flwr/server/app.py +7 -7
  20. flwr/server/compat/driver_client_proxy.py +5 -11
  21. flwr/server/run_serverapp.py +14 -9
  22. flwr/server/server.py +5 -5
  23. flwr/server/superlink/driver/driver_servicer.py +1 -1
  24. flwr/server/superlink/fleet/vce/vce_api.py +17 -5
  25. flwr/server/workflow/default_workflows.py +4 -8
  26. flwr/simulation/__init__.py +2 -5
  27. flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
  28. flwr/simulation/run_simulation.py +301 -76
  29. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
  30. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
  31. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
  32. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
  33. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.grpc import create_channel
28
28
  from flwr.common.logger import log, warn_experimental_feature
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
51
52
  def grpc_request_response(
52
53
  server_address: str,
53
54
  insecure: bool,
55
+ retry_invoker: RetryInvoker,
54
56
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
55
57
  root_certificates: Optional[Union[bytes, str]] = None,
56
58
  ) -> Iterator[
@@ -72,6 +74,13 @@ def grpc_request_response(
72
74
  The IPv6 address of the server with `http://` or `https://`.
73
75
  If the Flower server runs on the same machine
74
76
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
77
+ insecure : bool
78
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
79
+ when False, using system certificates if `root_certificates` is None.
80
+ retry_invoker: RetryInvoker
81
+ `RetryInvoker` object that will try to reconnect the client to the server
82
+ after gRPC errors. If None, the client will only try to
83
+ reconnect once after a failure.
75
84
  max_message_length : int
76
85
  Ignored, only present to preserve API-compatibility.
77
86
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -113,7 +122,8 @@ def grpc_request_response(
113
122
  def create_node() -> None:
114
123
  """Set create_node."""
115
124
  create_node_request = CreateNodeRequest()
116
- create_node_response = stub.CreateNode(
125
+ create_node_response = retry_invoker.invoke(
126
+ stub.CreateNode,
117
127
  request=create_node_request,
118
128
  )
119
129
  node_store[KEY_NODE] = create_node_response.node
@@ -127,7 +137,7 @@ def grpc_request_response(
127
137
  node: Node = cast(Node, node_store[KEY_NODE])
128
138
 
129
139
  delete_node_request = DeleteNodeRequest(node=node)
130
- stub.DeleteNode(request=delete_node_request)
140
+ retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
131
141
 
132
142
  del node_store[KEY_NODE]
133
143
 
@@ -141,7 +151,7 @@ def grpc_request_response(
141
151
 
142
152
  # Request instructions (task) from server
143
153
  request = PullTaskInsRequest(node=node)
144
- response = stub.PullTaskIns(request=request)
154
+ response = retry_invoker.invoke(stub.PullTaskIns, request=request)
145
155
 
146
156
  # Get the current TaskIns
147
157
  task_ins: Optional[TaskIns] = get_task_ins(response)
@@ -185,7 +195,7 @@ def grpc_request_response(
185
195
 
186
196
  # Serialize ProtoBuf to bytes
187
197
  request = PushTaskResRequest(task_res_list=[task_res])
188
- _ = stub.PushTaskRes(request)
198
+ _ = retry_invoker.invoke(stub.PushTaskRes, request)
189
199
 
190
200
  state[KEY_METADATA] = None
191
201
 
flwr/client/message_handler/message_handler.py CHANGED
@@ -27,12 +27,7 @@ from flwr.client.client import (
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFn
29
29
  from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
30
- from flwr.common.constant import (
31
- MESSAGE_TYPE_EVALUATE,
32
- MESSAGE_TYPE_FIT,
33
- MESSAGE_TYPE_GET_PARAMETERS,
34
- MESSAGE_TYPE_GET_PROPERTIES,
35
- )
30
+ from flwr.common.constant import MessageType, MessageTypeLegacy
36
31
  from flwr.common.recordset_compat import (
37
32
  evaluateres_to_recordset,
38
33
  fitres_to_recordset,
@@ -115,14 +110,14 @@ def handle_legacy_message_from_msgtype(
115
110
  message_type = message.metadata.message_type
116
111
 
117
112
  # Handle GetPropertiesIns
118
- if message_type == MESSAGE_TYPE_GET_PROPERTIES:
113
+ if message_type == MessageTypeLegacy.GET_PROPERTIES:
119
114
  get_properties_res = maybe_call_get_properties(
120
115
  client=client,
121
116
  get_properties_ins=recordset_to_getpropertiesins(message.content),
122
117
  )
123
118
  out_recordset = getpropertiesres_to_recordset(get_properties_res)
124
119
  # Handle GetParametersIns
125
- elif message_type == MESSAGE_TYPE_GET_PARAMETERS:
120
+ elif message_type == MessageTypeLegacy.GET_PARAMETERS:
126
121
  get_parameters_res = maybe_call_get_parameters(
127
122
  client=client,
128
123
  get_parameters_ins=recordset_to_getparametersins(message.content),
@@ -131,14 +126,14 @@ def handle_legacy_message_from_msgtype(
131
126
  get_parameters_res, keep_input=False
132
127
  )
133
128
  # Handle FitIns
134
- elif message_type == MESSAGE_TYPE_FIT:
129
+ elif message_type == MessageType.TRAIN:
135
130
  fit_res = maybe_call_fit(
136
131
  client=client,
137
132
  fit_ins=recordset_to_fitins(message.content, keep_input=True),
138
133
  )
139
134
  out_recordset = fitres_to_recordset(fit_res, keep_input=False)
140
135
  # Handle EvaluateIns
141
- elif message_type == MESSAGE_TYPE_EVALUATE:
136
+ elif message_type == MessageType.EVALUATE:
142
137
  evaluate_res = maybe_call_evaluate(
143
138
  client=client,
144
139
  evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
flwr/client/mod/centraldp_mods.py CHANGED
@@ -18,7 +18,7 @@
18
18
  from flwr.client.typing import ClientAppCallable
19
19
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
20
20
  from flwr.common import recordset_compat as compat
21
- from flwr.common.constant import MESSAGE_TYPE_FIT
21
+ from flwr.common.constant import MessageType
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.differential_privacy import (
24
24
  compute_adaptive_clip_model_update,
@@ -40,7 +40,7 @@ def fixedclipping_mod(
40
40
 
41
41
  This mod clips the client model updates before sending them to the server.
42
42
 
43
- It operates on messages with type MESSAGE_TYPE_FIT.
43
+ It operates on messages with type MessageType.TRAIN.
44
44
 
45
45
  Notes
46
46
  -----
@@ -48,7 +48,7 @@ def fixedclipping_mod(
48
48
 
49
49
  Typically, fixedclipping_mod should be the last to operate on params.
50
50
  """
51
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
51
+ if msg.metadata.message_type != MessageType.TRAIN:
52
52
  return call_next(msg, ctxt)
53
53
  fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
54
54
  if KEY_CLIPPING_NORM not in fit_ins.config:
@@ -93,7 +93,7 @@ def adaptiveclipping_mod(
93
93
 
94
94
  It also sends KEY_NORM_BIT to the server for computing the new clipping value.
95
95
 
96
- It operates on messages with type MESSAGE_TYPE_FIT.
96
+ It operates on messages with type MessageType.TRAIN.
97
97
 
98
98
  Notes
99
99
  -----
@@ -101,7 +101,7 @@ def adaptiveclipping_mod(
101
101
 
102
102
  Typically, adaptiveclipping_mod should be the last to operate on params.
103
103
  """
104
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
104
+ if msg.metadata.message_type != MessageType.TRAIN:
105
105
  return call_next(msg, ctxt)
106
106
 
107
107
  fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
flwr/client/mod/secure_aggregation/secaggplus_mod.py CHANGED
@@ -30,7 +30,7 @@ from flwr.common import (
30
30
  parameters_to_ndarrays,
31
31
  )
32
32
  from flwr.common import recordset_compat as compat
33
- from flwr.common.constant import MESSAGE_TYPE_FIT
33
+ from flwr.common.constant import MessageType
34
34
  from flwr.common.logger import log
35
35
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
36
36
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
@@ -150,7 +150,7 @@ def secaggplus_mod(
150
150
  ) -> Message:
151
151
  """Handle incoming message and return results, following the SecAgg+ protocol."""
152
152
  # Ignore non-fit messages
153
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
153
+ if msg.metadata.message_type != MessageType.TRAIN:
154
154
  return call_next(msg, ctxt)
155
155
 
156
156
  # Retrieve local state
flwr/client/rest_client/connection.py CHANGED
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.constant import MISSING_EXTRA_REST
28
28
  from flwr.common.logger import log
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
61
62
  def http_request_response(
62
63
  server_address: str,
63
64
  insecure: bool, # pylint: disable=unused-argument
65
+ retry_invoker: RetryInvoker,
64
66
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
65
67
  root_certificates: Optional[
66
68
  Union[bytes, str]
@@ -84,6 +86,12 @@ def http_request_response(
84
86
  The IPv6 address of the server with `http://` or `https://`.
85
87
  If the Flower server runs on the same machine
86
88
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
89
+ insecure : bool
90
+ Unused argument present for compatibilty.
91
+ retry_invoker: RetryInvoker
92
+ `RetryInvoker` object that will try to reconnect the client to the server
93
+ after REST connection errors. If None, the client will only try to
94
+ reconnect once after a failure.
87
95
  max_message_length : int
88
96
  Ignored, only present to preserve API-compatibility.
89
97
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -134,7 +142,8 @@ def http_request_response(
134
142
  create_node_req_proto = CreateNodeRequest()
135
143
  create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
136
144
 
137
- res = requests.post(
145
+ res = retry_invoker.invoke(
146
+ requests.post,
138
147
  url=f"{base_url}/{PATH_CREATE_NODE}",
139
148
  headers={
140
149
  "Accept": "application/protobuf",
@@ -177,7 +186,8 @@ def http_request_response(
177
186
  node: Node = cast(Node, node_store[KEY_NODE])
178
187
  delete_node_req_proto = DeleteNodeRequest(node=node)
179
188
  delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
180
- res = requests.post(
189
+ res = retry_invoker.invoke(
190
+ requests.post,
181
191
  url=f"{base_url}/{PATH_DELETE_NODE}",
182
192
  headers={
183
193
  "Accept": "application/protobuf",
@@ -218,7 +228,8 @@ def http_request_response(
218
228
  pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
219
229
 
220
230
  # Request instructions (task) from server
221
- res = requests.post(
231
+ res = retry_invoker.invoke(
232
+ requests.post,
222
233
  url=f"{base_url}/{PATH_PULL_TASK_INS}",
223
234
  headers={
224
235
  "Accept": "application/protobuf",
@@ -298,7 +309,8 @@ def http_request_response(
298
309
  )
299
310
 
300
311
  # Send ClientMessage to server
301
- res = requests.post(
312
+ res = retry_invoker.invoke(
313
+ requests.post,
302
314
  url=f"{base_url}/{PATH_PUSH_TASK_RES}",
303
315
  headers={
304
316
  "Accept": "application/protobuf",
flwr/common/__init__.py CHANGED
@@ -15,11 +15,14 @@
15
15
  """Common components shared between server and client."""
16
16
 
17
17
 
18
+ from .constant import MessageType as MessageType
19
+ from .constant import MessageTypeLegacy as MessageTypeLegacy
18
20
  from .context import Context as Context
19
21
  from .date import now as now
20
22
  from .grpc import GRPC_MAX_MESSAGE_LENGTH
21
23
  from .logger import configure as configure
22
24
  from .logger import log as log
25
+ from .message import Error as Error
23
26
  from .message import Message as Message
24
27
  from .message import Metadata as Metadata
25
28
  from .parameter import bytes_to_ndarray as bytes_to_ndarray
@@ -74,6 +77,7 @@ __all__ = [
74
77
  "EventType",
75
78
  "FitIns",
76
79
  "FitRes",
80
+ "Error",
77
81
  "GetParametersIns",
78
82
  "GetParametersRes",
79
83
  "GetPropertiesIns",
@@ -81,6 +85,8 @@ __all__ = [
81
85
  "GRPC_MAX_MESSAGE_LENGTH",
82
86
  "log",
83
87
  "Message",
88
+ "MessageType",
89
+ "MessageTypeLegacy",
84
90
  "Metadata",
85
91
  "Metrics",
86
92
  "MetricsAggregationFn",
flwr/common/constant.py CHANGED
@@ -36,10 +36,27 @@ TRANSPORT_TYPES = [
36
36
  TRANSPORT_TYPE_VCE,
37
37
  ]
38
38
 
39
- MESSAGE_TYPE_GET_PROPERTIES = "get_properties"
40
- MESSAGE_TYPE_GET_PARAMETERS = "get_parameters"
41
- MESSAGE_TYPE_FIT = "fit"
42
- MESSAGE_TYPE_EVALUATE = "evaluate"
39
+
40
+ class MessageType:
41
+ """Message type."""
42
+
43
+ TRAIN = "train"
44
+ EVALUATE = "evaluate"
45
+
46
+ def __new__(cls) -> MessageType:
47
+ """Prevent instantiation."""
48
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
49
+
50
+
51
+ class MessageTypeLegacy:
52
+ """Legacy message type."""
53
+
54
+ GET_PROPERTIES = "get_properties"
55
+ GET_PARAMETERS = "get_parameters"
56
+
57
+ def __new__(cls) -> MessageTypeLegacy:
58
+ """Prevent instantiation."""
59
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
43
60
 
44
61
 
45
62
  class SType:
flwr/server/app.py CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
362
362
  f_stop = asyncio.Event() # Does nothing
363
363
  _run_fleet_api_vce(
364
364
  num_supernodes=args.num_supernodes,
365
- client_app_module_name=args.client_app,
365
+ client_app_attr=args.client_app,
366
366
  backend_name=args.backend,
367
367
  backend_config_json_stream=args.backend_config,
368
- working_dir=args.dir,
368
+ app_dir=args.app_dir,
369
369
  state_factory=state_factory,
370
370
  f_stop=f_stop,
371
371
  )
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
438
438
  # pylint: disable=too-many-arguments
439
439
  def _run_fleet_api_vce(
440
440
  num_supernodes: int,
441
- client_app_module_name: str,
441
+ client_app_attr: str,
442
442
  backend_name: str,
443
443
  backend_config_json_stream: str,
444
- working_dir: str,
444
+ app_dir: str,
445
445
  state_factory: StateFactory,
446
446
  f_stop: asyncio.Event,
447
447
  ) -> None:
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
449
449
 
450
450
  start_vce(
451
451
  num_supernodes=num_supernodes,
452
- client_app_module_name=client_app_module_name,
452
+ client_app_attr=client_app_attr,
453
453
  backend_name=backend_name,
454
454
  backend_config_json_stream=backend_config_json_stream,
455
455
  state_factory=state_factory,
456
- working_dir=working_dir,
456
+ app_dir=app_dir,
457
457
  f_stop=f_stop,
458
458
  )
459
459
 
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
705
705
  "`flwr.common.typing.ConfigsRecordValues`. ",
706
706
  )
707
707
  parser.add_argument(
708
- "--dir",
708
+ "--app-dir",
709
709
  default="",
710
710
  help="Add specified directory to the PYTHONPATH and load"
711
711
  "ClientApp from there."
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -19,15 +19,9 @@ import time
19
19
  from typing import List, Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import RecordSet
22
+ from flwr.common import MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
24
  from flwr.common import serde
25
- from flwr.common.constant import (
26
- MESSAGE_TYPE_EVALUATE,
27
- MESSAGE_TYPE_FIT,
28
- MESSAGE_TYPE_GET_PARAMETERS,
29
- MESSAGE_TYPE_GET_PROPERTIES,
30
- )
31
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
32
26
  from flwr.server.client_proxy import ClientProxy
33
27
 
@@ -57,7 +51,7 @@ class DriverClientProxy(ClientProxy):
57
51
  out_recordset = compat.getpropertiesins_to_recordset(ins)
58
52
  # Fetch response
59
53
  in_recordset = self._send_receive_recordset(
60
- out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
54
+ out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
61
55
  )
62
56
  # RecordSet to Res
63
57
  return compat.recordset_to_getpropertiesres(in_recordset)
@@ -73,7 +67,7 @@ class DriverClientProxy(ClientProxy):
73
67
  out_recordset = compat.getparametersins_to_recordset(ins)
74
68
  # Fetch response
75
69
  in_recordset = self._send_receive_recordset(
76
- out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
70
+ out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
77
71
  )
78
72
  # RecordSet to Res
79
73
  return compat.recordset_to_getparametersres(in_recordset, False)
@@ -86,7 +80,7 @@ class DriverClientProxy(ClientProxy):
86
80
  out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
87
81
  # Fetch response
88
82
  in_recordset = self._send_receive_recordset(
89
- out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
83
+ out_recordset, MessageType.TRAIN, timeout, group_id
90
84
  )
91
85
  # RecordSet to Res
92
86
  return compat.recordset_to_fitres(in_recordset, keep_input=False)
@@ -99,7 +93,7 @@ class DriverClientProxy(ClientProxy):
99
93
  out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
100
94
  # Fetch response
101
95
  in_recordset = self._send_receive_recordset(
102
- out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
96
+ out_recordset, MessageType.EVALUATE, timeout, group_id
103
97
  )
104
98
  # RecordSet to Res
105
99
  return compat.recordset_to_evaluateres(in_recordset)
flwr/server/run_serverapp.py CHANGED
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import asyncio
20
19
  import sys
21
20
  from logging import DEBUG, WARN
22
21
  from pathlib import Path
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
30
29
 
31
30
 
32
31
  def run(
33
- server_app_attr: str,
34
32
  driver: Driver,
35
33
  server_app_dir: str,
36
- stop_event: Optional[asyncio.Event] = None,
34
+ server_app_attr: Optional[str] = None,
35
+ loaded_server_app: Optional[ServerApp] = None,
37
36
  ) -> None:
38
37
  """Run ServerApp with a given Driver."""
38
+ if not (server_app_attr is None) ^ (loaded_server_app is None):
39
+ raise ValueError(
40
+ "Either `server_app_attr` or `loaded_server_app` should be set "
41
+ "but not both. "
42
+ )
43
+
39
44
  if server_app_dir is not None:
40
45
  sys.path.insert(0, server_app_dir)
41
46
 
47
+ # Load ServerApp if needed
42
48
  def _load() -> ServerApp:
43
- server_app: ServerApp = load_server_app(server_app_attr)
49
+ if server_app_attr:
50
+ server_app: ServerApp = load_server_app(server_app_attr)
51
+ if loaded_server_app:
52
+ server_app = loaded_server_app
44
53
  return server_app
45
54
 
46
55
  server_app = _load()
@@ -52,10 +61,6 @@ def run(
52
61
  server_app(driver=driver, context=context)
53
62
 
54
63
  log(DEBUG, "ServerApp finished running.")
55
- # Upon completion, trigger stop event if one was passed
56
- if stop_event is not None:
57
- log(DEBUG, "Triggering stop event.")
58
- stop_event.set()
59
64
 
60
65
 
61
66
  def run_server_app() -> None:
@@ -117,7 +122,7 @@ def run_server_app() -> None:
117
122
  )
118
123
 
119
124
  # Run the Server App with the Driver
120
- run(server_app_attr, driver, server_app_dir)
125
+ run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
121
126
 
122
127
  # Clean up
123
128
  driver.__del__() # pylint: disable=unnecessary-dunder-call
flwr/server/server.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import concurrent.futures
19
19
  import timeit
20
- from logging import DEBUG, INFO, WARN
20
+ from logging import INFO, WARN
21
21
  from typing import Dict, List, Optional, Tuple, Union
22
22
 
23
23
  from flwr.common import (
@@ -173,7 +173,7 @@ class Server:
173
173
  log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
174
174
  return None
175
175
  log(
176
- DEBUG,
176
+ INFO,
177
177
  "evaluate_round %s: strategy sampled %s clients (out of %s)",
178
178
  server_round,
179
179
  len(client_instructions),
@@ -188,7 +188,7 @@ class Server:
188
188
  group_id=server_round,
189
189
  )
190
190
  log(
191
- DEBUG,
191
+ INFO,
192
192
  "evaluate_round %s received %s results and %s failures",
193
193
  server_round,
194
194
  len(results),
@@ -223,7 +223,7 @@ class Server:
223
223
  log(INFO, "fit_round %s: no clients selected, cancel", server_round)
224
224
  return None
225
225
  log(
226
- DEBUG,
226
+ INFO,
227
227
  "fit_round %s: strategy sampled %s clients (out of %s)",
228
228
  server_round,
229
229
  len(client_instructions),
@@ -238,7 +238,7 @@ class Server:
238
238
  group_id=server_round,
239
239
  )
240
240
  log(
241
- DEBUG,
241
+ INFO,
242
242
  "fit_round %s received %s results and %s failures",
243
243
  server_round,
244
244
  len(results),
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
49
49
  self, request: GetNodesRequest, context: grpc.ServicerContext
50
50
  ) -> GetNodesResponse:
51
51
  """Get available nodes."""
52
- log(INFO, "DriverServicer.GetNodes")
52
+ log(DEBUG, "DriverServicer.GetNodes")
53
53
  state: State = self.state_factory.state()
54
54
  all_ids: Set[int] = state.get_nodes(request.run_id)
55
55
  nodes: List[Node] = [
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -219,16 +219,23 @@ async def run(
219
219
 
220
220
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals
221
221
  def start_vce(
222
- client_app_module_name: str,
223
222
  backend_name: str,
224
223
  backend_config_json_stream: str,
225
- working_dir: str,
224
+ app_dir: str,
226
225
  f_stop: asyncio.Event,
226
+ client_app: Optional[ClientApp] = None,
227
+ client_app_attr: Optional[str] = None,
227
228
  num_supernodes: Optional[int] = None,
228
229
  state_factory: Optional[StateFactory] = None,
229
230
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
230
231
  ) -> None:
231
232
  """Start Fleet API with the Simulation Engine."""
233
+ if client_app_attr is not None and client_app is not None:
234
+ raise ValueError(
235
+ "Both `client_app_attr` and `client_app` are provided, "
236
+ "but only one is allowed."
237
+ )
238
+
232
239
  if num_supernodes is not None and existing_nodes_mapping is not None:
233
240
  raise ValueError(
234
241
  "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
@@ -290,12 +297,17 @@ def start_vce(
290
297
 
291
298
  def backend_fn() -> Backend:
292
299
  """Instantiate a Backend."""
293
- return backend_type(backend_config, work_dir=working_dir)
300
+ return backend_type(backend_config, work_dir=app_dir)
294
301
 
295
- log(INFO, "client_app_module_name = %s", client_app_module_name)
302
+ log(INFO, "client_app_attr = %s", client_app_attr)
296
303
 
304
+ # Load ClientApp if needed
297
305
  def _load() -> ClientApp:
298
- app: ClientApp = load_client_app(client_app_module_name)
306
+
307
+ if client_app_attr:
308
+ app: ClientApp = load_client_app(client_app_attr)
309
+ if client_app:
310
+ app = client_app
299
311
  return app
300
312
 
301
313
  app_fn = _load
flwr/server/workflow/default_workflows.py CHANGED
@@ -21,11 +21,7 @@ from typing import Optional, cast
21
21
 
22
22
  import flwr.common.recordset_compat as compat
23
23
  from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
- from flwr.common.constant import (
25
- MESSAGE_TYPE_EVALUATE,
26
- MESSAGE_TYPE_FIT,
27
- MESSAGE_TYPE_GET_PARAMETERS,
28
- )
24
+ from flwr.common.constant import MessageType, MessageTypeLegacy
29
25
 
30
26
  from ..compat.app_utils import start_update_client_manager_thread
31
27
  from ..compat.legacy_context import LegacyContext
@@ -134,7 +130,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
134
130
  [
135
131
  driver.create_message(
136
132
  content=content,
137
- message_type=MESSAGE_TYPE_GET_PARAMETERS,
133
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
138
134
  dst_node_id=random_client.node_id,
139
135
  group_id="",
140
136
  ttl="",
@@ -232,7 +228,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
232
228
  out_messages = [
233
229
  driver.create_message(
234
230
  content=compat.fitins_to_recordset(fitins, True),
235
- message_type=MESSAGE_TYPE_FIT,
231
+ message_type=MessageType.TRAIN,
236
232
  dst_node_id=proxy.node_id,
237
233
  group_id="",
238
234
  ttl="",
@@ -313,7 +309,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
313
309
  out_messages = [
314
310
  driver.create_message(
315
311
  content=compat.evaluateins_to_recordset(evalins, True),
316
- message_type=MESSAGE_TYPE_EVALUATE,
312
+ message_type=MessageType.EVALUATE,
317
313
  dst_node_id=proxy.node_id,
318
314
  group_id="",
319
315
  ttl="",
flwr/simulation/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import importlib
19
19
 
20
- from flwr.simulation.run_simulation import run_simulation
20
+ from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
21
21
 
22
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
23
23
 
@@ -36,7 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
36
  raise ImportError(RAY_IMPORT_ERROR)
37
37
 
38
38
 
39
- __all__ = [
40
- "start_simulation",
41
- "run_simulation",
42
- ]
39
+ __all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]