flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,17 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
16
|
-
|
|
15
|
+
"""Flower gRPC Driver."""
|
|
17
16
|
|
|
17
|
+
import time
|
|
18
|
+
import warnings
|
|
18
19
|
from logging import DEBUG, ERROR, WARNING
|
|
19
|
-
from typing import Optional
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple
|
|
20
21
|
|
|
21
22
|
import grpc
|
|
22
23
|
|
|
23
|
-
from flwr.common import EventType, event
|
|
24
|
+
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
24
25
|
from flwr.common.grpc import create_channel
|
|
25
26
|
from flwr.common.logger import log
|
|
27
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
28
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
27
29
|
CreateRunRequest,
|
|
28
30
|
CreateRunResponse,
|
|
@@ -34,19 +36,23 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
34
36
|
PushTaskInsResponse,
|
|
35
37
|
)
|
|
36
38
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
40
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
|
+
|
|
42
|
+
from .driver import Driver
|
|
37
43
|
|
|
38
44
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
39
45
|
|
|
40
46
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
41
47
|
[Driver] Error: Not connected.
|
|
42
48
|
|
|
43
|
-
Call `connect()` on the `
|
|
44
|
-
`
|
|
49
|
+
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
|
|
50
|
+
`GrpcDriverHelper` methods.
|
|
45
51
|
"""
|
|
46
52
|
|
|
47
53
|
|
|
48
|
-
class
|
|
49
|
-
"""`
|
|
54
|
+
class GrpcDriverHelper:
|
|
55
|
+
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
|
|
50
56
|
|
|
51
57
|
def __init__(
|
|
52
58
|
self,
|
|
@@ -89,7 +95,7 @@ class GrpcDriver:
|
|
|
89
95
|
# Check if channel is open
|
|
90
96
|
if self.stub is None:
|
|
91
97
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
92
|
-
raise ConnectionError("`
|
|
98
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
93
99
|
|
|
94
100
|
# Call Driver API
|
|
95
101
|
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
@@ -100,7 +106,7 @@ class GrpcDriver:
|
|
|
100
106
|
# Check if channel is open
|
|
101
107
|
if self.stub is None:
|
|
102
108
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
103
|
-
raise ConnectionError("`
|
|
109
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
104
110
|
|
|
105
111
|
# Call gRPC Driver API
|
|
106
112
|
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
@@ -111,7 +117,7 @@ class GrpcDriver:
|
|
|
111
117
|
# Check if channel is open
|
|
112
118
|
if self.stub is None:
|
|
113
119
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
114
|
-
raise ConnectionError("`
|
|
120
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
115
121
|
|
|
116
122
|
# Call gRPC Driver API
|
|
117
123
|
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
@@ -122,8 +128,188 @@ class GrpcDriver:
|
|
|
122
128
|
# Check if channel is open
|
|
123
129
|
if self.stub is None:
|
|
124
130
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
125
|
-
raise ConnectionError("`
|
|
131
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
126
132
|
|
|
127
133
|
# Call Driver API
|
|
128
134
|
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
129
135
|
return res
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class GrpcDriver(Driver):
|
|
139
|
+
"""`Driver` class provides an interface to the Driver API.
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
driver_service_address : Optional[str]
|
|
144
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
145
|
+
Defaults to `"[::]:9091"`.
|
|
146
|
+
certificates : bytes (default: None)
|
|
147
|
+
Tuple containing root certificate, server certificate, and private key
|
|
148
|
+
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
+
three bytes elements in the following order:
|
|
150
|
+
|
|
151
|
+
* CA certificate.
|
|
152
|
+
* server certificate.
|
|
153
|
+
* server private key.
|
|
154
|
+
fab_id : str (default: None)
|
|
155
|
+
The identifier of the FAB used in the run.
|
|
156
|
+
fab_version : str (default: None)
|
|
157
|
+
The version of the FAB used in the run.
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
def __init__(
|
|
161
|
+
self,
|
|
162
|
+
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
163
|
+
root_certificates: Optional[bytes] = None,
|
|
164
|
+
fab_id: Optional[str] = None,
|
|
165
|
+
fab_version: Optional[str] = None,
|
|
166
|
+
) -> None:
|
|
167
|
+
self.addr = driver_service_address
|
|
168
|
+
self.root_certificates = root_certificates
|
|
169
|
+
self.driver_helper: Optional[GrpcDriverHelper] = None
|
|
170
|
+
self.run_id: Optional[int] = None
|
|
171
|
+
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
+
self.fab_version = fab_version if fab_version is not None else ""
|
|
173
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
174
|
+
|
|
175
|
+
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
|
|
176
|
+
# Check if the GrpcDriverHelper is initialized
|
|
177
|
+
if self.driver_helper is None or self.run_id is None:
|
|
178
|
+
# Connect and create run
|
|
179
|
+
self.driver_helper = GrpcDriverHelper(
|
|
180
|
+
driver_service_address=self.addr,
|
|
181
|
+
root_certificates=self.root_certificates,
|
|
182
|
+
)
|
|
183
|
+
self.driver_helper.connect()
|
|
184
|
+
req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
|
|
185
|
+
res = self.driver_helper.create_run(req)
|
|
186
|
+
self.run_id = res.run_id
|
|
187
|
+
return self.driver_helper, self.run_id
|
|
188
|
+
|
|
189
|
+
def _check_message(self, message: Message) -> None:
|
|
190
|
+
# Check if the message is valid
|
|
191
|
+
if not (
|
|
192
|
+
message.metadata.run_id == self.run_id
|
|
193
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
194
|
+
and message.metadata.message_id == ""
|
|
195
|
+
and message.metadata.reply_to_message == ""
|
|
196
|
+
and message.metadata.ttl > 0
|
|
197
|
+
):
|
|
198
|
+
raise ValueError(f"Invalid message: {message}")
|
|
199
|
+
|
|
200
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
201
|
+
self,
|
|
202
|
+
content: RecordSet,
|
|
203
|
+
message_type: str,
|
|
204
|
+
dst_node_id: int,
|
|
205
|
+
group_id: str,
|
|
206
|
+
ttl: Optional[float] = None,
|
|
207
|
+
) -> Message:
|
|
208
|
+
"""Create a new message with specified parameters.
|
|
209
|
+
|
|
210
|
+
This method constructs a new `Message` with given content and metadata.
|
|
211
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
212
|
+
"""
|
|
213
|
+
_, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
214
|
+
if ttl:
|
|
215
|
+
warnings.warn(
|
|
216
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
217
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
218
|
+
"version of Flower.",
|
|
219
|
+
stacklevel=2,
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
223
|
+
metadata = Metadata(
|
|
224
|
+
run_id=run_id,
|
|
225
|
+
message_id="", # Will be set by the server
|
|
226
|
+
src_node_id=self.node.node_id,
|
|
227
|
+
dst_node_id=dst_node_id,
|
|
228
|
+
reply_to_message="",
|
|
229
|
+
group_id=group_id,
|
|
230
|
+
ttl=ttl_,
|
|
231
|
+
message_type=message_type,
|
|
232
|
+
)
|
|
233
|
+
return Message(metadata=metadata, content=content)
|
|
234
|
+
|
|
235
|
+
def get_node_ids(self) -> List[int]:
|
|
236
|
+
"""Get node IDs."""
|
|
237
|
+
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
238
|
+
# Call GrpcDriverHelper method
|
|
239
|
+
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
|
|
240
|
+
return [node.node_id for node in res.nodes]
|
|
241
|
+
|
|
242
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
243
|
+
"""Push messages to specified node IDs.
|
|
244
|
+
|
|
245
|
+
This method takes an iterable of messages and sends each message
|
|
246
|
+
to the node specified in `dst_node_id`.
|
|
247
|
+
"""
|
|
248
|
+
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
|
|
249
|
+
# Construct TaskIns
|
|
250
|
+
task_ins_list: List[TaskIns] = []
|
|
251
|
+
for msg in messages:
|
|
252
|
+
# Check message
|
|
253
|
+
self._check_message(msg)
|
|
254
|
+
# Convert Message to TaskIns
|
|
255
|
+
taskins = message_to_taskins(msg)
|
|
256
|
+
# Add to list
|
|
257
|
+
task_ins_list.append(taskins)
|
|
258
|
+
# Call GrpcDriverHelper method
|
|
259
|
+
res = grpc_driver_helper.push_task_ins(
|
|
260
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
261
|
+
)
|
|
262
|
+
return list(res.task_ids)
|
|
263
|
+
|
|
264
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
265
|
+
"""Pull messages based on message IDs.
|
|
266
|
+
|
|
267
|
+
This method is used to collect messages from the SuperLink that correspond to a
|
|
268
|
+
set of given message IDs.
|
|
269
|
+
"""
|
|
270
|
+
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
|
|
271
|
+
# Pull TaskRes
|
|
272
|
+
res = grpc_driver.pull_task_res(
|
|
273
|
+
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
274
|
+
)
|
|
275
|
+
# Convert TaskRes to Message
|
|
276
|
+
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
277
|
+
return msgs
|
|
278
|
+
|
|
279
|
+
def send_and_receive(
|
|
280
|
+
self,
|
|
281
|
+
messages: Iterable[Message],
|
|
282
|
+
*,
|
|
283
|
+
timeout: Optional[float] = None,
|
|
284
|
+
) -> Iterable[Message]:
|
|
285
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
286
|
+
|
|
287
|
+
This method sends a list of messages to their destination node IDs and then
|
|
288
|
+
waits for the replies. It continues to pull replies until either all replies are
|
|
289
|
+
received or the specified timeout duration is exceeded.
|
|
290
|
+
"""
|
|
291
|
+
# Push messages
|
|
292
|
+
msg_ids = set(self.push_messages(messages))
|
|
293
|
+
|
|
294
|
+
# Pull messages
|
|
295
|
+
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
296
|
+
ret: List[Message] = []
|
|
297
|
+
while timeout is None or time.time() < end_time:
|
|
298
|
+
res_msgs = self.pull_messages(msg_ids)
|
|
299
|
+
ret.extend(res_msgs)
|
|
300
|
+
msg_ids.difference_update(
|
|
301
|
+
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
302
|
+
)
|
|
303
|
+
if len(msg_ids) == 0:
|
|
304
|
+
break
|
|
305
|
+
# Sleep
|
|
306
|
+
time.sleep(3)
|
|
307
|
+
return ret
|
|
308
|
+
|
|
309
|
+
def close(self) -> None:
|
|
310
|
+
"""Disconnect from the SuperLink if connected."""
|
|
311
|
+
# Check if GrpcDriverHelper is initialized
|
|
312
|
+
if self.driver_helper is None:
|
|
313
|
+
return
|
|
314
|
+
# Disconnect
|
|
315
|
+
self.driver_helper.disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
|
|
|
25
25
|
from flwr.common.logger import log, update_console_handler
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
27
|
|
|
28
|
-
from .driver
|
|
28
|
+
from .driver import Driver, GrpcDriver
|
|
29
29
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
30
|
|
|
31
31
|
|
|
@@ -128,13 +128,15 @@ def run_server_app() -> None:
|
|
|
128
128
|
server_app_dir = args.dir
|
|
129
129
|
server_app_attr = getattr(args, "server-app")
|
|
130
130
|
|
|
131
|
-
# Initialize
|
|
132
|
-
driver =
|
|
131
|
+
# Initialize GrpcDriver
|
|
132
|
+
driver = GrpcDriver(
|
|
133
133
|
driver_service_address=args.server,
|
|
134
134
|
root_certificates=root_certificates,
|
|
135
|
+
fab_id=args.fab_id,
|
|
136
|
+
fab_version=args.fab_version,
|
|
135
137
|
)
|
|
136
138
|
|
|
137
|
-
# Run the
|
|
139
|
+
# Run the ServerApp with the Driver
|
|
138
140
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
|
139
141
|
|
|
140
142
|
# Clean up
|
|
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
183
185
|
"app from there."
|
|
184
186
|
" Default: current working directory.",
|
|
185
187
|
)
|
|
188
|
+
parser.add_argument(
|
|
189
|
+
"--fab-id",
|
|
190
|
+
default=None,
|
|
191
|
+
type=str,
|
|
192
|
+
help="The identifier of the FAB used in the run.",
|
|
193
|
+
)
|
|
194
|
+
parser.add_argument(
|
|
195
|
+
"--fab-version",
|
|
196
|
+
default=None,
|
|
197
|
+
type=str,
|
|
198
|
+
help="The version of the FAB used in the run.",
|
|
199
|
+
)
|
|
186
200
|
|
|
187
201
|
return parser
|
|
@@ -200,7 +200,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
200
200
|
|
|
201
201
|
log(
|
|
202
202
|
INFO,
|
|
203
|
-
"aggregate_fit: parameters are clipped by value:
|
|
203
|
+
"aggregate_fit: parameters are clipped by value: %.4f.",
|
|
204
204
|
self.clipping_norm,
|
|
205
205
|
)
|
|
206
206
|
|
|
@@ -234,7 +234,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
234
234
|
)
|
|
235
235
|
log(
|
|
236
236
|
INFO,
|
|
237
|
-
"aggregate_fit: central DP noise with
|
|
237
|
+
"aggregate_fit: central DP noise with "
|
|
238
|
+
"standard deviation: %.4f added to parameters.",
|
|
238
239
|
compute_stdv(
|
|
239
240
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
240
241
|
),
|
|
@@ -424,7 +425,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
424
425
|
)
|
|
425
426
|
log(
|
|
426
427
|
INFO,
|
|
427
|
-
"aggregate_fit: central DP noise with
|
|
428
|
+
"aggregate_fit: central DP noise with "
|
|
429
|
+
"standard deviation: %.4f added to parameters.",
|
|
428
430
|
compute_stdv(
|
|
429
431
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
430
432
|
),
|
|
@@ -158,7 +158,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
158
158
|
)
|
|
159
159
|
log(
|
|
160
160
|
INFO,
|
|
161
|
-
"aggregate_fit: parameters are clipped by value:
|
|
161
|
+
"aggregate_fit: parameters are clipped by value: %.4f.",
|
|
162
162
|
self.clipping_norm,
|
|
163
163
|
)
|
|
164
164
|
# Convert back to parameters
|
|
@@ -180,7 +180,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
180
180
|
|
|
181
181
|
log(
|
|
182
182
|
INFO,
|
|
183
|
-
"aggregate_fit: central DP noise with
|
|
183
|
+
"aggregate_fit: central DP noise with "
|
|
184
|
+
"standard deviation: %.4f added to parameters.",
|
|
184
185
|
compute_stdv(
|
|
185
186
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
186
187
|
),
|
|
@@ -337,11 +338,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
337
338
|
)
|
|
338
339
|
log(
|
|
339
340
|
INFO,
|
|
340
|
-
"aggregate_fit: central DP noise with
|
|
341
|
+
"aggregate_fit: central DP noise with "
|
|
342
|
+
"standard deviation: %.4f added to parameters.",
|
|
341
343
|
compute_stdv(
|
|
342
344
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
343
345
|
),
|
|
344
346
|
)
|
|
347
|
+
|
|
345
348
|
return aggregated_params, metrics
|
|
346
349
|
|
|
347
350
|
def aggregate_evaluate(
|
|
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
64
64
|
"""Create run ID."""
|
|
65
65
|
log(INFO, "DriverServicer.CreateRun")
|
|
66
66
|
state: State = self.state_factory.state()
|
|
67
|
-
run_id = state.create_run()
|
|
67
|
+
run_id = state.create_run(request.fab_id, request.fab_version)
|
|
68
68
|
return CreateRunResponse(run_id=run_id)
|
|
69
69
|
|
|
70
70
|
def PushTaskIns(
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import concurrent.futures
|
|
19
19
|
import sys
|
|
20
20
|
from logging import ERROR
|
|
21
|
-
from typing import Any, Callable, Optional, Tuple, Union
|
|
21
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
162
162
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
163
163
|
keepalive_time_ms: int = 210000,
|
|
164
164
|
certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
|
|
165
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
165
166
|
) -> grpc.Server:
|
|
166
167
|
"""Create a gRPC server with a single servicer.
|
|
167
168
|
|
|
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
249
250
|
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
|
|
250
251
|
maximum_concurrent_rpcs=max_concurrent_workers,
|
|
251
252
|
options=options,
|
|
253
|
+
interceptors=interceptors,
|
|
252
254
|
)
|
|
253
255
|
add_servicer_to_server_fn(servicer, server)
|
|
254
256
|
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower server interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
from logging import WARNING
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
27
|
+
bytes_to_private_key,
|
|
28
|
+
bytes_to_public_key,
|
|
29
|
+
generate_shared_key,
|
|
30
|
+
verify_hmac,
|
|
31
|
+
)
|
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
|
+
CreateNodeRequest,
|
|
34
|
+
CreateNodeResponse,
|
|
35
|
+
DeleteNodeRequest,
|
|
36
|
+
DeleteNodeResponse,
|
|
37
|
+
GetRunRequest,
|
|
38
|
+
GetRunResponse,
|
|
39
|
+
PingRequest,
|
|
40
|
+
PingResponse,
|
|
41
|
+
PullTaskInsRequest,
|
|
42
|
+
PullTaskInsResponse,
|
|
43
|
+
PushTaskResRequest,
|
|
44
|
+
PushTaskResResponse,
|
|
45
|
+
)
|
|
46
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
|
+
from flwr.server.superlink.state import State
|
|
48
|
+
|
|
49
|
+
_PUBLIC_KEY_HEADER = "public-key"
|
|
50
|
+
_AUTH_TOKEN_HEADER = "auth-token"
|
|
51
|
+
|
|
52
|
+
Request = Union[
|
|
53
|
+
CreateNodeRequest,
|
|
54
|
+
DeleteNodeRequest,
|
|
55
|
+
PullTaskInsRequest,
|
|
56
|
+
PushTaskResRequest,
|
|
57
|
+
GetRunRequest,
|
|
58
|
+
PingRequest,
|
|
59
|
+
]
|
|
60
|
+
|
|
61
|
+
Response = Union[
|
|
62
|
+
CreateNodeResponse,
|
|
63
|
+
DeleteNodeResponse,
|
|
64
|
+
PullTaskInsResponse,
|
|
65
|
+
PushTaskResResponse,
|
|
66
|
+
GetRunResponse,
|
|
67
|
+
PingResponse,
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_value_from_tuples(
|
|
72
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
73
|
+
) -> bytes:
|
|
74
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
75
|
+
if isinstance(value, str):
|
|
76
|
+
return value.encode()
|
|
77
|
+
|
|
78
|
+
return value
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
82
|
+
"""Server interceptor for client authentication."""
|
|
83
|
+
|
|
84
|
+
def __init__(self, state: State):
|
|
85
|
+
self.state = state
|
|
86
|
+
|
|
87
|
+
self.client_public_keys = state.get_client_public_keys()
|
|
88
|
+
if len(self.client_public_keys) == 0:
|
|
89
|
+
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
90
|
+
|
|
91
|
+
private_key = self.state.get_server_private_key()
|
|
92
|
+
public_key = self.state.get_server_public_key()
|
|
93
|
+
|
|
94
|
+
if private_key is None or public_key is None:
|
|
95
|
+
raise ValueError("Error loading authentication keys")
|
|
96
|
+
|
|
97
|
+
self.server_private_key = bytes_to_private_key(private_key)
|
|
98
|
+
self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
|
|
99
|
+
|
|
100
|
+
def intercept_service(
|
|
101
|
+
self,
|
|
102
|
+
continuation: Callable[[Any], Any],
|
|
103
|
+
handler_call_details: grpc.HandlerCallDetails,
|
|
104
|
+
) -> grpc.RpcMethodHandler:
|
|
105
|
+
"""Flower server interceptor authentication logic.
|
|
106
|
+
|
|
107
|
+
Intercept all unary calls from clients and authenticate clients by validating
|
|
108
|
+
auth metadata sent by the client. Continue RPC call if client is authenticated,
|
|
109
|
+
else, terminate RPC call by setting context to abort.
|
|
110
|
+
"""
|
|
111
|
+
# One of the method handlers in
|
|
112
|
+
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
113
|
+
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
114
|
+
return self._generic_auth_unary_method_handler(method_handler)
|
|
115
|
+
|
|
116
|
+
def _generic_auth_unary_method_handler(
|
|
117
|
+
self, method_handler: grpc.RpcMethodHandler
|
|
118
|
+
) -> grpc.RpcMethodHandler:
|
|
119
|
+
def _generic_method_handler(
|
|
120
|
+
request: Request,
|
|
121
|
+
context: grpc.ServicerContext,
|
|
122
|
+
) -> Response:
|
|
123
|
+
client_public_key_bytes = base64.urlsafe_b64decode(
|
|
124
|
+
_get_value_from_tuples(
|
|
125
|
+
_PUBLIC_KEY_HEADER, context.invocation_metadata()
|
|
126
|
+
)
|
|
127
|
+
)
|
|
128
|
+
if client_public_key_bytes not in self.client_public_keys:
|
|
129
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
130
|
+
|
|
131
|
+
if isinstance(request, CreateNodeRequest):
|
|
132
|
+
return self._create_authenticated_node(
|
|
133
|
+
client_public_key_bytes, request, context
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
# Verify hmac value
|
|
137
|
+
hmac_value = base64.urlsafe_b64decode(
|
|
138
|
+
_get_value_from_tuples(
|
|
139
|
+
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
140
|
+
)
|
|
141
|
+
)
|
|
142
|
+
public_key = bytes_to_public_key(client_public_key_bytes)
|
|
143
|
+
|
|
144
|
+
if not self._verify_hmac(public_key, request, hmac_value):
|
|
145
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
146
|
+
|
|
147
|
+
# Verify node_id
|
|
148
|
+
node_id = self.state.get_node_id(client_public_key_bytes)
|
|
149
|
+
|
|
150
|
+
if not self._verify_node_id(node_id, request):
|
|
151
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
152
|
+
|
|
153
|
+
return method_handler.unary_unary(request, context) # type: ignore
|
|
154
|
+
|
|
155
|
+
return grpc.unary_unary_rpc_method_handler(
|
|
156
|
+
_generic_method_handler,
|
|
157
|
+
request_deserializer=method_handler.request_deserializer,
|
|
158
|
+
response_serializer=method_handler.response_serializer,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
def _verify_node_id(
|
|
162
|
+
self,
|
|
163
|
+
node_id: Optional[int],
|
|
164
|
+
request: Union[
|
|
165
|
+
DeleteNodeRequest,
|
|
166
|
+
PullTaskInsRequest,
|
|
167
|
+
PushTaskResRequest,
|
|
168
|
+
GetRunRequest,
|
|
169
|
+
PingRequest,
|
|
170
|
+
],
|
|
171
|
+
) -> bool:
|
|
172
|
+
if node_id is None:
|
|
173
|
+
return False
|
|
174
|
+
if isinstance(request, PushTaskResRequest):
|
|
175
|
+
if len(request.task_res_list) == 0:
|
|
176
|
+
return False
|
|
177
|
+
return request.task_res_list[0].task.producer.node_id == node_id
|
|
178
|
+
if isinstance(request, GetRunRequest):
|
|
179
|
+
return node_id in self.state.get_nodes(request.run_id)
|
|
180
|
+
return request.node.node_id == node_id
|
|
181
|
+
|
|
182
|
+
def _verify_hmac(
|
|
183
|
+
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
184
|
+
) -> bool:
|
|
185
|
+
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
186
|
+
return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)
|
|
187
|
+
|
|
188
|
+
def _create_authenticated_node(
|
|
189
|
+
self,
|
|
190
|
+
public_key_bytes: bytes,
|
|
191
|
+
request: CreateNodeRequest,
|
|
192
|
+
context: grpc.ServicerContext,
|
|
193
|
+
) -> CreateNodeResponse:
|
|
194
|
+
context.send_initial_metadata(
|
|
195
|
+
(
|
|
196
|
+
(
|
|
197
|
+
_PUBLIC_KEY_HEADER,
|
|
198
|
+
self.encoded_server_public_key,
|
|
199
|
+
),
|
|
200
|
+
)
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
node_id = self.state.get_node_id(public_key_bytes)
|
|
204
|
+
|
|
205
|
+
# Handle `CreateNode` here instead of calling the default method handler
|
|
206
|
+
# Return previously assigned `node_id` for the provided `public_key`
|
|
207
|
+
if node_id is not None:
|
|
208
|
+
self.state.acknowledge_ping(node_id, request.ping_interval)
|
|
209
|
+
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
210
|
+
|
|
211
|
+
# No `node_id` exists for the provided `public_key`
|
|
212
|
+
# Handle `CreateNode` here instead of calling the default method handler
|
|
213
|
+
# Note: the innermost `CreateNode` method will never be called
|
|
214
|
+
node_id = self.state.create_node(request.ping_interval, public_key_bytes)
|
|
215
|
+
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -33,6 +33,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
33
33
|
PushTaskResRequest,
|
|
34
34
|
PushTaskResResponse,
|
|
35
35
|
Reconnect,
|
|
36
|
+
Run,
|
|
36
37
|
)
|
|
37
38
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
38
39
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -109,4 +110,6 @@ def get_run(
|
|
|
109
110
|
request: GetRunRequest, state: State # pylint: disable=W0613
|
|
110
111
|
) -> GetRunResponse:
|
|
111
112
|
"""Get run information."""
|
|
112
|
-
|
|
113
|
+
run_id, fab_id, fab_version = state.get_run(request.run_id)
|
|
114
|
+
run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version)
|
|
115
|
+
return GetRunResponse(run=run)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
17
|
import pathlib
|
|
18
|
-
from logging import ERROR, INFO
|
|
18
|
+
from logging import DEBUG, ERROR, INFO
|
|
19
19
|
from typing import Callable, Dict, List, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
@@ -46,7 +46,7 @@ class RayBackend(Backend):
|
|
|
46
46
|
) -> None:
|
|
47
47
|
"""Prepare RayBackend by initialising Ray and creating the ActorPool."""
|
|
48
48
|
log(INFO, "Initialising: %s", self.__class__.__name__)
|
|
49
|
-
log(
|
|
49
|
+
log(DEBUG, "Backend config: %s", backend_config)
|
|
50
50
|
|
|
51
51
|
if not pathlib.Path(work_dir).exists():
|
|
52
52
|
raise ValueError(f"Specified work_dir {work_dir} does not exist.")
|
|
@@ -109,7 +109,7 @@ class RayBackend(Backend):
|
|
|
109
109
|
else:
|
|
110
110
|
client_resources = {"num_cpus": 2, "num_gpus": 0.0}
|
|
111
111
|
log(
|
|
112
|
-
|
|
112
|
+
DEBUG,
|
|
113
113
|
"`%s` not specified in backend config. Applying default setting: %s",
|
|
114
114
|
self.client_resources_key,
|
|
115
115
|
client_resources,
|
|
@@ -129,7 +129,7 @@ class RayBackend(Backend):
|
|
|
129
129
|
async def build(self) -> None:
|
|
130
130
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
131
131
|
await self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
132
|
-
log(
|
|
132
|
+
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
133
133
|
|
|
134
134
|
async def process_message(
|
|
135
135
|
self,
|
|
@@ -173,4 +173,4 @@ class RayBackend(Backend):
|
|
|
173
173
|
"""Terminate all actors in actor pool."""
|
|
174
174
|
await self.pool.terminate_all_actors()
|
|
175
175
|
ray.shutdown()
|
|
176
|
-
log(
|
|
176
|
+
log(DEBUG, "Terminated %s", self.__class__.__name__)
|