flwr-nightly 1.12.0.dev20240906__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
flwr/common/config.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import os
|
|
18
18
|
import re
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import Any,
|
|
20
|
+
from typing import Any, Optional, Union, cast, get_args
|
|
21
21
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
@@ -53,7 +53,7 @@ def get_project_dir(
|
|
|
53
53
|
return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
def get_project_config(project_dir: Union[str, Path]) ->
|
|
56
|
+
def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
|
|
57
57
|
"""Return pyproject.toml in the given project directory."""
|
|
58
58
|
# Load pyproject.toml file
|
|
59
59
|
toml_path = Path(project_dir) / FAB_CONFIG_FILE
|
|
@@ -137,13 +137,13 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
|
|
137
137
|
|
|
138
138
|
|
|
139
139
|
def flatten_dict(
|
|
140
|
-
raw_dict: Optional[
|
|
140
|
+
raw_dict: Optional[dict[str, Any]], parent_key: str = ""
|
|
141
141
|
) -> UserConfig:
|
|
142
142
|
"""Flatten dict by joining nested keys with a given separator."""
|
|
143
143
|
if raw_dict is None:
|
|
144
144
|
return {}
|
|
145
145
|
|
|
146
|
-
items:
|
|
146
|
+
items: list[tuple[str, UserConfigValue]] = []
|
|
147
147
|
separator: str = "."
|
|
148
148
|
for k, v in raw_dict.items():
|
|
149
149
|
new_key = f"{parent_key}{separator}{k}" if parent_key else k
|
|
@@ -159,9 +159,9 @@ def flatten_dict(
|
|
|
159
159
|
return dict(items)
|
|
160
160
|
|
|
161
161
|
|
|
162
|
-
def unflatten_dict(flat_dict:
|
|
162
|
+
def unflatten_dict(flat_dict: dict[str, Any]) -> dict[str, Any]:
|
|
163
163
|
"""Unflatten a dict with keys containing separators into a nested dict."""
|
|
164
|
-
unflattened_dict:
|
|
164
|
+
unflattened_dict: dict[str, Any] = {}
|
|
165
165
|
separator: str = "."
|
|
166
166
|
|
|
167
167
|
for key, value in flat_dict.items():
|
|
@@ -177,7 +177,7 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
177
177
|
|
|
178
178
|
|
|
179
179
|
def parse_config_args(
|
|
180
|
-
config: Optional[
|
|
180
|
+
config: Optional[list[str]],
|
|
181
181
|
) -> UserConfig:
|
|
182
182
|
"""Parse separator separated list of key-value pairs separated by '='."""
|
|
183
183
|
overrides: UserConfig = {}
|
|
@@ -209,7 +209,7 @@ def parse_config_args(
|
|
|
209
209
|
return overrides
|
|
210
210
|
|
|
211
211
|
|
|
212
|
-
def get_metadata_from_config(config:
|
|
212
|
+
def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
|
213
213
|
"""Extract `fab_version` and `fab_id` from a project config."""
|
|
214
214
|
return (
|
|
215
215
|
config["project"]["version"],
|
flwr/common/constant.py
CHANGED
|
@@ -37,7 +37,18 @@ TRANSPORT_TYPES = [
|
|
|
37
37
|
TRANSPORT_TYPE_VCE,
|
|
38
38
|
]
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
# Addresses
|
|
41
|
+
# SuperNode
|
|
42
|
+
CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094"
|
|
43
|
+
# SuperExec
|
|
44
|
+
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
45
|
+
# SuperLink
|
|
46
|
+
DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091"
|
|
47
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092"
|
|
48
|
+
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
|
|
49
|
+
"[::]:8080" # IPv6 to keep start_server compatible
|
|
50
|
+
)
|
|
51
|
+
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
41
52
|
|
|
42
53
|
# Constants for ping
|
|
43
54
|
PING_DEFAULT_INTERVAL = 30
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import WARNING
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
@@ -125,7 +125,7 @@ def compute_adaptive_noise_params(
|
|
|
125
125
|
noise_multiplier: float,
|
|
126
126
|
num_sampled_clients: float,
|
|
127
127
|
clipped_count_stddev: Optional[float],
|
|
128
|
-
) ->
|
|
128
|
+
) -> tuple[float, float]:
|
|
129
129
|
"""Compute noising parameters for the adaptive clipping.
|
|
130
130
|
|
|
131
131
|
Paper: https://arxiv.org/abs/1905.03871
|
flwr/common/dp.py
CHANGED
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""Building block functions for DP algorithms."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Tuple
|
|
19
|
-
|
|
20
18
|
import numpy as np
|
|
21
19
|
|
|
22
20
|
from flwr.common.logger import warn_deprecated_feature
|
|
@@ -41,7 +39,7 @@ def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays:
|
|
|
41
39
|
return update_noised
|
|
42
40
|
|
|
43
41
|
|
|
44
|
-
def clip_by_l2(update: NDArrays, threshold: float) ->
|
|
42
|
+
def clip_by_l2(update: NDArrays, threshold: float) -> tuple[NDArrays, bool]:
|
|
45
43
|
"""Scales the update so thats its L2 norm is upper-bound to threshold."""
|
|
46
44
|
warn_deprecated_feature("`clip_by_l2` method")
|
|
47
45
|
update_norm = _get_update_norm(update)
|
flwr/common/exit_handlers.py
CHANGED
|
@@ -19,7 +19,7 @@ import sys
|
|
|
19
19
|
from signal import SIGINT, SIGTERM, signal
|
|
20
20
|
from threading import Thread
|
|
21
21
|
from types import FrameType
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from grpc import Server
|
|
25
25
|
|
|
@@ -28,8 +28,8 @@ from flwr.common.telemetry import EventType, event
|
|
|
28
28
|
|
|
29
29
|
def register_exit_handlers(
|
|
30
30
|
event_type: EventType,
|
|
31
|
-
grpc_servers: Optional[
|
|
32
|
-
bckg_threads: Optional[
|
|
31
|
+
grpc_servers: Optional[list[Server]] = None,
|
|
32
|
+
bckg_threads: Optional[list[Thread]] = None,
|
|
33
33
|
) -> None:
|
|
34
34
|
"""Register exit handlers for `SIGINT` and `SIGTERM` signals.
|
|
35
35
|
|
flwr/common/grpc.py
CHANGED
flwr/common/logger.py
CHANGED
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import logging
|
|
19
19
|
from logging import WARN, LogRecord
|
|
20
20
|
from logging.handlers import HTTPHandler
|
|
21
|
-
from typing import TYPE_CHECKING, Any,
|
|
21
|
+
from typing import TYPE_CHECKING, Any, Optional, TextIO
|
|
22
22
|
|
|
23
23
|
# Create logger
|
|
24
24
|
LOGGER_NAME = "flwr"
|
|
@@ -119,12 +119,12 @@ class CustomHTTPHandler(HTTPHandler):
|
|
|
119
119
|
url: str,
|
|
120
120
|
method: str = "GET",
|
|
121
121
|
secure: bool = False,
|
|
122
|
-
credentials: Optional[
|
|
122
|
+
credentials: Optional[tuple[str, str]] = None,
|
|
123
123
|
) -> None:
|
|
124
124
|
super().__init__(host, url, method, secure, credentials)
|
|
125
125
|
self.identifier = identifier
|
|
126
126
|
|
|
127
|
-
def mapLogRecord(self, record: LogRecord) ->
|
|
127
|
+
def mapLogRecord(self, record: LogRecord) -> dict[str, Any]:
|
|
128
128
|
"""Filter for the properties to be send to the logserver."""
|
|
129
129
|
record_dict = record.__dict__
|
|
130
130
|
return {
|
flwr/common/object_ref.py
CHANGED
|
@@ -21,7 +21,7 @@ import sys
|
|
|
21
21
|
from importlib.util import find_spec
|
|
22
22
|
from logging import WARN
|
|
23
23
|
from pathlib import Path
|
|
24
|
-
from typing import Any, Optional,
|
|
24
|
+
from typing import Any, Optional, Union
|
|
25
25
|
|
|
26
26
|
from .logger import log
|
|
27
27
|
|
|
@@ -40,7 +40,7 @@ def validate(
|
|
|
40
40
|
module_attribute_str: str,
|
|
41
41
|
check_module: bool = True,
|
|
42
42
|
project_dir: Optional[Union[str, Path]] = None,
|
|
43
|
-
) ->
|
|
43
|
+
) -> tuple[bool, Optional[str]]:
|
|
44
44
|
"""Validate object reference.
|
|
45
45
|
|
|
46
46
|
Parameters
|
|
@@ -106,7 +106,7 @@ def validate(
|
|
|
106
106
|
|
|
107
107
|
def load_app( # pylint: disable= too-many-branches
|
|
108
108
|
module_attribute_str: str,
|
|
109
|
-
error_type:
|
|
109
|
+
error_type: type[Exception],
|
|
110
110
|
project_dir: Optional[Union[str, Path]] = None,
|
|
111
111
|
) -> Any:
|
|
112
112
|
"""Return the object specified in a module attribute string.
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""ConfigsRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Optional, get_args
|
|
19
19
|
|
|
20
20
|
from flwr.common.typing import ConfigsRecordValues, ConfigsScalar
|
|
21
21
|
|
|
@@ -109,7 +109,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
109
109
|
|
|
110
110
|
def __init__(
|
|
111
111
|
self,
|
|
112
|
-
configs_dict: Optional[
|
|
112
|
+
configs_dict: Optional[dict[str, ConfigsRecordValues]] = None,
|
|
113
113
|
keep_input: bool = True,
|
|
114
114
|
) -> None:
|
|
115
115
|
|
|
@@ -141,7 +141,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
141
141
|
num_bytes = 0
|
|
142
142
|
|
|
143
143
|
for k, v in self.items():
|
|
144
|
-
if isinstance(v,
|
|
144
|
+
if isinstance(v, list):
|
|
145
145
|
if isinstance(v[0], (bytes, str)):
|
|
146
146
|
# not all str are of equal length necessarily
|
|
147
147
|
# for both the footprint of each element is 1 Byte
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""MetricsRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Optional, get_args
|
|
19
19
|
|
|
20
20
|
from flwr.common.typing import MetricsRecordValues, MetricsScalar
|
|
21
21
|
|
|
@@ -115,7 +115,7 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]):
|
|
|
115
115
|
|
|
116
116
|
def __init__(
|
|
117
117
|
self,
|
|
118
|
-
metrics_dict: Optional[
|
|
118
|
+
metrics_dict: Optional[dict[str, MetricsRecordValues]] = None,
|
|
119
119
|
keep_input: bool = True,
|
|
120
120
|
):
|
|
121
121
|
super().__init__(_check_key, _check_value)
|
|
@@ -130,7 +130,7 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]):
|
|
|
130
130
|
num_bytes = 0
|
|
131
131
|
|
|
132
132
|
for k, v in self.items():
|
|
133
|
-
if isinstance(v,
|
|
133
|
+
if isinstance(v, list):
|
|
134
134
|
# both int and float normally take 4 bytes
|
|
135
135
|
# But MetricRecords are mapped to 64bit int/float
|
|
136
136
|
# during protobuffing
|
|
@@ -14,9 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""ParametersRecord and Array."""
|
|
16
16
|
|
|
17
|
+
from collections import OrderedDict
|
|
17
18
|
from dataclasses import dataclass
|
|
18
19
|
from io import BytesIO
|
|
19
|
-
from typing import
|
|
20
|
+
from typing import Optional, cast
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
@@ -51,7 +52,7 @@ class Array:
|
|
|
51
52
|
"""
|
|
52
53
|
|
|
53
54
|
dtype: str
|
|
54
|
-
shape:
|
|
55
|
+
shape: list[int]
|
|
55
56
|
stype: str
|
|
56
57
|
data: bytes
|
|
57
58
|
|
flwr/common/record/recordset.py
CHANGED
|
@@ -119,7 +119,7 @@ class RecordSet:
|
|
|
119
119
|
Let's see an example.
|
|
120
120
|
|
|
121
121
|
>>> from flwr.common import RecordSet
|
|
122
|
-
>>> from flwr.common import
|
|
122
|
+
>>> from flwr.common import ConfigsRecord, MetricsRecord, ParametersRecord
|
|
123
123
|
>>>
|
|
124
124
|
>>> # Let's begin with an empty record
|
|
125
125
|
>>> my_recordset = RecordSet()
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,7 +15,8 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import ItemsView, Iterator, KeysView, MutableMapping, ValuesView
|
|
19
|
+
from typing import Callable, Generic, TypeVar, cast
|
|
19
20
|
|
|
20
21
|
K = TypeVar("K") # Key type
|
|
21
22
|
V = TypeVar("V") # Value type
|
|
@@ -38,38 +39,50 @@ class TypedDict(MutableMapping[K, V], Generic[K, V]):
|
|
|
38
39
|
cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value)
|
|
39
40
|
|
|
40
41
|
# Set key-value pair
|
|
41
|
-
cast(
|
|
42
|
+
cast(dict[K, V], self.__dict__["_data"])[key] = value
|
|
42
43
|
|
|
43
44
|
def __delitem__(self, key: K) -> None:
|
|
44
45
|
"""Remove the item with the specified key."""
|
|
45
|
-
del cast(
|
|
46
|
+
del cast(dict[K, V], self.__dict__["_data"])[key]
|
|
46
47
|
|
|
47
48
|
def __getitem__(self, item: K) -> V:
|
|
48
49
|
"""Return the value for the specified key."""
|
|
49
|
-
return cast(
|
|
50
|
+
return cast(dict[K, V], self.__dict__["_data"])[item]
|
|
50
51
|
|
|
51
52
|
def __iter__(self) -> Iterator[K]:
|
|
52
53
|
"""Yield an iterator over the keys of the dictionary."""
|
|
53
|
-
return iter(cast(
|
|
54
|
+
return iter(cast(dict[K, V], self.__dict__["_data"]))
|
|
54
55
|
|
|
55
56
|
def __repr__(self) -> str:
|
|
56
57
|
"""Return a string representation of the dictionary."""
|
|
57
|
-
return cast(
|
|
58
|
+
return cast(dict[K, V], self.__dict__["_data"]).__repr__()
|
|
58
59
|
|
|
59
60
|
def __len__(self) -> int:
|
|
60
61
|
"""Return the number of items in the dictionary."""
|
|
61
|
-
return len(cast(
|
|
62
|
+
return len(cast(dict[K, V], self.__dict__["_data"]))
|
|
62
63
|
|
|
63
64
|
def __contains__(self, key: object) -> bool:
|
|
64
65
|
"""Check if the dictionary contains the specified key."""
|
|
65
|
-
return key in cast(
|
|
66
|
+
return key in cast(dict[K, V], self.__dict__["_data"])
|
|
66
67
|
|
|
67
68
|
def __eq__(self, other: object) -> bool:
|
|
68
69
|
"""Compare this instance to another dictionary or TypedDict."""
|
|
69
|
-
data = cast(
|
|
70
|
+
data = cast(dict[K, V], self.__dict__["_data"])
|
|
70
71
|
if isinstance(other, TypedDict):
|
|
71
|
-
other_data = cast(
|
|
72
|
+
other_data = cast(dict[K, V], other.__dict__["_data"])
|
|
72
73
|
return data == other_data
|
|
73
74
|
if isinstance(other, dict):
|
|
74
75
|
return data == other
|
|
75
76
|
return NotImplemented
|
|
77
|
+
|
|
78
|
+
def keys(self) -> KeysView[K]:
|
|
79
|
+
"""D.keys() -> a set-like object providing a view on D's keys."""
|
|
80
|
+
return cast(dict[K, V], self.__dict__["_data"]).keys()
|
|
81
|
+
|
|
82
|
+
def values(self) -> ValuesView[V]:
|
|
83
|
+
"""D.values() -> an object providing a view on D's values."""
|
|
84
|
+
return cast(dict[K, V], self.__dict__["_data"]).values()
|
|
85
|
+
|
|
86
|
+
def items(self) -> ItemsView[K, V]:
|
|
87
|
+
"""D.items() -> a set-like object providing a view on D's items."""
|
|
88
|
+
return cast(dict[K, V], self.__dict__["_data"]).items()
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -15,7 +15,9 @@
|
|
|
15
15
|
"""RecordSet utilities."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import Union, cast, get_args
|
|
19
21
|
|
|
20
22
|
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet
|
|
21
23
|
from .typing import (
|
|
@@ -115,7 +117,7 @@ def parameters_to_parametersrecord(
|
|
|
115
117
|
|
|
116
118
|
def _check_mapping_from_recordscalartype_to_scalar(
|
|
117
119
|
record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
|
|
118
|
-
) ->
|
|
120
|
+
) -> dict[str, Scalar]:
|
|
119
121
|
"""Check mapping `common.*RecordValues` into `common.Scalar` is possible."""
|
|
120
122
|
for value in record_data.values():
|
|
121
123
|
if not isinstance(value, get_args(Scalar)):
|
|
@@ -126,14 +128,14 @@ def _check_mapping_from_recordscalartype_to_scalar(
|
|
|
126
128
|
"supported by the `common.RecordSet` infrastructure. "
|
|
127
129
|
f"You used type: {type(value)}"
|
|
128
130
|
)
|
|
129
|
-
return cast(
|
|
131
|
+
return cast(dict[str, Scalar], record_data)
|
|
130
132
|
|
|
131
133
|
|
|
132
134
|
def _recordset_to_fit_or_evaluate_ins_components(
|
|
133
135
|
recordset: RecordSet,
|
|
134
136
|
ins_str: str,
|
|
135
137
|
keep_input: bool,
|
|
136
|
-
) ->
|
|
138
|
+
) -> tuple[Parameters, dict[str, Scalar]]:
|
|
137
139
|
"""Derive Fit/Evaluate Ins from a RecordSet."""
|
|
138
140
|
# get Array and construct Parameters
|
|
139
141
|
parameters_record = recordset.parameters_records[f"{ins_str}.parameters"]
|
|
@@ -169,7 +171,7 @@ def _fit_or_evaluate_ins_to_recordset(
|
|
|
169
171
|
def _embed_status_into_recordset(
|
|
170
172
|
res_str: str, status: Status, recordset: RecordSet
|
|
171
173
|
) -> RecordSet:
|
|
172
|
-
status_dict:
|
|
174
|
+
status_dict: dict[str, ConfigsRecordValues] = {
|
|
173
175
|
"code": int(status.code.value),
|
|
174
176
|
"message": status.message,
|
|
175
177
|
}
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -18,20 +18,9 @@
|
|
|
18
18
|
import itertools
|
|
19
19
|
import random
|
|
20
20
|
import time
|
|
21
|
+
from collections.abc import Generator, Iterable
|
|
21
22
|
from dataclasses import dataclass
|
|
22
|
-
from typing import
|
|
23
|
-
Any,
|
|
24
|
-
Callable,
|
|
25
|
-
Dict,
|
|
26
|
-
Generator,
|
|
27
|
-
Iterable,
|
|
28
|
-
List,
|
|
29
|
-
Optional,
|
|
30
|
-
Tuple,
|
|
31
|
-
Type,
|
|
32
|
-
Union,
|
|
33
|
-
cast,
|
|
34
|
-
)
|
|
23
|
+
from typing import Any, Callable, Optional, Union, cast
|
|
35
24
|
|
|
36
25
|
|
|
37
26
|
def exponential(
|
|
@@ -93,8 +82,8 @@ class RetryState:
|
|
|
93
82
|
"""State for callbacks in RetryInvoker."""
|
|
94
83
|
|
|
95
84
|
target: Callable[..., Any]
|
|
96
|
-
args:
|
|
97
|
-
kwargs:
|
|
85
|
+
args: tuple[Any, ...]
|
|
86
|
+
kwargs: dict[str, Any]
|
|
98
87
|
tries: int
|
|
99
88
|
elapsed_time: float
|
|
100
89
|
exception: Optional[Exception] = None
|
|
@@ -167,7 +156,7 @@ class RetryInvoker:
|
|
|
167
156
|
def __init__(
|
|
168
157
|
self,
|
|
169
158
|
wait_gen_factory: Callable[[], Generator[float, None, None]],
|
|
170
|
-
recoverable_exceptions: Union[
|
|
159
|
+
recoverable_exceptions: Union[type[Exception], tuple[type[Exception], ...]],
|
|
171
160
|
max_tries: Optional[int],
|
|
172
161
|
max_time: Optional[float],
|
|
173
162
|
*,
|
|
@@ -244,7 +233,7 @@ class RetryInvoker:
|
|
|
244
233
|
try_cnt = 0
|
|
245
234
|
wait_generator = self.wait_gen_factory()
|
|
246
235
|
start = time.monotonic()
|
|
247
|
-
ref_state:
|
|
236
|
+
ref_state: list[Optional[RetryState]] = [None]
|
|
248
237
|
|
|
249
238
|
while True:
|
|
250
239
|
try_cnt += 1
|
|
@@ -17,20 +17,20 @@
|
|
|
17
17
|
|
|
18
18
|
import pickle
|
|
19
19
|
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import cast
|
|
21
21
|
|
|
22
22
|
from Crypto.Protocol.SecretSharing import Shamir
|
|
23
23
|
from Crypto.Util.Padding import pad, unpad
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_shares(secret: bytes, threshold: int, num: int) ->
|
|
26
|
+
def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]:
|
|
27
27
|
"""Return list of shares (bytes)."""
|
|
28
28
|
secret_padded = pad(secret, 16)
|
|
29
29
|
secret_padded_chunk = [
|
|
30
30
|
(threshold, num, secret_padded[i : i + 16])
|
|
31
31
|
for i in range(0, len(secret_padded), 16)
|
|
32
32
|
]
|
|
33
|
-
share_list:
|
|
33
|
+
share_list: list[list[tuple[int, bytes]]] = [[] for _ in range(num)]
|
|
34
34
|
|
|
35
35
|
with ThreadPoolExecutor(max_workers=10) as executor:
|
|
36
36
|
for chunk_shares in executor.map(
|
|
@@ -43,22 +43,22 @@ def create_shares(secret: bytes, threshold: int, num: int) -> List[bytes]:
|
|
|
43
43
|
return [pickle.dumps(shares) for shares in share_list]
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def _shamir_split(threshold: int, num: int, chunk: bytes) ->
|
|
46
|
+
def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]:
|
|
47
47
|
return Shamir.split(threshold, num, chunk, ssss=False)
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
# Reconstructing secret with PyCryptodome
|
|
51
|
-
def combine_shares(share_list:
|
|
51
|
+
def combine_shares(share_list: list[bytes]) -> bytes:
|
|
52
52
|
"""Reconstruct secret from shares."""
|
|
53
|
-
unpickled_share_list:
|
|
54
|
-
cast(
|
|
53
|
+
unpickled_share_list: list[list[tuple[int, bytes]]] = [
|
|
54
|
+
cast(list[tuple[int, bytes]], pickle.loads(share)) for share in share_list
|
|
55
55
|
]
|
|
56
56
|
|
|
57
57
|
chunk_num = len(unpickled_share_list[0])
|
|
58
58
|
secret_padded = bytearray(0)
|
|
59
|
-
chunk_shares_list:
|
|
59
|
+
chunk_shares_list: list[list[tuple[int, bytes]]] = []
|
|
60
60
|
for i in range(chunk_num):
|
|
61
|
-
chunk_shares:
|
|
61
|
+
chunk_shares: list[tuple[int, bytes]] = []
|
|
62
62
|
for share in unpickled_share_list:
|
|
63
63
|
chunk_shares.append(share[i])
|
|
64
64
|
chunk_shares_list.append(chunk_shares)
|
|
@@ -71,5 +71,5 @@ def combine_shares(share_list: List[bytes]) -> bytes:
|
|
|
71
71
|
return bytes(secret)
|
|
72
72
|
|
|
73
73
|
|
|
74
|
-
def _shamir_combine(shares:
|
|
74
|
+
def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes:
|
|
75
75
|
return Shamir.combine(shares, ssss=False)
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import cast
|
|
20
20
|
|
|
21
21
|
from cryptography.exceptions import InvalidSignature
|
|
22
22
|
from cryptography.fernet import Fernet
|
|
@@ -26,7 +26,7 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def generate_key_pairs() -> (
|
|
29
|
-
|
|
29
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
30
30
|
):
|
|
31
31
|
"""Generate private and public key pairs with Cryptography."""
|
|
32
32
|
private_key = ec.generate_private_key(ec.SECP384R1())
|
|
@@ -15,51 +15,51 @@
|
|
|
15
15
|
"""Utility functions for performing operations on Numpy NDArrays."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Any,
|
|
18
|
+
from typing import Any, Union
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
from numpy.typing import DTypeLike, NDArray
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def factor_combine(factor: int, parameters:
|
|
24
|
+
def factor_combine(factor: int, parameters: list[NDArray[Any]]) -> list[NDArray[Any]]:
|
|
25
25
|
"""Combine factor with parameters."""
|
|
26
26
|
return [np.array([factor])] + parameters
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def factor_extract(
|
|
30
|
-
parameters:
|
|
31
|
-
) ->
|
|
30
|
+
parameters: list[NDArray[Any]],
|
|
31
|
+
) -> tuple[int, list[NDArray[Any]]]:
|
|
32
32
|
"""Extract factor from parameters."""
|
|
33
33
|
return parameters[0][0], parameters[1:]
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
def get_parameters_shape(parameters:
|
|
36
|
+
def get_parameters_shape(parameters: list[NDArray[Any]]) -> list[tuple[int, ...]]:
|
|
37
37
|
"""Get dimensions of each NDArray in parameters."""
|
|
38
38
|
return [arr.shape for arr in parameters]
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
def get_zero_parameters(
|
|
42
|
-
dimensions_list:
|
|
43
|
-
) ->
|
|
42
|
+
dimensions_list: list[tuple[int, ...]], dtype: DTypeLike = np.int64
|
|
43
|
+
) -> list[NDArray[Any]]:
|
|
44
44
|
"""Generate zero parameters based on the dimensions list."""
|
|
45
45
|
return [np.zeros(dimensions, dtype=dtype) for dimensions in dimensions_list]
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
def parameters_addition(
|
|
49
|
-
parameters1:
|
|
50
|
-
) ->
|
|
49
|
+
parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]]
|
|
50
|
+
) -> list[NDArray[Any]]:
|
|
51
51
|
"""Add two parameters."""
|
|
52
52
|
return [parameters1[idx] + parameters2[idx] for idx in range(len(parameters1))]
|
|
53
53
|
|
|
54
54
|
|
|
55
55
|
def parameters_subtraction(
|
|
56
|
-
parameters1:
|
|
57
|
-
) ->
|
|
56
|
+
parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]]
|
|
57
|
+
) -> list[NDArray[Any]]:
|
|
58
58
|
"""Subtract parameters from the other parameters."""
|
|
59
59
|
return [parameters1[idx] - parameters2[idx] for idx in range(len(parameters1))]
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def parameters_mod(parameters:
|
|
62
|
+
def parameters_mod(parameters: list[NDArray[Any]], divisor: int) -> list[NDArray[Any]]:
|
|
63
63
|
"""Take mod of parameters with an integer divisor."""
|
|
64
64
|
if bin(divisor).count("1") == 1:
|
|
65
65
|
msk = divisor - 1
|
|
@@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
def parameters_multiply(
|
|
71
|
-
parameters:
|
|
72
|
-
) ->
|
|
71
|
+
parameters: list[NDArray[Any]], multiplier: Union[int, float]
|
|
72
|
+
) -> list[NDArray[Any]]:
|
|
73
73
|
"""Multiply parameters by an integer/float multiplier."""
|
|
74
74
|
return [parameters[idx] * multiplier for idx in range(len(parameters))]
|
|
75
75
|
|
|
76
76
|
|
|
77
77
|
def parameters_divide(
|
|
78
|
-
parameters:
|
|
79
|
-
) ->
|
|
78
|
+
parameters: list[NDArray[Any]], divisor: Union[int, float]
|
|
79
|
+
) -> list[NDArray[Any]]:
|
|
80
80
|
"""Divide weight by an integer/float divisor."""
|
|
81
81
|
return [parameters[idx] / divisor for idx in range(len(parameters))]
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Utility functions for model quantization."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import cast
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
@@ -30,10 +30,10 @@ def _stochastic_round(arr: NDArrayFloat) -> NDArrayInt:
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
def quantize(
|
|
33
|
-
parameters:
|
|
34
|
-
) ->
|
|
33
|
+
parameters: list[NDArrayFloat], clipping_range: float, target_range: int
|
|
34
|
+
) -> list[NDArrayInt]:
|
|
35
35
|
"""Quantize float Numpy arrays to integer Numpy arrays."""
|
|
36
|
-
quantized_list:
|
|
36
|
+
quantized_list: list[NDArrayInt] = []
|
|
37
37
|
quantizer = target_range / (2 * clipping_range)
|
|
38
38
|
for arr in parameters:
|
|
39
39
|
# Stochastic quantization
|
|
@@ -49,12 +49,12 @@ def quantize(
|
|
|
49
49
|
|
|
50
50
|
# Dequantize parameters to range [-clipping_range, clipping_range]
|
|
51
51
|
def dequantize(
|
|
52
|
-
quantized_parameters:
|
|
52
|
+
quantized_parameters: list[NDArrayInt],
|
|
53
53
|
clipping_range: float,
|
|
54
54
|
target_range: int,
|
|
55
|
-
) ->
|
|
55
|
+
) -> list[NDArrayFloat]:
|
|
56
56
|
"""Dequantize integer Numpy arrays to float Numpy arrays."""
|
|
57
|
-
reverse_quantized_list:
|
|
57
|
+
reverse_quantized_list: list[NDArrayFloat] = []
|
|
58
58
|
quantizer = (2 * clipping_range) / target_range
|
|
59
59
|
shift = -clipping_range
|
|
60
60
|
for arr in quantized_parameters:
|