flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. See the package's registry page for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/compat/app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,14 +15,11 @@
|
|
|
15
15
|
"""Flower driver app."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import sys
|
|
19
18
|
from logging import INFO
|
|
20
|
-
from
|
|
21
|
-
from typing import Optional, Union
|
|
19
|
+
from typing import Optional
|
|
22
20
|
|
|
23
21
|
from flwr.common import EventType, event
|
|
24
|
-
from flwr.common.
|
|
25
|
-
from flwr.common.logger import log, warn_deprecated_feature
|
|
22
|
+
from flwr.common.logger import log
|
|
26
23
|
from flwr.server.client_manager import ClientManager
|
|
27
24
|
from flwr.server.history import History
|
|
28
25
|
from flwr.server.server import Server, init_defaults, run_fl
|
|
@@ -32,33 +29,21 @@ from flwr.server.strategy import Strategy
|
|
|
32
29
|
from ..driver import Driver
|
|
33
30
|
from .app_utils import start_update_client_manager_thread
|
|
34
31
|
|
|
35
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
36
|
-
|
|
37
|
-
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
38
|
-
[Driver] Error: Not connected.
|
|
39
|
-
|
|
40
|
-
Call `connect()` on the `Driver` instance before calling any of the other `Driver`
|
|
41
|
-
methods.
|
|
42
|
-
"""
|
|
43
|
-
|
|
44
32
|
|
|
45
33
|
def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
46
34
|
*,
|
|
47
|
-
|
|
35
|
+
driver: Driver,
|
|
48
36
|
server: Optional[Server] = None,
|
|
49
37
|
config: Optional[ServerConfig] = None,
|
|
50
38
|
strategy: Optional[Strategy] = None,
|
|
51
39
|
client_manager: Optional[ClientManager] = None,
|
|
52
|
-
root_certificates: Optional[Union[bytes, str]] = None,
|
|
53
|
-
driver: Optional[Driver] = None,
|
|
54
40
|
) -> History:
|
|
55
41
|
"""Start a Flower Driver API server.
|
|
56
42
|
|
|
57
43
|
Parameters
|
|
58
44
|
----------
|
|
59
|
-
|
|
60
|
-
The
|
|
61
|
-
Defaults to `"[::]:8080"`.
|
|
45
|
+
driver : Driver
|
|
46
|
+
The Driver object to use.
|
|
62
47
|
server : Optional[flwr.server.Server] (default: None)
|
|
63
48
|
A server implementation, either `flwr.server.Server` or a subclass
|
|
64
49
|
thereof. If no instance is provided, then `start_driver` will create
|
|
@@ -74,50 +59,14 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
74
59
|
An implementation of the class `flwr.server.ClientManager`. If no
|
|
75
60
|
implementation is provided, then `start_driver` will use
|
|
76
61
|
`flwr.server.SimpleClientManager`.
|
|
77
|
-
root_certificates : Optional[Union[bytes, str]] (default: None)
|
|
78
|
-
The PEM-encoded root certificates as a byte string or a path string.
|
|
79
|
-
If provided, a secure connection using the certificates will be
|
|
80
|
-
established to an SSL-enabled Flower server.
|
|
81
|
-
driver : Optional[Driver] (default: None)
|
|
82
|
-
The Driver object to use.
|
|
83
62
|
|
|
84
63
|
Returns
|
|
85
64
|
-------
|
|
86
65
|
hist : flwr.server.history.History
|
|
87
66
|
Object containing training and evaluation metrics.
|
|
88
|
-
|
|
89
|
-
Examples
|
|
90
|
-
--------
|
|
91
|
-
Starting a driver that connects to an insecure server:
|
|
92
|
-
|
|
93
|
-
>>> start_driver()
|
|
94
|
-
|
|
95
|
-
Starting a driver that connects to an SSL-enabled server:
|
|
96
|
-
|
|
97
|
-
>>> start_driver(
|
|
98
|
-
>>> root_certificates=Path("/crts/root.pem").read_bytes()
|
|
99
|
-
>>> )
|
|
100
67
|
"""
|
|
101
68
|
event(EventType.START_DRIVER_ENTER)
|
|
102
69
|
|
|
103
|
-
if driver is None:
|
|
104
|
-
# Not passing a `Driver` object is deprecated
|
|
105
|
-
warn_deprecated_feature("start_driver")
|
|
106
|
-
|
|
107
|
-
# Parse IP address
|
|
108
|
-
parsed_address = parse_address(server_address)
|
|
109
|
-
if not parsed_address:
|
|
110
|
-
sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
|
|
111
|
-
host, port, is_v6 = parsed_address
|
|
112
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
113
|
-
|
|
114
|
-
# Create the Driver
|
|
115
|
-
if isinstance(root_certificates, str):
|
|
116
|
-
root_certificates = Path(root_certificates).read_bytes()
|
|
117
|
-
driver = Driver(
|
|
118
|
-
driver_service_address=address, root_certificates=root_certificates
|
|
119
|
-
)
|
|
120
|
-
|
|
121
70
|
# Initialize the Driver API server and config
|
|
122
71
|
initialized_server, initialized_config = init_defaults(
|
|
123
72
|
server=server,
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
|
-
import time
|
|
20
19
|
from typing import Dict, Tuple
|
|
21
20
|
|
|
22
21
|
from ..client_manager import ClientManager
|
|
@@ -60,6 +59,7 @@ def start_update_client_manager_thread(
|
|
|
60
59
|
client_manager,
|
|
61
60
|
f_stop,
|
|
62
61
|
),
|
|
62
|
+
daemon=True,
|
|
63
63
|
)
|
|
64
64
|
thread.start()
|
|
65
65
|
|
|
@@ -89,9 +89,9 @@ def _update_client_manager(
|
|
|
89
89
|
for node_id in new_nodes:
|
|
90
90
|
client_proxy = DriverClientProxy(
|
|
91
91
|
node_id=node_id,
|
|
92
|
-
driver=driver
|
|
92
|
+
driver=driver,
|
|
93
93
|
anonymous=False,
|
|
94
|
-
run_id=driver.run_id,
|
|
94
|
+
run_id=driver.run.run_id,
|
|
95
95
|
)
|
|
96
96
|
if client_manager.register(client_proxy):
|
|
97
97
|
registered_nodes[node_id] = client_proxy
|
|
@@ -99,4 +99,5 @@ def _update_client_manager(
|
|
|
99
99
|
raise RuntimeError("Could not register node.")
|
|
100
100
|
|
|
101
101
|
# Sleep for 3 seconds
|
|
102
|
-
|
|
102
|
+
if not f_stop.is_set():
|
|
103
|
+
f_stop.wait(3)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,16 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import MessageType, MessageTypeLegacy, RecordSet
|
|
22
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
|
-
from flwr.common import serde
|
|
25
|
-
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
24
|
from flwr.server.client_proxy import ClientProxy
|
|
27
25
|
|
|
28
|
-
from ..driver.
|
|
26
|
+
from ..driver.driver import Driver
|
|
29
27
|
|
|
30
28
|
SLEEP_TIME = 1
|
|
31
29
|
|
|
@@ -33,7 +31,7 @@ SLEEP_TIME = 1
|
|
|
33
31
|
class DriverClientProxy(ClientProxy):
|
|
34
32
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
33
|
|
|
36
|
-
def __init__(self, node_id: int, driver:
|
|
34
|
+
def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
|
|
37
35
|
super().__init__(str(node_id))
|
|
38
36
|
self.node_id = node_id
|
|
39
37
|
self.driver = driver
|
|
@@ -114,55 +112,38 @@ class DriverClientProxy(ClientProxy):
|
|
|
114
112
|
timeout: Optional[float],
|
|
115
113
|
group_id: Optional[int],
|
|
116
114
|
) -> RecordSet:
|
|
117
|
-
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
|
118
|
-
task_id="",
|
|
119
|
-
group_id=str(group_id) if group_id is not None else "",
|
|
120
|
-
run_id=self.run_id,
|
|
121
|
-
task=task_pb2.Task( # pylint: disable=E1101
|
|
122
|
-
producer=node_pb2.Node( # pylint: disable=E1101
|
|
123
|
-
node_id=0,
|
|
124
|
-
anonymous=True,
|
|
125
|
-
),
|
|
126
|
-
consumer=node_pb2.Node( # pylint: disable=E1101
|
|
127
|
-
node_id=self.node_id,
|
|
128
|
-
anonymous=self.anonymous,
|
|
129
|
-
),
|
|
130
|
-
task_type=task_type,
|
|
131
|
-
recordset=serde.recordset_to_proto(recordset),
|
|
132
|
-
),
|
|
133
|
-
)
|
|
134
|
-
push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
|
|
135
|
-
task_ins_list=[task_ins]
|
|
136
|
-
)
|
|
137
115
|
|
|
138
|
-
#
|
|
139
|
-
|
|
116
|
+
# Create message
|
|
117
|
+
message = self.driver.create_message(
|
|
118
|
+
content=recordset,
|
|
119
|
+
message_type=task_type,
|
|
120
|
+
dst_node_id=self.node_id,
|
|
121
|
+
group_id=str(group_id) if group_id else "",
|
|
122
|
+
ttl=timeout,
|
|
123
|
+
)
|
|
140
124
|
|
|
141
|
-
|
|
142
|
-
|
|
125
|
+
# Push message
|
|
126
|
+
message_ids = list(self.driver.push_messages(messages=[message]))
|
|
127
|
+
if len(message_ids) != 1:
|
|
128
|
+
raise ValueError("Unexpected number of message_ids")
|
|
143
129
|
|
|
144
|
-
|
|
145
|
-
if
|
|
146
|
-
raise ValueError(f"Failed to
|
|
130
|
+
message_id = message_ids[0]
|
|
131
|
+
if message_id == "":
|
|
132
|
+
raise ValueError(f"Failed to send message to node {self.node_id}")
|
|
147
133
|
|
|
148
134
|
if timeout:
|
|
149
135
|
start_time = time.time()
|
|
150
136
|
|
|
151
137
|
while True:
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
pull_task_res_res.task_res_list
|
|
162
|
-
)
|
|
163
|
-
if len(task_res_list) == 1:
|
|
164
|
-
task_res = task_res_list[0]
|
|
165
|
-
return serde.recordset_from_proto(task_res.task.recordset)
|
|
138
|
+
messages = list(self.driver.pull_messages(message_ids))
|
|
139
|
+
if len(messages) == 1:
|
|
140
|
+
msg: Message = messages[0]
|
|
141
|
+
if msg.has_error():
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
+
"It originated during client-side execution of a message."
|
|
145
|
+
)
|
|
146
|
+
return msg.content
|
|
166
147
|
|
|
167
148
|
if timeout is not None and time.time() > start_time + timeout:
|
|
168
149
|
raise RuntimeError("Timeout reached")
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Context
|
|
21
|
+
from flwr.common import Context
|
|
22
22
|
|
|
23
23
|
from ..client_manager import ClientManager, SimpleClientManager
|
|
24
24
|
from ..history import History
|
|
@@ -35,9 +35,9 @@ class LegacyContext(Context):
|
|
|
35
35
|
client_manager: ClientManager
|
|
36
36
|
history: History
|
|
37
37
|
|
|
38
|
-
def __init__(
|
|
38
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
39
39
|
self,
|
|
40
|
-
|
|
40
|
+
context: Context,
|
|
41
41
|
config: Optional[ServerConfig] = None,
|
|
42
42
|
strategy: Optional[Strategy] = None,
|
|
43
43
|
client_manager: Optional[ClientManager] = None,
|
|
@@ -52,4 +52,5 @@ class LegacyContext(Context):
|
|
|
52
52
|
self.strategy = strategy
|
|
53
53
|
self.client_manager = client_manager
|
|
54
54
|
self.history = History()
|
|
55
|
-
|
|
55
|
+
|
|
56
|
+
super().__init__(**vars(context))
|
flwr/server/driver/__init__.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,85 +12,32 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Driver (abstract base class)."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
from typing import Iterable, List, Optional
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Iterable, List, Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Message,
|
|
22
|
-
from flwr.common.
|
|
23
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
24
|
-
CreateRunRequest,
|
|
25
|
-
GetNodesRequest,
|
|
26
|
-
PullTaskResRequest,
|
|
27
|
-
PushTaskInsRequest,
|
|
28
|
-
)
|
|
29
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
30
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
21
|
+
from flwr.common import Message, RecordSet
|
|
22
|
+
from flwr.common.typing import Run
|
|
31
23
|
|
|
32
|
-
from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
|
|
33
24
|
|
|
25
|
+
class Driver(ABC):
|
|
26
|
+
"""Abstract base Driver class for the Driver API."""
|
|
34
27
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
----------
|
|
40
|
-
driver_service_address : Optional[str]
|
|
41
|
-
The IPv4 or IPv6 address of the Driver API server.
|
|
42
|
-
Defaults to `"[::]:9091"`.
|
|
43
|
-
certificates : bytes (default: None)
|
|
44
|
-
Tuple containing root certificate, server certificate, and private key
|
|
45
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
46
|
-
three bytes elements in the following order:
|
|
47
|
-
|
|
48
|
-
* CA certificate.
|
|
49
|
-
* server certificate.
|
|
50
|
-
* server private key.
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
def __init__(
|
|
54
|
-
self,
|
|
55
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
56
|
-
root_certificates: Optional[bytes] = None,
|
|
57
|
-
) -> None:
|
|
58
|
-
self.addr = driver_service_address
|
|
59
|
-
self.root_certificates = root_certificates
|
|
60
|
-
self.grpc_driver: Optional[GrpcDriver] = None
|
|
61
|
-
self.run_id: Optional[int] = None
|
|
62
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
63
|
-
|
|
64
|
-
def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
|
|
65
|
-
# Check if the GrpcDriver is initialized
|
|
66
|
-
if self.grpc_driver is None or self.run_id is None:
|
|
67
|
-
# Connect and create run
|
|
68
|
-
self.grpc_driver = GrpcDriver(
|
|
69
|
-
driver_service_address=self.addr,
|
|
70
|
-
root_certificates=self.root_certificates,
|
|
71
|
-
)
|
|
72
|
-
self.grpc_driver.connect()
|
|
73
|
-
res = self.grpc_driver.create_run(CreateRunRequest())
|
|
74
|
-
self.run_id = res.run_id
|
|
75
|
-
return self.grpc_driver, self.run_id
|
|
76
|
-
|
|
77
|
-
def _check_message(self, message: Message) -> None:
|
|
78
|
-
# Check if the message is valid
|
|
79
|
-
if not (
|
|
80
|
-
message.metadata.run_id == self.run_id
|
|
81
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
82
|
-
and message.metadata.message_id == ""
|
|
83
|
-
and message.metadata.reply_to_message == ""
|
|
84
|
-
):
|
|
85
|
-
raise ValueError(f"Invalid message: {message}")
|
|
28
|
+
@property
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def run(self) -> Run:
|
|
31
|
+
"""Run information."""
|
|
86
32
|
|
|
33
|
+
@abstractmethod
|
|
87
34
|
def create_message( # pylint: disable=too-many-arguments
|
|
88
35
|
self,
|
|
89
36
|
content: RecordSet,
|
|
90
37
|
message_type: str,
|
|
91
38
|
dst_node_id: int,
|
|
92
39
|
group_id: str,
|
|
93
|
-
ttl:
|
|
40
|
+
ttl: Optional[float] = None,
|
|
94
41
|
) -> Message:
|
|
95
42
|
"""Create a new message with specified parameters.
|
|
96
43
|
|
|
@@ -110,36 +57,23 @@ class Driver:
|
|
|
110
57
|
group_id : str
|
|
111
58
|
The ID of the group to which this message is associated. In some settings,
|
|
112
59
|
this is used as the FL round.
|
|
113
|
-
ttl :
|
|
60
|
+
ttl : Optional[float] (default: None)
|
|
114
61
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
115
|
-
this message to receiving a reply. It specifies the duration for
|
|
116
|
-
message and its potential reply are considered valid.
|
|
62
|
+
this message to receiving a reply. It specifies in seconds the duration for
|
|
63
|
+
which the message and its potential reply are considered valid. If unset,
|
|
64
|
+
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
|
|
117
65
|
|
|
118
66
|
Returns
|
|
119
67
|
-------
|
|
120
68
|
message : Message
|
|
121
69
|
A new `Message` instance with the specified content and metadata.
|
|
122
70
|
"""
|
|
123
|
-
_, run_id = self._get_grpc_driver_and_run_id()
|
|
124
|
-
metadata = Metadata(
|
|
125
|
-
run_id=run_id,
|
|
126
|
-
message_id="", # Will be set by the server
|
|
127
|
-
src_node_id=self.node.node_id,
|
|
128
|
-
dst_node_id=dst_node_id,
|
|
129
|
-
reply_to_message="",
|
|
130
|
-
group_id=group_id,
|
|
131
|
-
ttl=ttl,
|
|
132
|
-
message_type=message_type,
|
|
133
|
-
)
|
|
134
|
-
return Message(metadata=metadata, content=content)
|
|
135
71
|
|
|
72
|
+
@abstractmethod
|
|
136
73
|
def get_node_ids(self) -> List[int]:
|
|
137
74
|
"""Get node IDs."""
|
|
138
|
-
grpc_driver, run_id = self._get_grpc_driver_and_run_id()
|
|
139
|
-
# Call GrpcDriver method
|
|
140
|
-
res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
|
|
141
|
-
return [node.node_id for node in res.nodes]
|
|
142
75
|
|
|
76
|
+
@abstractmethod
|
|
143
77
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
144
78
|
"""Push messages to specified node IDs.
|
|
145
79
|
|
|
@@ -157,20 +91,8 @@ class Driver:
|
|
|
157
91
|
An iterable of IDs for the messages that were sent, which can be used
|
|
158
92
|
to pull replies.
|
|
159
93
|
"""
|
|
160
|
-
grpc_driver, _ = self._get_grpc_driver_and_run_id()
|
|
161
|
-
# Construct TaskIns
|
|
162
|
-
task_ins_list: List[TaskIns] = []
|
|
163
|
-
for msg in messages:
|
|
164
|
-
# Check message
|
|
165
|
-
self._check_message(msg)
|
|
166
|
-
# Convert Message to TaskIns
|
|
167
|
-
taskins = message_to_taskins(msg)
|
|
168
|
-
# Add to list
|
|
169
|
-
task_ins_list.append(taskins)
|
|
170
|
-
# Call GrpcDriver method
|
|
171
|
-
res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
172
|
-
return list(res.task_ids)
|
|
173
94
|
|
|
95
|
+
@abstractmethod
|
|
174
96
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
175
97
|
"""Pull messages based on message IDs.
|
|
176
98
|
|
|
@@ -187,15 +109,8 @@ class Driver:
|
|
|
187
109
|
messages : Iterable[Message]
|
|
188
110
|
An iterable of messages received.
|
|
189
111
|
"""
|
|
190
|
-
grpc_driver, _ = self._get_grpc_driver_and_run_id()
|
|
191
|
-
# Pull TaskRes
|
|
192
|
-
res = grpc_driver.pull_task_res(
|
|
193
|
-
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
194
|
-
)
|
|
195
|
-
# Convert TaskRes to Message
|
|
196
|
-
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
197
|
-
return msgs
|
|
198
112
|
|
|
113
|
+
@abstractmethod
|
|
199
114
|
def send_and_receive(
|
|
200
115
|
self,
|
|
201
116
|
messages: Iterable[Message],
|
|
@@ -229,28 +144,3 @@ class Driver:
|
|
|
229
144
|
replies for all sent messages. A message remains valid until its TTL,
|
|
230
145
|
which is not affected by `timeout`.
|
|
231
146
|
"""
|
|
232
|
-
# Push messages
|
|
233
|
-
msg_ids = set(self.push_messages(messages))
|
|
234
|
-
|
|
235
|
-
# Pull messages
|
|
236
|
-
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
237
|
-
ret: List[Message] = []
|
|
238
|
-
while timeout is None or time.time() < end_time:
|
|
239
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
240
|
-
ret.extend(res_msgs)
|
|
241
|
-
msg_ids.difference_update(
|
|
242
|
-
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
243
|
-
)
|
|
244
|
-
if len(msg_ids) == 0:
|
|
245
|
-
break
|
|
246
|
-
# Sleep
|
|
247
|
-
time.sleep(3)
|
|
248
|
-
return ret
|
|
249
|
-
|
|
250
|
-
def close(self) -> None:
|
|
251
|
-
"""Disconnect from the SuperLink if connected."""
|
|
252
|
-
# Check if GrpcDriver is initialized
|
|
253
|
-
if self.grpc_driver is None:
|
|
254
|
-
return
|
|
255
|
-
# Disconnect
|
|
256
|
-
self.grpc_driver.disconnect()
|