flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/proto/task_pb2.pyi
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
isort:skip_file
|
4
4
|
"""
|
5
5
|
import builtins
|
6
|
+
import flwr.proto.error_pb2
|
6
7
|
import flwr.proto.node_pb2
|
7
8
|
import flwr.proto.recordset_pb2
|
8
9
|
import google.protobuf.descriptor
|
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
|
|
23
24
|
ANCESTRY_FIELD_NUMBER: builtins.int
|
24
25
|
TASK_TYPE_FIELD_NUMBER: builtins.int
|
25
26
|
RECORDSET_FIELD_NUMBER: builtins.int
|
27
|
+
ERROR_FIELD_NUMBER: builtins.int
|
26
28
|
@property
|
27
29
|
def producer(self) -> flwr.proto.node_pb2.Node: ...
|
28
30
|
@property
|
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
|
|
35
37
|
task_type: typing.Text
|
36
38
|
@property
|
37
39
|
def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
|
40
|
+
@property
|
41
|
+
def error(self) -> flwr.proto.error_pb2.Error: ...
|
38
42
|
def __init__(self,
|
39
43
|
*,
|
40
44
|
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
|
|
45
49
|
ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
46
50
|
task_type: typing.Text = ...,
|
47
51
|
recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
|
52
|
+
error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
|
48
53
|
) -> None: ...
|
49
|
-
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
50
|
-
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
54
|
+
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
55
|
+
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
51
56
|
global___Task = Task
|
52
57
|
|
53
58
|
class TaskIns(google.protobuf.message.Message):
|
flwr/server/__init__.py
CHANGED
@@ -16,12 +16,14 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from . import strategy
|
19
|
+
from . import workflow as workflow
|
19
20
|
from .app import run_driver_api as run_driver_api
|
20
21
|
from .app import run_fleet_api as run_fleet_api
|
21
22
|
from .app import run_superlink as run_superlink
|
22
23
|
from .app import start_server as start_server
|
23
24
|
from .client_manager import ClientManager as ClientManager
|
24
25
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
26
|
+
from .compat import LegacyContext as LegacyContext
|
25
27
|
from .compat import start_driver as start_driver
|
26
28
|
from .driver import Driver as Driver
|
27
29
|
from .history import History as History
|
@@ -34,6 +36,7 @@ __all__ = [
|
|
34
36
|
"ClientManager",
|
35
37
|
"Driver",
|
36
38
|
"History",
|
39
|
+
"LegacyContext",
|
37
40
|
"run_driver_api",
|
38
41
|
"run_fleet_api",
|
39
42
|
"run_server_app",
|
@@ -45,4 +48,5 @@ __all__ = [
|
|
45
48
|
"start_driver",
|
46
49
|
"start_server",
|
47
50
|
"strategy",
|
51
|
+
"workflow",
|
48
52
|
]
|
flwr/server/app.py
CHANGED
@@ -14,8 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Flower server app."""
|
16
16
|
|
17
|
-
|
18
17
|
import argparse
|
18
|
+
import asyncio
|
19
19
|
import importlib.util
|
20
20
|
import sys
|
21
21
|
import threading
|
@@ -36,9 +36,6 @@ from flwr.common.constant import (
|
|
36
36
|
)
|
37
37
|
from flwr.common.exit_handlers import register_exit_handlers
|
38
38
|
from flwr.common.logger import log
|
39
|
-
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
40
|
-
add_DriverServicer_to_server,
|
41
|
-
)
|
42
39
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
43
40
|
add_FleetServicer_to_server,
|
44
41
|
)
|
@@ -48,7 +45,7 @@ from .history import History
|
|
48
45
|
from .server import Server, init_defaults, run_fl
|
49
46
|
from .server_config import ServerConfig
|
50
47
|
from .strategy import Strategy
|
51
|
-
from .superlink.driver.driver_servicer import DriverServicer
|
48
|
+
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
52
49
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
53
50
|
generic_create_grpc_server,
|
54
51
|
start_grpc_server,
|
@@ -204,7 +201,7 @@ def run_driver_api() -> None:
|
|
204
201
|
state_factory = StateFactory(args.database)
|
205
202
|
|
206
203
|
# Start server
|
207
|
-
grpc_server: grpc.Server = _run_driver_api_grpc(
|
204
|
+
grpc_server: grpc.Server = run_driver_api_grpc(
|
208
205
|
address=address,
|
209
206
|
state_factory=state_factory,
|
210
207
|
certificates=certificates,
|
@@ -313,7 +310,7 @@ def run_superlink() -> None:
|
|
313
310
|
state_factory = StateFactory(args.database)
|
314
311
|
|
315
312
|
# Start Driver API
|
316
|
-
driver_server: grpc.Server = _run_driver_api_grpc(
|
313
|
+
driver_server: grpc.Server = run_driver_api_grpc(
|
317
314
|
address=address,
|
318
315
|
state_factory=state_factory,
|
319
316
|
certificates=certificates,
|
@@ -362,6 +359,7 @@ def run_superlink() -> None:
|
|
362
359
|
)
|
363
360
|
grpc_servers.append(fleet_server)
|
364
361
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
362
|
+
f_stop = asyncio.Event() # Does nothing
|
365
363
|
_run_fleet_api_vce(
|
366
364
|
num_supernodes=args.num_supernodes,
|
367
365
|
client_app_module_name=args.client_app,
|
@@ -369,6 +367,7 @@ def run_superlink() -> None:
|
|
369
367
|
backend_config_json_stream=args.backend_config,
|
370
368
|
working_dir=args.dir,
|
371
369
|
state_factory=state_factory,
|
370
|
+
f_stop=f_stop,
|
372
371
|
)
|
373
372
|
else:
|
374
373
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
@@ -412,30 +411,6 @@ def _try_obtain_certificates(
|
|
412
411
|
return certificates
|
413
412
|
|
414
413
|
|
415
|
-
def _run_driver_api_grpc(
|
416
|
-
address: str,
|
417
|
-
state_factory: StateFactory,
|
418
|
-
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
419
|
-
) -> grpc.Server:
|
420
|
-
"""Run Driver API (gRPC, request-response)."""
|
421
|
-
# Create Driver API gRPC server
|
422
|
-
driver_servicer: grpc.Server = DriverServicer(
|
423
|
-
state_factory=state_factory,
|
424
|
-
)
|
425
|
-
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
|
426
|
-
driver_grpc_server = generic_create_grpc_server(
|
427
|
-
servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
|
428
|
-
server_address=address,
|
429
|
-
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
430
|
-
certificates=certificates,
|
431
|
-
)
|
432
|
-
|
433
|
-
log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
|
434
|
-
driver_grpc_server.start()
|
435
|
-
|
436
|
-
return driver_grpc_server
|
437
|
-
|
438
|
-
|
439
414
|
def _run_fleet_api_grpc_rere(
|
440
415
|
address: str,
|
441
416
|
state_factory: StateFactory,
|
@@ -468,6 +443,7 @@ def _run_fleet_api_vce(
|
|
468
443
|
backend_config_json_stream: str,
|
469
444
|
working_dir: str,
|
470
445
|
state_factory: StateFactory,
|
446
|
+
f_stop: asyncio.Event,
|
471
447
|
) -> None:
|
472
448
|
log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
|
473
449
|
|
@@ -478,6 +454,7 @@ def _run_fleet_api_vce(
|
|
478
454
|
backend_config_json_stream=backend_config_json_stream,
|
479
455
|
state_factory=state_factory,
|
480
456
|
working_dir=working_dir,
|
457
|
+
f_stop=f_stop,
|
481
458
|
)
|
482
459
|
|
483
460
|
|
flwr/server/client_proxy.py
CHANGED
@@ -47,6 +47,7 @@ class ClientProxy(ABC):
|
|
47
47
|
self,
|
48
48
|
ins: GetPropertiesIns,
|
49
49
|
timeout: Optional[float],
|
50
|
+
group_id: Optional[int],
|
50
51
|
) -> GetPropertiesRes:
|
51
52
|
"""Return the client's properties."""
|
52
53
|
|
@@ -55,6 +56,7 @@ class ClientProxy(ABC):
|
|
55
56
|
self,
|
56
57
|
ins: GetParametersIns,
|
57
58
|
timeout: Optional[float],
|
59
|
+
group_id: Optional[int],
|
58
60
|
) -> GetParametersRes:
|
59
61
|
"""Return the current local model parameters."""
|
60
62
|
|
@@ -63,6 +65,7 @@ class ClientProxy(ABC):
|
|
63
65
|
self,
|
64
66
|
ins: FitIns,
|
65
67
|
timeout: Optional[float],
|
68
|
+
group_id: Optional[int],
|
66
69
|
) -> FitRes:
|
67
70
|
"""Refine the provided parameters using the locally held dataset."""
|
68
71
|
|
@@ -71,6 +74,7 @@ class ClientProxy(ABC):
|
|
71
74
|
self,
|
72
75
|
ins: EvaluateIns,
|
73
76
|
timeout: Optional[float],
|
77
|
+
group_id: Optional[int],
|
74
78
|
) -> EvaluateRes:
|
75
79
|
"""Evaluate the provided parameters using the locally held dataset."""
|
76
80
|
|
@@ -79,5 +83,6 @@ class ClientProxy(ABC):
|
|
79
83
|
self,
|
80
84
|
ins: ReconnectIns,
|
81
85
|
timeout: Optional[float],
|
86
|
+
group_id: Optional[int],
|
82
87
|
) -> DisconnectRes:
|
83
88
|
"""Disconnect and (optionally) reconnect later."""
|
flwr/server/compat/__init__.py
CHANGED
flwr/server/compat/app.py
CHANGED
@@ -16,16 +16,13 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import sys
|
19
|
-
import threading
|
20
|
-
import time
|
21
19
|
from logging import INFO
|
22
20
|
from pathlib import Path
|
23
|
-
from typing import Dict, Optional, Union
|
21
|
+
from typing import Optional, Union
|
24
22
|
|
25
23
|
from flwr.common import EventType, event
|
26
24
|
from flwr.common.address import parse_address
|
27
25
|
from flwr.common.logger import log, warn_deprecated_feature
|
28
|
-
from flwr.proto import driver_pb2 # pylint: disable=E0611
|
29
26
|
from flwr.server.client_manager import ClientManager
|
30
27
|
from flwr.server.history import History
|
31
28
|
from flwr.server.server import Server, init_defaults, run_fl
|
@@ -33,8 +30,7 @@ from flwr.server.server_config import ServerConfig
|
|
33
30
|
from flwr.server.strategy import Strategy
|
34
31
|
|
35
32
|
from ..driver import Driver
|
36
|
-
from
|
37
|
-
from .driver_client_proxy import DriverClientProxy
|
33
|
+
from .app_utils import start_update_client_manager_thread
|
38
34
|
|
39
35
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
40
36
|
|
@@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
104
100
|
"""
|
105
101
|
event(EventType.START_DRIVER_ENTER)
|
106
102
|
|
107
|
-
if driver:
|
108
|
-
# pylint: disable=protected-access
|
109
|
-
grpc_driver, _ = driver._get_grpc_driver_and_run_id()
|
110
|
-
# pylint: enable=protected-access
|
111
|
-
else:
|
103
|
+
if driver is None:
|
112
104
|
# Not passing a `Driver` object is deprecated
|
113
105
|
warn_deprecated_feature("start_driver")
|
114
106
|
|
@@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
122
114
|
# Create the Driver
|
123
115
|
if isinstance(root_certificates, str):
|
124
116
|
root_certificates = Path(root_certificates).read_bytes()
|
125
|
-
|
117
|
+
driver = Driver(
|
126
118
|
driver_service_address=address, root_certificates=root_certificates
|
127
119
|
)
|
128
|
-
grpc_driver.connect()
|
129
|
-
|
130
|
-
lock = threading.Lock()
|
131
120
|
|
132
121
|
# Initialize the Driver API server and config
|
133
122
|
initialized_server, initialized_config = init_defaults(
|
@@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
142
131
|
initialized_config,
|
143
132
|
)
|
144
133
|
|
145
|
-
f_stop = threading.Event()
|
146
134
|
# Start the thread updating nodes
|
147
|
-
thread = threading.Thread(
|
148
|
-
|
149
|
-
args=(
|
150
|
-
grpc_driver,
|
151
|
-
initialized_server.client_manager(),
|
152
|
-
lock,
|
153
|
-
f_stop,
|
154
|
-
),
|
135
|
+
thread, f_stop = start_update_client_manager_thread(
|
136
|
+
driver, initialized_server.client_manager()
|
155
137
|
)
|
156
|
-
thread.start()
|
157
138
|
|
158
139
|
# Start training
|
159
140
|
hist = run_fl(
|
@@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
164
145
|
f_stop.set()
|
165
146
|
|
166
147
|
# Stop the Driver API server and the thread
|
167
|
-
|
168
|
-
if driver:
|
169
|
-
del driver
|
170
|
-
else:
|
171
|
-
grpc_driver.disconnect()
|
148
|
+
del driver
|
172
149
|
|
173
150
|
thread.join()
|
174
151
|
|
175
152
|
event(EventType.START_SERVER_LEAVE)
|
176
153
|
|
177
154
|
return hist
|
178
|
-
|
179
|
-
|
180
|
-
def update_client_manager(
|
181
|
-
driver: GrpcDriver,
|
182
|
-
client_manager: ClientManager,
|
183
|
-
lock: threading.Lock,
|
184
|
-
f_stop: threading.Event,
|
185
|
-
) -> None:
|
186
|
-
"""Update the nodes list in the client manager.
|
187
|
-
|
188
|
-
This function periodically communicates with the associated driver to get all
|
189
|
-
node_ids. Each node_id is then converted into a `DriverClientProxy` instance
|
190
|
-
and stored in the `registered_nodes` dictionary with node_id as key.
|
191
|
-
|
192
|
-
New nodes will be added to the ClientManager via `client_manager.register()`,
|
193
|
-
and dead nodes will be removed from the ClientManager via
|
194
|
-
`client_manager.unregister()`.
|
195
|
-
"""
|
196
|
-
# Request for run_id
|
197
|
-
run_id = driver.create_run(
|
198
|
-
driver_pb2.CreateRunRequest() # pylint: disable=E1101
|
199
|
-
).run_id
|
200
|
-
|
201
|
-
# Loop until the driver is disconnected
|
202
|
-
registered_nodes: Dict[int, DriverClientProxy] = {}
|
203
|
-
while not f_stop.is_set():
|
204
|
-
with lock:
|
205
|
-
# End the while loop if the driver is disconnected
|
206
|
-
if driver.stub is None:
|
207
|
-
break
|
208
|
-
get_nodes_res = driver.get_nodes(
|
209
|
-
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
|
210
|
-
)
|
211
|
-
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
|
212
|
-
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
213
|
-
new_nodes = all_node_ids.difference(registered_nodes)
|
214
|
-
|
215
|
-
# Unregister dead nodes
|
216
|
-
for node_id in dead_nodes:
|
217
|
-
client_proxy = registered_nodes[node_id]
|
218
|
-
client_manager.unregister(client_proxy)
|
219
|
-
del registered_nodes[node_id]
|
220
|
-
|
221
|
-
# Register new nodes
|
222
|
-
for node_id in new_nodes:
|
223
|
-
client_proxy = DriverClientProxy(
|
224
|
-
node_id=node_id,
|
225
|
-
driver=driver,
|
226
|
-
anonymous=False,
|
227
|
-
run_id=run_id,
|
228
|
-
)
|
229
|
-
if client_manager.register(client_proxy):
|
230
|
-
registered_nodes[node_id] = client_proxy
|
231
|
-
else:
|
232
|
-
raise RuntimeError("Could not register node.")
|
233
|
-
|
234
|
-
# Sleep for 3 seconds
|
235
|
-
time.sleep(3)
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Utility functions for the `start_driver`."""
|
16
|
+
|
17
|
+
|
18
|
+
import threading
|
19
|
+
import time
|
20
|
+
from typing import Dict, Tuple
|
21
|
+
|
22
|
+
from ..client_manager import ClientManager
|
23
|
+
from ..compat.driver_client_proxy import DriverClientProxy
|
24
|
+
from ..driver import Driver
|
25
|
+
|
26
|
+
|
27
|
+
def start_update_client_manager_thread(
|
28
|
+
driver: Driver,
|
29
|
+
client_manager: ClientManager,
|
30
|
+
) -> Tuple[threading.Thread, threading.Event]:
|
31
|
+
"""Periodically update the nodes list in the client manager in a thread.
|
32
|
+
|
33
|
+
This function starts a thread that periodically uses the associated driver to
|
34
|
+
get all node_ids. Each node_id is then converted into a `DriverClientProxy`
|
35
|
+
instance and stored in the `registered_nodes` dictionary with node_id as key.
|
36
|
+
|
37
|
+
New nodes will be added to the ClientManager via `client_manager.register()`,
|
38
|
+
and dead nodes will be removed from the ClientManager via
|
39
|
+
`client_manager.unregister()`.
|
40
|
+
|
41
|
+
Parameters
|
42
|
+
----------
|
43
|
+
driver : Driver
|
44
|
+
The Driver object to use.
|
45
|
+
client_manager : ClientManager
|
46
|
+
The ClientManager object to be updated.
|
47
|
+
|
48
|
+
Returns
|
49
|
+
-------
|
50
|
+
threading.Thread
|
51
|
+
A thread that updates the ClientManager and handles the stop event.
|
52
|
+
threading.Event
|
53
|
+
An event that, when set, signals the thread to stop.
|
54
|
+
"""
|
55
|
+
f_stop = threading.Event()
|
56
|
+
thread = threading.Thread(
|
57
|
+
target=_update_client_manager,
|
58
|
+
args=(
|
59
|
+
driver,
|
60
|
+
client_manager,
|
61
|
+
f_stop,
|
62
|
+
),
|
63
|
+
)
|
64
|
+
thread.start()
|
65
|
+
|
66
|
+
return thread, f_stop
|
67
|
+
|
68
|
+
|
69
|
+
def _update_client_manager(
|
70
|
+
driver: Driver,
|
71
|
+
client_manager: ClientManager,
|
72
|
+
f_stop: threading.Event,
|
73
|
+
) -> None:
|
74
|
+
"""Update the nodes list in the client manager."""
|
75
|
+
# Loop until the driver is disconnected
|
76
|
+
registered_nodes: Dict[int, DriverClientProxy] = {}
|
77
|
+
while not f_stop.is_set():
|
78
|
+
all_node_ids = set(driver.get_node_ids())
|
79
|
+
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
80
|
+
new_nodes = all_node_ids.difference(registered_nodes)
|
81
|
+
|
82
|
+
# Unregister dead nodes
|
83
|
+
for node_id in dead_nodes:
|
84
|
+
client_proxy = registered_nodes[node_id]
|
85
|
+
client_manager.unregister(client_proxy)
|
86
|
+
del registered_nodes[node_id]
|
87
|
+
|
88
|
+
# Register new nodes
|
89
|
+
for node_id in new_nodes:
|
90
|
+
client_proxy = DriverClientProxy(
|
91
|
+
node_id=node_id,
|
92
|
+
driver=driver.grpc_driver, # type: ignore
|
93
|
+
anonymous=False,
|
94
|
+
run_id=driver.run_id, # type: ignore
|
95
|
+
)
|
96
|
+
if client_manager.register(client_proxy):
|
97
|
+
registered_nodes[node_id] = client_proxy
|
98
|
+
else:
|
99
|
+
raise RuntimeError("Could not register node.")
|
100
|
+
|
101
|
+
# Sleep for 3 seconds
|
102
|
+
time.sleep(3)
|
@@ -47,57 +47,68 @@ class DriverClientProxy(ClientProxy):
|
|
47
47
|
self.anonymous = anonymous
|
48
48
|
|
49
49
|
def get_properties(
|
50
|
-
self, ins: common.GetPropertiesIns, timeout: Optional[float]
|
50
|
+
self,
|
51
|
+
ins: common.GetPropertiesIns,
|
52
|
+
timeout: Optional[float],
|
53
|
+
group_id: Optional[int],
|
51
54
|
) -> common.GetPropertiesRes:
|
52
55
|
"""Return client's properties."""
|
53
56
|
# Ins to RecordSet
|
54
57
|
out_recordset = compat.getpropertiesins_to_recordset(ins)
|
55
58
|
# Fetch response
|
56
59
|
in_recordset = self._send_receive_recordset(
|
57
|
-
out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout
|
60
|
+
out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
|
58
61
|
)
|
59
62
|
# RecordSet to Res
|
60
63
|
return compat.recordset_to_getpropertiesres(in_recordset)
|
61
64
|
|
62
65
|
def get_parameters(
|
63
|
-
self, ins: common.GetParametersIns, timeout: Optional[float]
|
66
|
+
self,
|
67
|
+
ins: common.GetParametersIns,
|
68
|
+
timeout: Optional[float],
|
69
|
+
group_id: Optional[int],
|
64
70
|
) -> common.GetParametersRes:
|
65
71
|
"""Return the current local model parameters."""
|
66
72
|
# Ins to RecordSet
|
67
73
|
out_recordset = compat.getparametersins_to_recordset(ins)
|
68
74
|
# Fetch response
|
69
75
|
in_recordset = self._send_receive_recordset(
|
70
|
-
out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout
|
76
|
+
out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
|
71
77
|
)
|
72
78
|
# RecordSet to Res
|
73
79
|
return compat.recordset_to_getparametersres(in_recordset, False)
|
74
80
|
|
75
|
-
def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
|
81
|
+
def fit(
|
82
|
+
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
83
|
+
) -> common.FitRes:
|
76
84
|
"""Train model parameters on the locally held dataset."""
|
77
85
|
# Ins to RecordSet
|
78
86
|
out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
|
79
87
|
# Fetch response
|
80
88
|
in_recordset = self._send_receive_recordset(
|
81
|
-
out_recordset, MESSAGE_TYPE_FIT, timeout
|
89
|
+
out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
|
82
90
|
)
|
83
91
|
# RecordSet to Res
|
84
92
|
return compat.recordset_to_fitres(in_recordset, keep_input=False)
|
85
93
|
|
86
94
|
def evaluate(
|
87
|
-
self, ins: common.EvaluateIns, timeout: Optional[float]
|
95
|
+
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
88
96
|
) -> common.EvaluateRes:
|
89
97
|
"""Evaluate model parameters on the locally held dataset."""
|
90
98
|
# Ins to RecordSet
|
91
99
|
out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
|
92
100
|
# Fetch response
|
93
101
|
in_recordset = self._send_receive_recordset(
|
94
|
-
out_recordset, MESSAGE_TYPE_EVALUATE, timeout
|
102
|
+
out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
|
95
103
|
)
|
96
104
|
# RecordSet to Res
|
97
105
|
return compat.recordset_to_evaluateres(in_recordset)
|
98
106
|
|
99
107
|
def reconnect(
|
100
|
-
self, ins: common.ReconnectIns, timeout: Optional[float]
|
108
|
+
self,
|
109
|
+
ins: common.ReconnectIns,
|
110
|
+
timeout: Optional[float],
|
111
|
+
group_id: Optional[int],
|
101
112
|
) -> common.DisconnectRes:
|
102
113
|
"""Disconnect and (optionally) reconnect later."""
|
103
114
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
@@ -107,10 +118,11 @@ class DriverClientProxy(ClientProxy):
|
|
107
118
|
recordset: RecordSet,
|
108
119
|
task_type: str,
|
109
120
|
timeout: Optional[float],
|
121
|
+
group_id: Optional[int],
|
110
122
|
) -> RecordSet:
|
111
123
|
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
112
124
|
task_id="",
|
113
|
-
group_id="",
|
125
|
+
group_id=str(group_id) if group_id is not None else "",
|
114
126
|
run_id=self.run_id,
|
115
127
|
task=task_pb2.Task( # pylint: disable=E1101
|
116
128
|
producer=node_pb2.Node( # pylint: disable=E1101
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Legacy Context."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
from flwr.common import Context, RecordSet
|
22
|
+
|
23
|
+
from ..client_manager import ClientManager, SimpleClientManager
|
24
|
+
from ..history import History
|
25
|
+
from ..server_config import ServerConfig
|
26
|
+
from ..strategy import FedAvg, Strategy
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass
|
30
|
+
class LegacyContext(Context):
|
31
|
+
"""Legacy Context."""
|
32
|
+
|
33
|
+
config: ServerConfig
|
34
|
+
strategy: Strategy
|
35
|
+
client_manager: ClientManager
|
36
|
+
history: History
|
37
|
+
|
38
|
+
def __init__(
|
39
|
+
self,
|
40
|
+
state: RecordSet,
|
41
|
+
config: Optional[ServerConfig] = None,
|
42
|
+
strategy: Optional[Strategy] = None,
|
43
|
+
client_manager: Optional[ClientManager] = None,
|
44
|
+
) -> None:
|
45
|
+
if config is None:
|
46
|
+
config = ServerConfig()
|
47
|
+
if strategy is None:
|
48
|
+
strategy = FedAvg()
|
49
|
+
if client_manager is None:
|
50
|
+
client_manager = SimpleClientManager()
|
51
|
+
self.config = config
|
52
|
+
self.strategy = strategy
|
53
|
+
self.client_manager = client_manager
|
54
|
+
self.history = History()
|
55
|
+
super().__init__(state)
|
flwr/server/run_serverapp.py
CHANGED