flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -21,7 +21,10 @@ import threading
|
|
|
21
21
|
from contextlib import contextmanager
|
|
22
22
|
from copy import copy
|
|
23
23
|
from logging import ERROR, INFO, WARN
|
|
24
|
-
from typing import Callable, Iterator, Optional, Tuple, Union
|
|
24
|
+
from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
|
|
25
|
+
|
|
26
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
27
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
25
28
|
|
|
26
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
27
30
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
@@ -42,6 +45,9 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
42
45
|
CreateNodeRequest,
|
|
43
46
|
CreateNodeResponse,
|
|
44
47
|
DeleteNodeRequest,
|
|
48
|
+
DeleteNodeResponse,
|
|
49
|
+
GetRunRequest,
|
|
50
|
+
GetRunResponse,
|
|
45
51
|
PingRequest,
|
|
46
52
|
PingResponse,
|
|
47
53
|
PullTaskInsRequest,
|
|
@@ -63,10 +69,13 @@ PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
|
|
63
69
|
PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
|
|
64
70
|
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
65
71
|
PATH_PING: str = "api/v0/fleet/ping"
|
|
72
|
+
# No leading slash: request URLs are assembled as f"{base_url}/{api_path}", and
# the other PATH_* constants are written relative as well.
PATH_GET_RUN: str = "api/v0/fleet/get-run"
|
|
73
|
+
|
|
74
|
+
T = TypeVar("T", bound=GrpcMessage)
|
|
66
75
|
|
|
67
76
|
|
|
68
77
|
@contextmanager
|
|
69
|
-
def http_request_response( # pylint: disable
|
|
78
|
+
def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
70
79
|
server_address: str,
|
|
71
80
|
insecure: bool, # pylint: disable=unused-argument
|
|
72
81
|
retry_invoker: RetryInvoker,
|
|
@@ -74,12 +83,16 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
74
83
|
root_certificates: Optional[
|
|
75
84
|
Union[bytes, str]
|
|
76
85
|
] = None, # pylint: disable=unused-argument
|
|
86
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
87
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
88
|
+
] = None,
|
|
77
89
|
) -> Iterator[
|
|
78
90
|
Tuple[
|
|
79
91
|
Callable[[], Optional[Message]],
|
|
80
92
|
Callable[[Message], None],
|
|
81
93
|
Optional[Callable[[], None]],
|
|
82
94
|
Optional[Callable[[], None]],
|
|
95
|
+
Optional[Callable[[int], Tuple[str, str]]],
|
|
83
96
|
]
|
|
84
97
|
]:
|
|
85
98
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -141,55 +154,72 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
141
154
|
ping_stop_event = threading.Event()
|
|
142
155
|
|
|
143
156
|
###########################################################################
|
|
144
|
-
# ping/create_node/delete_node/receive/send functions
|
|
157
|
+
# ping/create_node/delete_node/receive/send/get_run functions
|
|
145
158
|
###########################################################################
|
|
146
159
|
|
|
147
|
-
def
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
# Construct the ping request
|
|
154
|
-
req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
|
|
155
|
-
req_bytes: bytes = req.SerializeToString()
|
|
160
|
+
def _request(
|
|
161
|
+
req: GrpcMessage, res_type: Type[T], api_path: str, retry: bool = True
|
|
162
|
+
) -> Optional[T]:
|
|
163
|
+
# Serialize the request
|
|
164
|
+
req_bytes = req.SerializeToString()
|
|
156
165
|
|
|
157
166
|
# Send the request
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
167
|
+
def post() -> requests.Response:
|
|
168
|
+
return requests.post(
|
|
169
|
+
f"{base_url}/{api_path}",
|
|
170
|
+
data=req_bytes,
|
|
171
|
+
headers={
|
|
172
|
+
"Accept": "application/protobuf",
|
|
173
|
+
"Content-Type": "application/protobuf",
|
|
174
|
+
},
|
|
175
|
+
verify=verify,
|
|
176
|
+
timeout=None,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
if retry:
|
|
180
|
+
res: requests.Response = retry_invoker.invoke(post)
|
|
181
|
+
else:
|
|
182
|
+
res = post()
|
|
168
183
|
|
|
169
184
|
# Check status code and headers
|
|
170
185
|
if res.status_code != 200:
|
|
171
|
-
return
|
|
186
|
+
return None
|
|
172
187
|
if "content-type" not in res.headers:
|
|
173
188
|
log(
|
|
174
189
|
WARN,
|
|
175
190
|
"[Node] POST /%s: missing header `Content-Type`",
|
|
176
|
-
|
|
191
|
+
api_path,
|
|
177
192
|
)
|
|
178
|
-
return
|
|
193
|
+
return None
|
|
179
194
|
if res.headers["content-type"] != "application/protobuf":
|
|
180
195
|
log(
|
|
181
196
|
WARN,
|
|
182
197
|
"[Node] POST /%s: header `Content-Type` has wrong value",
|
|
183
|
-
|
|
198
|
+
api_path,
|
|
184
199
|
)
|
|
185
|
-
return
|
|
200
|
+
return None
|
|
186
201
|
|
|
187
202
|
# Deserialize ProtoBuf from bytes
|
|
188
|
-
|
|
189
|
-
|
|
203
|
+
grpc_res = res_type()
|
|
204
|
+
grpc_res.ParseFromString(res.content)
|
|
205
|
+
return grpc_res
|
|
206
|
+
|
|
207
|
+
def ping() -> None:
|
|
208
|
+
# Get Node
|
|
209
|
+
if node is None:
|
|
210
|
+
log(ERROR, "Node instance missing")
|
|
211
|
+
return
|
|
212
|
+
|
|
213
|
+
# Construct the ping request
|
|
214
|
+
req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
|
|
215
|
+
|
|
216
|
+
# Send the request
|
|
217
|
+
res = _request(req, PingResponse, PATH_PING, retry=False)
|
|
218
|
+
if res is None:
|
|
219
|
+
return
|
|
190
220
|
|
|
191
221
|
# Check if success
|
|
192
|
-
if not
|
|
222
|
+
if not res.success:
|
|
193
223
|
raise RuntimeError("Ping failed unexpectedly.")
|
|
194
224
|
|
|
195
225
|
# Wait
|
|
@@ -201,46 +231,16 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
201
231
|
|
|
202
232
|
def create_node() -> None:
|
|
203
233
|
"""Set create_node."""
|
|
204
|
-
|
|
205
|
-
create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
|
|
206
|
-
|
|
207
|
-
res = retry_invoker.invoke(
|
|
208
|
-
requests.post,
|
|
209
|
-
url=f"{base_url}/{PATH_CREATE_NODE}",
|
|
210
|
-
headers={
|
|
211
|
-
"Accept": "application/protobuf",
|
|
212
|
-
"Content-Type": "application/protobuf",
|
|
213
|
-
},
|
|
214
|
-
data=create_node_req_bytes,
|
|
215
|
-
verify=verify,
|
|
216
|
-
timeout=None,
|
|
217
|
-
)
|
|
234
|
+
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
218
235
|
|
|
219
|
-
#
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
if "content-type" not in res.headers:
|
|
223
|
-
log(
|
|
224
|
-
WARN,
|
|
225
|
-
"[Node] POST /%s: missing header `Content-Type`",
|
|
226
|
-
PATH_PULL_TASK_INS,
|
|
227
|
-
)
|
|
228
|
-
return
|
|
229
|
-
if res.headers["content-type"] != "application/protobuf":
|
|
230
|
-
log(
|
|
231
|
-
WARN,
|
|
232
|
-
"[Node] POST /%s: header `Content-Type` has wrong value",
|
|
233
|
-
PATH_PULL_TASK_INS,
|
|
234
|
-
)
|
|
236
|
+
# Send the request
|
|
237
|
+
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
|
238
|
+
if res is None:
|
|
235
239
|
return
|
|
236
240
|
|
|
237
|
-
# Deserialize ProtoBuf from bytes
|
|
238
|
-
create_node_response_proto = CreateNodeResponse()
|
|
239
|
-
create_node_response_proto.ParseFromString(res.content)
|
|
240
|
-
|
|
241
241
|
# Remember the node and the ping-loop thread
|
|
242
242
|
nonlocal node, ping_thread
|
|
243
|
-
node =
|
|
243
|
+
node = res.node
|
|
244
244
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
245
245
|
|
|
246
246
|
def delete_node() -> None:
|
|
@@ -256,36 +256,12 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
256
256
|
ping_thread.join()
|
|
257
257
|
|
|
258
258
|
# Send DeleteNode request
|
|
259
|
-
|
|
260
|
-
delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
|
|
261
|
-
res = retry_invoker.invoke(
|
|
262
|
-
requests.post,
|
|
263
|
-
url=f"{base_url}/{PATH_DELETE_NODE}",
|
|
264
|
-
headers={
|
|
265
|
-
"Accept": "application/protobuf",
|
|
266
|
-
"Content-Type": "application/protobuf",
|
|
267
|
-
},
|
|
268
|
-
data=delete_node_req_req_bytes,
|
|
269
|
-
verify=verify,
|
|
270
|
-
timeout=None,
|
|
271
|
-
)
|
|
259
|
+
req = DeleteNodeRequest(node=node)
|
|
272
260
|
|
|
273
|
-
#
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
if "content-type" not in res.headers:
|
|
277
|
-
log(
|
|
278
|
-
WARN,
|
|
279
|
-
"[Node] POST /%s: missing header `Content-Type`",
|
|
280
|
-
PATH_PULL_TASK_INS,
|
|
281
|
-
)
|
|
261
|
+
# Send the request
|
|
262
|
+
# A DeleteNodeRequest must be posted to the delete-node endpoint
res = _request(req, DeleteNodeResponse, PATH_DELETE_NODE)
|
|
263
|
+
if res is None:
|
|
282
264
|
return
|
|
283
|
-
if res.headers["content-type"] != "application/protobuf":
|
|
284
|
-
log(
|
|
285
|
-
WARN,
|
|
286
|
-
"[Node] POST /%s: header `Content-Type` has wrong value",
|
|
287
|
-
PATH_PULL_TASK_INS,
|
|
288
|
-
)
|
|
289
265
|
|
|
290
266
|
# Cleanup
|
|
291
267
|
node = None
|
|
@@ -298,46 +274,15 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
298
274
|
return None
|
|
299
275
|
|
|
300
276
|
# Request instructions (task) from server
|
|
301
|
-
|
|
302
|
-
pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
|
|
277
|
+
req = PullTaskInsRequest(node=node)
|
|
303
278
|
|
|
304
|
-
#
|
|
305
|
-
res =
|
|
306
|
-
|
|
307
|
-
url=f"{base_url}/{PATH_PULL_TASK_INS}",
|
|
308
|
-
headers={
|
|
309
|
-
"Accept": "application/protobuf",
|
|
310
|
-
"Content-Type": "application/protobuf",
|
|
311
|
-
},
|
|
312
|
-
data=pull_task_ins_req_bytes,
|
|
313
|
-
verify=verify,
|
|
314
|
-
timeout=None,
|
|
315
|
-
)
|
|
316
|
-
|
|
317
|
-
# Check status code and headers
|
|
318
|
-
if res.status_code != 200:
|
|
319
|
-
return None
|
|
320
|
-
if "content-type" not in res.headers:
|
|
321
|
-
log(
|
|
322
|
-
WARN,
|
|
323
|
-
"[Node] POST /%s: missing header `Content-Type`",
|
|
324
|
-
PATH_PULL_TASK_INS,
|
|
325
|
-
)
|
|
326
|
-
return None
|
|
327
|
-
if res.headers["content-type"] != "application/protobuf":
|
|
328
|
-
log(
|
|
329
|
-
WARN,
|
|
330
|
-
"[Node] POST /%s: header `Content-Type` has wrong value",
|
|
331
|
-
PATH_PULL_TASK_INS,
|
|
332
|
-
)
|
|
279
|
+
# Send the request
|
|
280
|
+
res = _request(req, PullTaskInsResponse, PATH_PULL_TASK_INS)
|
|
281
|
+
if res is None:
|
|
333
282
|
return None
|
|
334
283
|
|
|
335
|
-
# Deserialize ProtoBuf from bytes
|
|
336
|
-
pull_task_ins_response_proto = PullTaskInsResponse()
|
|
337
|
-
pull_task_ins_response_proto.ParseFromString(res.content)
|
|
338
|
-
|
|
339
284
|
# Get the current TaskIns
|
|
340
|
-
task_ins: Optional[TaskIns] = get_task_ins(
|
|
285
|
+
task_ins: Optional[TaskIns] = get_task_ins(res)
|
|
341
286
|
|
|
342
287
|
# Discard the current TaskIns if not valid
|
|
343
288
|
if task_ins is not None and not (
|
|
@@ -372,61 +317,39 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
372
317
|
if not validate_out_message(message, metadata):
|
|
373
318
|
log(ERROR, "Invalid out message")
|
|
374
319
|
return
|
|
320
|
+
metadata = None
|
|
375
321
|
|
|
376
322
|
# Construct TaskRes
|
|
377
323
|
task_res = message_to_taskres(message)
|
|
378
324
|
|
|
379
325
|
# Serialize ProtoBuf to bytes
|
|
380
|
-
|
|
381
|
-
push_task_res_request_bytes: bytes = (
|
|
382
|
-
push_task_res_request_proto.SerializeToString()
|
|
383
|
-
)
|
|
384
|
-
|
|
385
|
-
# Send ClientMessage to server
|
|
386
|
-
res = retry_invoker.invoke(
|
|
387
|
-
requests.post,
|
|
388
|
-
url=f"{base_url}/{PATH_PUSH_TASK_RES}",
|
|
389
|
-
headers={
|
|
390
|
-
"Accept": "application/protobuf",
|
|
391
|
-
"Content-Type": "application/protobuf",
|
|
392
|
-
},
|
|
393
|
-
data=push_task_res_request_bytes,
|
|
394
|
-
verify=verify,
|
|
395
|
-
timeout=None,
|
|
396
|
-
)
|
|
397
|
-
|
|
398
|
-
metadata = None
|
|
326
|
+
req = PushTaskResRequest(task_res_list=[task_res])
|
|
399
327
|
|
|
400
|
-
#
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
if "content-type" not in res.headers:
|
|
404
|
-
log(
|
|
405
|
-
WARN,
|
|
406
|
-
"[Node] POST /%s: missing header `Content-Type`",
|
|
407
|
-
PATH_PUSH_TASK_RES,
|
|
408
|
-
)
|
|
409
|
-
return
|
|
410
|
-
if res.headers["content-type"] != "application/protobuf":
|
|
411
|
-
log(
|
|
412
|
-
WARN,
|
|
413
|
-
"[Node] POST /%s: header `Content-Type` has wrong value",
|
|
414
|
-
PATH_PUSH_TASK_RES,
|
|
415
|
-
)
|
|
328
|
+
# Send the request
|
|
329
|
+
res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
|
|
330
|
+
if res is None:
|
|
416
331
|
return
|
|
417
332
|
|
|
418
|
-
# Deserialize ProtoBuf from bytes
|
|
419
|
-
push_task_res_response_proto = PushTaskResResponse()
|
|
420
|
-
push_task_res_response_proto.ParseFromString(res.content)
|
|
421
333
|
log(
|
|
422
334
|
INFO,
|
|
423
335
|
"[Node] POST /%s: success, created result %s",
|
|
424
336
|
PATH_PUSH_TASK_RES,
|
|
425
|
-
|
|
337
|
+
res.results, # pylint: disable=no-member
|
|
426
338
|
)
|
|
427
339
|
|
|
340
|
+
def get_run(run_id: int) -> Tuple[str, str]:
|
|
341
|
+
# Construct the request
|
|
342
|
+
req = GetRunRequest(run_id=run_id)
|
|
343
|
+
|
|
344
|
+
# Send the request
|
|
345
|
+
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
346
|
+
if res is None:
|
|
347
|
+
return "", ""
|
|
348
|
+
|
|
349
|
+
return res.run.fab_id, res.run.fab_version
|
|
350
|
+
|
|
428
351
|
try:
|
|
429
352
|
# Yield methods
|
|
430
|
-
yield (receive, send, create_node, delete_node)
|
|
353
|
+
yield (receive, send, create_node, delete_node, get_run)
|
|
431
354
|
except Exception as exc: # pylint: disable=broad-except
|
|
432
355
|
log(ERROR, exc)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower SuperNode."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .app import run_client_app as run_client_app
|
|
19
|
+
from .app import run_supernode as run_supernode
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"run_client_app",
|
|
23
|
+
"run_supernode",
|
|
24
|
+
]
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower SuperNode."""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import sys
|
|
19
|
+
from logging import DEBUG, INFO, WARN
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Callable, Optional, Tuple
|
|
22
|
+
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
25
|
+
load_ssh_private_key,
|
|
26
|
+
load_ssh_public_key,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
30
|
+
from flwr.common import EventType, event
|
|
31
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
|
32
|
+
from flwr.common.logger import log
|
|
33
|
+
from flwr.common.object_ref import load_app, validate
|
|
34
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
35
|
+
ssh_types_to_elliptic_curve,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from ..app import _start_client_internal
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def run_supernode() -> None:
    """Run the Flower SuperNode entry point.

    Logs the start-up, emits the SuperNode enter event, parses the command
    line (the parsed namespace is currently discarded) and registers the exit
    handlers for the SuperNode leave event type.
    """
    log(INFO, "Starting Flower SuperNode")
    event(EventType.RUN_SUPERNODE_ENTER)

    # The result is not used yet; parsing is still performed on the CLI input
    supernode_parser = _parse_args_run_supernode()
    supernode_parser.parse_args()

    log(DEBUG, "Flower SuperNode starting...")

    # Graceful shutdown
    register_exit_handlers(event_type=EventType.RUN_SUPERNODE_LEAVE)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def run_client_app() -> None:
    """Run the Flower client app entry point.

    Parses the command line, resolves the root certificates, the ClientApp
    loader and the optional authentication key pair, and passes all of them to
    `_start_client_internal`. Exit handlers are registered once that call
    returns.
    """
    log(INFO, "Long-running Flower client starting")
    event(EventType.RUN_CLIENT_APP_ENTER)

    cli_args = _parse_args_run_client_app().parse_args()

    certificates = _get_certificates(cli_args)
    app_reference = getattr(cli_args, "client-app")
    log(DEBUG, "Flower will load ClientApp `%s`", app_reference)
    client_app_loader = _get_load_client_app_fn(cli_args)
    key_pair = _try_setup_client_authentication(cli_args)

    # `--rest` selects the REST transport, otherwise gRPC request-response
    if cli_args.rest:
        transport_name = "rest"
    else:
        transport_name = "grpc-rere"

    _start_client_internal(
        server_address=cli_args.server,
        load_client_app_fn=client_app_loader,
        transport=transport_name,
        root_certificates=certificates,
        insecure=cli_args.insecure,
        authentication_keys=key_pair,
        max_retries=cli_args.max_retries,
        max_wait_time=cli_args.max_wait_time,
    )
    register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
    """Return the root certificates requested on the command line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; `insecure`, `root_certificates` and `server` are read.

    Returns
    -------
    Optional[bytes]
        Content of the file given via `--root-certificates`, or `None` when
        running insecure or when no certificate path was given.

    Notes
    -----
    Exits the process when `--insecure` and `--root-certificates` are combined.
    """
    if not args.insecure:
        # Secure mode: read the certificate file only if a path was provided
        certificate_path = args.root_certificates
        certificate_bytes = (
            None if certificate_path is None else Path(certificate_path).read_bytes()
        )
        log(
            DEBUG,
            "Starting secure HTTPS client connected to %s "
            "with the following certificates: %s.",
            args.server,
            certificate_path,
        )
        return certificate_bytes

    # Insecure mode: explicit certificates are a contradiction, so refuse to run
    if args.root_certificates is not None:
        sys.exit(
            "Conflicting options: The '--insecure' flag disables HTTPS, "
            "but '--root-certificates' was also specified. Please remove "
            "the '--root-certificates' option when running in insecure mode, "
            "or omit '--insecure' to use HTTPS."
        )
    log(
        WARN,
        "Option `--insecure` was set. "
        "Starting insecure HTTP client connected to %s.",
        args.server,
    )
    return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _get_load_client_app_fn(
    args: argparse.Namespace,
) -> Callable[[], ClientApp]:
    """Create the zero-argument function that loads the configured ClientApp.

    The directory given via `--dir` is put at the front of `sys.path`, and the
    `client-app` reference is validated once up front. The returned callable
    loads the reference each time it is invoked.

    Raises
    ------
    LoadClientAppError
        If validation fails with an error message, or (from the returned
        callable) if the loaded object is not a `ClientApp`.
    """
    search_dir = args.dir
    if search_dir is not None:
        sys.path.insert(0, search_dir)

    app_ref: str = getattr(args, "client-app")
    is_valid, reason = validate(app_ref)
    if reason and not is_valid:
        raise LoadClientAppError(reason) from None

    def _load() -> ClientApp:
        loaded = load_app(app_ref, LoadClientAppError)

        if isinstance(loaded, ClientApp):
            return loaded

        raise LoadClientAppError(
            f"Attribute {app_ref} is not of type {ClientApp}",
        ) from None

    return _load
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _parse_args_run_supernode() -> argparse.ArgumentParser:
    """Build the argument parser for the flower-supernode command.

    The parser accepts an optional positional `client-app` reference, the
    options shared with the client app command (see `_parse_args_common`) and
    `--flwr-dir`.

    Returns
    -------
    argparse.ArgumentParser
        The configured parser; the caller is responsible for `parse_args()`.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower SuperNode",
    )

    parser.add_argument(
        "client-app",
        nargs="?",
        default="",
        help="For example: `client:app` or `project.package.module:wrapper.app`. "
        "This is optional and serves as the default ClientApp to be loaded when "
        "the ServerApp does not specify `fab_id` and `fab_version`. "
        "If not provided, defaults to an empty string.",
    )
    _parse_args_common(parser)
    parser.add_argument(
        "--flwr-dir",
        default=None,
        help="""The path containing installed Flower Apps.
    By default, this value is equal to:

        - `$FLWR_HOME/` if `$FLWR_HOME` is defined
        - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
        - `$HOME/.flwr/` in all other cases
    """,
    )

    return parser
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _parse_args_run_client_app() -> argparse.ArgumentParser:
    """Build the argument parser for the flower-client-app command.

    The parser takes a required positional `client-app` reference plus the
    options added by `_parse_args_common`.
    """
    client_app_parser = argparse.ArgumentParser(description="Start a Flower client app")
    client_app_parser.add_argument(
        "client-app",
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )

    # Options shared with the SuperNode command
    _parse_args_common(client_app_parser)
    return client_app_parser
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _parse_args_common(parser: argparse.ArgumentParser) -> None:
    """Add the options shared by the SuperNode and client app commands.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend. It is modified in place with `--insecure`, `--rest`,
        `--root-certificates`, `--server`, `--max-retries`, `--max-wait-time`,
        `--dir` and `--authentication-keys`.
    """
    parser.add_argument(
        "--insecure",
        action="store_true",
        help="Run the client without HTTPS. By default, the client runs with "
        "HTTPS enabled. Use this flag only if you understand the risks.",
    )
    parser.add_argument(
        "--rest",
        action="store_true",
        help="Use REST as a transport layer for the client.",
    )
    parser.add_argument(
        "--root-certificates",
        metavar="ROOT_CERT",
        type=str,
        help="Specifies the path to the PEM-encoded root certificate file for "
        "establishing secure HTTPS connections.",
    )
    parser.add_argument(
        "--server",
        default="0.0.0.0:9092",
        help="Server address",
    )
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment that is continued on the next line must end with a space.
    parser.add_argument(
        "--max-retries",
        type=int,
        default=None,
        help="The maximum number of times the client will try to connect to the "
        "server before giving up in case of a connection error. By default, "
        "it is set to None, meaning there is no limit to the number of tries.",
    )
    parser.add_argument(
        "--max-wait-time",
        type=float,
        default=None,
        help="The maximum duration before the client stops trying to "
        "connect to the server in case of connection error. By default, it "
        "is set to None, meaning there is no limit to the total time.",
    )
    parser.add_argument(
        "--dir",
        default="",
        help="Add specified directory to the PYTHONPATH and load Flower "
        "app from there."
        " Default: current working directory.",
    )
    parser.add_argument(
        "--authentication-keys",
        nargs=2,
        metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
        type=str,
        help="Provide two file paths: (1) the client's private "
        "key file, and (2) the client's public key file.",
    )
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _try_setup_client_authentication(
    args: argparse.Namespace,
) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
    """Load the client key pair given via `--authentication-keys`, if any.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; `authentication_keys` is expected to hold the private
        key path followed by the public key path, or to be empty/`None`.

    Returns
    -------
    Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]
        The (private, public) key pair, or `None` when no key files were given.

    Notes
    -----
    Exits the process with an explanatory message when the files cannot be
    interpreted as an elliptic curve key pair.
    """
    if not args.authentication_keys:
        return None

    try:
        # Parsing happens inside the `try` as well: malformed key material makes
        # the loaders raise `ValueError`, which should produce the same friendly
        # exit as a key of the wrong type.
        ssh_private_key = load_ssh_private_key(
            Path(args.authentication_keys[0]).read_bytes(),
            None,
        )
        ssh_public_key = load_ssh_public_key(
            Path(args.authentication_keys[1]).read_bytes()
        )
        client_private_key, client_public_key = ssh_types_to_elliptic_curve(
            ssh_private_key, ssh_public_key
        )
    except (TypeError, ValueError):
        sys.exit(
            "The file paths provided could not be read as a private and public "
            "key pair. Client authentication requires an elliptic curve public and "
            "private key pair. Please provide the file paths containing elliptic "
            "curve private and public keys to '--authentication-keys'."
        )

    return (
        client_private_key,
        client_public_key,
    )
|