flwr-nightly 1.8.0.dev20240304__py3-none-any.whl → 1.8.0.dev20240306__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/cli/app.py +2 -0
- flwr/cli/flower_toml.py +151 -0
- flwr/cli/new/new.py +1 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +12 -0
- flwr/cli/new/templates/app/flower.toml.tpl +2 -2
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +2 -0
- flwr/cli/run/__init__.py +21 -0
- flwr/cli/run/run.py +102 -0
- flwr/client/app.py +93 -8
- flwr/client/grpc_client/connection.py +16 -14
- flwr/client/grpc_rere_client/connection.py +14 -4
- flwr/client/message_handler/message_handler.py +5 -10
- flwr/client/mod/centraldp_mods.py +5 -5
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +2 -2
- flwr/client/rest_client/connection.py +16 -4
- flwr/common/__init__.py +6 -0
- flwr/common/constant.py +21 -4
- flwr/server/app.py +7 -7
- flwr/server/compat/driver_client_proxy.py +5 -11
- flwr/server/run_serverapp.py +14 -9
- flwr/server/server.py +5 -5
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +17 -5
- flwr/server/workflow/default_workflows.py +4 -8
- flwr/simulation/__init__.py +2 -5
- flwr/simulation/ray_transport/ray_client_proxy.py +5 -10
- flwr/simulation/run_simulation.py +301 -76
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/RECORD +33 -27
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/entry_points.txt +1 -1
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240304.dist-info → flwr_nightly-1.8.0.dev20240306.dist-info}/WHEEL +0 -0
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.grpc import create_channel
|
28
28
|
from flwr.common.logger import log, warn_experimental_feature
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
51
52
|
def grpc_request_response(
|
52
53
|
server_address: str,
|
53
54
|
insecure: bool,
|
55
|
+
retry_invoker: RetryInvoker,
|
54
56
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
55
57
|
root_certificates: Optional[Union[bytes, str]] = None,
|
56
58
|
) -> Iterator[
|
@@ -72,6 +74,13 @@ def grpc_request_response(
|
|
72
74
|
The IPv6 address of the server with `http://` or `https://`.
|
73
75
|
If the Flower server runs on the same machine
|
74
76
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
77
|
+
insecure : bool
|
78
|
+
Starts an insecure gRPC connection when True. Enables HTTPS connection
|
79
|
+
when False, using system certificates if `root_certificates` is None.
|
80
|
+
retry_invoker: RetryInvoker
|
81
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
82
|
+
after gRPC errors. If None, the client will only try to
|
83
|
+
reconnect once after a failure.
|
75
84
|
max_message_length : int
|
76
85
|
Ignored, only present to preserve API-compatibility.
|
77
86
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -113,7 +122,8 @@ def grpc_request_response(
|
|
113
122
|
def create_node() -> None:
|
114
123
|
"""Set create_node."""
|
115
124
|
create_node_request = CreateNodeRequest()
|
116
|
-
create_node_response =
|
125
|
+
create_node_response = retry_invoker.invoke(
|
126
|
+
stub.CreateNode,
|
117
127
|
request=create_node_request,
|
118
128
|
)
|
119
129
|
node_store[KEY_NODE] = create_node_response.node
|
@@ -127,7 +137,7 @@ def grpc_request_response(
|
|
127
137
|
node: Node = cast(Node, node_store[KEY_NODE])
|
128
138
|
|
129
139
|
delete_node_request = DeleteNodeRequest(node=node)
|
130
|
-
stub.DeleteNode
|
140
|
+
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
|
131
141
|
|
132
142
|
del node_store[KEY_NODE]
|
133
143
|
|
@@ -141,7 +151,7 @@ def grpc_request_response(
|
|
141
151
|
|
142
152
|
# Request instructions (task) from server
|
143
153
|
request = PullTaskInsRequest(node=node)
|
144
|
-
response = stub.PullTaskIns
|
154
|
+
response = retry_invoker.invoke(stub.PullTaskIns, request=request)
|
145
155
|
|
146
156
|
# Get the current TaskIns
|
147
157
|
task_ins: Optional[TaskIns] = get_task_ins(response)
|
@@ -185,7 +195,7 @@ def grpc_request_response(
|
|
185
195
|
|
186
196
|
# Serialize ProtoBuf to bytes
|
187
197
|
request = PushTaskResRequest(task_res_list=[task_res])
|
188
|
-
_ = stub.PushTaskRes
|
198
|
+
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
189
199
|
|
190
200
|
state[KEY_METADATA] = None
|
191
201
|
|
@@ -27,12 +27,7 @@ from flwr.client.client import (
|
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
28
28
|
from flwr.client.typing import ClientFn
|
29
29
|
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
|
30
|
-
from flwr.common.constant import
|
31
|
-
MESSAGE_TYPE_EVALUATE,
|
32
|
-
MESSAGE_TYPE_FIT,
|
33
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
34
|
-
MESSAGE_TYPE_GET_PROPERTIES,
|
35
|
-
)
|
30
|
+
from flwr.common.constant import MessageType, MessageTypeLegacy
|
36
31
|
from flwr.common.recordset_compat import (
|
37
32
|
evaluateres_to_recordset,
|
38
33
|
fitres_to_recordset,
|
@@ -115,14 +110,14 @@ def handle_legacy_message_from_msgtype(
|
|
115
110
|
message_type = message.metadata.message_type
|
116
111
|
|
117
112
|
# Handle GetPropertiesIns
|
118
|
-
if message_type ==
|
113
|
+
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
119
114
|
get_properties_res = maybe_call_get_properties(
|
120
115
|
client=client,
|
121
116
|
get_properties_ins=recordset_to_getpropertiesins(message.content),
|
122
117
|
)
|
123
118
|
out_recordset = getpropertiesres_to_recordset(get_properties_res)
|
124
119
|
# Handle GetParametersIns
|
125
|
-
elif message_type ==
|
120
|
+
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
126
121
|
get_parameters_res = maybe_call_get_parameters(
|
127
122
|
client=client,
|
128
123
|
get_parameters_ins=recordset_to_getparametersins(message.content),
|
@@ -131,14 +126,14 @@ def handle_legacy_message_from_msgtype(
|
|
131
126
|
get_parameters_res, keep_input=False
|
132
127
|
)
|
133
128
|
# Handle FitIns
|
134
|
-
elif message_type ==
|
129
|
+
elif message_type == MessageType.TRAIN:
|
135
130
|
fit_res = maybe_call_fit(
|
136
131
|
client=client,
|
137
132
|
fit_ins=recordset_to_fitins(message.content, keep_input=True),
|
138
133
|
)
|
139
134
|
out_recordset = fitres_to_recordset(fit_res, keep_input=False)
|
140
135
|
# Handle EvaluateIns
|
141
|
-
elif message_type ==
|
136
|
+
elif message_type == MessageType.EVALUATE:
|
142
137
|
evaluate_res = maybe_call_evaluate(
|
143
138
|
client=client,
|
144
139
|
evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
|
@@ -18,7 +18,7 @@
|
|
18
18
|
from flwr.client.typing import ClientAppCallable
|
19
19
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
20
20
|
from flwr.common import recordset_compat as compat
|
21
|
-
from flwr.common.constant import
|
21
|
+
from flwr.common.constant import MessageType
|
22
22
|
from flwr.common.context import Context
|
23
23
|
from flwr.common.differential_privacy import (
|
24
24
|
compute_adaptive_clip_model_update,
|
@@ -40,7 +40,7 @@ def fixedclipping_mod(
|
|
40
40
|
|
41
41
|
This mod clips the client model updates before sending them to the server.
|
42
42
|
|
43
|
-
It operates on messages with type
|
43
|
+
It operates on messages with type MessageType.TRAIN.
|
44
44
|
|
45
45
|
Notes
|
46
46
|
-----
|
@@ -48,7 +48,7 @@ def fixedclipping_mod(
|
|
48
48
|
|
49
49
|
Typically, fixedclipping_mod should be the last to operate on params.
|
50
50
|
"""
|
51
|
-
if msg.metadata.message_type !=
|
51
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
52
52
|
return call_next(msg, ctxt)
|
53
53
|
fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
|
54
54
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
@@ -93,7 +93,7 @@ def adaptiveclipping_mod(
|
|
93
93
|
|
94
94
|
It also sends KEY_NORM_BIT to the server for computing the new clipping value.
|
95
95
|
|
96
|
-
It operates on messages with type
|
96
|
+
It operates on messages with type MessageType.TRAIN.
|
97
97
|
|
98
98
|
Notes
|
99
99
|
-----
|
@@ -101,7 +101,7 @@ def adaptiveclipping_mod(
|
|
101
101
|
|
102
102
|
Typically, adaptiveclipping_mod should be the last to operate on params.
|
103
103
|
"""
|
104
|
-
if msg.metadata.message_type !=
|
104
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
105
105
|
return call_next(msg, ctxt)
|
106
106
|
|
107
107
|
fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
|
@@ -30,7 +30,7 @@ from flwr.common import (
|
|
30
30
|
parameters_to_ndarrays,
|
31
31
|
)
|
32
32
|
from flwr.common import recordset_compat as compat
|
33
|
-
from flwr.common.constant import
|
33
|
+
from flwr.common.constant import MessageType
|
34
34
|
from flwr.common.logger import log
|
35
35
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
36
36
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
@@ -150,7 +150,7 @@ def secaggplus_mod(
|
|
150
150
|
) -> Message:
|
151
151
|
"""Handle incoming message and return results, following the SecAgg+ protocol."""
|
152
152
|
# Ignore non-fit messages
|
153
|
-
if msg.metadata.message_type !=
|
153
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
154
154
|
return call_next(msg, ctxt)
|
155
155
|
|
156
156
|
# Retrieve local state
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
27
|
from flwr.common.constant import MISSING_EXTRA_REST
|
28
28
|
from flwr.common.logger import log
|
29
29
|
from flwr.common.message import Message, Metadata
|
30
|
+
from flwr.common.retry_invoker import RetryInvoker
|
30
31
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
32
33
|
CreateNodeRequest,
|
@@ -61,6 +62,7 @@ PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
61
62
|
def http_request_response(
|
62
63
|
server_address: str,
|
63
64
|
insecure: bool, # pylint: disable=unused-argument
|
65
|
+
retry_invoker: RetryInvoker,
|
64
66
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
65
67
|
root_certificates: Optional[
|
66
68
|
Union[bytes, str]
|
@@ -84,6 +86,12 @@ def http_request_response(
|
|
84
86
|
The IPv6 address of the server with `http://` or `https://`.
|
85
87
|
If the Flower server runs on the same machine
|
86
88
|
on port 8080, then `server_address` would be `"http://[::]:8080"`.
|
89
|
+
insecure : bool
|
90
|
+
Unused argument present for compatibilty.
|
91
|
+
retry_invoker: RetryInvoker
|
92
|
+
`RetryInvoker` object that will try to reconnect the client to the server
|
93
|
+
after REST connection errors. If None, the client will only try to
|
94
|
+
reconnect once after a failure.
|
87
95
|
max_message_length : int
|
88
96
|
Ignored, only present to preserve API-compatibility.
|
89
97
|
root_certificates : Optional[Union[bytes, str]] (default: None)
|
@@ -134,7 +142,8 @@ def http_request_response(
|
|
134
142
|
create_node_req_proto = CreateNodeRequest()
|
135
143
|
create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
|
136
144
|
|
137
|
-
res =
|
145
|
+
res = retry_invoker.invoke(
|
146
|
+
requests.post,
|
138
147
|
url=f"{base_url}/{PATH_CREATE_NODE}",
|
139
148
|
headers={
|
140
149
|
"Accept": "application/protobuf",
|
@@ -177,7 +186,8 @@ def http_request_response(
|
|
177
186
|
node: Node = cast(Node, node_store[KEY_NODE])
|
178
187
|
delete_node_req_proto = DeleteNodeRequest(node=node)
|
179
188
|
delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
|
180
|
-
res =
|
189
|
+
res = retry_invoker.invoke(
|
190
|
+
requests.post,
|
181
191
|
url=f"{base_url}/{PATH_DELETE_NODE}",
|
182
192
|
headers={
|
183
193
|
"Accept": "application/protobuf",
|
@@ -218,7 +228,8 @@ def http_request_response(
|
|
218
228
|
pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
|
219
229
|
|
220
230
|
# Request instructions (task) from server
|
221
|
-
res =
|
231
|
+
res = retry_invoker.invoke(
|
232
|
+
requests.post,
|
222
233
|
url=f"{base_url}/{PATH_PULL_TASK_INS}",
|
223
234
|
headers={
|
224
235
|
"Accept": "application/protobuf",
|
@@ -298,7 +309,8 @@ def http_request_response(
|
|
298
309
|
)
|
299
310
|
|
300
311
|
# Send ClientMessage to server
|
301
|
-
res =
|
312
|
+
res = retry_invoker.invoke(
|
313
|
+
requests.post,
|
302
314
|
url=f"{base_url}/{PATH_PUSH_TASK_RES}",
|
303
315
|
headers={
|
304
316
|
"Accept": "application/protobuf",
|
flwr/common/__init__.py
CHANGED
@@ -15,11 +15,14 @@
|
|
15
15
|
"""Common components shared between server and client."""
|
16
16
|
|
17
17
|
|
18
|
+
from .constant import MessageType as MessageType
|
19
|
+
from .constant import MessageTypeLegacy as MessageTypeLegacy
|
18
20
|
from .context import Context as Context
|
19
21
|
from .date import now as now
|
20
22
|
from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
21
23
|
from .logger import configure as configure
|
22
24
|
from .logger import log as log
|
25
|
+
from .message import Error as Error
|
23
26
|
from .message import Message as Message
|
24
27
|
from .message import Metadata as Metadata
|
25
28
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
@@ -74,6 +77,7 @@ __all__ = [
|
|
74
77
|
"EventType",
|
75
78
|
"FitIns",
|
76
79
|
"FitRes",
|
80
|
+
"Error",
|
77
81
|
"GetParametersIns",
|
78
82
|
"GetParametersRes",
|
79
83
|
"GetPropertiesIns",
|
@@ -81,6 +85,8 @@ __all__ = [
|
|
81
85
|
"GRPC_MAX_MESSAGE_LENGTH",
|
82
86
|
"log",
|
83
87
|
"Message",
|
88
|
+
"MessageType",
|
89
|
+
"MessageTypeLegacy",
|
84
90
|
"Metadata",
|
85
91
|
"Metrics",
|
86
92
|
"MetricsAggregationFn",
|
flwr/common/constant.py
CHANGED
@@ -36,10 +36,27 @@ TRANSPORT_TYPES = [
|
|
36
36
|
TRANSPORT_TYPE_VCE,
|
37
37
|
]
|
38
38
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
39
|
+
|
40
|
+
class MessageType:
|
41
|
+
"""Message type."""
|
42
|
+
|
43
|
+
TRAIN = "train"
|
44
|
+
EVALUATE = "evaluate"
|
45
|
+
|
46
|
+
def __new__(cls) -> MessageType:
|
47
|
+
"""Prevent instantiation."""
|
48
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
49
|
+
|
50
|
+
|
51
|
+
class MessageTypeLegacy:
|
52
|
+
"""Legacy message type."""
|
53
|
+
|
54
|
+
GET_PROPERTIES = "get_properties"
|
55
|
+
GET_PARAMETERS = "get_parameters"
|
56
|
+
|
57
|
+
def __new__(cls) -> MessageTypeLegacy:
|
58
|
+
"""Prevent instantiation."""
|
59
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
43
60
|
|
44
61
|
|
45
62
|
class SType:
|
flwr/server/app.py
CHANGED
@@ -362,10 +362,10 @@ def run_superlink() -> None:
|
|
362
362
|
f_stop = asyncio.Event() # Does nothing
|
363
363
|
_run_fleet_api_vce(
|
364
364
|
num_supernodes=args.num_supernodes,
|
365
|
-
|
365
|
+
client_app_attr=args.client_app,
|
366
366
|
backend_name=args.backend,
|
367
367
|
backend_config_json_stream=args.backend_config,
|
368
|
-
|
368
|
+
app_dir=args.app_dir,
|
369
369
|
state_factory=state_factory,
|
370
370
|
f_stop=f_stop,
|
371
371
|
)
|
@@ -438,10 +438,10 @@ def _run_fleet_api_grpc_rere(
|
|
438
438
|
# pylint: disable=too-many-arguments
|
439
439
|
def _run_fleet_api_vce(
|
440
440
|
num_supernodes: int,
|
441
|
-
|
441
|
+
client_app_attr: str,
|
442
442
|
backend_name: str,
|
443
443
|
backend_config_json_stream: str,
|
444
|
-
|
444
|
+
app_dir: str,
|
445
445
|
state_factory: StateFactory,
|
446
446
|
f_stop: asyncio.Event,
|
447
447
|
) -> None:
|
@@ -449,11 +449,11 @@ def _run_fleet_api_vce(
|
|
449
449
|
|
450
450
|
start_vce(
|
451
451
|
num_supernodes=num_supernodes,
|
452
|
-
|
452
|
+
client_app_attr=client_app_attr,
|
453
453
|
backend_name=backend_name,
|
454
454
|
backend_config_json_stream=backend_config_json_stream,
|
455
455
|
state_factory=state_factory,
|
456
|
-
|
456
|
+
app_dir=app_dir,
|
457
457
|
f_stop=f_stop,
|
458
458
|
)
|
459
459
|
|
@@ -705,7 +705,7 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
705
705
|
"`flwr.common.typing.ConfigsRecordValues`. ",
|
706
706
|
)
|
707
707
|
parser.add_argument(
|
708
|
-
"--dir",
|
708
|
+
"--app-dir",
|
709
709
|
default="",
|
710
710
|
help="Add specified directory to the PYTHONPATH and load"
|
711
711
|
"ClientApp from there."
|
@@ -19,15 +19,9 @@ import time
|
|
19
19
|
from typing import List, Optional
|
20
20
|
|
21
21
|
from flwr import common
|
22
|
-
from flwr.common import RecordSet
|
22
|
+
from flwr.common import MessageType, MessageTypeLegacy, RecordSet
|
23
23
|
from flwr.common import recordset_compat as compat
|
24
24
|
from flwr.common import serde
|
25
|
-
from flwr.common.constant import (
|
26
|
-
MESSAGE_TYPE_EVALUATE,
|
27
|
-
MESSAGE_TYPE_FIT,
|
28
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
29
|
-
MESSAGE_TYPE_GET_PROPERTIES,
|
30
|
-
)
|
31
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
32
26
|
from flwr.server.client_proxy import ClientProxy
|
33
27
|
|
@@ -57,7 +51,7 @@ class DriverClientProxy(ClientProxy):
|
|
57
51
|
out_recordset = compat.getpropertiesins_to_recordset(ins)
|
58
52
|
# Fetch response
|
59
53
|
in_recordset = self._send_receive_recordset(
|
60
|
-
out_recordset,
|
54
|
+
out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
61
55
|
)
|
62
56
|
# RecordSet to Res
|
63
57
|
return compat.recordset_to_getpropertiesres(in_recordset)
|
@@ -73,7 +67,7 @@ class DriverClientProxy(ClientProxy):
|
|
73
67
|
out_recordset = compat.getparametersins_to_recordset(ins)
|
74
68
|
# Fetch response
|
75
69
|
in_recordset = self._send_receive_recordset(
|
76
|
-
out_recordset,
|
70
|
+
out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
77
71
|
)
|
78
72
|
# RecordSet to Res
|
79
73
|
return compat.recordset_to_getparametersres(in_recordset, False)
|
@@ -86,7 +80,7 @@ class DriverClientProxy(ClientProxy):
|
|
86
80
|
out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
|
87
81
|
# Fetch response
|
88
82
|
in_recordset = self._send_receive_recordset(
|
89
|
-
out_recordset,
|
83
|
+
out_recordset, MessageType.TRAIN, timeout, group_id
|
90
84
|
)
|
91
85
|
# RecordSet to Res
|
92
86
|
return compat.recordset_to_fitres(in_recordset, keep_input=False)
|
@@ -99,7 +93,7 @@ class DriverClientProxy(ClientProxy):
|
|
99
93
|
out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
|
100
94
|
# Fetch response
|
101
95
|
in_recordset = self._send_receive_recordset(
|
102
|
-
out_recordset,
|
96
|
+
out_recordset, MessageType.EVALUATE, timeout, group_id
|
103
97
|
)
|
104
98
|
# RecordSet to Res
|
105
99
|
return compat.recordset_to_evaluateres(in_recordset)
|
flwr/server/run_serverapp.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import argparse
|
19
|
-
import asyncio
|
20
19
|
import sys
|
21
20
|
from logging import DEBUG, WARN
|
22
21
|
from pathlib import Path
|
@@ -30,17 +29,27 @@ from .server_app import ServerApp, load_server_app
|
|
30
29
|
|
31
30
|
|
32
31
|
def run(
|
33
|
-
server_app_attr: str,
|
34
32
|
driver: Driver,
|
35
33
|
server_app_dir: str,
|
36
|
-
|
34
|
+
server_app_attr: Optional[str] = None,
|
35
|
+
loaded_server_app: Optional[ServerApp] = None,
|
37
36
|
) -> None:
|
38
37
|
"""Run ServerApp with a given Driver."""
|
38
|
+
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
39
|
+
raise ValueError(
|
40
|
+
"Either `server_app_attr` or `loaded_server_app` should be set "
|
41
|
+
"but not both. "
|
42
|
+
)
|
43
|
+
|
39
44
|
if server_app_dir is not None:
|
40
45
|
sys.path.insert(0, server_app_dir)
|
41
46
|
|
47
|
+
# Load ServerApp if needed
|
42
48
|
def _load() -> ServerApp:
|
43
|
-
|
49
|
+
if server_app_attr:
|
50
|
+
server_app: ServerApp = load_server_app(server_app_attr)
|
51
|
+
if loaded_server_app:
|
52
|
+
server_app = loaded_server_app
|
44
53
|
return server_app
|
45
54
|
|
46
55
|
server_app = _load()
|
@@ -52,10 +61,6 @@ def run(
|
|
52
61
|
server_app(driver=driver, context=context)
|
53
62
|
|
54
63
|
log(DEBUG, "ServerApp finished running.")
|
55
|
-
# Upon completion, trigger stop event if one was passed
|
56
|
-
if stop_event is not None:
|
57
|
-
log(DEBUG, "Triggering stop event.")
|
58
|
-
stop_event.set()
|
59
64
|
|
60
65
|
|
61
66
|
def run_server_app() -> None:
|
@@ -117,7 +122,7 @@ def run_server_app() -> None:
|
|
117
122
|
)
|
118
123
|
|
119
124
|
# Run the Server App with the Driver
|
120
|
-
run(
|
125
|
+
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
121
126
|
|
122
127
|
# Clean up
|
123
128
|
driver.__del__() # pylint: disable=unnecessary-dunder-call
|
flwr/server/server.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import concurrent.futures
|
19
19
|
import timeit
|
20
|
-
from logging import
|
20
|
+
from logging import INFO, WARN
|
21
21
|
from typing import Dict, List, Optional, Tuple, Union
|
22
22
|
|
23
23
|
from flwr.common import (
|
@@ -173,7 +173,7 @@ class Server:
|
|
173
173
|
log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
|
174
174
|
return None
|
175
175
|
log(
|
176
|
-
|
176
|
+
INFO,
|
177
177
|
"evaluate_round %s: strategy sampled %s clients (out of %s)",
|
178
178
|
server_round,
|
179
179
|
len(client_instructions),
|
@@ -188,7 +188,7 @@ class Server:
|
|
188
188
|
group_id=server_round,
|
189
189
|
)
|
190
190
|
log(
|
191
|
-
|
191
|
+
INFO,
|
192
192
|
"evaluate_round %s received %s results and %s failures",
|
193
193
|
server_round,
|
194
194
|
len(results),
|
@@ -223,7 +223,7 @@ class Server:
|
|
223
223
|
log(INFO, "fit_round %s: no clients selected, cancel", server_round)
|
224
224
|
return None
|
225
225
|
log(
|
226
|
-
|
226
|
+
INFO,
|
227
227
|
"fit_round %s: strategy sampled %s clients (out of %s)",
|
228
228
|
server_round,
|
229
229
|
len(client_instructions),
|
@@ -238,7 +238,7 @@ class Server:
|
|
238
238
|
group_id=server_round,
|
239
239
|
)
|
240
240
|
log(
|
241
|
-
|
241
|
+
INFO,
|
242
242
|
"fit_round %s received %s results and %s failures",
|
243
243
|
server_round,
|
244
244
|
len(results),
|
@@ -49,7 +49,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
49
49
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
50
50
|
) -> GetNodesResponse:
|
51
51
|
"""Get available nodes."""
|
52
|
-
log(
|
52
|
+
log(DEBUG, "DriverServicer.GetNodes")
|
53
53
|
state: State = self.state_factory.state()
|
54
54
|
all_ids: Set[int] = state.get_nodes(request.run_id)
|
55
55
|
nodes: List[Node] = [
|
@@ -219,16 +219,23 @@ async def run(
|
|
219
219
|
|
220
220
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals
|
221
221
|
def start_vce(
|
222
|
-
client_app_module_name: str,
|
223
222
|
backend_name: str,
|
224
223
|
backend_config_json_stream: str,
|
225
|
-
|
224
|
+
app_dir: str,
|
226
225
|
f_stop: asyncio.Event,
|
226
|
+
client_app: Optional[ClientApp] = None,
|
227
|
+
client_app_attr: Optional[str] = None,
|
227
228
|
num_supernodes: Optional[int] = None,
|
228
229
|
state_factory: Optional[StateFactory] = None,
|
229
230
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
230
231
|
) -> None:
|
231
232
|
"""Start Fleet API with the Simulation Engine."""
|
233
|
+
if client_app_attr is not None and client_app is not None:
|
234
|
+
raise ValueError(
|
235
|
+
"Both `client_app_attr` and `client_app` are provided, "
|
236
|
+
"but only one is allowed."
|
237
|
+
)
|
238
|
+
|
232
239
|
if num_supernodes is not None and existing_nodes_mapping is not None:
|
233
240
|
raise ValueError(
|
234
241
|
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
|
@@ -290,12 +297,17 @@ def start_vce(
|
|
290
297
|
|
291
298
|
def backend_fn() -> Backend:
|
292
299
|
"""Instantiate a Backend."""
|
293
|
-
return backend_type(backend_config, work_dir=
|
300
|
+
return backend_type(backend_config, work_dir=app_dir)
|
294
301
|
|
295
|
-
log(INFO, "
|
302
|
+
log(INFO, "client_app_attr = %s", client_app_attr)
|
296
303
|
|
304
|
+
# Load ClientApp if needed
|
297
305
|
def _load() -> ClientApp:
|
298
|
-
|
306
|
+
|
307
|
+
if client_app_attr:
|
308
|
+
app: ClientApp = load_client_app(client_app_attr)
|
309
|
+
if client_app:
|
310
|
+
app = client_app
|
299
311
|
return app
|
300
312
|
|
301
313
|
app_fn = _load
|
@@ -21,11 +21,7 @@ from typing import Optional, cast
|
|
21
21
|
|
22
22
|
import flwr.common.recordset_compat as compat
|
23
23
|
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
24
|
-
from flwr.common.constant import
|
25
|
-
MESSAGE_TYPE_EVALUATE,
|
26
|
-
MESSAGE_TYPE_FIT,
|
27
|
-
MESSAGE_TYPE_GET_PARAMETERS,
|
28
|
-
)
|
24
|
+
from flwr.common.constant import MessageType, MessageTypeLegacy
|
29
25
|
|
30
26
|
from ..compat.app_utils import start_update_client_manager_thread
|
31
27
|
from ..compat.legacy_context import LegacyContext
|
@@ -134,7 +130,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
134
130
|
[
|
135
131
|
driver.create_message(
|
136
132
|
content=content,
|
137
|
-
message_type=
|
133
|
+
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
138
134
|
dst_node_id=random_client.node_id,
|
139
135
|
group_id="",
|
140
136
|
ttl="",
|
@@ -232,7 +228,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
|
|
232
228
|
out_messages = [
|
233
229
|
driver.create_message(
|
234
230
|
content=compat.fitins_to_recordset(fitins, True),
|
235
|
-
message_type=
|
231
|
+
message_type=MessageType.TRAIN,
|
236
232
|
dst_node_id=proxy.node_id,
|
237
233
|
group_id="",
|
238
234
|
ttl="",
|
@@ -313,7 +309,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
313
309
|
out_messages = [
|
314
310
|
driver.create_message(
|
315
311
|
content=compat.evaluateins_to_recordset(evalins, True),
|
316
|
-
message_type=
|
312
|
+
message_type=MessageType.EVALUATE,
|
317
313
|
dst_node_id=proxy.node_id,
|
318
314
|
group_id="",
|
319
315
|
ttl="",
|
flwr/simulation/__init__.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
-
from flwr.simulation.run_simulation import run_simulation
|
20
|
+
from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
|
21
21
|
|
22
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
23
23
|
|
@@ -36,7 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
36
|
raise ImportError(RAY_IMPORT_ERROR)
|
37
37
|
|
38
38
|
|
39
|
-
__all__ = [
|
40
|
-
"start_simulation",
|
41
|
-
"run_simulation",
|
42
|
-
]
|
39
|
+
__all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
|