flwr-nightly 1.10.0.dev20240618__py3-none-any.whl → 1.10.0.dev20240620__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +3 -7
- flwr/cli/new/new.py +1 -1
- flwr/cli/run/run.py +8 -1
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +4 -0
- flwr/client/client_app.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/__init__.py +4 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +29 -6
- flwr/common/__init__.py +12 -12
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +5 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/object_ref.py +39 -5
- flwr/common/record/__init__.py +1 -1
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +47 -7
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +15 -8
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +1 -1
- flwr/server/superlink/state/sqlite_state.py +1 -1
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +5 -2
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +15 -8
- flwr/superexec/app.py +1 -1
- {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/RECORD +98 -96
- {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -36,6 +36,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
37
|
from flwr.common.constant import (
|
|
38
38
|
MISSING_EXTRA_REST,
|
|
39
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
39
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
40
41
|
TRANSPORT_TYPE_REST,
|
|
41
42
|
)
|
|
@@ -48,6 +49,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
48
49
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
49
50
|
add_FleetServicer_to_server,
|
|
50
51
|
)
|
|
52
|
+
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
51
53
|
|
|
52
54
|
from .client_manager import ClientManager
|
|
53
55
|
from .history import History
|
|
@@ -55,6 +57,7 @@ from .server import Server, init_defaults, run_fl
|
|
|
55
57
|
from .server_config import ServerConfig
|
|
56
58
|
from .strategy import Strategy
|
|
57
59
|
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
|
60
|
+
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
58
61
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
59
62
|
generic_create_grpc_server,
|
|
60
63
|
start_grpc_server,
|
|
@@ -218,11 +221,13 @@ def run_superlink() -> None:
|
|
|
218
221
|
grpc_servers = [driver_server]
|
|
219
222
|
bckg_threads = []
|
|
220
223
|
if not args.fleet_api_address:
|
|
221
|
-
args.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
224
|
+
if args.fleet_api_type in [
|
|
225
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
226
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
227
|
+
]:
|
|
228
|
+
args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
|
|
229
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
230
|
+
args.fleet_api_address = ADDRESS_FLEET_API_REST
|
|
226
231
|
|
|
227
232
|
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
228
233
|
|
|
@@ -293,6 +298,13 @@ def run_superlink() -> None:
|
|
|
293
298
|
interceptors=interceptors,
|
|
294
299
|
)
|
|
295
300
|
grpc_servers.append(fleet_server)
|
|
301
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
|
|
302
|
+
fleet_server = _run_fleet_api_grpc_adapter(
|
|
303
|
+
address=fleet_address,
|
|
304
|
+
state_factory=state_factory,
|
|
305
|
+
certificates=certificates,
|
|
306
|
+
)
|
|
307
|
+
grpc_servers.append(fleet_server)
|
|
296
308
|
else:
|
|
297
309
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
298
310
|
|
|
@@ -419,7 +431,7 @@ def _try_obtain_certificates(
|
|
|
419
431
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
420
432
|
return None
|
|
421
433
|
# Check if certificates are provided
|
|
422
|
-
if args.fleet_api_type
|
|
434
|
+
if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
|
|
423
435
|
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
424
436
|
if not isfile(args.ssl_ca_certfile):
|
|
425
437
|
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
@@ -491,6 +503,30 @@ def _run_fleet_api_grpc_rere(
|
|
|
491
503
|
return fleet_grpc_server
|
|
492
504
|
|
|
493
505
|
|
|
506
|
+
def _run_fleet_api_grpc_adapter(
|
|
507
|
+
address: str,
|
|
508
|
+
state_factory: StateFactory,
|
|
509
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
510
|
+
) -> grpc.Server:
|
|
511
|
+
"""Run Fleet API (GrpcAdapter)."""
|
|
512
|
+
# Create Fleet API gRPC server
|
|
513
|
+
fleet_servicer = GrpcAdapterServicer(
|
|
514
|
+
state_factory=state_factory,
|
|
515
|
+
)
|
|
516
|
+
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
|
|
517
|
+
fleet_grpc_server = generic_create_grpc_server(
|
|
518
|
+
servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
|
|
519
|
+
server_address=address,
|
|
520
|
+
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
521
|
+
certificates=certificates,
|
|
522
|
+
)
|
|
523
|
+
|
|
524
|
+
log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
|
|
525
|
+
fleet_grpc_server.start()
|
|
526
|
+
|
|
527
|
+
return fleet_grpc_server
|
|
528
|
+
|
|
529
|
+
|
|
494
530
|
# pylint: disable=import-outside-toplevel,too-many-arguments
|
|
495
531
|
def _run_fleet_api_rest(
|
|
496
532
|
host: str,
|
|
@@ -606,7 +642,11 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
|
606
642
|
"--fleet-api-type",
|
|
607
643
|
default=TRANSPORT_TYPE_GRPC_RERE,
|
|
608
644
|
type=str,
|
|
609
|
-
choices=[
|
|
645
|
+
choices=[
|
|
646
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
647
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
648
|
+
TRANSPORT_TYPE_REST,
|
|
649
|
+
],
|
|
610
650
|
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
|
|
611
651
|
)
|
|
612
652
|
parser.add_argument(
|
flwr/server/compat/app.py
CHANGED
flwr/server/compat/app_utils.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from typing import Iterable, List, Optional
|
|
20
20
|
|
|
21
21
|
from flwr.common import Message, RecordSet
|
|
22
|
+
from flwr.common.typing import Run
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class Driver(ABC):
|
|
25
26
|
"""Abstract base Driver class for the Driver API."""
|
|
26
27
|
|
|
28
|
+
@property
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def run(self) -> Run:
|
|
31
|
+
"""Run information."""
|
|
32
|
+
|
|
27
33
|
@abstractmethod
|
|
28
34
|
def create_message( # pylint: disable=too-many-arguments
|
|
29
35
|
self,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
19
|
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
|
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
+
from flwr.common.typing import Run
|
|
28
29
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
30
|
CreateRunRequest,
|
|
30
31
|
CreateRunResponse,
|
|
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
40
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
42
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
43
|
|
|
42
44
|
from .driver import Driver
|
|
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
|
46
48
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
47
49
|
[Driver] Error: Not connected.
|
|
48
50
|
|
|
49
|
-
Call `connect()` on the `
|
|
50
|
-
`
|
|
51
|
+
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
52
|
+
`GrpcDriverStub` methods.
|
|
51
53
|
"""
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
class
|
|
55
|
-
"""`
|
|
56
|
+
class GrpcDriverStub:
|
|
57
|
+
"""`GrpcDriverStub` provides access to the gRPC Driver API/service.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
driver_service_address : Optional[str]
|
|
62
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
63
|
+
Defaults to `"[::]:9091"`.
|
|
64
|
+
root_certificates : Optional[bytes] (default: None)
|
|
65
|
+
The PEM-encoded root certificates as a byte string.
|
|
66
|
+
If provided, a secure connection using the certificates will be
|
|
67
|
+
established to an SSL-enabled Flower server.
|
|
68
|
+
"""
|
|
56
69
|
|
|
57
70
|
def __init__(
|
|
58
71
|
self,
|
|
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
|
|
|
64
77
|
self.channel: Optional[grpc.Channel] = None
|
|
65
78
|
self.stub: Optional[DriverStub] = None
|
|
66
79
|
|
|
80
|
+
def is_connected(self) -> bool:
|
|
81
|
+
"""Return True if connected to the Driver API server, otherwise False."""
|
|
82
|
+
return self.channel is not None
|
|
83
|
+
|
|
67
84
|
def connect(self) -> None:
|
|
68
85
|
"""Connect to the Driver API."""
|
|
69
86
|
event(EventType.DRIVER_CONNECT)
|
|
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
|
|
|
95
112
|
# Check if channel is open
|
|
96
113
|
if self.stub is None:
|
|
97
114
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
98
|
-
raise ConnectionError("`
|
|
115
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
99
116
|
|
|
100
117
|
# Call Driver API
|
|
101
118
|
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
102
119
|
return res
|
|
103
120
|
|
|
121
|
+
def get_run(self, req: GetRunRequest) -> GetRunResponse:
|
|
122
|
+
"""Get run information."""
|
|
123
|
+
# Check if channel is open
|
|
124
|
+
if self.stub is None:
|
|
125
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
+
|
|
128
|
+
# Call gRPC Driver API
|
|
129
|
+
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
+
return res
|
|
131
|
+
|
|
104
132
|
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
105
133
|
"""Get client IDs."""
|
|
106
134
|
# Check if channel is open
|
|
107
135
|
if self.stub is None:
|
|
108
136
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
109
|
-
raise ConnectionError("`
|
|
137
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
110
138
|
|
|
111
139
|
# Call gRPC Driver API
|
|
112
140
|
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
|
|
|
117
145
|
# Check if channel is open
|
|
118
146
|
if self.stub is None:
|
|
119
147
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
120
|
-
raise ConnectionError("`
|
|
148
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
121
149
|
|
|
122
150
|
# Call gRPC Driver API
|
|
123
151
|
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
|
|
|
128
156
|
# Check if channel is open
|
|
129
157
|
if self.stub is None:
|
|
130
158
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
131
|
-
raise ConnectionError("`
|
|
159
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
132
160
|
|
|
133
161
|
# Call Driver API
|
|
134
162
|
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
|
|
|
140
168
|
|
|
141
169
|
Parameters
|
|
142
170
|
----------
|
|
143
|
-
|
|
144
|
-
The
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
-
three bytes elements in the following order:
|
|
150
|
-
|
|
151
|
-
* CA certificate.
|
|
152
|
-
* server certificate.
|
|
153
|
-
* server private key.
|
|
154
|
-
fab_id : str (default: None)
|
|
155
|
-
The identifier of the FAB used in the run.
|
|
156
|
-
fab_version : str (default: None)
|
|
157
|
-
The version of the FAB used in the run.
|
|
171
|
+
run_id : int
|
|
172
|
+
The identifier of the run.
|
|
173
|
+
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
+
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
+
If None, an instance connected to "[::]:9091" will be created.
|
|
158
176
|
"""
|
|
159
177
|
|
|
160
|
-
def __init__(
|
|
178
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
161
179
|
self,
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
fab_id: Optional[str] = None,
|
|
165
|
-
fab_version: Optional[str] = None,
|
|
180
|
+
run_id: int,
|
|
181
|
+
stub: Optional[GrpcDriverStub] = None,
|
|
166
182
|
) -> None:
|
|
167
|
-
self.
|
|
168
|
-
self.
|
|
169
|
-
self.
|
|
170
|
-
self.run_id: Optional[int] = None
|
|
171
|
-
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
183
|
+
self._run_id = run_id
|
|
184
|
+
self._run: Optional[Run] = None
|
|
185
|
+
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
173
186
|
self.node = Node(node_id=0, anonymous=True)
|
|
174
187
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
188
|
+
@property
|
|
189
|
+
def run(self) -> Run:
|
|
190
|
+
"""Run information."""
|
|
191
|
+
self._get_stub_and_run_id()
|
|
192
|
+
return Run(**vars(cast(Run, self._run)))
|
|
193
|
+
|
|
194
|
+
def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
|
|
195
|
+
# Check if is initialized
|
|
196
|
+
if self._run is None:
|
|
197
|
+
# Connect
|
|
198
|
+
if not self.stub.is_connected():
|
|
199
|
+
self.stub.connect()
|
|
200
|
+
# Get the run info
|
|
201
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
+
res = self.stub.get_run(req)
|
|
203
|
+
if not res.HasField("run"):
|
|
204
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
+
self._run = Run(
|
|
206
|
+
run_id=res.run.run_id,
|
|
207
|
+
fab_id=res.run.fab_id,
|
|
208
|
+
fab_version=res.run.fab_version,
|
|
182
209
|
)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
res = self.driver_helper.create_run(req)
|
|
186
|
-
self.run_id = res.run_id
|
|
187
|
-
return self.driver_helper, self.run_id
|
|
210
|
+
|
|
211
|
+
return self.stub, self._run.run_id
|
|
188
212
|
|
|
189
213
|
def _check_message(self, message: Message) -> None:
|
|
190
214
|
# Check if the message is valid
|
|
191
215
|
if not (
|
|
192
|
-
message.metadata.run_id == self.run_id
|
|
216
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
193
217
|
and message.metadata.src_node_id == self.node.node_id
|
|
194
218
|
and message.metadata.message_id == ""
|
|
195
219
|
and message.metadata.reply_to_message == ""
|
|
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
|
|
|
210
234
|
This method constructs a new `Message` with given content and metadata.
|
|
211
235
|
The `run_id` and `src_node_id` will be set automatically.
|
|
212
236
|
"""
|
|
213
|
-
_, run_id = self.
|
|
237
|
+
_, run_id = self._get_stub_and_run_id()
|
|
214
238
|
if ttl:
|
|
215
239
|
warnings.warn(
|
|
216
240
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
|
|
|
234
258
|
|
|
235
259
|
def get_node_ids(self) -> List[int]:
|
|
236
260
|
"""Get node IDs."""
|
|
237
|
-
|
|
238
|
-
# Call
|
|
239
|
-
res =
|
|
261
|
+
stub, run_id = self._get_stub_and_run_id()
|
|
262
|
+
# Call GrpcDriverStub method
|
|
263
|
+
res = stub.get_nodes(GetNodesRequest(run_id=run_id))
|
|
240
264
|
return [node.node_id for node in res.nodes]
|
|
241
265
|
|
|
242
266
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
|
|
|
245
269
|
This method takes an iterable of messages and sends each message
|
|
246
270
|
to the node specified in `dst_node_id`.
|
|
247
271
|
"""
|
|
248
|
-
|
|
272
|
+
stub, _ = self._get_stub_and_run_id()
|
|
249
273
|
# Construct TaskIns
|
|
250
274
|
task_ins_list: List[TaskIns] = []
|
|
251
275
|
for msg in messages:
|
|
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
|
|
|
255
279
|
taskins = message_to_taskins(msg)
|
|
256
280
|
# Add to list
|
|
257
281
|
task_ins_list.append(taskins)
|
|
258
|
-
# Call
|
|
259
|
-
res =
|
|
260
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
261
|
-
)
|
|
282
|
+
# Call GrpcDriverStub method
|
|
283
|
+
res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
262
284
|
return list(res.task_ids)
|
|
263
285
|
|
|
264
286
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
|
|
|
267
289
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
268
290
|
set of given message IDs.
|
|
269
291
|
"""
|
|
270
|
-
|
|
292
|
+
stub, _ = self._get_stub_and_run_id()
|
|
271
293
|
# Pull TaskRes
|
|
272
|
-
res =
|
|
294
|
+
res = stub.pull_task_res(
|
|
273
295
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
274
296
|
)
|
|
275
297
|
# Convert TaskRes to Message
|
|
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
|
|
|
308
330
|
|
|
309
331
|
def close(self) -> None:
|
|
310
332
|
"""Disconnect from the SuperLink if connected."""
|
|
311
|
-
# Check if
|
|
312
|
-
if self.
|
|
333
|
+
# Check if `connect` was called before
|
|
334
|
+
if not self.stub.is_connected():
|
|
313
335
|
return
|
|
314
336
|
# Disconnect
|
|
315
|
-
self.
|
|
337
|
+
self.stub.disconnect()
|
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
import warnings
|
|
20
|
-
from typing import Iterable, List, Optional
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
23
23
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
24
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.common.typing import Run
|
|
25
26
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
27
|
from flwr.server.superlink.state import StateFactory
|
|
27
28
|
|
|
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
|
|
|
33
34
|
|
|
34
35
|
Parameters
|
|
35
36
|
----------
|
|
37
|
+
run_id : int
|
|
38
|
+
The identifier of the run.
|
|
36
39
|
state_factory : StateFactory
|
|
37
40
|
A StateFactory embedding a state that this driver can interface with.
|
|
38
|
-
fab_id : str (default: None)
|
|
39
|
-
The identifier of the FAB used in the run.
|
|
40
|
-
fab_version : str (default: None)
|
|
41
|
-
The version of the FAB used in the run.
|
|
42
41
|
"""
|
|
43
42
|
|
|
44
43
|
def __init__(
|
|
45
44
|
self,
|
|
45
|
+
run_id: int,
|
|
46
46
|
state_factory: StateFactory,
|
|
47
|
-
fab_id: Optional[str] = None,
|
|
48
|
-
fab_version: Optional[str] = None,
|
|
49
47
|
) -> None:
|
|
50
|
-
self.
|
|
51
|
-
self.
|
|
52
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
53
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
48
|
+
self._run_id = run_id
|
|
49
|
+
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
51
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
55
52
|
|
|
56
53
|
def _check_message(self, message: Message) -> None:
|
|
54
|
+
self._init_run()
|
|
57
55
|
# Check if the message is valid
|
|
58
56
|
if not (
|
|
59
|
-
message.metadata.run_id == self.run_id
|
|
57
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
60
58
|
and message.metadata.src_node_id == self.node.node_id
|
|
61
59
|
and message.metadata.message_id == ""
|
|
62
60
|
and message.metadata.reply_to_message == ""
|
|
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
|
|
|
64
62
|
):
|
|
65
63
|
raise ValueError(f"Invalid message: {message}")
|
|
66
64
|
|
|
67
|
-
def
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
if
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
def _init_run(self) -> None:
|
|
66
|
+
"""Initialize the run."""
|
|
67
|
+
if self._run is not None:
|
|
68
|
+
return
|
|
69
|
+
run = self.state.get_run(self._run_id)
|
|
70
|
+
if run is None:
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
72
|
+
self._run = run
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def run(self) -> Run:
|
|
76
|
+
"""Run ID."""
|
|
77
|
+
self._init_run()
|
|
78
|
+
return Run(**vars(cast(Run, self._run)))
|
|
77
79
|
|
|
78
80
|
def create_message( # pylint: disable=too-many-arguments
|
|
79
81
|
self,
|
|
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
|
|
|
88
90
|
This method constructs a new `Message` with given content and metadata.
|
|
89
91
|
The `run_id` and `src_node_id` will be set automatically.
|
|
90
92
|
"""
|
|
91
|
-
|
|
93
|
+
self._init_run()
|
|
92
94
|
if ttl:
|
|
93
95
|
warnings.warn(
|
|
94
96
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
|
|
|
99
101
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
100
102
|
|
|
101
103
|
metadata = Metadata(
|
|
102
|
-
run_id=run_id,
|
|
104
|
+
run_id=cast(Run, self._run).run_id,
|
|
103
105
|
message_id="", # Will be set by the server
|
|
104
106
|
src_node_id=self.node.node_id,
|
|
105
107
|
dst_node_id=dst_node_id,
|
|
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
|
|
|
112
114
|
|
|
113
115
|
def get_node_ids(self) -> List[int]:
|
|
114
116
|
"""Get node IDs."""
|
|
115
|
-
|
|
116
|
-
return list(self.state.get_nodes(run_id))
|
|
117
|
+
self._init_run()
|
|
118
|
+
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
117
119
|
|
|
118
120
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
119
121
|
"""Push messages to specified node IDs.
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -24,8 +24,10 @@ from typing import Optional
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
25
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
|
+
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
|
|
27
28
|
|
|
28
|
-
from .driver import Driver
|
|
29
|
+
from .driver import Driver
|
|
30
|
+
from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
|
|
29
31
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
32
|
|
|
31
33
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -50,7 +52,9 @@ def run(
|
|
|
50
52
|
# Load ServerApp if needed
|
|
51
53
|
def _load() -> ServerApp:
|
|
52
54
|
if server_app_attr:
|
|
53
|
-
server_app: ServerApp = load_app(
|
|
55
|
+
server_app: ServerApp = load_app(
|
|
56
|
+
server_app_attr, LoadServerAppError, server_app_dir
|
|
57
|
+
)
|
|
54
58
|
|
|
55
59
|
if not isinstance(server_app, ServerApp):
|
|
56
60
|
raise LoadServerAppError(
|
|
@@ -147,13 +151,16 @@ def run_server_app() -> None:
|
|
|
147
151
|
server_app_dir = args.dir
|
|
148
152
|
server_app_attr = getattr(args, "server-app")
|
|
149
153
|
|
|
150
|
-
#
|
|
151
|
-
|
|
152
|
-
driver_service_address=args.superlink,
|
|
153
|
-
root_certificates=root_certificates,
|
|
154
|
-
fab_id=args.fab_id,
|
|
155
|
-
fab_version=args.fab_version,
|
|
154
|
+
# Create run
|
|
155
|
+
stub = GrpcDriverStub(
|
|
156
|
+
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
156
157
|
)
|
|
158
|
+
stub.connect()
|
|
159
|
+
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
160
|
+
res = stub.create_run(req)
|
|
161
|
+
|
|
162
|
+
# Initialize GrpcDriver
|
|
163
|
+
driver = GrpcDriver(run_id=res.run_id, stub=stub)
|
|
157
164
|
|
|
158
165
|
# Run the ServerApp with the Driver
|
|
159
166
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
flwr/server/strategy/__init__.py
CHANGED
|
@@ -53,9 +53,10 @@ __all__ = [
|
|
|
53
53
|
"DPFedAvgAdaptive",
|
|
54
54
|
"DPFedAvgFixed",
|
|
55
55
|
"DifferentialPrivacyClientSideAdaptiveClipping",
|
|
56
|
-
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
57
56
|
"DifferentialPrivacyClientSideFixedClipping",
|
|
57
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
58
58
|
"DifferentialPrivacyServerSideFixedClipping",
|
|
59
|
+
"FaultTolerantFedAvg",
|
|
59
60
|
"FedAdagrad",
|
|
60
61
|
"FedAdam",
|
|
61
62
|
"FedAvg",
|
|
@@ -69,7 +70,6 @@ __all__ = [
|
|
|
69
70
|
"FedXgbCyclic",
|
|
70
71
|
"FedXgbNnAvg",
|
|
71
72
|
"FedYogi",
|
|
72
|
-
"FaultTolerantFedAvg",
|
|
73
73
|
"Krum",
|
|
74
74
|
"QFedAvg",
|
|
75
75
|
"Strategy",
|
flwr/server/strategy/bulyan.py
CHANGED
flwr/server/strategy/fedadam.py
CHANGED
flwr/server/strategy/fedavgm.py
CHANGED