flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/flower_toml.py +151 -0
- flwr/cli/new/new.py +1 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
- flwr/cli/new/templates/app/flower.toml.tpl +2 -2
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
- flwr/cli/run/__init__.py +21 -0
- flwr/cli/run/run.py +102 -0
- flwr/client/app.py +93 -8
- flwr/client/grpc_client/connection.py +16 -14
- flwr/client/grpc_rere_client/connection.py +14 -4
- flwr/client/message_handler/message_handler.py +5 -10
- flwr/client/mod/centraldp_mods.py +5 -5
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
- flwr/client/rest_client/connection.py +16 -4
- flwr/common/__init__.py +6 -0
- flwr/common/constant.py +21 -4
- flwr/server/app.py +7 -7
- flwr/server/compat/driver_client_proxy.py +5 -11
- flwr/server/run_serverapp.py +14 -9
- flwr/server/server.py +5 -5
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +17 -5
- flwr/server/workflow/default_workflows.py +4 -8
- flwr/simulation/__init__.py +2 -5
- flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
- flwr/simulation/run_simulation.py +301 -76
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.grpc import create_channel
|
28
28
|
from flwr.common.logger import log, warn_experimental_feature
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
51
52
|
def grpc_request_response(
|
52
53
|
server_address: str,
|
53
54
|
insecure: bool,
|
55
|
+
retry_invoker: RetryInvoker,
|
54
56
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
55
57
|
root_certificates: Optional[Union[bytes, str]] = None,
|
56
58
|
) -> Iterator[
|
@@ -72,6 +74,13 @@ def grpc_request_response(
|
|
72
74
|
The IPv6 address of the server with `http://` or `https://`.
|
73
75
|
If the Flower server runs on the same machine
|
74
76
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
77
|
+
insecure : bool
|
78
|
+
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
79
|
+
when False, using system certificates if `root_certificates` is None.
|
80
|
+
retry_invoker : RetryInvoker
|
81
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
82
|
+
after gRPC errors. If None, the client will only try to
|
83
|
+
reconnect once after a failure.
|
75
84
|
max_message_length : int
|
76
85
|
Ignored, only present to preserve API-compatibility.
|
77
86
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -113,7 +122,8 @@ def grpc_request_response(
|
|
113
122
|
def create_node() -> None:
|
114
123
|
"""Set create_node."""
|
115
124
|
create_node_request = CreateNodeRequest()
|
116
|
-
create_node_response =
|
125
|
+
create_node_response = retry_invoker.invoke(
|
126
|
+
stub.CreateNode,
|
117
127
|
request=create_node_request,
|
118
128
|
)
|
119
129
|
node_store[KEY_NODE] = create_node_response.node
|
@@ -127,7 +137,7 @@ def grpc_request_response(
|
|
127
137
|
node: Node = cast(Node, node_store[KEY_NODE])
|
128
138
|
|
129
139
|
delete_node_request = DeleteNodeRequest(node=node)
|
130
|
-
stub.DeleteNode
|
140
|
+
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
|
131
141
|
|
132
142
|
del node_store[KEY_NODE]
|
133
143
|
|
@@ -141,7 +151,7 @@ def grpc_request_response(
|
|
141
151
|
|
142
152
|
# Request instructions (task) from server
|
143
153
|
request = PullTaskInsRequest(node=node)
|
144
|
-
response = stub.PullTaskIns
|
154
|
+
response = retry_invoker.invoke(stub.PullTaskIns, request=request)
|
145
155
|
|
146
156
|
# Get the current TaskIns
|
147
157
|
task_ins: Optional[TaskIns] = get_task_ins(response)
|
@@ -185,7 +195,7 @@ def grpc_request_response(
|
|
185
195
|
|
186
196
|
# Serialize ProtoBuf to bytes
|
187
197
|
request = PushTaskResRequest(task_res_list=[task_res])
|
188
|
-
_ = stub.PushTaskRes
|
198
|
+
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
189
199
|
|
190
200
|
state[KEY_METADATA] = None
|
191
201
|
|
@@ -27,12 +27,7 @@ from flwr.client.client import (
|
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
28
28
|
from flwr.client.typing import ClientFn
|
29
29
|
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
|
30
|
-
from flwr.common.constant import
|
31
|
-
MESSAGE_TYPE_EVALUATE,
|
32
|
-
MESSAGE_TYPE_FIT,
|
33
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
34
|
-
MESSAGE_TYPE_GET_PROPERTIES,
|
35
|
-
)
|
30
|
+
from flwr.common.constant import MessageType, MessageTypeLegacy
|
36
31
|
from flwr.common.recordset_compat import (
|
37
32
|
evaluateres_to_recordset,
|
38
33
|
fitres_to_recordset,
|
@@ -115,14 +110,14 @@ def handle_legacy_message_from_msgtype(
|
|
115
110
|
message_type = message.metadata.message_type
|
116
111
|
|
117
112
|
# Handle GetPropertiesIns
|
118
|
-
if message_type ==
|
113
|
+
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
119
114
|
get_properties_res = maybe_call_get_properties(
|
120
115
|
client=client,
|
121
116
|
get_properties_ins=recordset_to_getpropertiesins(message.content),
|
122
117
|
)
|
123
118
|
out_recordset = getpropertiesres_to_recordset(get_properties_res)
|
124
119
|
# Handle GetParametersIns
|
125
|
-
elif message_type ==
|
120
|
+
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
126
121
|
get_parameters_res = maybe_call_get_parameters(
|
127
122
|
client=client,
|
128
123
|
get_parameters_ins=recordset_to_getparametersins(message.content),
|
@@ -131,14 +126,14 @@ def handle_legacy_message_from_msgtype(
|
|
131
126
|
get_parameters_res, keep_input=False
|
132
127
|
)
|
133
128
|
# Handle FitIns
|
134
|
-
elif message_type ==
|
129
|
+
elif message_type == MessageType.TRAIN:
|
135
130
|
fit_res = maybe_call_fit(
|
136
131
|
client=client,
|
137
132
|
fit_ins=recordset_to_fitins(message.content, keep_input=True),
|
138
133
|
)
|
139
134
|
out_recordset = fitres_to_recordset(fit_res, keep_input=False)
|
140
135
|
# Handle EvaluateIns
|
141
|
-
elif message_type ==
|
136
|
+
elif message_type == MessageType.EVALUATE:
|
142
137
|
evaluate_res = maybe_call_evaluate(
|
143
138
|
client=client,
|
144
139
|
evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
|
@@ -18,7 +18,7 @@
|
|
18
18
|
from flwr.client.typing import ClientAppCallable
|
19
19
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
20
20
|
from flwr.common import recordset_compat as compat
|
21
|
-
from flwr.common.constant import
|
21
|
+
from flwr.common.constant import MessageType
|
22
22
|
from flwr.common.context import Context
|
23
23
|
from flwr.common.differential_privacy import (
|
24
24
|
compute_adaptive_clip_model_update,
|
@@ -40,7 +40,7 @@ def fixedclipping_mod(
|
|
40
40
|
|
41
41
|
This mod clips the client model updates before sending them to the server.
|
42
42
|
|
43
|
-
It operates on messages with type
|
43
|
+
It operates on messages with type MessageType.TRAIN.
|
44
44
|
|
45
45
|
Notes
|
46
46
|
-----
|
@@ -48,7 +48,7 @@ def fixedclipping_mod(
|
|
48
48
|
|
49
49
|
Typically, fixedclipping_mod should be the last to operate on params.
|
50
50
|
"""
|
51
|
-
if msg.metadata.message_type !=
|
51
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
52
52
|
return call_next(msg, ctxt)
|
53
53
|
fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
|
54
54
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
@@ -93,7 +93,7 @@ def adaptiveclipping_mod(
|
|
93
93
|
|
94
94
|
It also sends KEY_NORM_BIT to the server for computing the new clipping value.
|
95
95
|
|
96
|
-
It operates on messages with type
|
96
|
+
It operates on messages with type MessageType.TRAIN.
|
97
97
|
|
98
98
|
Notes
|
99
99
|
-----
|
@@ -101,7 +101,7 @@ def adaptiveclipping_mod(
|
|
101
101
|
|
102
102
|
Typically, adaptiveclipping_mod should be the last to operate on params.
|
103
103
|
"""
|
104
|
-
if msg.metadata.message_type !=
|
104
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
105
105
|
return call_next(msg, ctxt)
|
106
106
|
|
107
107
|
fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
|
@@ -30,7 +30,7 @@ from flwr.common import (
|
|
30
30
|
parameters_to_ndarrays,
|
31
31
|
)
|
32
32
|
from flwr.common import recordset_compat as compat
|
33
|
-
from flwr.common.constant import
|
33
|
+
from flwr.common.constant import MessageType
|
34
34
|
from flwr.common.logger import log
|
35
35
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
36
36
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
@@ -150,7 +150,7 @@ def secaggplus_mod(
|
|
150
150
|
) -> Message:
|
151
151
|
"""Handle incoming message and return results, following the SecAgg+ protocol."""
|
152
152
|
# Ignore non-fit messages
|
153
|
-
if msg.metadata.message_type !=
|
153
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
154
154
|
return call_next(msg, ctxt)
|
155
155
|
|
156
156
|
# Retrieve local state
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.constant import MISSING_EXTRA_REST
|
28
28
|
from flwr.common.logger import log
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
61
62
|
def http_request_response(
|
62
63
|
server_address: str,
|
63
64
|
insecure: bool, # pylint: disable=unused-argument
|
65
|
+
retry_invoker: RetryInvoker,
|
64
66
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
65
67
|
root_certificates: Optional[
|
66
68
|
Union[bytes, str]
|
@@ -84,6 +86,12 @@ def http_request_response(
|
|
84
86
|
The IPv6 address of the server with `http://` or `https://`.
|
85
87
|
If the Flower server runs on the same machine
|
86
88
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
89
|
+
insecure : bool
|
90
|
+
Unused argument present for compatibility.
|
91
|
+
retry_invoker : RetryInvoker
|
92
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
93
|
+
after REST connection errors. If None, the client will only try to
|
94
|
+
reconnect once after a failure.
|
87
95
|
max_message_length : int
|
88
96
|
Ignored, only present to preserve API-compatibility.
|
89
97
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -134,7 +142,8 @@ def http_request_response(
|
|
134
142
|
create_node_req_proto = CreateNodeRequest()
|
135
143
|
create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
|
136
144
|
|
137
|
-
res =
|
145
|
+
res = retry_invoker.invoke(
|
146
|
+
requests.post,
|
138
147
|
url=f"{base_url}/{PATH_CREATE_NODE}",
|
139
148
|
headers={
|
140
149
|
"Accept": "application/protobuf",
|
@@ -177,7 +186,8 @@ def http_request_response(
|
|
177
186
|
node: Node = cast(Node, node_store[KEY_NODE])
|
178
187
|
delete_node_req_proto = DeleteNodeRequest(node=node)
|
179
188
|
delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
|
180
|
-
res =
|
189
|
+
res = retry_invoker.invoke(
|
190
|
+
requests.post,
|
181
191
|
url=f"{base_url}/{PATH_DELETE_NODE}",
|
182
192
|
headers={
|
183
193
|
"Accept": "application/protobuf",
|
@@ -218,7 +228,8 @@ def http_request_response(
|
|
218
228
|
pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
|
219
229
|
|
220
230
|
# Request instructions (task) from server
|
221
|
-
res =
|
231
|
+
res = retry_invoker.invoke(
|
232
|
+
requests.post,
|
222
233
|
url=f"{base_url}/{PATH_PULL_TASK_INS}",
|
223
234
|
headers={
|
224
235
|
"Accept": "application/protobuf",
|
@@ -298,7 +309,8 @@ def http_request_response(
|
|
298
309
|
)
|
299
310
|
|
300
311
|
# Send ClientMessage to server
|
301
|
-
res =
|
312
|
+
res = retry_invoker.invoke(
|
313
|
+
requests.post,
|
302
314
|
url=f"{base_url}/{PATH_PUSH_TASK_RES}",
|
303
315
|
headers={
|
304
316
|
"Accept": "application/protobuf",
|
flwr/common/__init__.py
CHANGED
@@ -15,11 +15,14 @@
|
|
15
15
|
"""Common components shared between server and client."""
|
16
16
|
|
17
17
|
|
18
|
+
from .constant import MessageType as MessageType
|
19
|
+
from .constant import MessageTypeLegacy as MessageTypeLegacy
|
18
20
|
from .context import Context as Context
|
19
21
|
from .date import now as now
|
20
22
|
from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
21
23
|
from .logger import configure as configure
|
22
24
|
from .logger import log as log
|
25
|
+
from .message import Error as Error
|
23
26
|
from .message import Message as Message
|
24
27
|
from .message import Metadata as Metadata
|
25
28
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
@@ -74,6 +77,7 @@ __all__ = [
|
|
74
77
|
"EventType",
|
75
78
|
"FitIns",
|
76
79
|
"FitRes",
|
80
|
+
"Error",
|
77
81
|
"GetParametersIns",
|
78
82
|
"GetParametersRes",
|
79
83
|
"GetPropertiesIns",
|
@@ -81,6 +85,8 @@ __all__ = [
|
|
81
85
|
"GRPC_MAX_MESSAGE_LENGTH",
|
82
86
|
"log",
|
83
87
|
"Message",
|
88
|
+
"MessageType",
|
89
|
+
"MessageTypeLegacy",
|
84
90
|
"Metadata",
|
85
91
|
"Metrics",
|
86
92
|
"MetricsAggregationFn",
|
flwr/common/constant.py
CHANGED
@@ -36,10 +36,27 @@ TRANSPORT_TYPES = [
|
|
36
36
|
TRANSPORT_TYPE_VCE,
|
37
37
|
]
|
38
38
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
39
|
+
|
40
|
+
class MessageType:
|
41
|
+
"""Message type."""
|
42
|
+
|
43
|
+
TRAIN = "train"
|
44
|
+
EVALUATE = "evaluate"
|
45
|
+
|
46
|
+
def __new__(cls) -> MessageType:
|
47
|
+
"""Prevent instantiation."""
|
48
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
49
|
+
|
50
|
+
|
51
|
+
class MessageTypeLegacy:
|
52
|
+
"""Legacy message type."""
|
53
|
+
|
54
|
+
GET_PROPERTIES = "get_properties"
|
55
|
+
GET_PARAMETERS = "get_parameters"
|
56
|
+
|
57
|
+
def __new__(cls) -> MessageTypeLegacy:
|
58
|
+
"""Prevent instantiation."""
|
59
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
43
60
|
|
44
61
|
|
45
62
|
class SType:
|
flwr/server/app.py
CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
|
|
362
362
|
f_stop = asyncio.Event() # Does nothing
|
363
363
|
_run_fleet_api_vce(
|
364
364
|
num_supernodes=args.num_supernodes,
|
365
|
-
|
365
|
+
client_app_attr=args.client_app,
|
366
366
|
backend_name=args.backend,
|
367
367
|
backend_config_json_stream=args.backend_config,
|
368
|
-
|
368
|
+
app_dir=args.app_dir,
|
369
369
|
state_factory=state_factory,
|
370
370
|
f_stop=f_stop,
|
371
371
|
)
|
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
|
|
438
438
|
# pylint: disable=too-many-arguments
|
439
439
|
def _run_fleet_api_vce(
|
440
440
|
num_supernodes: int,
|
441
|
-
|
441
|
+
client_app_attr: str,
|
442
442
|
backend_name: str,
|
443
443
|
backend_config_json_stream: str,
|
444
|
-
|
444
|
+
app_dir: str,
|
445
445
|
state_factory: StateFactory,
|
446
446
|
f_stop: asyncio.Event,
|
447
447
|
) -> None:
|
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
|
|
449
449
|
|
450
450
|
start_vce(
|
451
451
|
num_supernodes=num_supernodes,
|
452
|
-
|
452
|
+
client_app_attr=client_app_attr,
|
453
453
|
backend_name=backend_name,
|
454
454
|
backend_config_json_stream=backend_config_json_stream,
|
455
455
|
state_factory=state_factory,
|
456
|
-
|
456
|
+
app_dir=app_dir,
|
457
457
|
f_stop=f_stop,
|
458
458
|
)
|
459
459
|
|
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
705
705
|
"`flwr.common.typing.ConfigsRecordValues`. ",
|
706
706
|
)
|
707
707
|
parser.add_argument(
|
708
|
-
"--dir",
|
708
|
+
"--app-dir",
|
709
709
|
default="",
|
710
710
|
help="Add specified directory to the PYTHONPATH and load"
|
711
711
|
"ClientApp from there."
|
@@ -19,15 +19,9 @@ import time
|
|
19
19
|
from typing import List, Optional
|
20
20
|
|
21
21
|
from flwr import common
|
22
|
-
from flwr.common import RecordSet
|
22
|
+
from flwr.common import MessageType, MessageTypeLegacy, RecordSet
|
23
23
|
from flwr.common import recordset_compat as compat
|
24
24
|
from flwr.common import serde
|
25
|
-
from flwr.common.constant import (
|
26
|
-
MESSAGE_TYPE_EVALUATE,
|
27
|
-
MESSAGE_TYPE_FIT,
|
28
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
29
|
-
MESSAGE_TYPE_GET_PROPERTIES,
|
30
|
-
)
|
31
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
32
26
|
from flwr.server.client_proxy import ClientProxy
|
33
27
|
|
@@ -57,7 +51,7 @@ class DriverClientProxy(ClientProxy):
|
|
57
51
|
out_recordset = compat.getpropertiesins_to_recordset(ins)
|
58
52
|
# Fetch response
|
59
53
|
in_recordset = self._send_receive_recordset(
|
60
|
-
out_recordset,
|
54
|
+
out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
61
55
|
)
|
62
56
|
# RecordSet to Res
|
63
57
|
return compat.recordset_to_getpropertiesres(in_recordset)
|
@@ -73,7 +67,7 @@ class DriverClientProxy(ClientProxy):
|
|
73
67
|
out_recordset = compat.getparametersins_to_recordset(ins)
|
74
68
|
# Fetch response
|
75
69
|
in_recordset = self._send_receive_recordset(
|
76
|
-
out_recordset,
|
70
|
+
out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
77
71
|
)
|
78
72
|
# RecordSet to Res
|
79
73
|
return compat.recordset_to_getparametersres(in_recordset, False)
|
@@ -86,7 +80,7 @@ class DriverClientProxy(ClientProxy):
|
|
86
80
|
out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
|
87
81
|
# Fetch response
|
88
82
|
in_recordset = self._send_receive_recordset(
|
89
|
-
out_recordset,
|
83
|
+
out_recordset, MessageType.TRAIN, timeout, group_id
|
90
84
|
)
|
91
85
|
# RecordSet to Res
|
92
86
|
return compat.recordset_to_fitres(in_recordset, keep_input=False)
|
@@ -99,7 +93,7 @@ class DriverClientProxy(ClientProxy):
|
|
99
93
|
out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
|
100
94
|
# Fetch response
|
101
95
|
in_recordset = self._send_receive_recordset(
|
102
|
-
out_recordset,
|
96
|
+
out_recordset, MessageType.EVALUATE, timeout, group_id
|
103
97
|
)
|
104
98
|
# RecordSet to Res
|
105
99
|
return compat.recordset_to_evaluateres(in_recordset)
|
flwr/server/run_serverapp.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import argparse
|
19
|
-
import asyncio
|
20
19
|
import sys
|
21
20
|
from logging import DEBUG, WARN
|
22
21
|
from pathlib import Path
|
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
|
|
30
29
|
|
31
30
|
|
32
31
|
def run(
|
33
|
-
server_app_attr: str,
|
34
32
|
driver: Driver,
|
35
33
|
server_app_dir: str,
|
36
|
-
|
34
|
+
server_app_attr: Optional[str] = None,
|
35
|
+
loaded_server_app: Optional[ServerApp] = None,
|
37
36
|
) -> None:
|
38
37
|
"""Run ServerApp with a given Driver."""
|
38
|
+
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
39
|
+
raise ValueError(
|
40
|
+
"Either `server_app_attr` or `loaded_server_app` should be set "
|
41
|
+
"but not both. "
|
42
|
+
)
|
43
|
+
|
39
44
|
if server_app_dir is not None:
|
40
45
|
sys.path.insert(0, server_app_dir)
|
41
46
|
|
47
|
+
# Load ServerApp if needed
|
42
48
|
def _load() -> ServerApp:
|
43
|
-
|
49
|
+
if server_app_attr:
|
50
|
+
server_app: ServerApp = load_server_app(server_app_attr)
|
51
|
+
if loaded_server_app:
|
52
|
+
server_app = loaded_server_app
|
44
53
|
return server_app
|
45
54
|
|
46
55
|
server_app = _load()
|
@@ -52,10 +61,6 @@ def run(
|
|
52
61
|
server_app(driver=driver, context=context)
|
53
62
|
|
54
63
|
log(DEBUG, "ServerApp finished running.")
|
55
|
-
# Upon completion, trigger stop event if one was passed
|
56
|
-
if stop_event is not None:
|
57
|
-
log(DEBUG, "Triggering stop event.")
|
58
|
-
stop_event.set()
|
59
64
|
|
60
65
|
|
61
66
|
def run_server_app() -> None:
|
@@ -117,7 +122,7 @@ def run_server_app() -> None:
|
|
117
122
|
)
|
118
123
|
|
119
124
|
# Run the Server App with the Driver
|
120
|
-
run(
|
125
|
+
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
121
126
|
|
122
127
|
# Clean up
|
123
128
|
driver.__del__() # pylint: disable=unnecessary-dunder-call
|
flwr/server/server.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import concurrent.futures
|
19
19
|
import timeit
|
20
|
-
from logging import
|
20
|
+
from logging import INFO, WARN
|
21
21
|
from typing import Dict, List, Optional, Tuple, Union
|
22
22
|
|
23
23
|
from flwr.common import (
|
@@ -173,7 +173,7 @@ class Server:
|
|
173
173
|
log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
|
174
174
|
return None
|
175
175
|
log(
|
176
|
-
|
176
|
+
INFO,
|
177
177
|
"evaluate_round %s: strategy sampled %s clients (out of %s)",
|
178
178
|
server_round,
|
179
179
|
len(client_instructions),
|
@@ -188,7 +188,7 @@ class Server:
|
|
188
188
|
group_id=server_round,
|
189
189
|
)
|
190
190
|
log(
|
191
|
-
|
191
|
+
INFO,
|
192
192
|
"evaluate_round %s received %s results and %s failures",
|
193
193
|
server_round,
|
194
194
|
len(results),
|
@@ -223,7 +223,7 @@ class Server:
|
|
223
223
|
log(INFO, "fit_round %s: no clients selected, cancel", server_round)
|
224
224
|
return None
|
225
225
|
log(
|
226
|
-
|
226
|
+
INFO,
|
227
227
|
"fit_round %s: strategy sampled %s clients (out of %s)",
|
228
228
|
server_round,
|
229
229
|
len(client_instructions),
|
@@ -238,7 +238,7 @@ class Server:
|
|
238
238
|
group_id=server_round,
|
239
239
|
)
|
240
240
|
log(
|
241
|
-
|
241
|
+
INFO,
|
242
242
|
"fit_round %s received %s results and %s failures",
|
243
243
|
server_round,
|
244
244
|
len(results),
|
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
49
49
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
50
50
|
) -> GetNodesResponse:
|
51
51
|
"""Get available nodes."""
|
52
|
-
log(
|
52
|
+
log(DEBUG, "DriverServicer.GetNodes")
|
53
53
|
state: State = self.state_factory.state()
|
54
54
|
all_ids: Set[int] = state.get_nodes(request.run_id)
|
55
55
|
nodes: List[Node] = [
|
@@ -219,16 +219,23 @@ async def run(
|
|
219
219
|
|
220
220
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
221
221
|
def start_vce(
|
222
|
-
client_app_module_name: str,
|
223
222
|
backend_name: str,
|
224
223
|
backend_config_json_stream: str,
|
225
|
-
|
224
|
+
app_dir: str,
|
226
225
|
f_stop: asyncio.Event,
|
226
|
+
client_app: Optional[ClientApp] = None,
|
227
|
+
client_app_attr: Optional[str] = None,
|
227
228
|
num_supernodes: Optional[int] = None,
|
228
229
|
state_factory: Optional[StateFactory] = None,
|
229
230
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
230
231
|
) -> None:
|
231
232
|
"""Start Fleet API with the Simulation Engine."""
|
233
|
+
if client_app_attr is not None and client_app is not None:
|
234
|
+
raise ValueError(
|
235
|
+
"Both `client_app_attr` and `client_app` are provided, "
|
236
|
+
"but only one is allowed."
|
237
|
+
)
|
238
|
+
|
232
239
|
if num_supernodes is not None and existing_nodes_mapping is not None:
|
233
240
|
raise ValueError(
|
234
241
|
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
|
@@ -290,12 +297,17 @@ def start_vce(
|
|
290
297
|
|
291
298
|
def backend_fn() -> Backend:
|
292
299
|
"""Instantiate a Backend."""
|
293
|
-
return backend_type(backend_config, work_dir=
|
300
|
+
return backend_type(backend_config, work_dir=app_dir)
|
294
301
|
|
295
|
-
log(INFO, "
|
302
|
+
log(INFO, "client_app_attr = %s", client_app_attr)
|
296
303
|
|
304
|
+
# Load ClientApp if needed
|
297
305
|
def _load() -> ClientApp:
|
298
|
-
|
306
|
+
|
307
|
+
if client_app_attr:
|
308
|
+
app: ClientApp = load_client_app(client_app_attr)
|
309
|
+
if client_app:
|
310
|
+
app = client_app
|
299
311
|
return app
|
300
312
|
|
301
313
|
app_fn = _load
|
@@ -21,11 +21,7 @@ from typing import Optional, cast
|
|
21
21
|
|
22
22
|
import flwr.common.recordset_compat as compat
|
23
23
|
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
24
|
-
from flwr.common.constant import
|
25
|
-
MESSAGE_TYPE_EVALUATE,
|
26
|
-
MESSAGE_TYPE_FIT,
|
27
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
28
|
-
)
|
24
|
+
from flwr.common.constant import MessageType, MessageTypeLegacy
|
29
25
|
|
30
26
|
from ..compat.app_utils import start_update_client_manager_thread
|
31
27
|
from ..compat.legacy_context import LegacyContext
|
@@ -134,7 +130,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
134
130
|
[
|
135
131
|
driver.create_message(
|
136
132
|
content=content,
|
137
|
-
message_type=
|
133
|
+
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
138
134
|
dst_node_id=random_client.node_id,
|
139
135
|
group_id="",
|
140
136
|
ttl="",
|
@@ -232,7 +228,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
|
|
232
228
|
out_messages = [
|
233
229
|
driver.create_message(
|
234
230
|
content=compat.fitins_to_recordset(fitins, True),
|
235
|
-
message_type=
|
231
|
+
message_type=MessageType.TRAIN,
|
236
232
|
dst_node_id=proxy.node_id,
|
237
233
|
group_id="",
|
238
234
|
ttl="",
|
@@ -313,7 +309,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
313
309
|
out_messages = [
|
314
310
|
driver.create_message(
|
315
311
|
content=compat.evaluateins_to_recordset(evalins, True),
|
316
|
-
message_type=
|
312
|
+
message_type=MessageType.EVALUATE,
|
317
313
|
dst_node_id=proxy.node_id,
|
318
314
|
group_id="",
|
319
315
|
ttl="",
|
flwr/simulation/__init__.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
-
from flwr.simulation.run_simulation import run_simulation
|
20
|
+
from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
|
21
21
|
|
22
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
23
23
|
|
@@ -36,7 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
36
|
raise ImportError(RAY_IMPORT_ERROR)
|
37
37
|
|
38
38
|
|
39
|
-
__all__ = [
|
40
|
-
"start_simulation",
|
41
|
-
"run_simulation",
|
42
|
-
]
|
39
|
+
__all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
|