flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +34 -88
- flwr/client/app.py +23 -20
- flwr/client/clientapp/app.py +22 -18
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/args.py +83 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +39 -5
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +10 -1
- flwr/proto/exec_pb2.py +14 -17
- flwr/proto/exec_pb2.pyi +14 -22
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +26 -0
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +272 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +25 -36
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +214 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
- flwr/server/superlink/{state → linkstate}/utils.py +81 -30
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +273 -345
- flwr/simulation/legacy_app.py +382 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +61 -66
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +36 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +54 -107
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
|
@@ -23,7 +23,7 @@ from typing import Optional, cast
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import
|
|
26
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
|
|
27
27
|
from flwr.common.grpc import create_channel
|
|
28
28
|
from flwr.common.logger import log
|
|
29
29
|
from flwr.common.serde import (
|
|
@@ -32,7 +32,9 @@ from flwr.common.serde import (
|
|
|
32
32
|
user_config_from_proto,
|
|
33
33
|
)
|
|
34
34
|
from flwr.common.typing import Run
|
|
35
|
-
from flwr.proto.
|
|
35
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
36
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
37
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
36
38
|
GetNodesRequest,
|
|
37
39
|
GetNodesResponse,
|
|
38
40
|
PullTaskResRequest,
|
|
@@ -40,9 +42,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
40
42
|
PushTaskInsRequest,
|
|
41
43
|
PushTaskInsResponse,
|
|
42
44
|
)
|
|
43
|
-
from flwr.proto.
|
|
44
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
46
46
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
47
47
|
|
|
48
48
|
from .driver import Driver
|
|
@@ -56,14 +56,12 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class GrpcDriver(Driver):
|
|
59
|
-
"""`GrpcDriver` provides an interface to the
|
|
59
|
+
"""`GrpcDriver` provides an interface to the ServerAppIo API.
|
|
60
60
|
|
|
61
61
|
Parameters
|
|
62
62
|
----------
|
|
63
|
-
|
|
64
|
-
The
|
|
65
|
-
driver_service_address : str (default: "[::]:9091")
|
|
66
|
-
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
63
|
+
serverappio_service_address : str (default: "[::]:9091")
|
|
64
|
+
The address (URL, IPv6, IPv4) of the SuperLink ServerAppIo API service.
|
|
67
65
|
root_certificates : Optional[bytes] (default: None)
|
|
68
66
|
The PEM-encoded root certificates as a byte string.
|
|
69
67
|
If provided, a secure connection using the certificates will be
|
|
@@ -72,25 +70,23 @@ class GrpcDriver(Driver):
|
|
|
72
70
|
|
|
73
71
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
72
|
self,
|
|
75
|
-
|
|
76
|
-
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
73
|
+
serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
77
74
|
root_certificates: Optional[bytes] = None,
|
|
78
75
|
) -> None:
|
|
79
|
-
self.
|
|
80
|
-
self._addr = driver_service_address
|
|
76
|
+
self._addr = serverappio_service_address
|
|
81
77
|
self._cert = root_certificates
|
|
82
78
|
self._run: Optional[Run] = None
|
|
83
|
-
self._grpc_stub: Optional[
|
|
79
|
+
self._grpc_stub: Optional[ServerAppIoStub] = None
|
|
84
80
|
self._channel: Optional[grpc.Channel] = None
|
|
85
81
|
self.node = Node(node_id=0, anonymous=True)
|
|
86
82
|
|
|
87
83
|
@property
|
|
88
84
|
def _is_connected(self) -> bool:
|
|
89
|
-
"""Check if connected to the
|
|
85
|
+
"""Check if connected to the ServerAppIo API server."""
|
|
90
86
|
return self._channel is not None
|
|
91
87
|
|
|
92
88
|
def _connect(self) -> None:
|
|
93
|
-
"""Connect to the
|
|
89
|
+
"""Connect to the ServerAppIo API.
|
|
94
90
|
|
|
95
91
|
This will not call GetRun.
|
|
96
92
|
"""
|
|
@@ -102,11 +98,11 @@ class GrpcDriver(Driver):
|
|
|
102
98
|
insecure=(self._cert is None),
|
|
103
99
|
root_certificates=self._cert,
|
|
104
100
|
)
|
|
105
|
-
self._grpc_stub =
|
|
101
|
+
self._grpc_stub = ServerAppIoStub(self._channel)
|
|
106
102
|
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
107
103
|
|
|
108
104
|
def _disconnect(self) -> None:
|
|
109
|
-
"""Disconnect from the
|
|
105
|
+
"""Disconnect from the ServerAppIo API."""
|
|
110
106
|
if not self._is_connected:
|
|
111
107
|
log(DEBUG, "Already disconnected")
|
|
112
108
|
return
|
|
@@ -116,15 +112,13 @@ class GrpcDriver(Driver):
|
|
|
116
112
|
channel.close()
|
|
117
113
|
log(DEBUG, "[Driver] Disconnected")
|
|
118
114
|
|
|
119
|
-
def
|
|
120
|
-
|
|
121
|
-
if self._run is not None:
|
|
122
|
-
return
|
|
115
|
+
def set_run(self, run_id: int) -> None:
|
|
116
|
+
"""Set the run."""
|
|
123
117
|
# Get the run info
|
|
124
|
-
req = GetRunRequest(run_id=
|
|
118
|
+
req = GetRunRequest(run_id=run_id)
|
|
125
119
|
res: GetRunResponse = self._stub.GetRun(req)
|
|
126
120
|
if not res.HasField("run"):
|
|
127
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
121
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
128
122
|
self._run = Run(
|
|
129
123
|
run_id=res.run.run_id,
|
|
130
124
|
fab_id=res.run.fab_id,
|
|
@@ -136,21 +130,20 @@ class GrpcDriver(Driver):
|
|
|
136
130
|
@property
|
|
137
131
|
def run(self) -> Run:
|
|
138
132
|
"""Run information."""
|
|
139
|
-
self._init_run()
|
|
140
133
|
return Run(**vars(self._run))
|
|
141
134
|
|
|
142
135
|
@property
|
|
143
|
-
def _stub(self) ->
|
|
144
|
-
"""
|
|
136
|
+
def _stub(self) -> ServerAppIoStub:
|
|
137
|
+
"""ServerAppIo stub."""
|
|
145
138
|
if not self._is_connected:
|
|
146
139
|
self._connect()
|
|
147
|
-
return cast(
|
|
140
|
+
return cast(ServerAppIoStub, self._grpc_stub)
|
|
148
141
|
|
|
149
142
|
def _check_message(self, message: Message) -> None:
|
|
150
143
|
# Check if the message is valid
|
|
151
144
|
if not (
|
|
152
145
|
# Assume self._run being initialized
|
|
153
|
-
message.metadata.run_id == self.
|
|
146
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
154
147
|
and message.metadata.src_node_id == self.node.node_id
|
|
155
148
|
and message.metadata.message_id == ""
|
|
156
149
|
and message.metadata.reply_to_message == ""
|
|
@@ -171,7 +164,6 @@ class GrpcDriver(Driver):
|
|
|
171
164
|
This method constructs a new `Message` with given content and metadata.
|
|
172
165
|
The `run_id` and `src_node_id` will be set automatically.
|
|
173
166
|
"""
|
|
174
|
-
self._init_run()
|
|
175
167
|
if ttl:
|
|
176
168
|
warnings.warn(
|
|
177
169
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -182,7 +174,7 @@ class GrpcDriver(Driver):
|
|
|
182
174
|
|
|
183
175
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
184
176
|
metadata = Metadata(
|
|
185
|
-
run_id=self.
|
|
177
|
+
run_id=cast(Run, self._run).run_id,
|
|
186
178
|
message_id="", # Will be set by the server
|
|
187
179
|
src_node_id=self.node.node_id,
|
|
188
180
|
dst_node_id=dst_node_id,
|
|
@@ -195,10 +187,9 @@ class GrpcDriver(Driver):
|
|
|
195
187
|
|
|
196
188
|
def get_node_ids(self) -> list[int]:
|
|
197
189
|
"""Get node IDs."""
|
|
198
|
-
self._init_run()
|
|
199
190
|
# Call ServerAppIoStub method
|
|
200
191
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
201
|
-
GetNodesRequest(run_id=self.
|
|
192
|
+
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
202
193
|
)
|
|
203
194
|
return [node.node_id for node in res.nodes]
|
|
204
195
|
|
|
@@ -208,7 +199,6 @@ class GrpcDriver(Driver):
|
|
|
208
199
|
This method takes an iterable of messages and sends each message
|
|
209
200
|
to the node specified in `dst_node_id`.
|
|
210
201
|
"""
|
|
211
|
-
self._init_run()
|
|
212
202
|
# Construct TaskIns
|
|
213
203
|
task_ins_list: list[TaskIns] = []
|
|
214
204
|
for msg in messages:
|
|
@@ -230,7 +220,6 @@ class GrpcDriver(Driver):
|
|
|
230
220
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
231
221
|
set of given message IDs.
|
|
232
222
|
"""
|
|
233
|
-
self._init_run()
|
|
234
223
|
# Pull TaskRes
|
|
235
224
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
236
225
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
@@ -25,18 +25,16 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
|
25
25
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from .driver import Driver
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class InMemoryDriver(Driver):
|
|
34
|
-
"""`InMemoryDriver` class provides an interface to the
|
|
34
|
+
"""`InMemoryDriver` class provides an interface to the ServerAppIo API.
|
|
35
35
|
|
|
36
36
|
Parameters
|
|
37
37
|
----------
|
|
38
|
-
run_id : int
|
|
39
|
-
The identifier of the run.
|
|
40
38
|
state_factory : LinkStateFactory
|
|
41
39
|
A LinkStateFactory embedding a state that this driver can interface with.
|
|
42
40
|
pull_interval : float (default=0.1)
|
|
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
|
|
|
45
43
|
|
|
46
44
|
def __init__(
|
|
47
45
|
self,
|
|
48
|
-
|
|
49
|
-
state_factory: StateFactory,
|
|
46
|
+
state_factory: LinkStateFactory,
|
|
50
47
|
pull_interval: float = 0.1,
|
|
51
48
|
) -> None:
|
|
52
|
-
self._run_id = run_id
|
|
53
49
|
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
55
51
|
self.pull_interval = pull_interval
|
|
56
52
|
self.node = Node(node_id=0, anonymous=True)
|
|
57
53
|
|
|
58
54
|
def _check_message(self, message: Message) -> None:
|
|
59
|
-
self._init_run()
|
|
60
55
|
# Check if the message is valid
|
|
61
56
|
if not (
|
|
62
57
|
message.metadata.run_id == cast(Run, self._run).run_id
|
|
@@ -67,19 +62,16 @@ class InMemoryDriver(Driver):
|
|
|
67
62
|
):
|
|
68
63
|
raise ValueError(f"Invalid message: {message}")
|
|
69
64
|
|
|
70
|
-
def
|
|
65
|
+
def set_run(self, run_id: int) -> None:
|
|
71
66
|
"""Initialize the run."""
|
|
72
|
-
|
|
73
|
-
return
|
|
74
|
-
run = self.state.get_run(self._run_id)
|
|
67
|
+
run = self.state.get_run(run_id)
|
|
75
68
|
if run is None:
|
|
76
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
69
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
77
70
|
self._run = run
|
|
78
71
|
|
|
79
72
|
@property
|
|
80
73
|
def run(self) -> Run:
|
|
81
74
|
"""Run ID."""
|
|
82
|
-
self._init_run()
|
|
83
75
|
return Run(**vars(cast(Run, self._run)))
|
|
84
76
|
|
|
85
77
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
@@ -95,7 +87,6 @@ class InMemoryDriver(Driver):
|
|
|
95
87
|
This method constructs a new `Message` with given content and metadata.
|
|
96
88
|
The `run_id` and `src_node_id` will be set automatically.
|
|
97
89
|
"""
|
|
98
|
-
self._init_run()
|
|
99
90
|
if ttl:
|
|
100
91
|
warnings.warn(
|
|
101
92
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -119,7 +110,6 @@ class InMemoryDriver(Driver):
|
|
|
119
110
|
|
|
120
111
|
def get_node_ids(self) -> list[int]:
|
|
121
112
|
"""Get node IDs."""
|
|
122
|
-
self._init_run()
|
|
123
113
|
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
124
114
|
|
|
125
115
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -31,10 +31,9 @@ from flwr.common.config import (
|
|
|
31
31
|
get_project_config,
|
|
32
32
|
get_project_dir,
|
|
33
33
|
)
|
|
34
|
-
from flwr.common.constant import
|
|
34
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
|
|
35
35
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
36
36
|
from flwr.common.object_ref import load_app
|
|
37
|
-
from flwr.common.typing import UserConfig
|
|
38
37
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
38
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
39
|
CreateRunRequest,
|
|
@@ -48,11 +47,11 @@ from .server_app import LoadServerAppError, ServerApp
|
|
|
48
47
|
|
|
49
48
|
def run(
|
|
50
49
|
driver: Driver,
|
|
50
|
+
context: Context,
|
|
51
51
|
server_app_dir: str,
|
|
52
|
-
server_app_run_config: UserConfig,
|
|
53
52
|
server_app_attr: Optional[str] = None,
|
|
54
53
|
loaded_server_app: Optional[ServerApp] = None,
|
|
55
|
-
) ->
|
|
54
|
+
) -> Context:
|
|
56
55
|
"""Run ServerApp with a given Driver."""
|
|
57
56
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
58
57
|
raise ValueError(
|
|
@@ -78,15 +77,11 @@ def run(
|
|
|
78
77
|
|
|
79
78
|
server_app = _load()
|
|
80
79
|
|
|
81
|
-
# Initialize Context
|
|
82
|
-
context = Context(
|
|
83
|
-
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
84
|
-
)
|
|
85
|
-
|
|
86
80
|
# Call ServerApp
|
|
87
81
|
server_app(driver=driver, context=context)
|
|
88
82
|
|
|
89
83
|
log(DEBUG, "ServerApp finished running.")
|
|
84
|
+
return context
|
|
90
85
|
|
|
91
86
|
|
|
92
87
|
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
@@ -111,18 +106,18 @@ def run_server_app() -> None:
|
|
|
111
106
|
"app by executing `flwr new` and following the prompt."
|
|
112
107
|
)
|
|
113
108
|
|
|
114
|
-
if args.server !=
|
|
109
|
+
if args.server != SERVERAPPIO_API_DEFAULT_ADDRESS:
|
|
115
110
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
116
111
|
warn_deprecated_feature(warn)
|
|
117
112
|
|
|
118
|
-
if args.superlink !=
|
|
113
|
+
if args.superlink != SERVERAPPIO_API_DEFAULT_ADDRESS:
|
|
119
114
|
# if `--superlink` also passed, then
|
|
120
115
|
# warn user that this argument overrides what was passed with `--server`
|
|
121
116
|
log(
|
|
122
117
|
WARN,
|
|
123
118
|
"Both `--server` and `--superlink` were passed. "
|
|
124
|
-
"`--server` will be ignored. Connecting to the
|
|
125
|
-
"at %s.",
|
|
119
|
+
"`--server` will be ignored. Connecting to the "
|
|
120
|
+
"SuperLink ServerAppIo API at %s.",
|
|
126
121
|
args.superlink,
|
|
127
122
|
)
|
|
128
123
|
else:
|
|
@@ -175,11 +170,11 @@ def run_server_app() -> None:
|
|
|
175
170
|
if app_path is None:
|
|
176
171
|
# User provided `--run-id`, but not `app_dir`
|
|
177
172
|
driver = GrpcDriver(
|
|
178
|
-
|
|
179
|
-
driver_service_address=args.superlink,
|
|
173
|
+
serverappio_service_address=args.superlink,
|
|
180
174
|
root_certificates=root_certificates,
|
|
181
175
|
)
|
|
182
176
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
177
|
+
driver.set_run(args.run_id)
|
|
183
178
|
run_ = driver.run
|
|
184
179
|
if not run_.fab_hash:
|
|
185
180
|
raise ValueError("FAB hash not provided.")
|
|
@@ -193,12 +188,12 @@ def run_server_app() -> None:
|
|
|
193
188
|
|
|
194
189
|
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
195
190
|
config = get_project_config(app_path)
|
|
191
|
+
run_id = run_.run_id
|
|
196
192
|
else:
|
|
197
193
|
# User provided `app_dir`, but not `--run-id`
|
|
198
194
|
# Create run if run_id is not provided
|
|
199
195
|
driver = GrpcDriver(
|
|
200
|
-
|
|
201
|
-
driver_service_address=args.superlink,
|
|
196
|
+
serverappio_service_address=args.superlink,
|
|
202
197
|
root_certificates=root_certificates,
|
|
203
198
|
)
|
|
204
199
|
# Load config from the project directory
|
|
@@ -208,8 +203,9 @@ def run_server_app() -> None:
|
|
|
208
203
|
# Create run
|
|
209
204
|
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
|
|
210
205
|
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
211
|
-
#
|
|
212
|
-
driver.
|
|
206
|
+
# Fetch full `Run` using `run_id`
|
|
207
|
+
driver.set_run(res.run_id) # pylint: disable=W0212
|
|
208
|
+
run_id = res.run_id
|
|
213
209
|
|
|
214
210
|
# Obtain server app reference and the run config
|
|
215
211
|
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
@@ -225,11 +221,20 @@ def run_server_app() -> None:
|
|
|
225
221
|
root_certificates,
|
|
226
222
|
)
|
|
227
223
|
|
|
224
|
+
# Initialize Context
|
|
225
|
+
context = Context(
|
|
226
|
+
run_id=run_id,
|
|
227
|
+
node_id=0,
|
|
228
|
+
node_config={},
|
|
229
|
+
state=RecordSet(),
|
|
230
|
+
run_config=server_app_run_config,
|
|
231
|
+
)
|
|
232
|
+
|
|
228
233
|
# Run the ServerApp with the Driver
|
|
229
234
|
run(
|
|
230
235
|
driver=driver,
|
|
236
|
+
context=context,
|
|
231
237
|
server_app_dir=app_path,
|
|
232
|
-
server_app_run_config=server_app_run_config,
|
|
233
238
|
server_app_attr=server_app_attr,
|
|
234
239
|
)
|
|
235
240
|
|
|
@@ -272,13 +277,14 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
272
277
|
)
|
|
273
278
|
parser.add_argument(
|
|
274
279
|
"--server",
|
|
275
|
-
default=
|
|
280
|
+
default=SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
276
281
|
help="Server address",
|
|
277
282
|
)
|
|
278
283
|
parser.add_argument(
|
|
279
284
|
"--superlink",
|
|
280
|
-
default=
|
|
281
|
-
help="SuperLink
|
|
285
|
+
default=SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
286
|
+
help="SuperLink ServerAppIo API (gRPC-rere) address "
|
|
287
|
+
"(IPv4, IPv6, or a domain name)",
|
|
282
288
|
)
|
|
283
289
|
parser.add_argument(
|
|
284
290
|
"--run-id",
|
|
@@ -12,17 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower AppIO service."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .sqlite_state import SqliteState as SqliteState
|
|
20
|
-
from .state import State as State
|
|
21
|
-
from .state_factory import StateFactory as StateFactory
|
|
18
|
+
from .app import flwr_serverapp as flwr_serverapp
|
|
22
19
|
|
|
23
20
|
__all__ = [
|
|
24
|
-
"
|
|
25
|
-
"SqliteState",
|
|
26
|
-
"State",
|
|
27
|
-
"StateFactory",
|
|
21
|
+
"flwr_serverapp",
|
|
28
22
|
]
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower ServerApp process."""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
from logging import DEBUG, ERROR, INFO
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from queue import Queue
|
|
21
|
+
from time import sleep
|
|
22
|
+
from typing import Optional
|
|
23
|
+
|
|
24
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
25
|
+
from flwr.cli.install import install_from_fab
|
|
26
|
+
from flwr.common.args import add_args_flwr_app_common, try_obtain_certificates
|
|
27
|
+
from flwr.common.config import (
|
|
28
|
+
get_flwr_dir,
|
|
29
|
+
get_fused_config_from_dir,
|
|
30
|
+
get_project_config,
|
|
31
|
+
get_project_dir,
|
|
32
|
+
)
|
|
33
|
+
from flwr.common.constant import Status, SubStatus
|
|
34
|
+
from flwr.common.logger import (
|
|
35
|
+
log,
|
|
36
|
+
mirror_output_to_queue,
|
|
37
|
+
restore_output,
|
|
38
|
+
start_log_uploader,
|
|
39
|
+
stop_log_uploader,
|
|
40
|
+
)
|
|
41
|
+
from flwr.common.serde import (
|
|
42
|
+
context_from_proto,
|
|
43
|
+
context_to_proto,
|
|
44
|
+
fab_from_proto,
|
|
45
|
+
run_from_proto,
|
|
46
|
+
run_status_to_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import RunStatus
|
|
49
|
+
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
50
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
51
|
+
PullServerAppInputsRequest,
|
|
52
|
+
PullServerAppInputsResponse,
|
|
53
|
+
PushServerAppOutputsRequest,
|
|
54
|
+
)
|
|
55
|
+
from flwr.server.driver.grpc_driver import GrpcDriver
|
|
56
|
+
from flwr.server.run_serverapp import run as run_
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def flwr_serverapp() -> None:
    """Run process-isolated Flower ServerApp.

    Mirrors the process output into a queue, parses the command-line
    arguments (``--superlink``, ``--run-once`` and the common app arguments
    added by ``add_args_flwr_app_common``), obtains the certificates and then
    hands control to ``run_serverapp``.

    The output mirroring is undone in a ``finally`` clause so that it is
    restored on every exit path, including ``SystemExit`` raised by
    ``argparse`` and exceptions propagating out of ``run_serverapp``.
    """
    # Capture stdout/stderr
    log_queue: Queue[Optional[str]] = Queue()
    mirror_output_to_queue(log_queue)

    try:
        parser = argparse.ArgumentParser(
            description="Run a Flower ServerApp",
        )
        parser.add_argument(
            "--superlink",
            type=str,
            help="Address of SuperLink's ServerAppIo API",
        )
        parser.add_argument(
            "--run-once",
            action="store_true",
            help="When set, this process will start a single ServerApp for a pending Run. "
            "If there is no pending Run, the process will exit.",
        )
        add_args_flwr_app_common(parser=parser)
        args = parser.parse_args()

        log(INFO, "Starting Flower ServerApp")
        certificates = try_obtain_certificates(args)

        log(
            DEBUG,
            "Starting isolated `ServerApp` connected to SuperLink's ServerAppIo API at %s",
            args.superlink,
        )
        run_serverapp(
            superlink=args.superlink,
            log_queue=log_queue,
            run_once=args.run_once,
            flwr_dir=args.flwr_dir,
            certificates=certificates,
        )
    finally:
        # Restore stdout/stderr, even if parsing or the ServerApp loop failed
        restore_output()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def run_serverapp(  # pylint: disable=R0914, disable=W0212
    superlink: str,
    log_queue: Queue[Optional[str]],
    run_once: bool,
    flwr_dir: Optional[str] = None,
    certificates: Optional[bytes] = None,
) -> None:
    """Run Flower ServerApp process.

    Repeatedly pulls ServerApp inputs through the ServerAppIo stub, installs
    the received FAB, runs the ServerApp with a ``GrpcDriver``, pushes the
    resulting context back and reports the final run status.

    Parameters
    ----------
    superlink : str
        Address passed to ``GrpcDriver`` as ``serverappio_service_address``.
    log_queue : Queue[Optional[str]]
        Queue handed to the log uploader started for each run.
    run_once : bool
        If True, leave the loop after the first iteration that pulled a run
        or raised an exception.
    flwr_dir : Optional[str] (default: None)
        Value forwarded to ``get_flwr_dir`` to resolve where FABs are installed.
    certificates : Optional[bytes] (default: None)
        Root certificates forwarded to ``GrpcDriver``.
    """
    driver = GrpcDriver(
        serverappio_service_address=superlink,
        root_certificates=certificates,
    )

    # Resolve directory where FABs are installed
    flwr_dir_ = get_flwr_dir(flwr_dir)
    log_uploader = None

    while True:
        # Reset per-iteration state so that the `finally` clause never reads an
        # unbound name or the run of a previous iteration.
        run = None
        run_status = None

        try:
            # Pull ServerAppInputs from LinkState
            req = PullServerAppInputsRequest()
            res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
            if not res.HasField("run"):
                sleep(3)
                # NOTE(review): `continue` skips the `run_once` check below, so
                # the process keeps polling until a run is available. The
                # `--run-once` help text says the process exits in this case;
                # confirm which behavior is intended.
                continue

            # Convert the run first so that a failure in any later step can
            # still be reported against this run in the `finally` clause.
            run = run_from_proto(res.run)
            context = context_from_proto(res.context)
            fab = fab_from_proto(res.fab)

            driver.set_run(run.run_id)

            # Start log uploader for this run
            log_uploader = start_log_uploader(
                log_queue=log_queue,
                node_id=0,
                run_id=run.run_id,
                stub=driver._stub,
            )

            log(DEBUG, "ServerApp process starts FAB installation.")
            install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)

            fab_id, fab_version = get_fab_metadata(fab.content)

            app_path = str(
                get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir_)
            )
            config = get_project_config(app_path)

            # Obtain server app reference and the run config
            server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
            server_app_run_config = get_fused_config_from_dir(
                Path(app_path), run.override_config
            )

            # Update run_config in context
            context.run_config = server_app_run_config

            log(
                DEBUG,
                "Flower will load ServerApp `%s` in %s",
                server_app_attr,
                app_path,
            )

            # Change status to Running
            run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
            driver._stub.UpdateRunStatus(
                UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
            )

            # Load and run the ServerApp with the Driver
            updated_context = run_(
                driver=driver,
                server_app_dir=app_path,
                server_app_attr=server_app_attr,
                context=context,
            )

            # Send resulting context
            context_proto = context_to_proto(updated_context)
            out_req = PushServerAppOutputsRequest(
                run_id=run.run_id, context=context_proto
            )
            _ = driver._stub.PushServerAppOutputs(out_req)

            run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

        except Exception as ex:  # pylint: disable=broad-exception-caught
            exc_entity = "ServerApp"
            log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
            if run is None:
                # The failure happened before a run was pulled, so there is no
                # run to mark as failed. Back off for the same interval as the
                # idle poll to avoid spinning on a persistent error.
                sleep(3)
            else:
                run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))

        finally:
            try:
                # Only report a status for a run pulled in this iteration
                if run is not None and run_status:
                    run_status_proto = run_status_to_proto(run_status)
                    driver._stub.UpdateRunStatus(
                        UpdateRunStatusRequest(
                            run_id=run.run_id, run_status=run_status_proto
                        )
                    )
            finally:
                # Stop log uploader for this run, even if the status update failed
                if log_uploader:
                    stop_log_uploader(log_queue, log_uploader)
                    log_uploader = None

        # Stop the loop if `flwr-serverapp` is expected to process a single run
        if run_once:
            break
|
|
@@ -48,12 +48,12 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
|
|
48
48
|
num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
|
|
49
49
|
|
|
50
50
|
# Compute scaling factors for each result
|
|
51
|
-
scaling_factors =
|
|
52
|
-
fit_res.num_examples / num_examples_total for _, fit_res in results
|
|
53
|
-
|
|
51
|
+
scaling_factors = np.asarray(
|
|
52
|
+
[fit_res.num_examples / num_examples_total for _, fit_res in results]
|
|
53
|
+
)
|
|
54
54
|
|
|
55
55
|
def _try_inplace(
|
|
56
|
-
x: NDArray, y: Union[NDArray,
|
|
56
|
+
x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
|
|
57
57
|
) -> NDArray:
|
|
58
58
|
return ( # type: ignore[no-any-return]
|
|
59
59
|
np_binary_op(x, y, out=x)
|
flwr/server/strategy/fedadam.py
CHANGED
|
@@ -170,8 +170,18 @@ class FedAdam(FedOpt):
|
|
|
170
170
|
for x, y in zip(self.v_t, delta_t)
|
|
171
171
|
]
|
|
172
172
|
|
|
173
|
+
# Compute the bias-corrected learning rate, `eta_norm` for improving convergence
|
|
174
|
+
# in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
|
|
175
|
+
# Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
|
|
176
|
+
# Optimization" in the formula line right before Section 2.1.
|
|
177
|
+
eta_norm = (
|
|
178
|
+
self.eta
|
|
179
|
+
* np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
|
|
180
|
+
/ (1 - np.power(self.beta_1, server_round + 1.0))
|
|
181
|
+
)
|
|
182
|
+
|
|
173
183
|
new_weights = [
|
|
174
|
-
x +
|
|
184
|
+
x + eta_norm * y / (np.sqrt(z) + self.tau)
|
|
175
185
|
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
176
186
|
]
|
|
177
187
|
|