flwr-nightly 1.12.0.dev20240906__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
flwr/client/client_app.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import inspect
|
|
19
|
-
from typing import Callable,
|
|
19
|
+
from typing import Callable, Optional
|
|
20
20
|
|
|
21
21
|
from flwr.client.client import Client
|
|
22
22
|
from flwr.client.message_handler.message_handler import (
|
|
@@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None:
|
|
|
41
41
|
|
|
42
42
|
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
|
|
43
43
|
client_fn_args = inspect.signature(client_fn).parameters
|
|
44
|
-
first_arg = list(client_fn_args.keys())[0]
|
|
45
44
|
|
|
46
45
|
if len(client_fn_args) != 1:
|
|
47
46
|
_alert_erroneous_client_fn()
|
|
48
47
|
|
|
48
|
+
first_arg = list(client_fn_args.keys())[0]
|
|
49
49
|
first_arg_type = client_fn_args[first_arg].annotation
|
|
50
50
|
|
|
51
51
|
if first_arg_type is str or first_arg == "cid":
|
|
@@ -109,9 +109,9 @@ class ClientApp:
|
|
|
109
109
|
def __init__(
|
|
110
110
|
self,
|
|
111
111
|
client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
|
|
112
|
-
mods: Optional[
|
|
112
|
+
mods: Optional[list[Mod]] = None,
|
|
113
113
|
) -> None:
|
|
114
|
-
self._mods:
|
|
114
|
+
self._mods: list[Mod] = mods if mods is not None else []
|
|
115
115
|
|
|
116
116
|
# Create wrapper function for `handle`
|
|
117
117
|
self._call: Optional[ClientAppCallable] = None
|
|
@@ -263,7 +263,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
263
263
|
>>> class FlowerClient(NumPyClient):
|
|
264
264
|
>>> # ...
|
|
265
265
|
>>>
|
|
266
|
-
>>> def client_fn(
|
|
266
|
+
>>> def client_fn(context: Context):
|
|
267
267
|
>>> return FlowerClient().to_client()
|
|
268
268
|
>>>
|
|
269
269
|
>>> app = ClientApp(
|
flwr/client/clientapp/app.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import argparse
|
|
18
18
|
import time
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
|
-
from typing import Optional
|
|
20
|
+
from typing import Optional
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -196,7 +196,7 @@ def get_token(stub: grpc.Channel) -> Optional[int]:
|
|
|
196
196
|
|
|
197
197
|
def pull_message(
|
|
198
198
|
stub: grpc.Channel, token: int
|
|
199
|
-
) ->
|
|
199
|
+
) -> tuple[Message, Context, Run, Optional[Fab]]:
|
|
200
200
|
"""Pull message from SuperNode to ClientApp."""
|
|
201
201
|
log(INFO, "Pulling ClientAppInputs for token %s", token)
|
|
202
202
|
try:
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import copy
|
|
19
|
-
from typing import Dict, Tuple
|
|
20
19
|
|
|
21
20
|
import numpy as np
|
|
22
21
|
|
|
@@ -39,7 +38,7 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
|
39
38
|
super().__init__()
|
|
40
39
|
self.client = client
|
|
41
40
|
|
|
42
|
-
def get_properties(self, config: Config) ->
|
|
41
|
+
def get_properties(self, config: Config) -> dict[str, Scalar]:
|
|
43
42
|
"""Get client properties using the given Numpy client.
|
|
44
43
|
|
|
45
44
|
Parameters
|
|
@@ -58,7 +57,7 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
|
58
57
|
"""
|
|
59
58
|
return self.client.get_properties(config)
|
|
60
59
|
|
|
61
|
-
def get_parameters(self, config:
|
|
60
|
+
def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
|
|
62
61
|
"""Return the current local model parameters.
|
|
63
62
|
|
|
64
63
|
Parameters
|
|
@@ -76,8 +75,8 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
|
76
75
|
return self.client.get_parameters(config)
|
|
77
76
|
|
|
78
77
|
def fit(
|
|
79
|
-
self, parameters: NDArrays, config:
|
|
80
|
-
) ->
|
|
78
|
+
self, parameters: NDArrays, config: dict[str, Scalar]
|
|
79
|
+
) -> tuple[NDArrays, int, dict[str, Scalar]]:
|
|
81
80
|
"""Train the provided parameters using the locally held dataset.
|
|
82
81
|
|
|
83
82
|
This method first updates the local model using the original parameters
|
|
@@ -153,8 +152,8 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
|
153
152
|
return updated_params, num_examples, metrics
|
|
154
153
|
|
|
155
154
|
def evaluate(
|
|
156
|
-
self, parameters: NDArrays, config:
|
|
157
|
-
) ->
|
|
155
|
+
self, parameters: NDArrays, config: dict[str, Scalar]
|
|
156
|
+
) -> tuple[float, int, dict[str, Scalar]]:
|
|
158
157
|
"""Evaluate the provided parameters using the locally held dataset.
|
|
159
158
|
|
|
160
159
|
Parameters
|
|
@@ -15,9 +15,10 @@
|
|
|
15
15
|
"""Contextmanager for a GrpcAdapter channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Iterator
|
|
18
19
|
from contextlib import contextmanager
|
|
19
20
|
from logging import ERROR
|
|
20
|
-
from typing import Callable,
|
|
21
|
+
from typing import Callable, Optional, Union
|
|
21
22
|
|
|
22
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
23
24
|
|
|
@@ -38,10 +39,10 @@ def grpc_adapter( # pylint: disable=R0913
|
|
|
38
39
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
39
40
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
40
41
|
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
41
|
-
|
|
42
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
42
43
|
] = None,
|
|
43
44
|
) -> Iterator[
|
|
44
|
-
|
|
45
|
+
tuple[
|
|
45
46
|
Callable[[], Optional[Message]],
|
|
46
47
|
Callable[[Message], None],
|
|
47
48
|
Optional[Callable[[], Optional[int]]],
|
|
@@ -16,11 +16,12 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import uuid
|
|
19
|
+
from collections.abc import Iterator
|
|
19
20
|
from contextlib import contextmanager
|
|
20
21
|
from logging import DEBUG, ERROR
|
|
21
22
|
from pathlib import Path
|
|
22
23
|
from queue import Queue
|
|
23
|
-
from typing import Callable,
|
|
24
|
+
from typing import Callable, Optional, Union, cast
|
|
24
25
|
|
|
25
26
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
27
|
|
|
@@ -66,10 +67,10 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
66
67
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
67
68
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
68
69
|
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
69
|
-
|
|
70
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
70
71
|
] = None,
|
|
71
72
|
) -> Iterator[
|
|
72
|
-
|
|
73
|
+
tuple[
|
|
73
74
|
Callable[[], Optional[Message]],
|
|
74
75
|
Callable[[Message], None],
|
|
75
76
|
Optional[Callable[[], Optional[int]]],
|
|
@@ -17,8 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
19
|
import collections
|
|
20
|
+
from collections.abc import Sequence
|
|
20
21
|
from logging import WARNING
|
|
21
|
-
from typing import Any, Callable, Optional,
|
|
22
|
+
from typing import Any, Callable, Optional, Union
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -53,7 +54,7 @@ Request = Union[
|
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
def _get_value_from_tuples(
|
|
56
|
-
key_string: str, tuples: Sequence[
|
|
57
|
+
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
|
57
58
|
) -> bytes:
|
|
58
59
|
value = next((value for key, value in tuples if key == key_string), "")
|
|
59
60
|
if isinstance(value, str):
|
|
@@ -130,13 +131,12 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
130
131
|
if self.shared_secret is None:
|
|
131
132
|
raise RuntimeError("Failure to compute hmac")
|
|
132
133
|
|
|
134
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
133
135
|
metadata.append(
|
|
134
136
|
(
|
|
135
137
|
_AUTH_TOKEN_HEADER,
|
|
136
138
|
base64.urlsafe_b64encode(
|
|
137
|
-
compute_hmac(
|
|
138
|
-
self.shared_secret, request.SerializeToString(True)
|
|
139
|
-
)
|
|
139
|
+
compute_hmac(self.shared_secret, message_bytes)
|
|
140
140
|
),
|
|
141
141
|
)
|
|
142
142
|
)
|
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
import random
|
|
19
19
|
import threading
|
|
20
|
+
from collections.abc import Iterator, Sequence
|
|
20
21
|
from contextlib import contextmanager
|
|
21
22
|
from copy import copy
|
|
22
23
|
from logging import DEBUG, ERROR
|
|
23
24
|
from pathlib import Path
|
|
24
|
-
from typing import Callable,
|
|
25
|
+
from typing import Callable, Optional, Union, cast
|
|
25
26
|
|
|
26
27
|
import grpc
|
|
27
28
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -77,11 +78,11 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
77
78
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
78
79
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
79
80
|
authentication_keys: Optional[
|
|
80
|
-
|
|
81
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
81
82
|
] = None,
|
|
82
|
-
adapter_cls: Optional[Union[
|
|
83
|
+
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
|
|
83
84
|
) -> Iterator[
|
|
84
|
-
|
|
85
|
+
tuple[
|
|
85
86
|
Callable[[], Optional[Message]],
|
|
86
87
|
Callable[[Message], None],
|
|
87
88
|
Optional[Callable[[], Optional[int]]],
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import sys
|
|
19
19
|
from logging import DEBUG
|
|
20
|
-
from typing import Any,
|
|
20
|
+
from typing import Any, TypeVar, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -59,7 +59,7 @@ class GrpcAdapter:
|
|
|
59
59
|
self.stub = GrpcAdapterStub(channel)
|
|
60
60
|
|
|
61
61
|
def _send_and_receive(
|
|
62
|
-
self, request: GrpcMessage, response_type:
|
|
62
|
+
self, request: GrpcMessage, response_type: type[T], **kwargs: Any
|
|
63
63
|
) -> T:
|
|
64
64
|
# Serialize request
|
|
65
65
|
container_req = MessageContainer(
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Client-side message handler."""
|
|
16
16
|
|
|
17
17
|
from logging import WARN
|
|
18
|
-
from typing import Optional,
|
|
18
|
+
from typing import Optional, cast
|
|
19
19
|
|
|
20
20
|
from flwr.client.client import (
|
|
21
21
|
maybe_call_evaluate,
|
|
@@ -52,7 +52,7 @@ class UnknownServerMessage(Exception):
|
|
|
52
52
|
"""Exception indicating that the received message is unknown."""
|
|
53
53
|
|
|
54
54
|
|
|
55
|
-
def handle_control_message(message: Message) ->
|
|
55
|
+
def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
56
56
|
"""Handle control part of the incoming message.
|
|
57
57
|
|
|
58
58
|
Parameters
|
|
@@ -147,7 +147,7 @@ def handle_legacy_message_from_msgtype(
|
|
|
147
147
|
|
|
148
148
|
def _reconnect(
|
|
149
149
|
reconnect_msg: ServerMessage.ReconnectIns,
|
|
150
|
-
) ->
|
|
150
|
+
) -> tuple[ClientMessage, int]:
|
|
151
151
|
# Determine the reason for sending DisconnectRes message
|
|
152
152
|
reason = Reason.ACK
|
|
153
153
|
sleep_duration = None
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import os
|
|
19
19
|
from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, WARNING
|
|
21
|
-
from typing import Any,
|
|
21
|
+
from typing import Any, cast
|
|
22
22
|
|
|
23
23
|
from flwr.client.typing import ClientAppCallable
|
|
24
24
|
from flwr.common import (
|
|
@@ -91,11 +91,11 @@ class SecAggPlusState:
|
|
|
91
91
|
# Random seed for generating the private mask
|
|
92
92
|
rd_seed: bytes = b""
|
|
93
93
|
|
|
94
|
-
rd_seed_share_dict:
|
|
95
|
-
sk1_share_dict:
|
|
94
|
+
rd_seed_share_dict: dict[int, bytes] = field(default_factory=dict)
|
|
95
|
+
sk1_share_dict: dict[int, bytes] = field(default_factory=dict)
|
|
96
96
|
# The dict of the shared secrets from sk2
|
|
97
|
-
ss2_dict:
|
|
98
|
-
public_keys_dict:
|
|
97
|
+
ss2_dict: dict[int, bytes] = field(default_factory=dict)
|
|
98
|
+
public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
|
|
99
99
|
|
|
100
100
|
def __init__(self, **kwargs: ConfigsRecordValues) -> None:
|
|
101
101
|
for k, v in kwargs.items():
|
|
@@ -104,8 +104,8 @@ class SecAggPlusState:
|
|
|
104
104
|
new_v: Any = v
|
|
105
105
|
if k.endswith(":K"):
|
|
106
106
|
k = k[:-2]
|
|
107
|
-
keys = cast(
|
|
108
|
-
values = cast(
|
|
107
|
+
keys = cast(list[int], v)
|
|
108
|
+
values = cast(list[bytes], kwargs[f"{k}:V"])
|
|
109
109
|
if len(values) > len(keys):
|
|
110
110
|
updated_values = [
|
|
111
111
|
tuple(values[i : i + 2]) for i in range(0, len(values), 2)
|
|
@@ -115,17 +115,17 @@ class SecAggPlusState:
|
|
|
115
115
|
new_v = dict(zip(keys, values))
|
|
116
116
|
self.__setattr__(k, new_v)
|
|
117
117
|
|
|
118
|
-
def to_dict(self) ->
|
|
118
|
+
def to_dict(self) -> dict[str, ConfigsRecordValues]:
|
|
119
119
|
"""Convert the state to a dictionary."""
|
|
120
120
|
ret = vars(self)
|
|
121
121
|
for k in list(ret.keys()):
|
|
122
122
|
if isinstance(ret[k], dict):
|
|
123
123
|
# Replace dict with two lists
|
|
124
|
-
v = cast(
|
|
124
|
+
v = cast(dict[str, Any], ret.pop(k))
|
|
125
125
|
ret[f"{k}:K"] = list(v.keys())
|
|
126
126
|
if k == "public_keys_dict":
|
|
127
|
-
v_list:
|
|
128
|
-
for b1_b2 in cast(
|
|
127
|
+
v_list: list[bytes] = []
|
|
128
|
+
for b1_b2 in cast(list[tuple[bytes, bytes]], v.values()):
|
|
129
129
|
v_list.extend(b1_b2)
|
|
130
130
|
ret[f"{k}:V"] = v_list
|
|
131
131
|
else:
|
|
@@ -276,7 +276,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
276
276
|
)
|
|
277
277
|
if not isinstance(configs[key], list) or any(
|
|
278
278
|
elm
|
|
279
|
-
for elm in cast(
|
|
279
|
+
for elm in cast(list[Any], configs[key])
|
|
280
280
|
# pylint: disable-next=unidiomatic-typecheck
|
|
281
281
|
if type(elm) is not expected_type
|
|
282
282
|
):
|
|
@@ -299,7 +299,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
299
299
|
)
|
|
300
300
|
if not isinstance(configs[key], list) or any(
|
|
301
301
|
elm
|
|
302
|
-
for elm in cast(
|
|
302
|
+
for elm in cast(list[Any], configs[key])
|
|
303
303
|
# pylint: disable-next=unidiomatic-typecheck
|
|
304
304
|
if type(elm) is not expected_type
|
|
305
305
|
):
|
|
@@ -314,7 +314,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
314
314
|
|
|
315
315
|
def _setup(
|
|
316
316
|
state: SecAggPlusState, configs: ConfigsRecord
|
|
317
|
-
) ->
|
|
317
|
+
) -> dict[str, ConfigsRecordValues]:
|
|
318
318
|
# Assigning parameter values to object fields
|
|
319
319
|
sec_agg_param_dict = configs
|
|
320
320
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
|
@@ -350,8 +350,8 @@ def _setup(
|
|
|
350
350
|
# pylint: disable-next=too-many-locals
|
|
351
351
|
def _share_keys(
|
|
352
352
|
state: SecAggPlusState, configs: ConfigsRecord
|
|
353
|
-
) ->
|
|
354
|
-
named_bytes_tuples = cast(
|
|
353
|
+
) -> dict[str, ConfigsRecordValues]:
|
|
354
|
+
named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
|
|
355
355
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
|
356
356
|
log(DEBUG, "Node %d: starting stage 1...", state.nid)
|
|
357
357
|
state.public_keys_dict = key_dict
|
|
@@ -361,7 +361,7 @@ def _share_keys(
|
|
|
361
361
|
raise ValueError("Available neighbours number smaller than threshold")
|
|
362
362
|
|
|
363
363
|
# Check if all public keys are unique
|
|
364
|
-
pk_list:
|
|
364
|
+
pk_list: list[bytes] = []
|
|
365
365
|
for pk1, pk2 in state.public_keys_dict.values():
|
|
366
366
|
pk_list.append(pk1)
|
|
367
367
|
pk_list.append(pk2)
|
|
@@ -415,11 +415,11 @@ def _collect_masked_vectors(
|
|
|
415
415
|
configs: ConfigsRecord,
|
|
416
416
|
num_examples: int,
|
|
417
417
|
updated_parameters: Parameters,
|
|
418
|
-
) ->
|
|
418
|
+
) -> dict[str, ConfigsRecordValues]:
|
|
419
419
|
log(DEBUG, "Node %d: starting stage 2...", state.nid)
|
|
420
|
-
available_clients:
|
|
421
|
-
ciphertexts = cast(
|
|
422
|
-
srcs = cast(
|
|
420
|
+
available_clients: list[int] = []
|
|
421
|
+
ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
|
|
422
|
+
srcs = cast(list[int], configs[Key.SOURCE_LIST])
|
|
423
423
|
if len(ciphertexts) + 1 < state.threshold:
|
|
424
424
|
raise ValueError("Not enough available neighbour clients.")
|
|
425
425
|
|
|
@@ -467,7 +467,7 @@ def _collect_masked_vectors(
|
|
|
467
467
|
|
|
468
468
|
quantized_parameters = factor_combine(q_ratio, quantized_parameters)
|
|
469
469
|
|
|
470
|
-
dimensions_list:
|
|
470
|
+
dimensions_list: list[tuple[int, ...]] = [a.shape for a in quantized_parameters]
|
|
471
471
|
|
|
472
472
|
# Add private mask
|
|
473
473
|
private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
|
|
@@ -499,11 +499,11 @@ def _collect_masked_vectors(
|
|
|
499
499
|
|
|
500
500
|
def _unmask(
|
|
501
501
|
state: SecAggPlusState, configs: ConfigsRecord
|
|
502
|
-
) ->
|
|
502
|
+
) -> dict[str, ConfigsRecordValues]:
|
|
503
503
|
log(DEBUG, "Node %d: starting stage 3...", state.nid)
|
|
504
504
|
|
|
505
|
-
active_nids = cast(
|
|
506
|
-
dead_nids = cast(
|
|
505
|
+
active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
|
506
|
+
dead_nids = cast(list[int], configs[Key.DEAD_NODE_ID_LIST])
|
|
507
507
|
# Send private mask seed share for every avaliable client (including itself)
|
|
508
508
|
# Send first private key share for building pairwise mask for every dropped client
|
|
509
509
|
if len(active_nids) < state.threshold:
|
flwr/client/mod/utils.py
CHANGED
|
@@ -15,13 +15,11 @@
|
|
|
15
15
|
"""Utility functions for mods."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import List
|
|
19
|
-
|
|
20
18
|
from flwr.client.typing import ClientAppCallable, Mod
|
|
21
19
|
from flwr.common import Context, Message
|
|
22
20
|
|
|
23
21
|
|
|
24
|
-
def make_ffn(ffn: ClientAppCallable, mods:
|
|
22
|
+
def make_ffn(ffn: ClientAppCallable, mods: list[Mod]) -> ClientAppCallable:
|
|
25
23
|
"""."""
|
|
26
24
|
|
|
27
25
|
def wrap_ffn(_ffn: ClientAppCallable, _mod: Mod) -> ClientAppCallable:
|
flwr/client/node_state.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import Optional
|
|
21
21
|
|
|
22
22
|
from flwr.common import Context, RecordSet
|
|
23
23
|
from flwr.common.config import (
|
|
@@ -46,7 +46,7 @@ class NodeState:
|
|
|
46
46
|
) -> None:
|
|
47
47
|
self.node_id = node_id
|
|
48
48
|
self.node_config = node_config
|
|
49
|
-
self.run_infos:
|
|
49
|
+
self.run_infos: dict[int, RunInfo] = {}
|
|
50
50
|
|
|
51
51
|
# pylint: disable=too-many-arguments
|
|
52
52
|
def register_context(
|
flwr/client/numpy_client.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC
|
|
19
|
-
from typing import Callable
|
|
19
|
+
from typing import Callable
|
|
20
20
|
|
|
21
21
|
from flwr.client.client import Client
|
|
22
22
|
from flwr.common import (
|
|
@@ -73,7 +73,7 @@ class NumPyClient(ABC):
|
|
|
73
73
|
|
|
74
74
|
_context: Context
|
|
75
75
|
|
|
76
|
-
def get_properties(self, config: Config) ->
|
|
76
|
+
def get_properties(self, config: Config) -> dict[str, Scalar]:
|
|
77
77
|
"""Return a client's set of properties.
|
|
78
78
|
|
|
79
79
|
Parameters
|
|
@@ -93,7 +93,7 @@ class NumPyClient(ABC):
|
|
|
93
93
|
_ = (self, config)
|
|
94
94
|
return {}
|
|
95
95
|
|
|
96
|
-
def get_parameters(self, config:
|
|
96
|
+
def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
|
|
97
97
|
"""Return the current local model parameters.
|
|
98
98
|
|
|
99
99
|
Parameters
|
|
@@ -112,8 +112,8 @@ class NumPyClient(ABC):
|
|
|
112
112
|
return []
|
|
113
113
|
|
|
114
114
|
def fit(
|
|
115
|
-
self, parameters: NDArrays, config:
|
|
116
|
-
) ->
|
|
115
|
+
self, parameters: NDArrays, config: dict[str, Scalar]
|
|
116
|
+
) -> tuple[NDArrays, int, dict[str, Scalar]]:
|
|
117
117
|
"""Train the provided parameters using the locally held dataset.
|
|
118
118
|
|
|
119
119
|
Parameters
|
|
@@ -141,8 +141,8 @@ class NumPyClient(ABC):
|
|
|
141
141
|
return [], 0, {}
|
|
142
142
|
|
|
143
143
|
def evaluate(
|
|
144
|
-
self, parameters: NDArrays, config:
|
|
145
|
-
) ->
|
|
144
|
+
self, parameters: NDArrays, config: dict[str, Scalar]
|
|
145
|
+
) -> tuple[float, int, dict[str, Scalar]]:
|
|
146
146
|
"""Evaluate the provided parameters using the locally held dataset.
|
|
147
147
|
|
|
148
148
|
Parameters
|
|
@@ -310,7 +310,7 @@ def _set_context(self: Client, context: Context) -> None:
|
|
|
310
310
|
|
|
311
311
|
|
|
312
312
|
def _wrap_numpy_client(client: NumPyClient) -> Client:
|
|
313
|
-
member_dict:
|
|
313
|
+
member_dict: dict[str, Callable] = { # type: ignore
|
|
314
314
|
"__init__": _constructor,
|
|
315
315
|
"get_context": _get_context,
|
|
316
316
|
"set_context": _set_context,
|
|
@@ -18,10 +18,11 @@
|
|
|
18
18
|
import random
|
|
19
19
|
import sys
|
|
20
20
|
import threading
|
|
21
|
+
from collections.abc import Iterator
|
|
21
22
|
from contextlib import contextmanager
|
|
22
23
|
from copy import copy
|
|
23
24
|
from logging import ERROR, INFO, WARN
|
|
24
|
-
from typing import Callable,
|
|
25
|
+
from typing import Callable, Optional, TypeVar, Union
|
|
25
26
|
|
|
26
27
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
27
28
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -90,10 +91,10 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
90
91
|
Union[bytes, str]
|
|
91
92
|
] = None, # pylint: disable=unused-argument
|
|
92
93
|
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
93
|
-
|
|
94
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
94
95
|
] = None,
|
|
95
96
|
) -> Iterator[
|
|
96
|
-
|
|
97
|
+
tuple[
|
|
97
98
|
Callable[[], Optional[Message]],
|
|
98
99
|
Callable[[Message], None],
|
|
99
100
|
Optional[Callable[[], Optional[int]]],
|
|
@@ -173,7 +174,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
173
174
|
###########################################################################
|
|
174
175
|
|
|
175
176
|
def _request(
|
|
176
|
-
req: GrpcMessage, res_type:
|
|
177
|
+
req: GrpcMessage, res_type: type[T], api_path: str, retry: bool = True
|
|
177
178
|
) -> Optional[T]:
|
|
178
179
|
# Serialize the request
|
|
179
180
|
req_bytes = req.SerializeToString()
|
flwr/client/supernode/app.py
CHANGED
|
@@ -18,7 +18,7 @@ import argparse
|
|
|
18
18
|
import sys
|
|
19
19
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import Optional
|
|
21
|
+
from typing import Optional
|
|
22
22
|
|
|
23
23
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
24
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -30,6 +30,7 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
30
30
|
from flwr.common import EventType, event
|
|
31
31
|
from flwr.common.config import parse_config_args
|
|
32
32
|
from flwr.common.constant import (
|
|
33
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
33
34
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
34
35
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
35
36
|
TRANSPORT_TYPE_REST,
|
|
@@ -44,8 +45,6 @@ from ..app import (
|
|
|
44
45
|
)
|
|
45
46
|
from ..clientapp.utils import get_load_client_app_fn
|
|
46
47
|
|
|
47
|
-
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
48
|
-
|
|
49
48
|
|
|
50
49
|
def run_supernode() -> None:
|
|
51
50
|
"""Run Flower SuperNode."""
|
|
@@ -103,11 +102,11 @@ def run_client_app() -> None:
|
|
|
103
102
|
|
|
104
103
|
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
|
|
105
104
|
"""Warn about the deprecated argument `--server`."""
|
|
106
|
-
if args.server !=
|
|
105
|
+
if args.server != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
107
106
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
108
107
|
warn_deprecated_feature(warn)
|
|
109
108
|
|
|
110
|
-
if args.superlink !=
|
|
109
|
+
if args.superlink != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
111
110
|
# if `--superlink` also passed, then
|
|
112
111
|
# warn user that this argument overrides what was passed with `--server`
|
|
113
112
|
log(
|
|
@@ -247,12 +246,12 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
247
246
|
)
|
|
248
247
|
parser.add_argument(
|
|
249
248
|
"--server",
|
|
250
|
-
default=
|
|
249
|
+
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
251
250
|
help="Server address",
|
|
252
251
|
)
|
|
253
252
|
parser.add_argument(
|
|
254
253
|
"--superlink",
|
|
255
|
-
default=
|
|
254
|
+
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
256
255
|
help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
257
256
|
)
|
|
258
257
|
parser.add_argument(
|
|
@@ -292,7 +291,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
292
291
|
|
|
293
292
|
def _try_setup_client_authentication(
|
|
294
293
|
args: argparse.Namespace,
|
|
295
|
-
) -> Optional[
|
|
294
|
+
) -> Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
296
295
|
if not args.auth_supernode_private_key and not args.auth_supernode_public_key:
|
|
297
296
|
return None
|
|
298
297
|
|
flwr/common/address.py
CHANGED
|
@@ -16,12 +16,12 @@
|
|
|
16
16
|
|
|
17
17
|
import socket
|
|
18
18
|
from ipaddress import ip_address
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
IPV6: int = 6
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def parse_address(address: str) -> Optional[
|
|
24
|
+
def parse_address(address: str) -> Optional[tuple[str, int, Optional[bool]]]:
|
|
25
25
|
"""Parse an IP address into host, port, and version.
|
|
26
26
|
|
|
27
27
|
Parameters
|