flwr-nightly 1.8.0.dev20240226__py3-none-any.whl → 1.8.0.dev20240227__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +4 -3
- flwr/client/mod/__init__.py +2 -0
- flwr/client/mod/centraldp_mods.py +76 -0
- flwr/common/differential_privacy_constants.py +2 -0
- flwr/common/exit_handlers.py +87 -0
- flwr/server/app.py +7 -54
- flwr/server/strategy/__init__.py +5 -1
- flwr/server/strategy/dp_fixed_clipping.py +156 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +31 -14
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +4 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- {flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/RECORD +20 -18
- /flwr/client/{clientapp.py → client_app.py} +0 -0
- {flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/entry_points.txt +0 -0
flwr/client/__init__.py
CHANGED
@@ -19,7 +19,7 @@ from .app import run_client_app as run_client_app
|
|
19
19
|
from .app import start_client as start_client
|
20
20
|
from .app import start_numpy_client as start_numpy_client
|
21
21
|
from .client import Client as Client
|
22
|
-
from .
|
22
|
+
from .client_app import ClientApp as ClientApp
|
23
23
|
from .numpy_client import NumPyClient as NumPyClient
|
24
24
|
from .typing import ClientFn as ClientFn
|
25
25
|
|
flwr/client/app.py
CHANGED
@@ -23,7 +23,7 @@ from pathlib import Path
|
|
23
23
|
from typing import Callable, ContextManager, Optional, Tuple, Union
|
24
24
|
|
25
25
|
from flwr.client.client import Client
|
26
|
-
from flwr.client.
|
26
|
+
from flwr.client.client_app import ClientApp
|
27
27
|
from flwr.client.typing import ClientFn
|
28
28
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
|
29
29
|
from flwr.common.address import parse_address
|
@@ -34,9 +34,10 @@ from flwr.common.constant import (
|
|
34
34
|
TRANSPORT_TYPE_REST,
|
35
35
|
TRANSPORT_TYPES,
|
36
36
|
)
|
37
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
37
38
|
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
|
38
39
|
|
39
|
-
from .
|
40
|
+
from .client_app import load_client_app
|
40
41
|
from .grpc_client.connection import grpc_connection
|
41
42
|
from .grpc_rere_client.connection import grpc_request_response
|
42
43
|
from .message_handler.message_handler import handle_control_message
|
@@ -104,7 +105,7 @@ def run_client_app() -> None:
|
|
104
105
|
root_certificates=root_certificates,
|
105
106
|
insecure=args.insecure,
|
106
107
|
)
|
107
|
-
|
108
|
+
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
108
109
|
|
109
110
|
|
110
111
|
def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
flwr/client/mod/__init__.py
CHANGED
@@ -15,10 +15,12 @@
|
|
15
15
|
"""Mods."""
|
16
16
|
|
17
17
|
|
18
|
+
from .centraldp_mods import fixedclipping_mod
|
18
19
|
from .secure_aggregation.secaggplus_mod import secaggplus_mod
|
19
20
|
from .utils import make_ffn
|
20
21
|
|
21
22
|
__all__ = [
|
22
23
|
"make_ffn",
|
23
24
|
"secaggplus_mod",
|
25
|
+
"fixedclipping_mod",
|
24
26
|
]
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Clipping modifiers for central DP with client-side clipping."""
|
16
|
+
|
17
|
+
|
18
|
+
from flwr.client.typing import ClientAppCallable
|
19
|
+
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
20
|
+
from flwr.common import recordset_compat as compat
|
21
|
+
from flwr.common.constant import MESSAGE_TYPE_FIT
|
22
|
+
from flwr.common.context import Context
|
23
|
+
from flwr.common.differential_privacy import compute_clip_model_update
|
24
|
+
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM
|
25
|
+
from flwr.common.message import Message
|
26
|
+
|
27
|
+
|
28
|
+
def fixedclipping_mod(
|
29
|
+
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
30
|
+
) -> Message:
|
31
|
+
"""Client-side fixed clipping modifier.
|
32
|
+
|
33
|
+
This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping
|
34
|
+
server-side strategy wrapper.
|
35
|
+
|
36
|
+
The wrapper sends the clipping_norm value to the client.
|
37
|
+
|
38
|
+
This mod clips the client model updates before sending them to the server.
|
39
|
+
|
40
|
+
It operates on messages with type MESSAGE_TYPE_FIT.
|
41
|
+
|
42
|
+
Notes
|
43
|
+
-----
|
44
|
+
Consider the order of mods when using multiple.
|
45
|
+
|
46
|
+
Typically, fixedclipping_mod should be the last to operate on params.
|
47
|
+
"""
|
48
|
+
if msg.metadata.message_type != MESSAGE_TYPE_FIT:
|
49
|
+
return call_next(msg, ctxt)
|
50
|
+
fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
|
51
|
+
if KEY_CLIPPING_NORM not in fit_ins.config:
|
52
|
+
raise KeyError(
|
53
|
+
f"The {KEY_CLIPPING_NORM} value is not supplied by the "
|
54
|
+
f"DifferentialPrivacyClientSideFixedClipping wrapper at"
|
55
|
+
f" the server side."
|
56
|
+
)
|
57
|
+
|
58
|
+
clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
|
59
|
+
server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
|
60
|
+
|
61
|
+
# Call inner app
|
62
|
+
out_msg = call_next(msg, ctxt)
|
63
|
+
fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
|
64
|
+
|
65
|
+
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
66
|
+
|
67
|
+
# Clip the client update
|
68
|
+
compute_clip_model_update(
|
69
|
+
client_to_server_params,
|
70
|
+
server_to_client_params,
|
71
|
+
clipping_norm,
|
72
|
+
)
|
73
|
+
|
74
|
+
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
75
|
+
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
76
|
+
return out_msg
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Constants for differential privacy."""
|
16
16
|
|
17
|
+
|
18
|
+
KEY_CLIPPING_NORM = "clipping_norm"
|
17
19
|
CLIENTS_DISCREPANCY_WARNING = (
|
18
20
|
"The number of clients returning parameters (%s)"
|
19
21
|
" differs from the number of sampled clients (%s)."
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Common function to register exit handlers for server and client."""
|
16
|
+
|
17
|
+
|
18
|
+
import sys
|
19
|
+
from signal import SIGINT, SIGTERM, signal
|
20
|
+
from threading import Thread
|
21
|
+
from types import FrameType
|
22
|
+
from typing import List, Optional
|
23
|
+
|
24
|
+
from grpc import Server
|
25
|
+
|
26
|
+
from flwr.common.telemetry import EventType, event
|
27
|
+
|
28
|
+
|
29
|
+
def register_exit_handlers(
|
30
|
+
event_type: EventType,
|
31
|
+
grpc_servers: Optional[List[Server]] = None,
|
32
|
+
bckg_threads: Optional[List[Thread]] = None,
|
33
|
+
) -> None:
|
34
|
+
"""Register exit handlers for `SIGINT` and `SIGTERM` signals.
|
35
|
+
|
36
|
+
Parameters
|
37
|
+
----------
|
38
|
+
event_type : EventType
|
39
|
+
The telemetry event that should be logged before exit.
|
40
|
+
grpc_servers: Optional[List[Server]] (default: None)
|
41
|
+
An optional list of gRPC servers that need to be gracefully
|
42
|
+
terminated before exiting.
|
43
|
+
bckg_threads: Optional[List[Thread]] (default: None)
|
44
|
+
An optional list of threads that need to be gracefully
|
45
|
+
terminated before exiting.
|
46
|
+
"""
|
47
|
+
default_handlers = {
|
48
|
+
SIGINT: None,
|
49
|
+
SIGTERM: None,
|
50
|
+
}
|
51
|
+
|
52
|
+
def graceful_exit_handler( # type: ignore
|
53
|
+
signalnum,
|
54
|
+
frame: FrameType, # pylint: disable=unused-argument
|
55
|
+
) -> None:
|
56
|
+
"""Exit handler to be registered with `signal.signal`.
|
57
|
+
|
58
|
+
When called will reset signal handler to original signal handler from
|
59
|
+
default_handlers.
|
60
|
+
"""
|
61
|
+
# Reset to default handler
|
62
|
+
signal(signalnum, default_handlers[signalnum])
|
63
|
+
|
64
|
+
event_res = event(event_type=event_type)
|
65
|
+
|
66
|
+
if grpc_servers is not None:
|
67
|
+
for grpc_server in grpc_servers:
|
68
|
+
grpc_server.stop(grace=1)
|
69
|
+
|
70
|
+
if bckg_threads is not None:
|
71
|
+
for bckg_thread in bckg_threads:
|
72
|
+
bckg_thread.join()
|
73
|
+
|
74
|
+
# Ensure event has happened
|
75
|
+
event_res.result()
|
76
|
+
|
77
|
+
# Setup things for graceful exit
|
78
|
+
sys.exit(0)
|
79
|
+
|
80
|
+
default_handlers[SIGINT] = signal( # type: ignore
|
81
|
+
SIGINT,
|
82
|
+
graceful_exit_handler, # type: ignore
|
83
|
+
)
|
84
|
+
default_handlers[SIGTERM] = signal( # type: ignore
|
85
|
+
SIGTERM,
|
86
|
+
graceful_exit_handler, # type: ignore
|
87
|
+
)
|
flwr/server/app.py
CHANGED
@@ -22,8 +22,6 @@ import threading
|
|
22
22
|
from logging import ERROR, INFO, WARN
|
23
23
|
from os.path import isfile
|
24
24
|
from pathlib import Path
|
25
|
-
from signal import SIGINT, SIGTERM, signal
|
26
|
-
from types import FrameType
|
27
25
|
from typing import List, Optional, Tuple
|
28
26
|
|
29
27
|
import grpc
|
@@ -36,6 +34,7 @@ from flwr.common.constant import (
|
|
36
34
|
TRANSPORT_TYPE_REST,
|
37
35
|
TRANSPORT_TYPE_VCE,
|
38
36
|
)
|
37
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
39
38
|
from flwr.common.logger import log
|
40
39
|
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
41
40
|
add_DriverServicer_to_server,
|
@@ -212,10 +211,10 @@ def run_driver_api() -> None:
|
|
212
211
|
)
|
213
212
|
|
214
213
|
# Graceful shutdown
|
215
|
-
|
214
|
+
register_exit_handlers(
|
215
|
+
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
216
216
|
grpc_servers=[grpc_server],
|
217
217
|
bckg_threads=[],
|
218
|
-
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
219
218
|
)
|
220
219
|
|
221
220
|
# Block
|
@@ -280,10 +279,10 @@ def run_fleet_api() -> None:
|
|
280
279
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
281
280
|
|
282
281
|
# Graceful shutdown
|
283
|
-
|
282
|
+
register_exit_handlers(
|
283
|
+
event_type=EventType.RUN_FLEET_API_LEAVE,
|
284
284
|
grpc_servers=grpc_servers,
|
285
285
|
bckg_threads=bckg_threads,
|
286
|
-
event_type=EventType.RUN_FLEET_API_LEAVE,
|
287
286
|
)
|
288
287
|
|
289
288
|
# Block
|
@@ -375,10 +374,10 @@ def run_superlink() -> None:
|
|
375
374
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
376
375
|
|
377
376
|
# Graceful shutdown
|
378
|
-
|
377
|
+
register_exit_handlers(
|
378
|
+
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
379
379
|
grpc_servers=grpc_servers,
|
380
380
|
bckg_threads=bckg_threads,
|
381
|
-
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
382
381
|
)
|
383
382
|
|
384
383
|
# Block
|
@@ -413,52 +412,6 @@ def _try_obtain_certificates(
|
|
413
412
|
return certificates
|
414
413
|
|
415
414
|
|
416
|
-
def _register_exit_handlers(
|
417
|
-
grpc_servers: List[grpc.Server],
|
418
|
-
bckg_threads: List[threading.Thread],
|
419
|
-
event_type: EventType,
|
420
|
-
) -> None:
|
421
|
-
default_handlers = {
|
422
|
-
SIGINT: None,
|
423
|
-
SIGTERM: None,
|
424
|
-
}
|
425
|
-
|
426
|
-
def graceful_exit_handler( # type: ignore
|
427
|
-
signalnum,
|
428
|
-
frame: FrameType, # pylint: disable=unused-argument
|
429
|
-
) -> None:
|
430
|
-
"""Exit handler to be registered with signal.signal.
|
431
|
-
|
432
|
-
When called will reset signal handler to original signal handler from
|
433
|
-
default_handlers.
|
434
|
-
"""
|
435
|
-
# Reset to default handler
|
436
|
-
signal(signalnum, default_handlers[signalnum])
|
437
|
-
|
438
|
-
event_res = event(event_type=event_type)
|
439
|
-
|
440
|
-
for grpc_server in grpc_servers:
|
441
|
-
grpc_server.stop(grace=1)
|
442
|
-
|
443
|
-
for bckg_thread in bckg_threads:
|
444
|
-
bckg_thread.join()
|
445
|
-
|
446
|
-
# Ensure event has happend
|
447
|
-
event_res.result()
|
448
|
-
|
449
|
-
# Setup things for graceful exit
|
450
|
-
sys.exit(0)
|
451
|
-
|
452
|
-
default_handlers[SIGINT] = signal( # type: ignore
|
453
|
-
SIGINT,
|
454
|
-
graceful_exit_handler, # type: ignore
|
455
|
-
)
|
456
|
-
default_handlers[SIGTERM] = signal( # type: ignore
|
457
|
-
SIGTERM,
|
458
|
-
graceful_exit_handler, # type: ignore
|
459
|
-
)
|
460
|
-
|
461
|
-
|
462
415
|
def _run_driver_api_grpc(
|
463
416
|
address: str,
|
464
417
|
state_factory: StateFactory,
|
flwr/server/strategy/__init__.py
CHANGED
@@ -16,7 +16,10 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .bulyan import Bulyan as Bulyan
|
19
|
-
from .dp_fixed_clipping import
|
19
|
+
from .dp_fixed_clipping import (
|
20
|
+
DifferentialPrivacyClientSideFixedClipping,
|
21
|
+
DifferentialPrivacyServerSideFixedClipping,
|
22
|
+
)
|
20
23
|
from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
|
21
24
|
from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
|
22
25
|
from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg
|
@@ -59,4 +62,5 @@ __all__ = [
|
|
59
62
|
"DPFedAvgFixed",
|
60
63
|
"Strategy",
|
61
64
|
"DifferentialPrivacyServerSideFixedClipping",
|
65
|
+
"DifferentialPrivacyClientSideFixedClipping",
|
62
66
|
]
|
@@ -36,7 +36,10 @@ from flwr.common.differential_privacy import (
|
|
36
36
|
add_gaussian_noise_to_params,
|
37
37
|
compute_clip_model_update,
|
38
38
|
)
|
39
|
-
from flwr.common.differential_privacy_constants import
|
39
|
+
from flwr.common.differential_privacy_constants import (
|
40
|
+
CLIENTS_DISCREPANCY_WARNING,
|
41
|
+
KEY_CLIPPING_NORM,
|
42
|
+
)
|
40
43
|
from flwr.common.logger import log
|
41
44
|
from flwr.server.client_manager import ClientManager
|
42
45
|
from flwr.server.client_proxy import ClientProxy
|
@@ -44,7 +47,8 @@ from flwr.server.strategy.strategy import Strategy
|
|
44
47
|
|
45
48
|
|
46
49
|
class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
47
|
-
"""
|
50
|
+
"""Strategy wrapper for central differential privacy with server-side fixed
|
51
|
+
clipping.
|
48
52
|
|
49
53
|
Parameters
|
50
54
|
----------
|
@@ -185,3 +189,153 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
185
189
|
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
|
186
190
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
187
191
|
return self.strategy.evaluate(server_round, parameters)
|
192
|
+
|
193
|
+
|
194
|
+
class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
195
|
+
"""Strategy wrapper for central differential privacy with client-side fixed
|
196
|
+
clipping.
|
197
|
+
|
198
|
+
Use `fixedclipping_mod` modifier at the client side.
|
199
|
+
|
200
|
+
In comparison to `DifferentialPrivacyServerSideFixedClipping`,
|
201
|
+
which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping`
|
202
|
+
expects clipping to happen on the client-side, usually by using the built-in
|
203
|
+
`fixedclipping_mod`.
|
204
|
+
|
205
|
+
Parameters
|
206
|
+
----------
|
207
|
+
strategy : Strategy
|
208
|
+
The strategy to which DP functionalities will be added by this wrapper.
|
209
|
+
noise_multiplier : float
|
210
|
+
The noise multiplier for the Gaussian mechanism for model updates.
|
211
|
+
A value of 1.0 or higher is recommended for strong privacy.
|
212
|
+
clipping_norm : float
|
213
|
+
The value of the clipping norm.
|
214
|
+
num_sampled_clients : int
|
215
|
+
The number of clients that are sampled on each round.
|
216
|
+
|
217
|
+
Examples
|
218
|
+
--------
|
219
|
+
Create a strategy:
|
220
|
+
|
221
|
+
>>> strategy = fl.server.strategy.FedAvg(...)
|
222
|
+
|
223
|
+
Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper:
|
224
|
+
|
225
|
+
>>> DifferentialPrivacyClientSideFixedClipping(
|
226
|
+
>>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
|
227
|
+
>>> )
|
228
|
+
|
229
|
+
On the client, add the `fixedclipping_mod` to the client-side mods:
|
230
|
+
|
231
|
+
>>> app = fl.client.ClientApp(
|
232
|
+
>>> client_fn=FlowerClient().to_client(), mods=[fixedclipping_mod]
|
233
|
+
>>> )
|
234
|
+
"""
|
235
|
+
|
236
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
237
|
+
def __init__(
|
238
|
+
self,
|
239
|
+
strategy: Strategy,
|
240
|
+
noise_multiplier: float,
|
241
|
+
clipping_norm: float,
|
242
|
+
num_sampled_clients: int,
|
243
|
+
) -> None:
|
244
|
+
super().__init__()
|
245
|
+
|
246
|
+
self.strategy = strategy
|
247
|
+
|
248
|
+
if noise_multiplier < 0:
|
249
|
+
raise ValueError("The noise multiplier should be a non-negative value.")
|
250
|
+
|
251
|
+
if clipping_norm <= 0:
|
252
|
+
raise ValueError("The clipping threshold should be a positive value.")
|
253
|
+
|
254
|
+
if num_sampled_clients <= 0:
|
255
|
+
raise ValueError(
|
256
|
+
"The number of sampled clients should be a positive value."
|
257
|
+
)
|
258
|
+
|
259
|
+
self.noise_multiplier = noise_multiplier
|
260
|
+
self.clipping_norm = clipping_norm
|
261
|
+
self.num_sampled_clients = num_sampled_clients
|
262
|
+
|
263
|
+
def __repr__(self) -> str:
|
264
|
+
"""Compute a string representation of the strategy."""
|
265
|
+
rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
|
266
|
+
return rep
|
267
|
+
|
268
|
+
def initialize_parameters(
|
269
|
+
self, client_manager: ClientManager
|
270
|
+
) -> Optional[Parameters]:
|
271
|
+
"""Initialize global model parameters using given strategy."""
|
272
|
+
return self.strategy.initialize_parameters(client_manager)
|
273
|
+
|
274
|
+
def configure_fit(
|
275
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
276
|
+
) -> List[Tuple[ClientProxy, FitIns]]:
|
277
|
+
"""Configure the next round of training."""
|
278
|
+
additional_config = {KEY_CLIPPING_NORM: self.clipping_norm}
|
279
|
+
inner_strategy_config_result = self.strategy.configure_fit(
|
280
|
+
server_round, parameters, client_manager
|
281
|
+
)
|
282
|
+
for _, fit_ins in inner_strategy_config_result:
|
283
|
+
fit_ins.config.update(additional_config)
|
284
|
+
|
285
|
+
return inner_strategy_config_result
|
286
|
+
|
287
|
+
def configure_evaluate(
|
288
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
289
|
+
) -> List[Tuple[ClientProxy, EvaluateIns]]:
|
290
|
+
"""Configure the next round of evaluation."""
|
291
|
+
return self.strategy.configure_evaluate(
|
292
|
+
server_round, parameters, client_manager
|
293
|
+
)
|
294
|
+
|
295
|
+
def aggregate_fit(
|
296
|
+
self,
|
297
|
+
server_round: int,
|
298
|
+
results: List[Tuple[ClientProxy, FitRes]],
|
299
|
+
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
|
300
|
+
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
|
301
|
+
"""Add noise to the aggregated parameters."""
|
302
|
+
if failures:
|
303
|
+
return None, {}
|
304
|
+
|
305
|
+
if len(results) != self.num_sampled_clients:
|
306
|
+
log(
|
307
|
+
WARNING,
|
308
|
+
CLIENTS_DISCREPANCY_WARNING,
|
309
|
+
len(results),
|
310
|
+
self.num_sampled_clients,
|
311
|
+
)
|
312
|
+
|
313
|
+
# Pass the new parameters for aggregation
|
314
|
+
aggregated_params, metrics = self.strategy.aggregate_fit(
|
315
|
+
server_round, results, failures
|
316
|
+
)
|
317
|
+
|
318
|
+
# Add Gaussian noise to the aggregated parameters
|
319
|
+
if aggregated_params:
|
320
|
+
aggregated_params = add_gaussian_noise_to_params(
|
321
|
+
aggregated_params,
|
322
|
+
self.noise_multiplier,
|
323
|
+
self.clipping_norm,
|
324
|
+
self.num_sampled_clients,
|
325
|
+
)
|
326
|
+
return aggregated_params, metrics
|
327
|
+
|
328
|
+
def aggregate_evaluate(
|
329
|
+
self,
|
330
|
+
server_round: int,
|
331
|
+
results: List[Tuple[ClientProxy, EvaluateRes]],
|
332
|
+
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
|
333
|
+
) -> Tuple[Optional[float], Dict[str, Scalar]]:
|
334
|
+
"""Aggregate evaluation losses using the given strategy."""
|
335
|
+
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
336
|
+
|
337
|
+
def evaluate(
|
338
|
+
self, server_round: int, parameters: Parameters
|
339
|
+
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
|
340
|
+
"""Evaluate model parameters using an evaluation function from the strategy."""
|
341
|
+
return self.strategy.evaluate(server_round, parameters)
|
@@ -18,7 +18,7 @@
|
|
18
18
|
from abc import ABC, abstractmethod
|
19
19
|
from typing import Callable, Dict, Tuple
|
20
20
|
|
21
|
-
from flwr.client.
|
21
|
+
from flwr.client.client_app import ClientApp
|
22
22
|
from flwr.common.context import Context
|
23
23
|
from flwr.common.message import Message
|
24
24
|
from flwr.common.typing import ConfigsRecordValues
|
@@ -15,10 +15,12 @@
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
16
16
|
|
17
17
|
import pathlib
|
18
|
-
from logging import INFO
|
18
|
+
from logging import ERROR, INFO
|
19
19
|
from typing import Callable, Dict, List, Tuple, Union
|
20
20
|
|
21
|
-
|
21
|
+
import ray
|
22
|
+
|
23
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
22
24
|
from flwr.common.context import Context
|
23
25
|
from flwr.common.logger import log
|
24
26
|
from flwr.common.message import Message
|
@@ -46,6 +48,9 @@ class RayBackend(Backend):
|
|
46
48
|
log(INFO, "Initialising: %s", self.__class__.__name__)
|
47
49
|
log(INFO, "Backend config: %s", backend_config)
|
48
50
|
|
51
|
+
if not pathlib.Path(work_dir).exists():
|
52
|
+
raise ValueError(f"Specified work_dir {work_dir} does not exist.")
|
53
|
+
|
49
54
|
# Init ray and append working dir if needed
|
50
55
|
runtime_env = (
|
51
56
|
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
@@ -138,22 +143,34 @@ class RayBackend(Backend):
|
|
138
143
|
"""
|
139
144
|
node_id = message.metadata.dst_node_id
|
140
145
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
+
try:
|
147
|
+
# Submit a task to the pool
|
148
|
+
future = await self.pool.submit(
|
149
|
+
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
150
|
+
(app, message, str(node_id), context),
|
151
|
+
)
|
146
152
|
|
147
|
-
|
153
|
+
await future
|
148
154
|
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
155
|
+
# Fetch result
|
156
|
+
(
|
157
|
+
out_mssg,
|
158
|
+
updated_context,
|
159
|
+
) = await self.pool.fetch_result_and_return_actor_to_pool(future)
|
154
160
|
|
155
|
-
|
161
|
+
return out_mssg, updated_context
|
162
|
+
|
163
|
+
except LoadClientAppError as load_ex:
|
164
|
+
log(
|
165
|
+
ERROR,
|
166
|
+
"An exception was raised when processing a message. Terminating %s",
|
167
|
+
self.__class__.__name__,
|
168
|
+
)
|
169
|
+
await self.terminate()
|
170
|
+
raise load_ex
|
156
171
|
|
157
172
|
async def terminate(self) -> None:
|
158
173
|
"""Terminate all actors in actor pool."""
|
159
174
|
await self.pool.terminate_all_actors()
|
175
|
+
ray.shutdown()
|
176
|
+
log(INFO, "Terminated %s", self.__class__.__name__)
|
@@ -19,7 +19,7 @@ import json
|
|
19
19
|
from logging import ERROR, INFO
|
20
20
|
from typing import Dict, Optional
|
21
21
|
|
22
|
-
from flwr.client.
|
22
|
+
from flwr.client.client_app import ClientApp, load_client_app
|
23
23
|
from flwr.client.node_state import NodeState
|
24
24
|
from flwr.common.logger import log
|
25
25
|
from flwr.server.superlink.state import StateFactory
|
@@ -25,7 +25,7 @@ import ray
|
|
25
25
|
from ray import ObjectRef
|
26
26
|
from ray.util.actor_pool import ActorPool
|
27
27
|
|
28
|
-
from flwr.client.
|
28
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
29
29
|
from flwr.common import Context, Message
|
30
30
|
from flwr.common.logger import log
|
31
31
|
|
@@ -67,6 +67,9 @@ class VirtualClientEngineActor(ABC):
|
|
67
67
|
# Handle task message
|
68
68
|
out_message = app(message=message, context=context)
|
69
69
|
|
70
|
+
except LoadClientAppError as load_ex:
|
71
|
+
raise load_ex
|
72
|
+
|
70
73
|
except Exception as ex:
|
71
74
|
client_trace = traceback.format_exc()
|
72
75
|
mssg = (
|
@@ -21,7 +21,7 @@ from typing import Optional
|
|
21
21
|
|
22
22
|
from flwr import common
|
23
23
|
from flwr.client import ClientFn
|
24
|
-
from flwr.client.
|
24
|
+
from flwr.client.client_app import ClientApp
|
25
25
|
from flwr.client.node_state import NodeState
|
26
26
|
from flwr.common import Message, Metadata, RecordSet
|
27
27
|
from flwr.common.constant import (
|
{flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/RECORD
RENAMED
@@ -17,10 +17,10 @@ flwr/cli/new/templates/app/flower.toml.tpl,sha256=nGEU30gV6A2ZaRPt0ZVUjsxoevic4J
|
|
17
17
|
flwr/cli/new/templates/app/requirements.pytorch.txt.tpl,sha256=9Z70jsiCPdsbuorhicrSdO6PVQn-3196vKZ5Ka2GkK0,87
|
18
18
|
flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl,sha256=WTbIgK5G_iG0W-xtvQLCZMxL_Og26rFXial2TkYH5dw,211
|
19
19
|
flwr/cli/utils.py,sha256=9cCEIt8QmJXz85JNmPk1IHPd7p8E3KDn6h5CfF0nDL4,1926
|
20
|
-
flwr/client/__init__.py,sha256=
|
21
|
-
flwr/client/app.py,sha256=
|
20
|
+
flwr/client/__init__.py,sha256=futk_IdY_N1h8BTve4Iru51bxm7H1gv58ZPIXWi5XUA,1187
|
21
|
+
flwr/client/app.py,sha256=WUfnjDmlhL6bajdRj22PQjW2u3d6ONvWoDzOSAeIRcc,20074
|
22
22
|
flwr/client/client.py,sha256=Vp9UkOkoHdNfn6iMYZsj_5m_GICiFfUlKEVaLad-YhM,8183
|
23
|
-
flwr/client/
|
23
|
+
flwr/client/client_app.py,sha256=jrDgJBswP2hD1YdGgQoI3GU_NkliYWVU8glBJLOVzQY,4205
|
24
24
|
flwr/client/dpfedavg_numpy_client.py,sha256=9Tnig4iml2J88HBKNahegjXjbfvIQyBtaIQaqjbeqsA,7435
|
25
25
|
flwr/client/grpc_client/__init__.py,sha256=LsnbqXiJhgQcB0XzAlUQgPx011Uf7Y7yabIC1HxivJ8,735
|
26
26
|
flwr/client/grpc_client/connection.py,sha256=QJKv39MlcDMLr2YQ80ulm-2mD3bAozEz3VKnNsymbYs,8381
|
@@ -29,7 +29,8 @@ flwr/client/grpc_rere_client/connection.py,sha256=QfshoyA9yYuHK15Vb0hlB0QDv0dQRq
|
|
29
29
|
flwr/client/message_handler/__init__.py,sha256=abHvBRJJiiaAMNgeILQbMOa6h8WqMK2BcnvxwQZFpic,719
|
30
30
|
flwr/client/message_handler/message_handler.py,sha256=369gEm8t1Tbp_Y74XlOGMy_mvD1zawD-dLwsBL174tY,6594
|
31
31
|
flwr/client/message_handler/task_handler.py,sha256=ZDJBKmrn2grRMNl1rU1iGs7FiMHL5VmZiSp_6h9GHVU,1824
|
32
|
-
flwr/client/mod/__init__.py,sha256=
|
32
|
+
flwr/client/mod/__init__.py,sha256=6LRDFRjUAMevk2TQ_azVs2zczumXcyKGzAQDRWLoe7A,911
|
33
|
+
flwr/client/mod/centraldp_mods.py,sha256=zhzjikh7PG3DCpkPxBKJFR8fG4QNouhuE7l8wexMz-U,2893
|
33
34
|
flwr/client/mod/secure_aggregation/__init__.py,sha256=AzCdezuzX2BfXUuxVRwXdv8-zUIXoU-Bf6u4LRhzvg8,796
|
34
35
|
flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=z_5t1YzqLs91ZLW5Yoo7Ozqw9_nyVuEpJ7Noa2a34bs,19890
|
35
36
|
flwr/client/mod/utils.py,sha256=lvETHcCYsSWz7h8I772hCV_kZspxqlMqzriMZ-SxmKc,1226
|
@@ -45,8 +46,9 @@ flwr/common/constant.py,sha256=jVUVKXo1cFb2HpRYqV70WKMG4RqCVrq7H6KC7zXs23Y,1572
|
|
45
46
|
flwr/common/context.py,sha256=ounF-mWPPtXGwtae3sg5EhF58ScviOa3MVqxRpGVu-8,1313
|
46
47
|
flwr/common/date.py,sha256=UWhBZj49yX9LD4BmatS_ZFZu_-kweGh0KQJ1djyWWH4,891
|
47
48
|
flwr/common/differential_privacy.py,sha256=pVSKRhciVNtdBlhoz1H0--8N5PMLjdO_bA1PLGq4WZ8,2969
|
48
|
-
flwr/common/differential_privacy_constants.py,sha256=
|
49
|
+
flwr/common/differential_privacy_constants.py,sha256=LUP9YurHRDm5--9jATnN2ddrnBSdEX30qvcys0BlUlY,1048
|
49
50
|
flwr/common/dp.py,sha256=Hc3lLHihjexbJaD_ft31gdv9XRcwOTgDBwJzICuok3A,2004
|
51
|
+
flwr/common/exit_handlers.py,sha256=2Nt0wLhc17KQQsLPFSRAjjhUiEFfJK6tNozdGiIY4Fs,2812
|
50
52
|
flwr/common/grpc.py,sha256=qVLB0d6bCuaBRW5YB0vEZXsR7Bo3R2lh4ONiCocqwRI,2270
|
51
53
|
flwr/common/logger.py,sha256=qX_gqEyrmGOH0x_r8uQ1Vskz4fGvEij9asdo4DUOPY8,4135
|
52
54
|
flwr/common/message.py,sha256=lCbaYFKSTI_Fpot-mJsq4rTvj46ZvqNAz9LVcJBlF1Q,6901
|
@@ -99,7 +101,7 @@ flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPc
|
|
99
101
|
flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
|
100
102
|
flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
101
103
|
flwr/server/__init__.py,sha256=_Wv3UkZzSXzbKWXyl2yY8tU6oqf_XqIruggpHjnmikE,1662
|
102
|
-
flwr/server/app.py,sha256=
|
104
|
+
flwr/server/app.py,sha256=wf0CCq5y7kPcf_Jy3Dh8YhopUWo3v2eEVmyF9KhWsZY,25101
|
103
105
|
flwr/server/client_manager.py,sha256=T8UDSRJBVD3fyIDI7NTAA-NA7GPrMNNgH2OAF54RRxE,6127
|
104
106
|
flwr/server/client_proxy.py,sha256=8ScGDvP3jHbl8DV3hyFID5N5VEVlXn8ZTQXtkdOfssI,2234
|
105
107
|
flwr/server/compat/__init__.py,sha256=KNvRFANbIc8LFRKHsBVfxcbOSekEImWgNq_gapCkbic,812
|
@@ -114,10 +116,10 @@ flwr/server/run_serverapp.py,sha256=tYNbz11xMvkbcMty6u5nkLHmZhwIjSO78CfW-Aas5R8,
|
|
114
116
|
flwr/server/server.py,sha256=kUIqgLIXnWcSrhEhXXkaZPRooYhTGGX-RDCYzG9J76g,17495
|
115
117
|
flwr/server/server_app.py,sha256=avNQ7AMMKsn09ly81C3UBgOfHhM_R29l4MrzlalGoj8,5892
|
116
118
|
flwr/server/server_config.py,sha256=yOHpkdyuhOm--Gy_4Vofvu6jCDxhyECEDpIy02beuCg,1018
|
117
|
-
flwr/server/strategy/__init__.py,sha256
|
119
|
+
flwr/server/strategy/__init__.py,sha256=-VCKDaWfVpiplrgbbtJ0rUOmQ180oIh7bgenE9UDDlM,2327
|
118
120
|
flwr/server/strategy/aggregate.py,sha256=QyRIJtI5gnuY1NbgrcrOvkHxGIxBvApq7d9Y4xl-6W4,13468
|
119
121
|
flwr/server/strategy/bulyan.py,sha256=8GsSVJzRSoSWE2zQUKqC3Z795grdN9xpmc3MSGGXnzM,6532
|
120
|
-
flwr/server/strategy/dp_fixed_clipping.py,sha256=
|
122
|
+
flwr/server/strategy/dp_fixed_clipping.py,sha256=h7yICGeC-1CsJgOWlH8MpMVXE23gUV7GzefrvHoGYBw,12187
|
121
123
|
flwr/server/strategy/dpfedavg_adaptive.py,sha256=hLJkPQJl1bHjwrBNg3PSRFKf3no0hg5EHiFaWhHlWqw,4877
|
122
124
|
flwr/server/strategy/dpfedavg_fixed.py,sha256=G0yYxrPoM-MHQ889DYN3OeNiEeU0yQrjgAzcq0G653w,7219
|
123
125
|
flwr/server/strategy/fault_tolerant_fedavg.py,sha256=veGcehB6rXT_MihNDrD1v5JY-TxJi7fybdDl-OZooDQ,5900
|
@@ -154,9 +156,9 @@ flwr/server/superlink/fleet/rest_rere/__init__.py,sha256=VKDvDq5H8koOUztpmQacVzG
|
|
154
156
|
flwr/server/superlink/fleet/rest_rere/rest_api.py,sha256=7JCs7NW4Qq8W5QhXxqsQNFiCLlRY-b_iD420vH1Mu-U,5906
|
155
157
|
flwr/server/superlink/fleet/vce/__init__.py,sha256=bogHbcWSXkD7wZkqUXiLRKQTJUs7jtr5uwaGlmoA-Yc,785
|
156
158
|
flwr/server/superlink/fleet/vce/backend/__init__.py,sha256=oBIzmnrSSRvH_H0vRGEGWhWzQQwqe3zn6e13RsNwlIY,1466
|
157
|
-
flwr/server/superlink/fleet/vce/backend/backend.py,sha256=
|
158
|
-
flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=
|
159
|
-
flwr/server/superlink/fleet/vce/vce_api.py,sha256=
|
159
|
+
flwr/server/superlink/fleet/vce/backend/backend.py,sha256=LJsKl7oixVvptcG98Rd9ejJycNWcEVB0ODvSreLGp-A,2260
|
160
|
+
flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=EYnLpX9bTRBrLiEW0F2LLMC6kt7Zhcy10zxqeKKfIGg,6351
|
161
|
+
flwr/server/superlink/fleet/vce/vce_api.py,sha256=TiDuQVdClc38heSjD3b2jo7kZg-KZuNhzh0OdJmmwT4,3099
|
160
162
|
flwr/server/superlink/state/__init__.py,sha256=ij-7Ms-hyordQdRmGQxY1-nVa4OhixJ0jr7_YDkys0s,1003
|
161
163
|
flwr/server/superlink/state/in_memory_state.py,sha256=sZX5XcpnU9cafhhC4Or5oRGIbKR2AKdjIfBDtMVGNLQ,8105
|
162
164
|
flwr/server/superlink/state/sqlite_state.py,sha256=Adc2g1DecAN9Cl9F8lekuTb885mIHiOi6sQv4nxbmSc,21203
|
@@ -169,11 +171,11 @@ flwr/server/utils/validator.py,sha256=IJN2475yyD_i_9kg_SJ_JodIuZh58ufpWGUDQRAqu2
|
|
169
171
|
flwr/simulation/__init__.py,sha256=E2eD5FlTmZZ80u21FmWCkacrM7O4mrEHD8iXqeCaBUQ,1278
|
170
172
|
flwr/simulation/app.py,sha256=WqJxdXTEuehwMW605p5NMmvBbKYx5tuqnV3Mp7jSWXM,13904
|
171
173
|
flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACkfeIam55BvW9g,734
|
172
|
-
flwr/simulation/ray_transport/ray_actor.py,sha256=
|
173
|
-
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=
|
174
|
+
flwr/simulation/ray_transport/ray_actor.py,sha256=zRETW_xuCAOLRFaYnQ-q3IBSz0LIv_0RifGuhgWaYOg,19872
|
175
|
+
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=DVuequIvgbXQGYz_8oIWghnE_sPiiswAGOVfHNNF4U8,6335
|
174
176
|
flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
|
175
|
-
flwr_nightly-1.8.0.
|
176
|
-
flwr_nightly-1.8.0.
|
177
|
-
flwr_nightly-1.8.0.
|
178
|
-
flwr_nightly-1.8.0.
|
179
|
-
flwr_nightly-1.8.0.
|
177
|
+
flwr_nightly-1.8.0.dev20240227.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
178
|
+
flwr_nightly-1.8.0.dev20240227.dist-info/METADATA,sha256=a6sTBhHs6TpkUp9isUB69U9Q7PoY_or1jNb7gA44z2w,15040
|
179
|
+
flwr_nightly-1.8.0.dev20240227.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
180
|
+
flwr_nightly-1.8.0.dev20240227.dist-info/entry_points.txt,sha256=S1zLNFLrz0uPWs4Zrgo2EPY0iQiIcCJHrIAlnQkkOBI,262
|
181
|
+
flwr_nightly-1.8.0.dev20240227.dist-info/RECORD,,
|
File without changes
|
{flwr_nightly-1.8.0.dev20240226.dist-info → flwr_nightly-1.8.0.dev20240227.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|