flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/flower_toml.py +151 -0
  3. flwr/cli/new/new.py +1 -0
  4. flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
  5. flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
  6. flwr/cli/new/templates/app/flower.toml.tpl +2 -2
  7. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
  8. flwr/cli/run/__init__.py +21 -0
  9. flwr/cli/run/run.py +102 -0
  10. flwr/client/app.py +93 -8
  11. flwr/client/grpc_client/connection.py +16 -14
  12. flwr/client/grpc_rere_client/connection.py +14 -4
  13. flwr/client/message_handler/message_handler.py +5 -10
  14. flwr/client/mod/centraldp_mods.py +5 -5
  15. flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
  16. flwr/client/rest_client/connection.py +16 -4
  17. flwr/common/__init__.py +6 -0
  18. flwr/common/constant.py +21 -4
  19. flwr/server/app.py +7 -7
  20. flwr/server/compat/driver_client_proxy.py +5 -11
  21. flwr/server/run_serverapp.py +14 -9
  22. flwr/server/server.py +5 -5
  23. flwr/server/superlink/driver/driver_servicer.py +1 -1
  24. flwr/server/superlink/fleet/vce/vce_api.py +17 -5
  25. flwr/server/workflow/default_workflows.py +4 -8
  26. flwr/simulation/__init__.py +2 -5
  27. flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
  28. flwr/simulation/run_simulation.py +301 -76
  29. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
  30. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
  31. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
  32. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
  33. {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.grpc import create_channel
28
28
  from flwr.common.logger import log, warn_experimental_feature
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
51
52
  def grpc_request_response(
52
53
  server_address: str,
53
54
  insecure: bool,
55
+ retry_invoker: RetryInvoker,
54
56
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
55
57
  root_certificates: Optional[Union[bytes, str]] = None,
56
58
  ) -> Iterator[
@@ -72,6 +74,13 @@ def grpc_request_response(
72
74
  The IPv6 address of the server with `http://` or `https://`.
73
75
  If the Flower server runs on the same machine
74
76
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
77
+ insecure : bool
78
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
79
+ when False, using system certificates if `root_certificates` is None.
80
+ retry_invoker: RetryInvoker
81
+ `RetryInvoker` object that will try to reconnect the client to the server
82
+ after gRPC errors. If None, the client will only try to
83
+ reconnect once after a failure.
75
84
  max_message_length : int
76
85
  Ignored, only present to preserve API-compatibility.
77
86
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -113,7 +122,8 @@ def grpc_request_response(
113
122
  def create_node() -> None:
114
123
  """Set create_node."""
115
124
  create_node_request = CreateNodeRequest()
116
- create_node_response = stub.CreateNode(
125
+ create_node_response = retry_invoker.invoke(
126
+ stub.CreateNode,
117
127
  request=create_node_request,
118
128
  )
119
129
  node_store[KEY_NODE] = create_node_response.node
@@ -127,7 +137,7 @@ def grpc_request_response(
127
137
  node: Node = cast(Node, node_store[KEY_NODE])
128
138
 
129
139
  delete_node_request = DeleteNodeRequest(node=node)
130
- stub.DeleteNode(request=delete_node_request)
140
+ retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
131
141
 
132
142
  del node_store[KEY_NODE]
133
143
 
@@ -141,7 +151,7 @@ def grpc_request_response(
141
151
 
142
152
  # Request instructions (task) from server
143
153
  request = PullTaskInsRequest(node=node)
144
- response = stub.PullTaskIns(request=request)
154
+ response = retry_invoker.invoke(stub.PullTaskIns, request=request)
145
155
 
146
156
  # Get the current TaskIns
147
157
  task_ins: Optional[TaskIns] = get_task_ins(response)
@@ -185,7 +195,7 @@ def grpc_request_response(
185
195
 
186
196
  # Serialize ProtoBuf to bytes
187
197
  request = PushTaskResRequest(task_res_list=[task_res])
188
- _ = stub.PushTaskRes(request)
198
+ _ = retry_invoker.invoke(stub.PushTaskRes, request)
189
199
 
190
200
  state[KEY_METADATA] = None
191
201
 
@@ -27,12 +27,7 @@ from flwr.client.client import (
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFn
29
29
  from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
30
- from flwr.common.constant import (
31
- MESSAGE_TYPE_EVALUATE,
32
- MESSAGE_TYPE_FIT,
33
- MESSAGE_TYPE_GET_PARAMETERS,
34
- MESSAGE_TYPE_GET_PROPERTIES,
35
- )
30
+ from flwr.common.constant import MessageType, MessageTypeLegacy
36
31
  from flwr.common.recordset_compat import (
37
32
  evaluateres_to_recordset,
38
33
  fitres_to_recordset,
@@ -115,14 +110,14 @@ def handle_legacy_message_from_msgtype(
115
110
  message_type = message.metadata.message_type
116
111
 
117
112
  # Handle GetPropertiesIns
118
- if message_type == MESSAGE_TYPE_GET_PROPERTIES:
113
+ if message_type == MessageTypeLegacy.GET_PROPERTIES:
119
114
  get_properties_res = maybe_call_get_properties(
120
115
  client=client,
121
116
  get_properties_ins=recordset_to_getpropertiesins(message.content),
122
117
  )
123
118
  out_recordset = getpropertiesres_to_recordset(get_properties_res)
124
119
  # Handle GetParametersIns
125
- elif message_type == MESSAGE_TYPE_GET_PARAMETERS:
120
+ elif message_type == MessageTypeLegacy.GET_PARAMETERS:
126
121
  get_parameters_res = maybe_call_get_parameters(
127
122
  client=client,
128
123
  get_parameters_ins=recordset_to_getparametersins(message.content),
@@ -131,14 +126,14 @@ def handle_legacy_message_from_msgtype(
131
126
  get_parameters_res, keep_input=False
132
127
  )
133
128
  # Handle FitIns
134
- elif message_type == MESSAGE_TYPE_FIT:
129
+ elif message_type == MessageType.TRAIN:
135
130
  fit_res = maybe_call_fit(
136
131
  client=client,
137
132
  fit_ins=recordset_to_fitins(message.content, keep_input=True),
138
133
  )
139
134
  out_recordset = fitres_to_recordset(fit_res, keep_input=False)
140
135
  # Handle EvaluateIns
141
- elif message_type == MESSAGE_TYPE_EVALUATE:
136
+ elif message_type == MessageType.EVALUATE:
142
137
  evaluate_res = maybe_call_evaluate(
143
138
  client=client,
144
139
  evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
@@ -18,7 +18,7 @@
18
18
  from flwr.client.typing import ClientAppCallable
19
19
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
20
20
  from flwr.common import recordset_compat as compat
21
- from flwr.common.constant import MESSAGE_TYPE_FIT
21
+ from flwr.common.constant import MessageType
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.differential_privacy import (
24
24
  compute_adaptive_clip_model_update,
@@ -40,7 +40,7 @@ def fixedclipping_mod(
40
40
 
41
41
  This mod clips the client model updates before sending them to the server.
42
42
 
43
- It operates on messages with type MESSAGE_TYPE_FIT.
43
+ It operates on messages with type MessageType.TRAIN.
44
44
 
45
45
  Notes
46
46
  -----
@@ -48,7 +48,7 @@ def fixedclipping_mod(
48
48
 
49
49
  Typically, fixedclipping_mod should be the last to operate on params.
50
50
  """
51
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
51
+ if msg.metadata.message_type != MessageType.TRAIN:
52
52
  return call_next(msg, ctxt)
53
53
  fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
54
54
  if KEY_CLIPPING_NORM not in fit_ins.config:
@@ -93,7 +93,7 @@ def adaptiveclipping_mod(
93
93
 
94
94
  It also sends KEY_NORM_BIT to the server for computing the new clipping value.
95
95
 
96
- It operates on messages with type MESSAGE_TYPE_FIT.
96
+ It operates on messages with type MessageType.TRAIN.
97
97
 
98
98
  Notes
99
99
  -----
@@ -101,7 +101,7 @@ def adaptiveclipping_mod(
101
101
 
102
102
  Typically, adaptiveclipping_mod should be the last to operate on params.
103
103
  """
104
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
104
+ if msg.metadata.message_type != MessageType.TRAIN:
105
105
  return call_next(msg, ctxt)
106
106
 
107
107
  fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
@@ -30,7 +30,7 @@ from flwr.common import (
30
30
  parameters_to_ndarrays,
31
31
  )
32
32
  from flwr.common import recordset_compat as compat
33
- from flwr.common.constant import MESSAGE_TYPE_FIT
33
+ from flwr.common.constant import MessageType
34
34
  from flwr.common.logger import log
35
35
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
36
36
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
@@ -150,7 +150,7 @@ def secaggplus_mod(
150
150
  ) -> Message:
151
151
  """Handle incoming message and return results, following the SecAgg+ protocol."""
152
152
  # Ignore non-fit messages
153
- if msg.metadata.message_type != MESSAGE_TYPE_FIT:
153
+ if msg.metadata.message_type != MessageType.TRAIN:
154
154
  return call_next(msg, ctxt)
155
155
 
156
156
  # Retrieve local state
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.constant import MISSING_EXTRA_REST
28
28
  from flwr.common.logger import log
29
29
  from flwr.common.message import Message, Metadata
30
+ from flwr.common.retry_invoker import RetryInvoker
30
31
  from flwr.common.serde import message_from_taskins, message_to_taskres
31
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
33
  CreateNodeRequest,
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
61
62
  def http_request_response(
62
63
  server_address: str,
63
64
  insecure: bool, # pylint: disable=unused-argument
65
+ retry_invoker: RetryInvoker,
64
66
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
65
67
  root_certificates: Optional[
66
68
  Union[bytes, str]
@@ -84,6 +86,12 @@ def http_request_response(
84
86
  The IPv6 address of the server with `http://` or `https://`.
85
87
  If the Flower server runs on the same machine
86
88
  on port 8080, then `server_address` would be `"http://[::]:8080"`.
89
+ insecure : bool
90
+ Unused argument present for compatibilty.
91
+ retry_invoker: RetryInvoker
92
+ `RetryInvoker` object that will try to reconnect the client to the server
93
+ after REST connection errors. If None, the client will only try to
94
+ reconnect once after a failure.
87
95
  max_message_length : int
88
96
  Ignored, only present to preserve API-compatibility.
89
97
  root_certificates : Optional[Union[bytes, str]] (default: None)
@@ -134,7 +142,8 @@ def http_request_response(
134
142
  create_node_req_proto = CreateNodeRequest()
135
143
  create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
136
144
 
137
- res = requests.post(
145
+ res = retry_invoker.invoke(
146
+ requests.post,
138
147
  url=f"{base_url}/{PATH_CREATE_NODE}",
139
148
  headers={
140
149
  "Accept": "application/protobuf",
@@ -177,7 +186,8 @@ def http_request_response(
177
186
  node: Node = cast(Node, node_store[KEY_NODE])
178
187
  delete_node_req_proto = DeleteNodeRequest(node=node)
179
188
  delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
180
- res = requests.post(
189
+ res = retry_invoker.invoke(
190
+ requests.post,
181
191
  url=f"{base_url}/{PATH_DELETE_NODE}",
182
192
  headers={
183
193
  "Accept": "application/protobuf",
@@ -218,7 +228,8 @@ def http_request_response(
218
228
  pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
219
229
 
220
230
  # Request instructions (task) from server
221
- res = requests.post(
231
+ res = retry_invoker.invoke(
232
+ requests.post,
222
233
  url=f"{base_url}/{PATH_PULL_TASK_INS}",
223
234
  headers={
224
235
  "Accept": "application/protobuf",
@@ -298,7 +309,8 @@ def http_request_response(
298
309
  )
299
310
 
300
311
  # Send ClientMessage to server
301
- res = requests.post(
312
+ res = retry_invoker.invoke(
313
+ requests.post,
302
314
  url=f"{base_url}/{PATH_PUSH_TASK_RES}",
303
315
  headers={
304
316
  "Accept": "application/protobuf",
flwr/common/__init__.py CHANGED
@@ -15,11 +15,14 @@
15
15
  """Common components shared between server and client."""
16
16
 
17
17
 
18
+ from .constant import MessageType as MessageType
19
+ from .constant import MessageTypeLegacy as MessageTypeLegacy
18
20
  from .context import Context as Context
19
21
  from .date import now as now
20
22
  from .grpc import GRPC_MAX_MESSAGE_LENGTH
21
23
  from .logger import configure as configure
22
24
  from .logger import log as log
25
+ from .message import Error as Error
23
26
  from .message import Message as Message
24
27
  from .message import Metadata as Metadata
25
28
  from .parameter import bytes_to_ndarray as bytes_to_ndarray
@@ -74,6 +77,7 @@ __all__ = [
74
77
  "EventType",
75
78
  "FitIns",
76
79
  "FitRes",
80
+ "Error",
77
81
  "GetParametersIns",
78
82
  "GetParametersRes",
79
83
  "GetPropertiesIns",
@@ -81,6 +85,8 @@ __all__ = [
81
85
  "GRPC_MAX_MESSAGE_LENGTH",
82
86
  "log",
83
87
  "Message",
88
+ "MessageType",
89
+ "MessageTypeLegacy",
84
90
  "Metadata",
85
91
  "Metrics",
86
92
  "MetricsAggregationFn",
flwr/common/constant.py CHANGED
@@ -36,10 +36,27 @@ TRANSPORT_TYPES = [
36
36
  TRANSPORT_TYPE_VCE,
37
37
  ]
38
38
 
39
- MESSAGE_TYPE_GET_PROPERTIES = "get_properties"
40
- MESSAGE_TYPE_GET_PARAMETERS = "get_parameters"
41
- MESSAGE_TYPE_FIT = "fit"
42
- MESSAGE_TYPE_EVALUATE = "evaluate"
39
+
40
+ class MessageType:
41
+ """Message type."""
42
+
43
+ TRAIN = "train"
44
+ EVALUATE = "evaluate"
45
+
46
+ def __new__(cls) -> MessageType:
47
+ """Prevent instantiation."""
48
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
49
+
50
+
51
+ class MessageTypeLegacy:
52
+ """Legacy message type."""
53
+
54
+ GET_PROPERTIES = "get_properties"
55
+ GET_PARAMETERS = "get_parameters"
56
+
57
+ def __new__(cls) -> MessageTypeLegacy:
58
+ """Prevent instantiation."""
59
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
43
60
 
44
61
 
45
62
  class SType:
flwr/server/app.py CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
362
362
  f_stop = asyncio.Event() # Does nothing
363
363
  _run_fleet_api_vce(
364
364
  num_supernodes=args.num_supernodes,
365
- client_app_module_name=args.client_app,
365
+ client_app_attr=args.client_app,
366
366
  backend_name=args.backend,
367
367
  backend_config_json_stream=args.backend_config,
368
- working_dir=args.dir,
368
+ app_dir=args.app_dir,
369
369
  state_factory=state_factory,
370
370
  f_stop=f_stop,
371
371
  )
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
438
438
  # pylint: disable=too-many-arguments
439
439
  def _run_fleet_api_vce(
440
440
  num_supernodes: int,
441
- client_app_module_name: str,
441
+ client_app_attr: str,
442
442
  backend_name: str,
443
443
  backend_config_json_stream: str,
444
- working_dir: str,
444
+ app_dir: str,
445
445
  state_factory: StateFactory,
446
446
  f_stop: asyncio.Event,
447
447
  ) -> None:
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
449
449
 
450
450
  start_vce(
451
451
  num_supernodes=num_supernodes,
452
- client_app_module_name=client_app_module_name,
452
+ client_app_attr=client_app_attr,
453
453
  backend_name=backend_name,
454
454
  backend_config_json_stream=backend_config_json_stream,
455
455
  state_factory=state_factory,
456
- working_dir=working_dir,
456
+ app_dir=app_dir,
457
457
  f_stop=f_stop,
458
458
  )
459
459
 
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
705
705
  "`flwr.common.typing.ConfigsRecordValues`. ",
706
706
  )
707
707
  parser.add_argument(
708
- "--dir",
708
+ "--app-dir",
709
709
  default="",
710
710
  help="Add specified directory to the PYTHONPATH and load"
711
711
  "ClientApp from there."
@@ -19,15 +19,9 @@ import time
19
19
  from typing import List, Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import RecordSet
22
+ from flwr.common import MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
24
  from flwr.common import serde
25
- from flwr.common.constant import (
26
- MESSAGE_TYPE_EVALUATE,
27
- MESSAGE_TYPE_FIT,
28
- MESSAGE_TYPE_GET_PARAMETERS,
29
- MESSAGE_TYPE_GET_PROPERTIES,
30
- )
31
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
32
26
  from flwr.server.client_proxy import ClientProxy
33
27
 
@@ -57,7 +51,7 @@ class DriverClientProxy(ClientProxy):
57
51
  out_recordset = compat.getpropertiesins_to_recordset(ins)
58
52
  # Fetch response
59
53
  in_recordset = self._send_receive_recordset(
60
- out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
54
+ out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
61
55
  )
62
56
  # RecordSet to Res
63
57
  return compat.recordset_to_getpropertiesres(in_recordset)
@@ -73,7 +67,7 @@ class DriverClientProxy(ClientProxy):
73
67
  out_recordset = compat.getparametersins_to_recordset(ins)
74
68
  # Fetch response
75
69
  in_recordset = self._send_receive_recordset(
76
- out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
70
+ out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
77
71
  )
78
72
  # RecordSet to Res
79
73
  return compat.recordset_to_getparametersres(in_recordset, False)
@@ -86,7 +80,7 @@ class DriverClientProxy(ClientProxy):
86
80
  out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
87
81
  # Fetch response
88
82
  in_recordset = self._send_receive_recordset(
89
- out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
83
+ out_recordset, MessageType.TRAIN, timeout, group_id
90
84
  )
91
85
  # RecordSet to Res
92
86
  return compat.recordset_to_fitres(in_recordset, keep_input=False)
@@ -99,7 +93,7 @@ class DriverClientProxy(ClientProxy):
99
93
  out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
100
94
  # Fetch response
101
95
  in_recordset = self._send_receive_recordset(
102
- out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
96
+ out_recordset, MessageType.EVALUATE, timeout, group_id
103
97
  )
104
98
  # RecordSet to Res
105
99
  return compat.recordset_to_evaluateres(in_recordset)
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import asyncio
20
19
  import sys
21
20
  from logging import DEBUG, WARN
22
21
  from pathlib import Path
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
30
29
 
31
30
 
32
31
  def run(
33
- server_app_attr: str,
34
32
  driver: Driver,
35
33
  server_app_dir: str,
36
- stop_event: Optional[asyncio.Event] = None,
34
+ server_app_attr: Optional[str] = None,
35
+ loaded_server_app: Optional[ServerApp] = None,
37
36
  ) -> None:
38
37
  """Run ServerApp with a given Driver."""
38
+ if not (server_app_attr is None) ^ (loaded_server_app is None):
39
+ raise ValueError(
40
+ "Either `server_app_attr` or `loaded_server_app` should be set "
41
+ "but not both. "
42
+ )
43
+
39
44
  if server_app_dir is not None:
40
45
  sys.path.insert(0, server_app_dir)
41
46
 
47
+ # Load ServerApp if needed
42
48
  def _load() -> ServerApp:
43
- server_app: ServerApp = load_server_app(server_app_attr)
49
+ if server_app_attr:
50
+ server_app: ServerApp = load_server_app(server_app_attr)
51
+ if loaded_server_app:
52
+ server_app = loaded_server_app
44
53
  return server_app
45
54
 
46
55
  server_app = _load()
@@ -52,10 +61,6 @@ def run(
52
61
  server_app(driver=driver, context=context)
53
62
 
54
63
  log(DEBUG, "ServerApp finished running.")
55
- # Upon completion, trigger stop event if one was passed
56
- if stop_event is not None:
57
- log(DEBUG, "Triggering stop event.")
58
- stop_event.set()
59
64
 
60
65
 
61
66
  def run_server_app() -> None:
@@ -117,7 +122,7 @@ def run_server_app() -> None:
117
122
  )
118
123
 
119
124
  # Run the Server App with the Driver
120
- run(server_app_attr, driver, server_app_dir)
125
+ run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
121
126
 
122
127
  # Clean up
123
128
  driver.__del__() # pylint: disable=unnecessary-dunder-call
flwr/server/server.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import concurrent.futures
19
19
  import timeit
20
- from logging import DEBUG, INFO, WARN
20
+ from logging import INFO, WARN
21
21
  from typing import Dict, List, Optional, Tuple, Union
22
22
 
23
23
  from flwr.common import (
@@ -173,7 +173,7 @@ class Server:
173
173
  log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
174
174
  return None
175
175
  log(
176
- DEBUG,
176
+ INFO,
177
177
  "evaluate_round %s: strategy sampled %s clients (out of %s)",
178
178
  server_round,
179
179
  len(client_instructions),
@@ -188,7 +188,7 @@ class Server:
188
188
  group_id=server_round,
189
189
  )
190
190
  log(
191
- DEBUG,
191
+ INFO,
192
192
  "evaluate_round %s received %s results and %s failures",
193
193
  server_round,
194
194
  len(results),
@@ -223,7 +223,7 @@ class Server:
223
223
  log(INFO, "fit_round %s: no clients selected, cancel", server_round)
224
224
  return None
225
225
  log(
226
- DEBUG,
226
+ INFO,
227
227
  "fit_round %s: strategy sampled %s clients (out of %s)",
228
228
  server_round,
229
229
  len(client_instructions),
@@ -238,7 +238,7 @@ class Server:
238
238
  group_id=server_round,
239
239
  )
240
240
  log(
241
- DEBUG,
241
+ INFO,
242
242
  "fit_round %s received %s results and %s failures",
243
243
  server_round,
244
244
  len(results),
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
49
49
  self, request: GetNodesRequest, context: grpc.ServicerContext
50
50
  ) -> GetNodesResponse:
51
51
  """Get available nodes."""
52
- log(INFO, "DriverServicer.GetNodes")
52
+ log(DEBUG, "DriverServicer.GetNodes")
53
53
  state: State = self.state_factory.state()
54
54
  all_ids: Set[int] = state.get_nodes(request.run_id)
55
55
  nodes: List[Node] = [
@@ -219,16 +219,23 @@ async def run(
219
219
 
220
220
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals
221
221
  def start_vce(
222
- client_app_module_name: str,
223
222
  backend_name: str,
224
223
  backend_config_json_stream: str,
225
- working_dir: str,
224
+ app_dir: str,
226
225
  f_stop: asyncio.Event,
226
+ client_app: Optional[ClientApp] = None,
227
+ client_app_attr: Optional[str] = None,
227
228
  num_supernodes: Optional[int] = None,
228
229
  state_factory: Optional[StateFactory] = None,
229
230
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
230
231
  ) -> None:
231
232
  """Start Fleet API with the Simulation Engine."""
233
+ if client_app_attr is not None and client_app is not None:
234
+ raise ValueError(
235
+ "Both `client_app_attr` and `client_app` are provided, "
236
+ "but only one is allowed."
237
+ )
238
+
232
239
  if num_supernodes is not None and existing_nodes_mapping is not None:
233
240
  raise ValueError(
234
241
  "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
@@ -290,12 +297,17 @@ def start_vce(
290
297
 
291
298
  def backend_fn() -> Backend:
292
299
  """Instantiate a Backend."""
293
- return backend_type(backend_config, work_dir=working_dir)
300
+ return backend_type(backend_config, work_dir=app_dir)
294
301
 
295
- log(INFO, "client_app_module_name = %s", client_app_module_name)
302
+ log(INFO, "client_app_attr = %s", client_app_attr)
296
303
 
304
+ # Load ClientApp if needed
297
305
  def _load() -> ClientApp:
298
- app: ClientApp = load_client_app(client_app_module_name)
306
+
307
+ if client_app_attr:
308
+ app: ClientApp = load_client_app(client_app_attr)
309
+ if client_app:
310
+ app = client_app
299
311
  return app
300
312
 
301
313
  app_fn = _load
@@ -21,11 +21,7 @@ from typing import Optional, cast
21
21
 
22
22
  import flwr.common.recordset_compat as compat
23
23
  from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
- from flwr.common.constant import (
25
- MESSAGE_TYPE_EVALUATE,
26
- MESSAGE_TYPE_FIT,
27
- MESSAGE_TYPE_GET_PARAMETERS,
28
- )
24
+ from flwr.common.constant import MessageType, MessageTypeLegacy
29
25
 
30
26
  from ..compat.app_utils import start_update_client_manager_thread
31
27
  from ..compat.legacy_context import LegacyContext
@@ -134,7 +130,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
134
130
  [
135
131
  driver.create_message(
136
132
  content=content,
137
- message_type=MESSAGE_TYPE_GET_PARAMETERS,
133
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
138
134
  dst_node_id=random_client.node_id,
139
135
  group_id="",
140
136
  ttl="",
@@ -232,7 +228,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
232
228
  out_messages = [
233
229
  driver.create_message(
234
230
  content=compat.fitins_to_recordset(fitins, True),
235
- message_type=MESSAGE_TYPE_FIT,
231
+ message_type=MessageType.TRAIN,
236
232
  dst_node_id=proxy.node_id,
237
233
  group_id="",
238
234
  ttl="",
@@ -313,7 +309,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
313
309
  out_messages = [
314
310
  driver.create_message(
315
311
  content=compat.evaluateins_to_recordset(evalins, True),
316
- message_type=MESSAGE_TYPE_EVALUATE,
312
+ message_type=MessageType.EVALUATE,
317
313
  dst_node_id=proxy.node_id,
318
314
  group_id="",
319
315
  ttl="",
@@ -17,7 +17,7 @@
17
17
 
18
18
  import importlib
19
19
 
20
- from flwr.simulation.run_simulation import run_simulation
20
+ from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
21
21
 
22
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
23
23
 
@@ -36,7 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
36
  raise ImportError(RAY_IMPORT_ERROR)
37
37
 
38
38
 
39
- __all__ = [
40
- "start_simulation",
41
- "run_simulation",
42
- ]
39
+ __all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]