flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/proto/task_pb2.pyi
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
isort:skip_file
|
4
4
|
"""
|
5
5
|
import builtins
|
6
|
+
import flwr.proto.error_pb2
|
6
7
|
import flwr.proto.node_pb2
|
7
8
|
import flwr.proto.recordset_pb2
|
8
9
|
import google.protobuf.descriptor
|
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
|
|
23
24
|
ANCESTRY_FIELD_NUMBER: builtins.int
|
24
25
|
TASK_TYPE_FIELD_NUMBER: builtins.int
|
25
26
|
RECORDSET_FIELD_NUMBER: builtins.int
|
27
|
+
ERROR_FIELD_NUMBER: builtins.int
|
26
28
|
@property
|
27
29
|
def producer(self) -> flwr.proto.node_pb2.Node: ...
|
28
30
|
@property
|
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
|
|
35
37
|
task_type: typing.Text
|
36
38
|
@property
|
37
39
|
def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
|
40
|
+
@property
|
41
|
+
def error(self) -> flwr.proto.error_pb2.Error: ...
|
38
42
|
def __init__(self,
|
39
43
|
*,
|
40
44
|
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
|
|
45
49
|
ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
46
50
|
task_type: typing.Text = ...,
|
47
51
|
recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
|
52
|
+
error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
|
48
53
|
) -> None: ...
|
49
|
-
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
50
|
-
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
54
|
+
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
55
|
+
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
51
56
|
global___Task = Task
|
52
57
|
|
53
58
|
class TaskIns(google.protobuf.message.Message):
|
flwr/server/__init__.py
CHANGED
@@ -16,12 +16,14 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from . import strategy
|
19
|
+
from . import workflow as workflow
|
19
20
|
from .app import run_driver_api as run_driver_api
|
20
21
|
from .app import run_fleet_api as run_fleet_api
|
21
22
|
from .app import run_superlink as run_superlink
|
22
23
|
from .app import start_server as start_server
|
23
24
|
from .client_manager import ClientManager as ClientManager
|
24
25
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
26
|
+
from .compat import LegacyContext as LegacyContext
|
25
27
|
from .compat import start_driver as start_driver
|
26
28
|
from .driver import Driver as Driver
|
27
29
|
from .history import History as History
|
@@ -34,6 +36,7 @@ __all__ = [
|
|
34
36
|
"ClientManager",
|
35
37
|
"Driver",
|
36
38
|
"History",
|
39
|
+
"LegacyContext",
|
37
40
|
"run_driver_api",
|
38
41
|
"run_fleet_api",
|
39
42
|
"run_server_app",
|
@@ -45,4 +48,5 @@ __all__ = [
|
|
45
48
|
"start_driver",
|
46
49
|
"start_server",
|
47
50
|
"strategy",
|
51
|
+
"workflow",
|
48
52
|
]
|
flwr/server/app.py
CHANGED
@@ -14,8 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Flower server app."""
|
16
16
|
|
17
|
-
|
18
17
|
import argparse
|
18
|
+
import asyncio
|
19
19
|
import importlib.util
|
20
20
|
import sys
|
21
21
|
import threading
|
@@ -36,9 +36,6 @@ from flwr.common.constant import (
|
|
36
36
|
)
|
37
37
|
from flwr.common.exit_handlers import register_exit_handlers
|
38
38
|
from flwr.common.logger import log
|
39
|
-
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
40
|
-
add_DriverServicer_to_server,
|
41
|
-
)
|
42
39
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
43
40
|
add_FleetServicer_to_server,
|
44
41
|
)
|
@@ -48,7 +45,7 @@ from .history import History
|
|
48
45
|
from .server import Server, init_defaults, run_fl
|
49
46
|
from .server_config import ServerConfig
|
50
47
|
from .strategy import Strategy
|
51
|
-
from .superlink.driver.
|
48
|
+
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
52
49
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
53
50
|
generic_create_grpc_server,
|
54
51
|
start_grpc_server,
|
@@ -204,7 +201,7 @@ def run_driver_api() -> None:
|
|
204
201
|
state_factory = StateFactory(args.database)
|
205
202
|
|
206
203
|
# Start server
|
207
|
-
grpc_server: grpc.Server =
|
204
|
+
grpc_server: grpc.Server = run_driver_api_grpc(
|
208
205
|
address=address,
|
209
206
|
state_factory=state_factory,
|
210
207
|
certificates=certificates,
|
@@ -313,7 +310,7 @@ def run_superlink() -> None:
|
|
313
310
|
state_factory = StateFactory(args.database)
|
314
311
|
|
315
312
|
# Start Driver API
|
316
|
-
driver_server: grpc.Server =
|
313
|
+
driver_server: grpc.Server = run_driver_api_grpc(
|
317
314
|
address=address,
|
318
315
|
state_factory=state_factory,
|
319
316
|
certificates=certificates,
|
@@ -362,6 +359,7 @@ def run_superlink() -> None:
|
|
362
359
|
)
|
363
360
|
grpc_servers.append(fleet_server)
|
364
361
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
362
|
+
f_stop = asyncio.Event() # Does nothing
|
365
363
|
_run_fleet_api_vce(
|
366
364
|
num_supernodes=args.num_supernodes,
|
367
365
|
client_app_module_name=args.client_app,
|
@@ -369,6 +367,7 @@ def run_superlink() -> None:
|
|
369
367
|
backend_config_json_stream=args.backend_config,
|
370
368
|
working_dir=args.dir,
|
371
369
|
state_factory=state_factory,
|
370
|
+
f_stop=f_stop,
|
372
371
|
)
|
373
372
|
else:
|
374
373
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
@@ -412,30 +411,6 @@ def _try_obtain_certificates(
|
|
412
411
|
return certificates
|
413
412
|
|
414
413
|
|
415
|
-
def _run_driver_api_grpc(
|
416
|
-
address: str,
|
417
|
-
state_factory: StateFactory,
|
418
|
-
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
419
|
-
) -> grpc.Server:
|
420
|
-
"""Run Driver API (gRPC, request-response)."""
|
421
|
-
# Create Driver API gRPC server
|
422
|
-
driver_servicer: grpc.Server = DriverServicer(
|
423
|
-
state_factory=state_factory,
|
424
|
-
)
|
425
|
-
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
|
426
|
-
driver_grpc_server = generic_create_grpc_server(
|
427
|
-
servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
|
428
|
-
server_address=address,
|
429
|
-
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
430
|
-
certificates=certificates,
|
431
|
-
)
|
432
|
-
|
433
|
-
log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
|
434
|
-
driver_grpc_server.start()
|
435
|
-
|
436
|
-
return driver_grpc_server
|
437
|
-
|
438
|
-
|
439
414
|
def _run_fleet_api_grpc_rere(
|
440
415
|
address: str,
|
441
416
|
state_factory: StateFactory,
|
@@ -468,6 +443,7 @@ def _run_fleet_api_vce(
|
|
468
443
|
backend_config_json_stream: str,
|
469
444
|
working_dir: str,
|
470
445
|
state_factory: StateFactory,
|
446
|
+
f_stop: asyncio.Event,
|
471
447
|
) -> None:
|
472
448
|
log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
|
473
449
|
|
@@ -478,6 +454,7 @@ def _run_fleet_api_vce(
|
|
478
454
|
backend_config_json_stream=backend_config_json_stream,
|
479
455
|
state_factory=state_factory,
|
480
456
|
working_dir=working_dir,
|
457
|
+
f_stop=f_stop,
|
481
458
|
)
|
482
459
|
|
483
460
|
|
flwr/server/client_proxy.py
CHANGED
@@ -47,6 +47,7 @@ class ClientProxy(ABC):
|
|
47
47
|
self,
|
48
48
|
ins: GetPropertiesIns,
|
49
49
|
timeout: Optional[float],
|
50
|
+
group_id: Optional[int],
|
50
51
|
) -> GetPropertiesRes:
|
51
52
|
"""Return the client's properties."""
|
52
53
|
|
@@ -55,6 +56,7 @@ class ClientProxy(ABC):
|
|
55
56
|
self,
|
56
57
|
ins: GetParametersIns,
|
57
58
|
timeout: Optional[float],
|
59
|
+
group_id: Optional[int],
|
58
60
|
) -> GetParametersRes:
|
59
61
|
"""Return the current local model parameters."""
|
60
62
|
|
@@ -63,6 +65,7 @@ class ClientProxy(ABC):
|
|
63
65
|
self,
|
64
66
|
ins: FitIns,
|
65
67
|
timeout: Optional[float],
|
68
|
+
group_id: Optional[int],
|
66
69
|
) -> FitRes:
|
67
70
|
"""Refine the provided parameters using the locally held dataset."""
|
68
71
|
|
@@ -71,6 +74,7 @@ class ClientProxy(ABC):
|
|
71
74
|
self,
|
72
75
|
ins: EvaluateIns,
|
73
76
|
timeout: Optional[float],
|
77
|
+
group_id: Optional[int],
|
74
78
|
) -> EvaluateRes:
|
75
79
|
"""Evaluate the provided parameters using the locally held dataset."""
|
76
80
|
|
@@ -79,5 +83,6 @@ class ClientProxy(ABC):
|
|
79
83
|
self,
|
80
84
|
ins: ReconnectIns,
|
81
85
|
timeout: Optional[float],
|
86
|
+
group_id: Optional[int],
|
82
87
|
) -> DisconnectRes:
|
83
88
|
"""Disconnect and (optionally) reconnect later."""
|
flwr/server/compat/__init__.py
CHANGED
flwr/server/compat/app.py
CHANGED
@@ -16,16 +16,13 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import sys
|
19
|
-
import threading
|
20
|
-
import time
|
21
19
|
from logging import INFO
|
22
20
|
from pathlib import Path
|
23
|
-
from typing import
|
21
|
+
from typing import Optional, Union
|
24
22
|
|
25
23
|
from flwr.common import EventType, event
|
26
24
|
from flwr.common.address import parse_address
|
27
25
|
from flwr.common.logger import log, warn_deprecated_feature
|
28
|
-
from flwr.proto import driver_pb2 # pylint: disable=E0611
|
29
26
|
from flwr.server.client_manager import ClientManager
|
30
27
|
from flwr.server.history import History
|
31
28
|
from flwr.server.server import Server, init_defaults, run_fl
|
@@ -33,8 +30,7 @@ from flwr.server.server_config import ServerConfig
|
|
33
30
|
from flwr.server.strategy import Strategy
|
34
31
|
|
35
32
|
from ..driver import Driver
|
36
|
-
from
|
37
|
-
from .driver_client_proxy import DriverClientProxy
|
33
|
+
from .app_utils import start_update_client_manager_thread
|
38
34
|
|
39
35
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
40
36
|
|
@@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
104
100
|
"""
|
105
101
|
event(EventType.START_DRIVER_ENTER)
|
106
102
|
|
107
|
-
if driver:
|
108
|
-
# pylint: disable=protected-access
|
109
|
-
grpc_driver, _ = driver._get_grpc_driver_and_run_id()
|
110
|
-
# pylint: enable=protected-access
|
111
|
-
else:
|
103
|
+
if driver is None:
|
112
104
|
# Not passing a `Driver` object is deprecated
|
113
105
|
warn_deprecated_feature("start_driver")
|
114
106
|
|
@@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
122
114
|
# Create the Driver
|
123
115
|
if isinstance(root_certificates, str):
|
124
116
|
root_certificates = Path(root_certificates).read_bytes()
|
125
|
-
|
117
|
+
driver = Driver(
|
126
118
|
driver_service_address=address, root_certificates=root_certificates
|
127
119
|
)
|
128
|
-
grpc_driver.connect()
|
129
|
-
|
130
|
-
lock = threading.Lock()
|
131
120
|
|
132
121
|
# Initialize the Driver API server and config
|
133
122
|
initialized_server, initialized_config = init_defaults(
|
@@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
142
131
|
initialized_config,
|
143
132
|
)
|
144
133
|
|
145
|
-
f_stop = threading.Event()
|
146
134
|
# Start the thread updating nodes
|
147
|
-
thread =
|
148
|
-
|
149
|
-
args=(
|
150
|
-
grpc_driver,
|
151
|
-
initialized_server.client_manager(),
|
152
|
-
lock,
|
153
|
-
f_stop,
|
154
|
-
),
|
135
|
+
thread, f_stop = start_update_client_manager_thread(
|
136
|
+
driver, initialized_server.client_manager()
|
155
137
|
)
|
156
|
-
thread.start()
|
157
138
|
|
158
139
|
# Start training
|
159
140
|
hist = run_fl(
|
@@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
164
145
|
f_stop.set()
|
165
146
|
|
166
147
|
# Stop the Driver API server and the thread
|
167
|
-
|
168
|
-
if driver:
|
169
|
-
del driver
|
170
|
-
else:
|
171
|
-
grpc_driver.disconnect()
|
148
|
+
del driver
|
172
149
|
|
173
150
|
thread.join()
|
174
151
|
|
175
152
|
event(EventType.START_SERVER_LEAVE)
|
176
153
|
|
177
154
|
return hist
|
178
|
-
|
179
|
-
|
180
|
-
def update_client_manager(
|
181
|
-
driver: GrpcDriver,
|
182
|
-
client_manager: ClientManager,
|
183
|
-
lock: threading.Lock,
|
184
|
-
f_stop: threading.Event,
|
185
|
-
) -> None:
|
186
|
-
"""Update the nodes list in the client manager.
|
187
|
-
|
188
|
-
This function periodically communicates with the associated driver to get all
|
189
|
-
node_ids. Each node_id is then converted into a `DriverClientProxy` instance
|
190
|
-
and stored in the `registered_nodes` dictionary with node_id as key.
|
191
|
-
|
192
|
-
New nodes will be added to the ClientManager via `client_manager.register()`,
|
193
|
-
and dead nodes will be removed from the ClientManager via
|
194
|
-
`client_manager.unregister()`.
|
195
|
-
"""
|
196
|
-
# Request for run_id
|
197
|
-
run_id = driver.create_run(
|
198
|
-
driver_pb2.CreateRunRequest() # pylint: disable=E1101
|
199
|
-
).run_id
|
200
|
-
|
201
|
-
# Loop until the driver is disconnected
|
202
|
-
registered_nodes: Dict[int, DriverClientProxy] = {}
|
203
|
-
while not f_stop.is_set():
|
204
|
-
with lock:
|
205
|
-
# End the while loop if the driver is disconnected
|
206
|
-
if driver.stub is None:
|
207
|
-
break
|
208
|
-
get_nodes_res = driver.get_nodes(
|
209
|
-
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
|
210
|
-
)
|
211
|
-
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
|
212
|
-
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
213
|
-
new_nodes = all_node_ids.difference(registered_nodes)
|
214
|
-
|
215
|
-
# Unregister dead nodes
|
216
|
-
for node_id in dead_nodes:
|
217
|
-
client_proxy = registered_nodes[node_id]
|
218
|
-
client_manager.unregister(client_proxy)
|
219
|
-
del registered_nodes[node_id]
|
220
|
-
|
221
|
-
# Register new nodes
|
222
|
-
for node_id in new_nodes:
|
223
|
-
client_proxy = DriverClientProxy(
|
224
|
-
node_id=node_id,
|
225
|
-
driver=driver,
|
226
|
-
anonymous=False,
|
227
|
-
run_id=run_id,
|
228
|
-
)
|
229
|
-
if client_manager.register(client_proxy):
|
230
|
-
registered_nodes[node_id] = client_proxy
|
231
|
-
else:
|
232
|
-
raise RuntimeError("Could not register node.")
|
233
|
-
|
234
|
-
# Sleep for 3 seconds
|
235
|
-
time.sleep(3)
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Utility functions for the `start_driver`."""
|
16
|
+
|
17
|
+
|
18
|
+
import threading
|
19
|
+
import time
|
20
|
+
from typing import Dict, Tuple
|
21
|
+
|
22
|
+
from ..client_manager import ClientManager
|
23
|
+
from ..compat.driver_client_proxy import DriverClientProxy
|
24
|
+
from ..driver import Driver
|
25
|
+
|
26
|
+
|
27
|
+
def start_update_client_manager_thread(
    driver: Driver,
    client_manager: ClientManager,
) -> Tuple[threading.Thread, threading.Event]:
    """Launch a background thread that keeps a ClientManager in sync with nodes.

    The started thread repeatedly asks `driver` for the current node_ids and
    mirrors them into `client_manager`: every newly seen node_id is wrapped in
    a `DriverClientProxy` and registered via `client_manager.register()`, and
    every node_id that disappeared has its proxy removed via
    `client_manager.unregister()`. Proxies are tracked in a `registered_nodes`
    dictionary keyed by node_id.

    Parameters
    ----------
    driver : Driver
        The Driver object to use.
    client_manager : ClientManager
        The ClientManager object to be updated.

    Returns
    -------
    threading.Thread
        The already-started worker thread that updates the ClientManager.
    threading.Event
        Stop signal for the worker; set it to make the thread finish.
    """
    stop_event = threading.Event()
    worker = threading.Thread(
        target=_update_client_manager,
        args=(driver, client_manager, stop_event),
    )
    worker.start()
    return worker, stop_event
|
67
|
+
|
68
|
+
|
69
|
+
def _update_client_manager(
    driver: Driver,
    client_manager: ClientManager,
    f_stop: threading.Event,
) -> None:
    """Update the nodes list in the client manager until `f_stop` is set.

    On each iteration, fetch the current node_ids from `driver`, unregister
    the `DriverClientProxy` of every node that is no longer reported, and
    register a new `DriverClientProxy` for every node that has appeared. The
    loop then waits up to 3 seconds before polling again.

    Parameters
    ----------
    driver : Driver
        The Driver object used to list node_ids. Its `grpc_driver` and
        `run_id` are passed to every new `DriverClientProxy`.
    client_manager : ClientManager
        The ClientManager kept in sync with the reported node_ids.
    f_stop : threading.Event
        Stop signal. Setting it ends the loop and also interrupts the wait
        between polls, so the thread finishes promptly.

    Raises
    ------
    RuntimeError
        If `client_manager.register()` rejects a new client proxy.
    """
    registered_nodes: Dict[int, DriverClientProxy] = {}
    # Loop until the stop event is set
    while not f_stop.is_set():
        all_node_ids = set(driver.get_node_ids())
        dead_nodes = set(registered_nodes).difference(all_node_ids)
        new_nodes = all_node_ids.difference(registered_nodes)

        # Unregister dead nodes
        for node_id in dead_nodes:
            client_proxy = registered_nodes[node_id]
            client_manager.unregister(client_proxy)
            del registered_nodes[node_id]

        # Register new nodes
        for node_id in new_nodes:
            client_proxy = DriverClientProxy(
                node_id=node_id,
                driver=driver.grpc_driver,  # type: ignore
                anonymous=False,
                run_id=driver.run_id,  # type: ignore
            )
            if client_manager.register(client_proxy):
                registered_nodes[node_id] = client_proxy
            else:
                raise RuntimeError("Could not register node.")

        # Wait up to 3 seconds before the next poll. Unlike `time.sleep(3)`,
        # this returns as soon as `f_stop` is set, so a caller that sets the
        # event and then `join()`s this thread is not held up by a full sleep.
        f_stop.wait(3)
|
@@ -47,57 +47,68 @@ class DriverClientProxy(ClientProxy):
|
|
47
47
|
self.anonymous = anonymous
|
48
48
|
|
49
49
|
def get_properties(
|
50
|
-
self,
|
50
|
+
self,
|
51
|
+
ins: common.GetPropertiesIns,
|
52
|
+
timeout: Optional[float],
|
53
|
+
group_id: Optional[int],
|
51
54
|
) -> common.GetPropertiesRes:
|
52
55
|
"""Return client's properties."""
|
53
56
|
# Ins to RecordSet
|
54
57
|
out_recordset = compat.getpropertiesins_to_recordset(ins)
|
55
58
|
# Fetch response
|
56
59
|
in_recordset = self._send_receive_recordset(
|
57
|
-
out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout
|
60
|
+
out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
|
58
61
|
)
|
59
62
|
# RecordSet to Res
|
60
63
|
return compat.recordset_to_getpropertiesres(in_recordset)
|
61
64
|
|
62
65
|
def get_parameters(
|
63
|
-
self,
|
66
|
+
self,
|
67
|
+
ins: common.GetParametersIns,
|
68
|
+
timeout: Optional[float],
|
69
|
+
group_id: Optional[int],
|
64
70
|
) -> common.GetParametersRes:
|
65
71
|
"""Return the current local model parameters."""
|
66
72
|
# Ins to RecordSet
|
67
73
|
out_recordset = compat.getparametersins_to_recordset(ins)
|
68
74
|
# Fetch response
|
69
75
|
in_recordset = self._send_receive_recordset(
|
70
|
-
out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout
|
76
|
+
out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
|
71
77
|
)
|
72
78
|
# RecordSet to Res
|
73
79
|
return compat.recordset_to_getparametersres(in_recordset, False)
|
74
80
|
|
75
|
-
def fit(
|
81
|
+
def fit(
|
82
|
+
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
83
|
+
) -> common.FitRes:
|
76
84
|
"""Train model parameters on the locally held dataset."""
|
77
85
|
# Ins to RecordSet
|
78
86
|
out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
|
79
87
|
# Fetch response
|
80
88
|
in_recordset = self._send_receive_recordset(
|
81
|
-
out_recordset, MESSAGE_TYPE_FIT, timeout
|
89
|
+
out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
|
82
90
|
)
|
83
91
|
# RecordSet to Res
|
84
92
|
return compat.recordset_to_fitres(in_recordset, keep_input=False)
|
85
93
|
|
86
94
|
def evaluate(
|
87
|
-
self, ins: common.EvaluateIns, timeout: Optional[float]
|
95
|
+
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
88
96
|
) -> common.EvaluateRes:
|
89
97
|
"""Evaluate model parameters on the locally held dataset."""
|
90
98
|
# Ins to RecordSet
|
91
99
|
out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
|
92
100
|
# Fetch response
|
93
101
|
in_recordset = self._send_receive_recordset(
|
94
|
-
out_recordset, MESSAGE_TYPE_EVALUATE, timeout
|
102
|
+
out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
|
95
103
|
)
|
96
104
|
# RecordSet to Res
|
97
105
|
return compat.recordset_to_evaluateres(in_recordset)
|
98
106
|
|
99
107
|
def reconnect(
|
100
|
-
self,
|
108
|
+
self,
|
109
|
+
ins: common.ReconnectIns,
|
110
|
+
timeout: Optional[float],
|
111
|
+
group_id: Optional[int],
|
101
112
|
) -> common.DisconnectRes:
|
102
113
|
"""Disconnect and (optionally) reconnect later."""
|
103
114
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
@@ -107,10 +118,11 @@ class DriverClientProxy(ClientProxy):
|
|
107
118
|
recordset: RecordSet,
|
108
119
|
task_type: str,
|
109
120
|
timeout: Optional[float],
|
121
|
+
group_id: Optional[int],
|
110
122
|
) -> RecordSet:
|
111
123
|
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
112
124
|
task_id="",
|
113
|
-
group_id="",
|
125
|
+
group_id=str(group_id) if group_id is not None else "",
|
114
126
|
run_id=self.run_id,
|
115
127
|
task=task_pb2.Task( # pylint: disable=E1101
|
116
128
|
producer=node_pb2.Node( # pylint: disable=E1101
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Legacy Context."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
from flwr.common import Context, RecordSet
|
22
|
+
|
23
|
+
from ..client_manager import ClientManager, SimpleClientManager
|
24
|
+
from ..history import History
|
25
|
+
from ..server_config import ServerConfig
|
26
|
+
from ..strategy import FedAvg, Strategy
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass
class LegacyContext(Context):
    """Legacy Context.

    A `Context` that, besides the `state` passed to the base class, carries
    a server config, a strategy, a client manager and a fresh `History`.
    """

    # Fields declared for the dataclass machinery; values are set in __init__.
    config: ServerConfig
    strategy: Strategy
    client_manager: ClientManager
    history: History

    def __init__(
        self,
        state: RecordSet,
        config: Optional[ServerConfig] = None,
        strategy: Optional[Strategy] = None,
        client_manager: Optional[ClientManager] = None,
    ) -> None:
        """Build the context, substituting defaults for omitted components.

        Parameters
        ----------
        state : RecordSet
            Forwarded unchanged to `Context.__init__`.
        config : Optional[ServerConfig]
            Server config; a default `ServerConfig()` is used when None.
        strategy : Optional[Strategy]
            Strategy; a default `FedAvg()` is used when None.
        client_manager : Optional[ClientManager]
            Client manager; a new `SimpleClientManager()` is used when None.
        """
        # Resolve every default first, then assign, then initialise the base
        # class -- the same order of operations as before.
        resolved_config = ServerConfig() if config is None else config
        resolved_strategy = FedAvg() if strategy is None else strategy
        resolved_manager = (
            SimpleClientManager() if client_manager is None else client_manager
        )
        self.config = resolved_config
        self.strategy = resolved_strategy
        self.client_manager = resolved_manager
        self.history = History()
        super().__init__(state)
|
flwr/server/run_serverapp.py
CHANGED