flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,20 +12,25 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower gRPC Driver."""
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
from
|
|
17
|
+
import time
|
|
18
|
+
import warnings
|
|
19
|
+
from logging import DEBUG, WARNING
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
20
21
|
|
|
21
22
|
import grpc
|
|
22
23
|
|
|
23
|
-
from flwr.common import EventType, event
|
|
24
|
+
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
24
25
|
from flwr.common.grpc import create_channel
|
|
25
26
|
from flwr.common.logger import log
|
|
27
|
+
from flwr.common.serde import (
|
|
28
|
+
message_from_taskres,
|
|
29
|
+
message_to_taskins,
|
|
30
|
+
user_config_from_proto,
|
|
31
|
+
)
|
|
32
|
+
from flwr.common.typing import Run
|
|
26
33
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
27
|
-
CreateRunRequest,
|
|
28
|
-
CreateRunResponse,
|
|
29
34
|
GetNodesRequest,
|
|
30
35
|
GetNodesResponse,
|
|
31
36
|
PullTaskResRequest,
|
|
@@ -34,96 +39,241 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
34
39
|
PushTaskInsResponse,
|
|
35
40
|
)
|
|
36
41
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
42
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
43
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
44
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
45
|
+
|
|
46
|
+
from .driver import Driver
|
|
37
47
|
|
|
38
48
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
39
49
|
|
|
40
50
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
41
51
|
[Driver] Error: Not connected.
|
|
42
52
|
|
|
43
|
-
Call `connect()` on the `
|
|
44
|
-
`
|
|
53
|
+
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
54
|
+
`GrpcDriverStub` methods.
|
|
45
55
|
"""
|
|
46
56
|
|
|
47
57
|
|
|
48
|
-
class GrpcDriver:
|
|
49
|
-
"""`GrpcDriver` provides
|
|
58
|
+
class GrpcDriver(Driver):
|
|
59
|
+
"""`GrpcDriver` provides an interface to the Driver API.
|
|
60
|
+
|
|
61
|
+
Parameters
|
|
62
|
+
----------
|
|
63
|
+
run_id : int
|
|
64
|
+
The identifier of the run.
|
|
65
|
+
driver_service_address : str (default: "[::]:9091")
|
|
66
|
+
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
67
|
+
root_certificates : Optional[bytes] (default: None)
|
|
68
|
+
The PEM-encoded root certificates as a byte string.
|
|
69
|
+
If provided, a secure connection using the certificates will be
|
|
70
|
+
established to an SSL-enabled Flower server.
|
|
71
|
+
"""
|
|
50
72
|
|
|
51
|
-
def __init__(
|
|
73
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
52
74
|
self,
|
|
75
|
+
run_id: int,
|
|
53
76
|
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
54
77
|
root_certificates: Optional[bytes] = None,
|
|
55
78
|
) -> None:
|
|
56
|
-
self.
|
|
57
|
-
self.
|
|
58
|
-
self.
|
|
59
|
-
self.
|
|
79
|
+
self._run_id = run_id
|
|
80
|
+
self._addr = driver_service_address
|
|
81
|
+
self._cert = root_certificates
|
|
82
|
+
self._run: Optional[Run] = None
|
|
83
|
+
self._grpc_stub: Optional[DriverStub] = None
|
|
84
|
+
self._channel: Optional[grpc.Channel] = None
|
|
85
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
60
86
|
|
|
61
|
-
|
|
62
|
-
|
|
87
|
+
@property
|
|
88
|
+
def _is_connected(self) -> bool:
|
|
89
|
+
"""Check if connected to the Driver API server."""
|
|
90
|
+
return self._channel is not None
|
|
91
|
+
|
|
92
|
+
def _connect(self) -> None:
|
|
93
|
+
"""Connect to the Driver API.
|
|
94
|
+
|
|
95
|
+
This will not call GetRun.
|
|
96
|
+
"""
|
|
63
97
|
event(EventType.DRIVER_CONNECT)
|
|
64
|
-
if self.
|
|
98
|
+
if self._is_connected:
|
|
65
99
|
log(WARNING, "Already connected")
|
|
66
100
|
return
|
|
67
|
-
self.
|
|
68
|
-
server_address=self.
|
|
69
|
-
insecure=(self.
|
|
70
|
-
root_certificates=self.
|
|
101
|
+
self._channel = create_channel(
|
|
102
|
+
server_address=self._addr,
|
|
103
|
+
insecure=(self._cert is None),
|
|
104
|
+
root_certificates=self._cert,
|
|
71
105
|
)
|
|
72
|
-
self.
|
|
73
|
-
log(DEBUG, "[Driver] Connected to %s", self.
|
|
106
|
+
self._grpc_stub = DriverStub(self._channel)
|
|
107
|
+
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
74
108
|
|
|
75
|
-
def
|
|
109
|
+
def _disconnect(self) -> None:
|
|
76
110
|
"""Disconnect from the Driver API."""
|
|
77
111
|
event(EventType.DRIVER_DISCONNECT)
|
|
78
|
-
if
|
|
112
|
+
if not self._is_connected:
|
|
79
113
|
log(DEBUG, "Already disconnected")
|
|
80
114
|
return
|
|
81
|
-
channel = self.
|
|
82
|
-
self.
|
|
83
|
-
self.
|
|
115
|
+
channel: grpc.Channel = self._channel
|
|
116
|
+
self._channel = None
|
|
117
|
+
self._grpc_stub = None
|
|
84
118
|
channel.close()
|
|
85
119
|
log(DEBUG, "[Driver] Disconnected")
|
|
86
120
|
|
|
87
|
-
def
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
return
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
if self.
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
121
|
+
def _init_run(self) -> None:
|
|
122
|
+
# Check if is initialized
|
|
123
|
+
if self._run is not None:
|
|
124
|
+
return
|
|
125
|
+
# Get the run info
|
|
126
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
127
|
+
res: GetRunResponse = self._stub.GetRun(req)
|
|
128
|
+
if not res.HasField("run"):
|
|
129
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
130
|
+
self._run = Run(
|
|
131
|
+
run_id=res.run.run_id,
|
|
132
|
+
fab_id=res.run.fab_id,
|
|
133
|
+
fab_version=res.run.fab_version,
|
|
134
|
+
override_config=user_config_from_proto(res.run.override_config),
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
@property
|
|
138
|
+
def run(self) -> Run:
|
|
139
|
+
"""Run information."""
|
|
140
|
+
self._init_run()
|
|
141
|
+
return Run(**vars(self._run))
|
|
142
|
+
|
|
143
|
+
@property
|
|
144
|
+
def _stub(self) -> DriverStub:
|
|
145
|
+
"""Driver stub."""
|
|
146
|
+
if not self._is_connected:
|
|
147
|
+
self._connect()
|
|
148
|
+
return cast(DriverStub, self._grpc_stub)
|
|
149
|
+
|
|
150
|
+
def _check_message(self, message: Message) -> None:
|
|
151
|
+
# Check if the message is valid
|
|
152
|
+
if not (
|
|
153
|
+
# Assume self._run being initialized
|
|
154
|
+
message.metadata.run_id == self._run_id
|
|
155
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
156
|
+
and message.metadata.message_id == ""
|
|
157
|
+
and message.metadata.reply_to_message == ""
|
|
158
|
+
and message.metadata.ttl > 0
|
|
159
|
+
):
|
|
160
|
+
raise ValueError(f"Invalid message: {message}")
|
|
161
|
+
|
|
162
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
163
|
+
self,
|
|
164
|
+
content: RecordSet,
|
|
165
|
+
message_type: str,
|
|
166
|
+
dst_node_id: int,
|
|
167
|
+
group_id: str,
|
|
168
|
+
ttl: Optional[float] = None,
|
|
169
|
+
) -> Message:
|
|
170
|
+
"""Create a new message with specified parameters.
|
|
171
|
+
|
|
172
|
+
This method constructs a new `Message` with given content and metadata.
|
|
173
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
174
|
+
"""
|
|
175
|
+
self._init_run()
|
|
176
|
+
if ttl:
|
|
177
|
+
warnings.warn(
|
|
178
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
179
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
180
|
+
"version of Flower.",
|
|
181
|
+
stacklevel=2,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
185
|
+
metadata = Metadata(
|
|
186
|
+
run_id=self._run_id,
|
|
187
|
+
message_id="", # Will be set by the server
|
|
188
|
+
src_node_id=self.node.node_id,
|
|
189
|
+
dst_node_id=dst_node_id,
|
|
190
|
+
reply_to_message="",
|
|
191
|
+
group_id=group_id,
|
|
192
|
+
ttl=ttl_,
|
|
193
|
+
message_type=message_type,
|
|
194
|
+
)
|
|
195
|
+
return Message(metadata=metadata, content=content)
|
|
196
|
+
|
|
197
|
+
def get_node_ids(self) -> List[int]:
|
|
198
|
+
"""Get node IDs."""
|
|
199
|
+
self._init_run()
|
|
200
|
+
# Call GrpcDriverStub method
|
|
201
|
+
res: GetNodesResponse = self._stub.GetNodes(
|
|
202
|
+
GetNodesRequest(run_id=self._run_id)
|
|
203
|
+
)
|
|
204
|
+
return [node.node_id for node in res.nodes]
|
|
205
|
+
|
|
206
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
207
|
+
"""Push messages to specified node IDs.
|
|
208
|
+
|
|
209
|
+
This method takes an iterable of messages and sends each message
|
|
210
|
+
to the node specified in `dst_node_id`.
|
|
211
|
+
"""
|
|
212
|
+
self._init_run()
|
|
213
|
+
# Construct TaskIns
|
|
214
|
+
task_ins_list: List[TaskIns] = []
|
|
215
|
+
for msg in messages:
|
|
216
|
+
# Check message
|
|
217
|
+
self._check_message(msg)
|
|
218
|
+
# Convert Message to TaskIns
|
|
219
|
+
taskins = message_to_taskins(msg)
|
|
220
|
+
# Add to list
|
|
221
|
+
task_ins_list.append(taskins)
|
|
222
|
+
# Call GrpcDriverStub method
|
|
223
|
+
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
224
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
225
|
+
)
|
|
226
|
+
return list(res.task_ids)
|
|
227
|
+
|
|
228
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
229
|
+
"""Pull messages based on message IDs.
|
|
230
|
+
|
|
231
|
+
This method is used to collect messages from the SuperLink that correspond to a
|
|
232
|
+
set of given message IDs.
|
|
233
|
+
"""
|
|
234
|
+
self._init_run()
|
|
235
|
+
# Pull TaskRes
|
|
236
|
+
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
237
|
+
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
238
|
+
)
|
|
239
|
+
# Convert TaskRes to Message
|
|
240
|
+
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
241
|
+
return msgs
|
|
242
|
+
|
|
243
|
+
def send_and_receive(
|
|
244
|
+
self,
|
|
245
|
+
messages: Iterable[Message],
|
|
246
|
+
*,
|
|
247
|
+
timeout: Optional[float] = None,
|
|
248
|
+
) -> Iterable[Message]:
|
|
249
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
250
|
+
|
|
251
|
+
This method sends a list of messages to their destination node IDs and then
|
|
252
|
+
waits for the replies. It continues to pull replies until either all replies are
|
|
253
|
+
received or the specified timeout duration is exceeded.
|
|
254
|
+
"""
|
|
255
|
+
# Push messages
|
|
256
|
+
msg_ids = set(self.push_messages(messages))
|
|
257
|
+
|
|
258
|
+
# Pull messages
|
|
259
|
+
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
260
|
+
ret: List[Message] = []
|
|
261
|
+
while timeout is None or time.time() < end_time:
|
|
262
|
+
res_msgs = self.pull_messages(msg_ids)
|
|
263
|
+
ret.extend(res_msgs)
|
|
264
|
+
msg_ids.difference_update(
|
|
265
|
+
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
266
|
+
)
|
|
267
|
+
if len(msg_ids) == 0:
|
|
268
|
+
break
|
|
269
|
+
# Sleep
|
|
270
|
+
time.sleep(3)
|
|
271
|
+
return ret
|
|
272
|
+
|
|
273
|
+
def close(self) -> None:
|
|
274
|
+
"""Disconnect from the SuperLink if connected."""
|
|
275
|
+
# Check if `connect` was called before
|
|
276
|
+
if not self._is_connected:
|
|
277
|
+
return
|
|
278
|
+
# Disconnect
|
|
279
|
+
self._disconnect()
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower in-memory Driver."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
import warnings
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
|
+
from uuid import UUID
|
|
22
|
+
|
|
23
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.common.typing import Run
|
|
26
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
27
|
+
from flwr.server.superlink.state import StateFactory
|
|
28
|
+
|
|
29
|
+
from .driver import Driver
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class InMemoryDriver(Driver):
|
|
33
|
+
"""`InMemoryDriver` class provides an interface to the Driver API.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
run_id : int
|
|
38
|
+
The identifier of the run.
|
|
39
|
+
state_factory : StateFactory
|
|
40
|
+
A StateFactory embedding a state that this driver can interface with.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(
|
|
44
|
+
self,
|
|
45
|
+
run_id: int,
|
|
46
|
+
state_factory: StateFactory,
|
|
47
|
+
) -> None:
|
|
48
|
+
self._run_id = run_id
|
|
49
|
+
self._run: Optional[Run] = None
|
|
50
|
+
self.state = state_factory.state()
|
|
51
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
52
|
+
|
|
53
|
+
def _check_message(self, message: Message) -> None:
|
|
54
|
+
self._init_run()
|
|
55
|
+
# Check if the message is valid
|
|
56
|
+
if not (
|
|
57
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
58
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
59
|
+
and message.metadata.message_id == ""
|
|
60
|
+
and message.metadata.reply_to_message == ""
|
|
61
|
+
and message.metadata.ttl > 0
|
|
62
|
+
):
|
|
63
|
+
raise ValueError(f"Invalid message: {message}")
|
|
64
|
+
|
|
65
|
+
def _init_run(self) -> None:
|
|
66
|
+
"""Initialize the run."""
|
|
67
|
+
if self._run is not None:
|
|
68
|
+
return
|
|
69
|
+
run = self.state.get_run(self._run_id)
|
|
70
|
+
if run is None:
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
72
|
+
self._run = run
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def run(self) -> Run:
|
|
76
|
+
"""Run ID."""
|
|
77
|
+
self._init_run()
|
|
78
|
+
return Run(**vars(cast(Run, self._run)))
|
|
79
|
+
|
|
80
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
81
|
+
self,
|
|
82
|
+
content: RecordSet,
|
|
83
|
+
message_type: str,
|
|
84
|
+
dst_node_id: int,
|
|
85
|
+
group_id: str,
|
|
86
|
+
ttl: Optional[float] = None,
|
|
87
|
+
) -> Message:
|
|
88
|
+
"""Create a new message with specified parameters.
|
|
89
|
+
|
|
90
|
+
This method constructs a new `Message` with given content and metadata.
|
|
91
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
92
|
+
"""
|
|
93
|
+
self._init_run()
|
|
94
|
+
if ttl:
|
|
95
|
+
warnings.warn(
|
|
96
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
97
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
98
|
+
"version of Flower.",
|
|
99
|
+
stacklevel=2,
|
|
100
|
+
)
|
|
101
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
102
|
+
|
|
103
|
+
metadata = Metadata(
|
|
104
|
+
run_id=cast(Run, self._run).run_id,
|
|
105
|
+
message_id="", # Will be set by the server
|
|
106
|
+
src_node_id=self.node.node_id,
|
|
107
|
+
dst_node_id=dst_node_id,
|
|
108
|
+
reply_to_message="",
|
|
109
|
+
group_id=group_id,
|
|
110
|
+
ttl=ttl_,
|
|
111
|
+
message_type=message_type,
|
|
112
|
+
)
|
|
113
|
+
return Message(metadata=metadata, content=content)
|
|
114
|
+
|
|
115
|
+
def get_node_ids(self) -> List[int]:
|
|
116
|
+
"""Get node IDs."""
|
|
117
|
+
self._init_run()
|
|
118
|
+
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
119
|
+
|
|
120
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
121
|
+
"""Push messages to specified node IDs.
|
|
122
|
+
|
|
123
|
+
This method takes an iterable of messages and sends each message
|
|
124
|
+
to the node specified in `dst_node_id`.
|
|
125
|
+
"""
|
|
126
|
+
task_ids: List[str] = []
|
|
127
|
+
for msg in messages:
|
|
128
|
+
# Check message
|
|
129
|
+
self._check_message(msg)
|
|
130
|
+
# Convert Message to TaskIns
|
|
131
|
+
taskins = message_to_taskins(msg)
|
|
132
|
+
# Store in state
|
|
133
|
+
taskins.task.pushed_at = time.time()
|
|
134
|
+
task_id = self.state.store_task_ins(taskins)
|
|
135
|
+
if task_id:
|
|
136
|
+
task_ids.append(str(task_id))
|
|
137
|
+
|
|
138
|
+
return task_ids
|
|
139
|
+
|
|
140
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
141
|
+
"""Pull messages based on message IDs.
|
|
142
|
+
|
|
143
|
+
This method is used to collect messages from the SuperLink that correspond to a
|
|
144
|
+
set of given message IDs.
|
|
145
|
+
"""
|
|
146
|
+
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
147
|
+
# Pull TaskRes
|
|
148
|
+
task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
|
|
149
|
+
# Delete tasks in state
|
|
150
|
+
self.state.delete_tasks(msg_ids)
|
|
151
|
+
# Convert TaskRes to Message
|
|
152
|
+
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
|
|
153
|
+
return msgs
|
|
154
|
+
|
|
155
|
+
def send_and_receive(
|
|
156
|
+
self,
|
|
157
|
+
messages: Iterable[Message],
|
|
158
|
+
*,
|
|
159
|
+
timeout: Optional[float] = None,
|
|
160
|
+
) -> Iterable[Message]:
|
|
161
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
162
|
+
|
|
163
|
+
This method sends a list of messages to their destination node IDs and then
|
|
164
|
+
waits for the replies. It continues to pull replies until either all replies are
|
|
165
|
+
received or the specified timeout duration is exceeded.
|
|
166
|
+
"""
|
|
167
|
+
# Push messages
|
|
168
|
+
msg_ids = set(self.push_messages(messages))
|
|
169
|
+
|
|
170
|
+
# Pull messages
|
|
171
|
+
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
172
|
+
ret: List[Message] = []
|
|
173
|
+
while timeout is None or time.time() < end_time:
|
|
174
|
+
res_msgs = self.pull_messages(msg_ids)
|
|
175
|
+
ret.extend(res_msgs)
|
|
176
|
+
msg_ids.difference_update(
|
|
177
|
+
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
178
|
+
)
|
|
179
|
+
if len(msg_ids) == 0:
|
|
180
|
+
break
|
|
181
|
+
# Sleep
|
|
182
|
+
time.sleep(3)
|
|
183
|
+
return ret
|
flwr/server/history.py
CHANGED
|
@@ -91,32 +91,32 @@ class History:
|
|
|
91
91
|
"""
|
|
92
92
|
rep = ""
|
|
93
93
|
if self.losses_distributed:
|
|
94
|
-
rep += "History (loss, distributed):\n" +
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
],
|
|
101
|
-
)
|
|
94
|
+
rep += "History (loss, distributed):\n" + reduce(
|
|
95
|
+
lambda a, b: a + b,
|
|
96
|
+
[
|
|
97
|
+
f"\tround {server_round}: {loss}\n"
|
|
98
|
+
for server_round, loss in self.losses_distributed
|
|
99
|
+
],
|
|
102
100
|
)
|
|
103
101
|
if self.losses_centralized:
|
|
104
|
-
rep += "History (loss, centralized):\n" +
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
],
|
|
111
|
-
)
|
|
102
|
+
rep += "History (loss, centralized):\n" + reduce(
|
|
103
|
+
lambda a, b: a + b,
|
|
104
|
+
[
|
|
105
|
+
f"\tround {server_round}: {loss}\n"
|
|
106
|
+
for server_round, loss in self.losses_centralized
|
|
107
|
+
],
|
|
112
108
|
)
|
|
113
109
|
if self.metrics_distributed_fit:
|
|
114
|
-
rep +=
|
|
115
|
-
|
|
110
|
+
rep += (
|
|
111
|
+
"History (metrics, distributed, fit):\n"
|
|
112
|
+
+ pprint.pformat(self.metrics_distributed_fit)
|
|
113
|
+
+ "\n"
|
|
116
114
|
)
|
|
117
115
|
if self.metrics_distributed:
|
|
118
|
-
rep +=
|
|
119
|
-
|
|
116
|
+
rep += (
|
|
117
|
+
"History (metrics, distributed, evaluate):\n"
|
|
118
|
+
+ pprint.pformat(self.metrics_distributed)
|
|
119
|
+
+ "\n"
|
|
120
120
|
)
|
|
121
121
|
if self.metrics_centralized:
|
|
122
122
|
rep += "History (metrics, centralized):\n" + pprint.pformat(
|