flwr-nightly 1.9.0.dev20240520__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +4 -19
- flwr/cli/config_utils.py +12 -27
- flwr/cli/install.py +196 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -1
- flwr/cli/run/run.py +20 -4
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +135 -97
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +6 -6
- flwr/client/mod/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -2
- flwr/client/supernode/app.py +70 -28
- flwr/common/object_ref.py +13 -9
- flwr/common/recordset_compat.py +8 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +0 -15
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +190 -395
- flwr/server/run_serverapp.py +29 -5
- flwr/server/server_app.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +4 -3
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +53 -44
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -14,8 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
+
import signal
|
|
17
18
|
import sys
|
|
18
19
|
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
19
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
22
|
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
21
23
|
|
|
@@ -37,7 +39,7 @@ from flwr.common.constant import (
|
|
|
37
39
|
)
|
|
38
40
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
39
41
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
42
|
+
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
41
43
|
|
|
42
44
|
from .grpc_client.connection import grpc_connection
|
|
43
45
|
from .grpc_rere_client.connection import grpc_request_response
|
|
@@ -263,6 +265,29 @@ def _start_client_internal(
|
|
|
263
265
|
transport, server_address
|
|
264
266
|
)
|
|
265
267
|
|
|
268
|
+
app_state_tracker = _AppStateTracker()
|
|
269
|
+
|
|
270
|
+
def _on_sucess(retry_state: RetryState) -> None:
|
|
271
|
+
app_state_tracker.is_connected = True
|
|
272
|
+
if retry_state.tries > 1:
|
|
273
|
+
log(
|
|
274
|
+
INFO,
|
|
275
|
+
"Connection successful after %.2f seconds and %s tries.",
|
|
276
|
+
retry_state.elapsed_time,
|
|
277
|
+
retry_state.tries,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
def _on_backoff(retry_state: RetryState) -> None:
|
|
281
|
+
app_state_tracker.is_connected = False
|
|
282
|
+
if retry_state.tries == 1:
|
|
283
|
+
log(WARN, "Connection attempt failed, retrying...")
|
|
284
|
+
else:
|
|
285
|
+
log(
|
|
286
|
+
DEBUG,
|
|
287
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
|
288
|
+
retry_state.actual_wait,
|
|
289
|
+
)
|
|
290
|
+
|
|
266
291
|
retry_invoker = RetryInvoker(
|
|
267
292
|
wait_gen_factory=exponential,
|
|
268
293
|
recoverable_exceptions=connection_error_type,
|
|
@@ -278,30 +303,13 @@ def _start_client_internal(
|
|
|
278
303
|
if retry_state.tries > 1
|
|
279
304
|
else None
|
|
280
305
|
),
|
|
281
|
-
on_success=
|
|
282
|
-
|
|
283
|
-
INFO,
|
|
284
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
285
|
-
retry_state.elapsed_time,
|
|
286
|
-
retry_state.tries,
|
|
287
|
-
)
|
|
288
|
-
if retry_state.tries > 1
|
|
289
|
-
else None
|
|
290
|
-
),
|
|
291
|
-
on_backoff=lambda retry_state: (
|
|
292
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
293
|
-
if retry_state.tries == 1
|
|
294
|
-
else log(
|
|
295
|
-
DEBUG,
|
|
296
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
297
|
-
retry_state.actual_wait,
|
|
298
|
-
)
|
|
299
|
-
),
|
|
306
|
+
on_success=_on_sucess,
|
|
307
|
+
on_backoff=_on_backoff,
|
|
300
308
|
)
|
|
301
309
|
|
|
302
310
|
node_state = NodeState()
|
|
303
311
|
|
|
304
|
-
while True:
|
|
312
|
+
while not app_state_tracker.interrupt:
|
|
305
313
|
sleep_duration: int = 0
|
|
306
314
|
with connection(
|
|
307
315
|
address,
|
|
@@ -318,99 +326,112 @@ def _start_client_internal(
|
|
|
318
326
|
if create_node is not None:
|
|
319
327
|
create_node() # pylint: disable=not-callable
|
|
320
328
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
329
|
+
app_state_tracker.register_signal_handler()
|
|
330
|
+
while not app_state_tracker.interrupt:
|
|
331
|
+
try:
|
|
332
|
+
# Receive
|
|
333
|
+
message = receive()
|
|
334
|
+
if message is None:
|
|
335
|
+
time.sleep(3) # Wait for 3s before asking again
|
|
336
|
+
continue
|
|
337
|
+
|
|
338
|
+
log(INFO, "")
|
|
339
|
+
if len(message.metadata.group_id) > 0:
|
|
340
|
+
log(
|
|
341
|
+
INFO,
|
|
342
|
+
"[RUN %s, ROUND %s]",
|
|
343
|
+
message.metadata.run_id,
|
|
344
|
+
message.metadata.group_id,
|
|
345
|
+
)
|
|
330
346
|
log(
|
|
331
347
|
INFO,
|
|
332
|
-
"
|
|
333
|
-
message.metadata.
|
|
334
|
-
message.metadata.
|
|
348
|
+
"Received: %s message %s",
|
|
349
|
+
message.metadata.message_type,
|
|
350
|
+
message.metadata.message_id,
|
|
335
351
|
)
|
|
336
|
-
log(
|
|
337
|
-
INFO,
|
|
338
|
-
"Received: %s message %s",
|
|
339
|
-
message.metadata.message_type,
|
|
340
|
-
message.metadata.message_id,
|
|
341
|
-
)
|
|
342
|
-
|
|
343
|
-
# Handle control message
|
|
344
|
-
out_message, sleep_duration = handle_control_message(message)
|
|
345
|
-
if out_message:
|
|
346
|
-
send(out_message)
|
|
347
|
-
break
|
|
348
|
-
|
|
349
|
-
# Register context for this run
|
|
350
|
-
node_state.register_context(run_id=message.metadata.run_id)
|
|
351
|
-
|
|
352
|
-
# Retrieve context for this run
|
|
353
|
-
context = node_state.retrieve_context(run_id=message.metadata.run_id)
|
|
354
352
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
353
|
+
# Handle control message
|
|
354
|
+
out_message, sleep_duration = handle_control_message(message)
|
|
355
|
+
if out_message:
|
|
356
|
+
send(out_message)
|
|
357
|
+
break
|
|
360
358
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
# Load ClientApp instance
|
|
364
|
-
client_app: ClientApp = load_client_app_fn()
|
|
365
|
-
|
|
366
|
-
# Execute ClientApp
|
|
367
|
-
reply_message = client_app(message=message, context=context)
|
|
368
|
-
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
369
|
-
|
|
370
|
-
# Legacy grpc-bidi
|
|
371
|
-
if transport in ["grpc-bidi", None]:
|
|
372
|
-
log(ERROR, "Client raised an exception.", exc_info=ex)
|
|
373
|
-
# Raise exception, crash process
|
|
374
|
-
raise ex
|
|
375
|
-
|
|
376
|
-
# Don't update/change NodeState
|
|
377
|
-
|
|
378
|
-
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
379
|
-
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
380
|
-
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
381
|
-
exc_entity = "ClientApp"
|
|
382
|
-
if isinstance(ex, LoadClientAppError):
|
|
383
|
-
reason = (
|
|
384
|
-
"An exception was raised when attempting to load "
|
|
385
|
-
"`ClientApp`"
|
|
386
|
-
)
|
|
387
|
-
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
388
|
-
exc_entity = "SuperNode"
|
|
359
|
+
# Register context for this run
|
|
360
|
+
node_state.register_context(run_id=message.metadata.run_id)
|
|
389
361
|
|
|
390
|
-
|
|
362
|
+
# Retrieve context for this run
|
|
363
|
+
context = node_state.retrieve_context(
|
|
364
|
+
run_id=message.metadata.run_id
|
|
365
|
+
)
|
|
391
366
|
|
|
392
|
-
# Create error message
|
|
367
|
+
# Create an error reply message that will never be used to prevent
|
|
368
|
+
# the used-before-assignment linting error
|
|
393
369
|
reply_message = message.create_error_reply(
|
|
394
|
-
error=Error(code=
|
|
395
|
-
)
|
|
396
|
-
else:
|
|
397
|
-
# No exception, update node state
|
|
398
|
-
node_state.update_context(
|
|
399
|
-
run_id=message.metadata.run_id,
|
|
400
|
-
context=context,
|
|
370
|
+
error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
|
|
401
371
|
)
|
|
402
372
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
373
|
+
# Handle app loading and task message
|
|
374
|
+
try:
|
|
375
|
+
# Load ClientApp instance
|
|
376
|
+
client_app: ClientApp = load_client_app_fn()
|
|
377
|
+
|
|
378
|
+
# Execute ClientApp
|
|
379
|
+
reply_message = client_app(message=message, context=context)
|
|
380
|
+
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
381
|
+
|
|
382
|
+
# Legacy grpc-bidi
|
|
383
|
+
if transport in ["grpc-bidi", None]:
|
|
384
|
+
log(ERROR, "Client raised an exception.", exc_info=ex)
|
|
385
|
+
# Raise exception, crash process
|
|
386
|
+
raise ex
|
|
387
|
+
|
|
388
|
+
# Don't update/change NodeState
|
|
389
|
+
|
|
390
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
391
|
+
# Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
392
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
393
|
+
exc_entity = "ClientApp"
|
|
394
|
+
if isinstance(ex, LoadClientAppError):
|
|
395
|
+
reason = (
|
|
396
|
+
"An exception was raised when attempting to load "
|
|
397
|
+
"`ClientApp`"
|
|
398
|
+
)
|
|
399
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
400
|
+
exc_entity = "SuperNode"
|
|
401
|
+
|
|
402
|
+
if not app_state_tracker.interrupt:
|
|
403
|
+
log(
|
|
404
|
+
ERROR, "%s raised an exception", exc_entity, exc_info=ex
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
# Create error message
|
|
408
|
+
reply_message = message.create_error_reply(
|
|
409
|
+
error=Error(code=e_code, reason=reason)
|
|
410
|
+
)
|
|
411
|
+
else:
|
|
412
|
+
# No exception, update node state
|
|
413
|
+
node_state.update_context(
|
|
414
|
+
run_id=message.metadata.run_id,
|
|
415
|
+
context=context,
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# Send
|
|
419
|
+
send(reply_message)
|
|
420
|
+
log(INFO, "Sent reply")
|
|
421
|
+
|
|
422
|
+
except StopIteration:
|
|
423
|
+
sleep_duration = 0
|
|
424
|
+
break
|
|
406
425
|
|
|
407
426
|
# Unregister node
|
|
408
|
-
if delete_node is not None:
|
|
427
|
+
if delete_node is not None and app_state_tracker.is_connected:
|
|
409
428
|
delete_node() # pylint: disable=not-callable
|
|
410
429
|
|
|
411
430
|
if sleep_duration == 0:
|
|
412
431
|
log(INFO, "Disconnect and shut down")
|
|
432
|
+
del app_state_tracker
|
|
413
433
|
break
|
|
434
|
+
|
|
414
435
|
# Sleep and reconnect afterwards
|
|
415
436
|
log(
|
|
416
437
|
INFO,
|
|
@@ -579,3 +600,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
579
600
|
)
|
|
580
601
|
|
|
581
602
|
return connection, address, error_type
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
@dataclass
|
|
606
|
+
class _AppStateTracker:
|
|
607
|
+
interrupt: bool = False
|
|
608
|
+
is_connected: bool = False
|
|
609
|
+
|
|
610
|
+
def register_signal_handler(self) -> None:
|
|
611
|
+
"""Register handlers for exit signals."""
|
|
612
|
+
|
|
613
|
+
def signal_handler(sig, frame): # type: ignore
|
|
614
|
+
# pylint: disable=unused-argument
|
|
615
|
+
self.interrupt = True
|
|
616
|
+
raise StopIteration from None
|
|
617
|
+
|
|
618
|
+
signal.signal(signal.SIGINT, signal_handler)
|
|
619
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
flwr/client/client_app.py
CHANGED
|
@@ -31,11 +31,11 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
31
31
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
32
|
CreateNodeRequest,
|
|
33
33
|
DeleteNodeRequest,
|
|
34
|
-
GetRunRequest,
|
|
35
34
|
PingRequest,
|
|
36
35
|
PullTaskInsRequest,
|
|
37
36
|
PushTaskResRequest,
|
|
38
37
|
)
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
39
|
|
|
40
40
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
41
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -21,7 +21,7 @@ from contextlib import contextmanager
|
|
|
21
21
|
from copy import copy
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from pathlib import Path
|
|
24
|
-
from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
|
|
24
|
+
from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast
|
|
25
25
|
|
|
26
26
|
import grpc
|
|
27
27
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -44,8 +44,6 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
|
44
44
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
45
|
CreateNodeRequest,
|
|
46
46
|
DeleteNodeRequest,
|
|
47
|
-
GetRunRequest,
|
|
48
|
-
GetRunResponse,
|
|
49
47
|
PingRequest,
|
|
50
48
|
PingResponse,
|
|
51
49
|
PullTaskInsRequest,
|
|
@@ -53,6 +51,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
53
51
|
)
|
|
54
52
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
55
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
54
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
56
55
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
57
56
|
|
|
58
57
|
from .client_interceptor import AuthenticateClientInterceptor
|
|
@@ -73,6 +72,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
73
72
|
authentication_keys: Optional[
|
|
74
73
|
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
74
|
] = None,
|
|
75
|
+
adapter_cls: Optional[Type[FleetStub]] = None,
|
|
76
76
|
) -> Iterator[
|
|
77
77
|
Tuple[
|
|
78
78
|
Callable[[], Optional[Message]],
|
|
@@ -133,7 +133,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
133
133
|
channel.subscribe(on_channel_state_change)
|
|
134
134
|
|
|
135
135
|
# Shared variables for inner functions
|
|
136
|
-
|
|
136
|
+
if adapter_cls is None:
|
|
137
|
+
adapter_cls = FleetStub
|
|
138
|
+
stub = adapter_cls(channel)
|
|
137
139
|
metadata: Optional[Metadata] = None
|
|
138
140
|
node: Optional[Node] = None
|
|
139
141
|
ping_thread: Optional[threading.Thread] = None
|
|
@@ -190,8 +192,6 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
190
192
|
|
|
191
193
|
# Stop the ping-loop thread
|
|
192
194
|
ping_stop_event.set()
|
|
193
|
-
if ping_thread is not None:
|
|
194
|
-
ping_thread.join()
|
|
195
195
|
|
|
196
196
|
# Call FleetAPI
|
|
197
197
|
delete_node_request = DeleteNodeRequest(node=node)
|
flwr/client/mod/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Mods."""
|
|
15
|
+
"""Flower Built-in Mods."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
@@ -46,8 +46,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
46
46
|
CreateNodeResponse,
|
|
47
47
|
DeleteNodeRequest,
|
|
48
48
|
DeleteNodeResponse,
|
|
49
|
-
GetRunRequest,
|
|
50
|
-
GetRunResponse,
|
|
51
49
|
PingRequest,
|
|
52
50
|
PingResponse,
|
|
53
51
|
PullTaskInsRequest,
|
|
@@ -56,6 +54,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
56
54
|
PushTaskResResponse,
|
|
57
55
|
)
|
|
58
56
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
57
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
59
58
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
60
59
|
|
|
61
60
|
try:
|
flwr/client/supernode/app.py
CHANGED
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO, WARN
|
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Callable, Optional, Tuple
|
|
22
22
|
|
|
23
|
+
from cryptography.exceptions import UnsupportedAlgorithm
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
from cryptography.hazmat.primitives.serialization import (
|
|
25
26
|
load_ssh_private_key,
|
|
@@ -29,14 +30,13 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
29
30
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
30
31
|
from flwr.common import EventType, event
|
|
31
32
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
32
|
-
from flwr.common.logger import log
|
|
33
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
|
33
34
|
from flwr.common.object_ref import load_app, validate
|
|
34
|
-
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
35
|
-
ssh_types_to_elliptic_curve,
|
|
36
|
-
)
|
|
37
35
|
|
|
38
36
|
from ..app import _start_client_internal
|
|
39
37
|
|
|
38
|
+
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
39
|
+
|
|
40
40
|
|
|
41
41
|
def run_supernode() -> None:
|
|
42
42
|
"""Run Flower SuperNode."""
|
|
@@ -65,6 +65,23 @@ def run_client_app() -> None:
|
|
|
65
65
|
|
|
66
66
|
args = _parse_args_run_client_app().parse_args()
|
|
67
67
|
|
|
68
|
+
if args.server != ADDRESS_FLEET_API_GRPC_RERE:
|
|
69
|
+
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
70
|
+
warn_deprecated_feature(warn)
|
|
71
|
+
|
|
72
|
+
if args.superlink != ADDRESS_FLEET_API_GRPC_RERE:
|
|
73
|
+
# if `--superlink` also passed, then
|
|
74
|
+
# warn user that this argument overrides what was passed with `--server`
|
|
75
|
+
log(
|
|
76
|
+
WARN,
|
|
77
|
+
"Both `--server` and `--superlink` were passed. "
|
|
78
|
+
"`--server` will be ignored. Connecting to the Superlink Fleet API "
|
|
79
|
+
"at %s.",
|
|
80
|
+
args.superlink,
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
args.superlink = args.server
|
|
84
|
+
|
|
68
85
|
root_certificates = _get_certificates(args)
|
|
69
86
|
log(
|
|
70
87
|
DEBUG,
|
|
@@ -75,7 +92,7 @@ def run_client_app() -> None:
|
|
|
75
92
|
authentication_keys = _try_setup_client_authentication(args)
|
|
76
93
|
|
|
77
94
|
_start_client_internal(
|
|
78
|
-
server_address=args.server,
|
|
95
|
+
server_address=args.superlink,
|
|
79
96
|
load_client_app_fn=load_fn,
|
|
80
97
|
transport="rest" if args.rest else "grpc-rere",
|
|
81
98
|
root_certificates=root_certificates,
|
|
@@ -102,7 +119,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
102
119
|
WARN,
|
|
103
120
|
"Option `--insecure` was set. "
|
|
104
121
|
"Starting insecure HTTP client connected to %s.",
|
|
105
|
-
args.server,
|
|
122
|
+
args.superlink,
|
|
106
123
|
)
|
|
107
124
|
root_certificates = None
|
|
108
125
|
else:
|
|
@@ -116,7 +133,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
116
133
|
DEBUG,
|
|
117
134
|
"Starting secure HTTPS client connected to %s "
|
|
118
135
|
"with the following certificates: %s.",
|
|
119
|
-
args.server,
|
|
136
|
+
args.superlink,
|
|
120
137
|
cert_path,
|
|
121
138
|
)
|
|
122
139
|
return root_certificates
|
|
@@ -215,9 +232,14 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
215
232
|
)
|
|
216
233
|
parser.add_argument(
|
|
217
234
|
"--server",
|
|
218
|
-
default="0.0.0.0:9092",
|
|
235
|
+
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
219
236
|
help="Server address",
|
|
220
237
|
)
|
|
238
|
+
parser.add_argument(
|
|
239
|
+
"--superlink",
|
|
240
|
+
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
241
|
+
help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
242
|
+
)
|
|
221
243
|
parser.add_argument(
|
|
222
244
|
"--max-retries",
|
|
223
245
|
type=int,
|
|
@@ -242,40 +264,60 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
242
264
|
" Default: current working directory.",
|
|
243
265
|
)
|
|
244
266
|
parser.add_argument(
|
|
245
|
-
"--
|
|
246
|
-
|
|
247
|
-
|
|
267
|
+
"--auth-supernode-private-key",
|
|
268
|
+
type=str,
|
|
269
|
+
help="The SuperNode's private key (as a path str) to enable authentication.",
|
|
270
|
+
)
|
|
271
|
+
parser.add_argument(
|
|
272
|
+
"--auth-supernode-public-key",
|
|
248
273
|
type=str,
|
|
249
|
-
help="
|
|
250
|
-
"key file, and (2) the client's public key file.",
|
|
274
|
+
help="The SuperNode's public key (as a path str) to enable authentication.",
|
|
251
275
|
)
|
|
252
276
|
|
|
253
277
|
|
|
254
278
|
def _try_setup_client_authentication(
|
|
255
279
|
args: argparse.Namespace,
|
|
256
280
|
) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
257
|
-
if not args.
|
|
281
|
+
if not args.auth_supernode_private_key and not args.auth_supernode_public_key:
|
|
258
282
|
return None
|
|
259
283
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
284
|
+
if not args.auth_supernode_private_key or not args.auth_supernode_public_key:
|
|
285
|
+
sys.exit(
|
|
286
|
+
"Authentication requires file paths to both "
|
|
287
|
+
"'--auth-supernode-private-key' and '--auth-supernode-public-key'"
|
|
288
|
+
"to be provided (providing only one of them is not sufficient)."
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
try:
|
|
292
|
+
ssh_private_key = load_ssh_private_key(
|
|
293
|
+
Path(args.auth_supernode_private_key).read_bytes(),
|
|
294
|
+
None,
|
|
295
|
+
)
|
|
296
|
+
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
|
297
|
+
raise ValueError()
|
|
298
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
299
|
+
sys.exit(
|
|
300
|
+
"Error: Unable to parse the private key file in "
|
|
301
|
+
"'--auth-supernode-private-key'. Authentication requires elliptic "
|
|
302
|
+
"curve private and public key pair. Please ensure that the file "
|
|
303
|
+
"path points to a valid private key file and try again."
|
|
304
|
+
)
|
|
265
305
|
|
|
266
306
|
try:
|
|
267
|
-
|
|
268
|
-
|
|
307
|
+
ssh_public_key = load_ssh_public_key(
|
|
308
|
+
Path(args.auth_supernode_public_key).read_bytes()
|
|
269
309
|
)
|
|
270
|
-
|
|
310
|
+
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
|
311
|
+
raise ValueError()
|
|
312
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
271
313
|
sys.exit(
|
|
272
|
-
"
|
|
273
|
-
"key
|
|
274
|
-
"private key pair. Please
|
|
275
|
-
"
|
|
314
|
+
"Error: Unable to parse the public key file in "
|
|
315
|
+
"'--auth-supernode-public-key'. Authentication requires elliptic "
|
|
316
|
+
"curve private and public key pair. Please ensure that the file "
|
|
317
|
+
"path points to a valid public key file and try again."
|
|
276
318
|
)
|
|
277
319
|
|
|
278
320
|
return (
|
|
279
|
-
|
|
280
|
-
|
|
321
|
+
ssh_private_key,
|
|
322
|
+
ssh_public_key,
|
|
281
323
|
)
|
flwr/common/object_ref.py
CHANGED
|
@@ -30,6 +30,7 @@ attribute.
|
|
|
30
30
|
|
|
31
31
|
def validate(
|
|
32
32
|
module_attribute_str: str,
|
|
33
|
+
check_module: bool = True,
|
|
33
34
|
) -> Tuple[bool, Optional[str]]:
|
|
34
35
|
"""Validate object reference.
|
|
35
36
|
|
|
@@ -56,15 +57,18 @@ def validate(
|
|
|
56
57
|
f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
|
57
58
|
)
|
|
58
59
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
if
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
60
|
+
if check_module:
|
|
61
|
+
# Load module
|
|
62
|
+
module = find_spec(module_str)
|
|
63
|
+
if module and module.origin:
|
|
64
|
+
if not _find_attribute_in_module(module.origin, attributes_str):
|
|
65
|
+
return (
|
|
66
|
+
False,
|
|
67
|
+
f"Unable to find attribute {attributes_str} in module {module_str}"
|
|
68
|
+
f"{OBJECT_REF_HELP_STR}",
|
|
69
|
+
)
|
|
70
|
+
return (True, None)
|
|
71
|
+
else:
|
|
68
72
|
return (True, None)
|
|
69
73
|
|
|
70
74
|
return (
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -35,6 +35,8 @@ from .typing import (
|
|
|
35
35
|
Status,
|
|
36
36
|
)
|
|
37
37
|
|
|
38
|
+
EMPTY_TENSOR_KEY = "_empty"
|
|
39
|
+
|
|
38
40
|
|
|
39
41
|
def parametersrecord_to_parameters(
|
|
40
42
|
record: ParametersRecord, keep_input: bool
|
|
@@ -59,7 +61,8 @@ def parametersrecord_to_parameters(
|
|
|
59
61
|
parameters = Parameters(tensors=[], tensor_type="")
|
|
60
62
|
|
|
61
63
|
for key in list(record.keys()):
|
|
62
|
-
|
|
64
|
+
if key != EMPTY_TENSOR_KEY:
|
|
65
|
+
parameters.tensors.append(record[key].data)
|
|
63
66
|
|
|
64
67
|
if not parameters.tensor_type:
|
|
65
68
|
# Setting from first array in record. Recall the warning in the docstrings
|
|
@@ -103,6 +106,10 @@ def parameters_to_parametersrecord(
|
|
|
103
106
|
data=tensor, dtype="", stype=tensor_type, shape=[]
|
|
104
107
|
)
|
|
105
108
|
|
|
109
|
+
if num_arrays == 0:
|
|
110
|
+
ordered_dict[EMPTY_TENSOR_KEY] = Array(
|
|
111
|
+
data=b"", dtype="", stype=tensor_type, shape=[]
|
|
112
|
+
)
|
|
106
113
|
return ParametersRecord(ordered_dict, keep_input=keep_input)
|
|
107
114
|
|
|
108
115
|
|
|
@@ -117,18 +117,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
|
117
117
|
return True
|
|
118
118
|
except InvalidSignature:
|
|
119
119
|
return False
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def ssh_types_to_elliptic_curve(
|
|
123
|
-
private_key: serialization.SSHPrivateKeyTypes,
|
|
124
|
-
public_key: serialization.SSHPublicKeyTypes,
|
|
125
|
-
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
|
|
126
|
-
"""Cast SSH key types to elliptic curve."""
|
|
127
|
-
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
|
|
128
|
-
public_key, ec.EllipticCurvePublicKey
|
|
129
|
-
):
|
|
130
|
-
return (private_key, public_key)
|
|
131
|
-
|
|
132
|
-
raise TypeError(
|
|
133
|
-
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
|
|
134
|
-
)
|