flwr-nightly 1.12.0.dev20240907__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
|
@@ -22,7 +22,7 @@ Paper: arxiv.org/abs/2304.07537
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
from logging import WARNING
|
|
25
|
-
from typing import Any,
|
|
25
|
+
from typing import Any, Optional, Union
|
|
26
26
|
|
|
27
27
|
from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays
|
|
28
28
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
@@ -56,7 +56,7 @@ class FedXgbNnAvg(FedAvg):
|
|
|
56
56
|
|
|
57
57
|
def evaluate(
|
|
58
58
|
self, server_round: int, parameters: Any
|
|
59
|
-
) -> Optional[
|
|
59
|
+
) -> Optional[tuple[float, dict[str, Scalar]]]:
|
|
60
60
|
"""Evaluate model parameters using an evaluation function."""
|
|
61
61
|
if self.evaluate_fn is None:
|
|
62
62
|
# No evaluation function provided
|
|
@@ -70,9 +70,9 @@ class FedXgbNnAvg(FedAvg):
|
|
|
70
70
|
def aggregate_fit(
|
|
71
71
|
self,
|
|
72
72
|
server_round: int,
|
|
73
|
-
results:
|
|
74
|
-
failures:
|
|
75
|
-
) ->
|
|
73
|
+
results: list[tuple[ClientProxy, FitRes]],
|
|
74
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
|
75
|
+
) -> tuple[Optional[Any], dict[str, Scalar]]:
|
|
76
76
|
"""Aggregate fit results using weighted average."""
|
|
77
77
|
if not results:
|
|
78
78
|
return None, {}
|
flwr/server/strategy/fedyogi.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from typing import Callable,
|
|
21
|
+
from typing import Callable, Optional, Union
|
|
22
22
|
|
|
23
23
|
import numpy as np
|
|
24
24
|
|
|
@@ -93,12 +93,12 @@ class FedYogi(FedOpt):
|
|
|
93
93
|
min_available_clients: int = 2,
|
|
94
94
|
evaluate_fn: Optional[
|
|
95
95
|
Callable[
|
|
96
|
-
[int, NDArrays,
|
|
97
|
-
Optional[
|
|
96
|
+
[int, NDArrays, dict[str, Scalar]],
|
|
97
|
+
Optional[tuple[float, dict[str, Scalar]]],
|
|
98
98
|
]
|
|
99
99
|
] = None,
|
|
100
|
-
on_fit_config_fn: Optional[Callable[[int],
|
|
101
|
-
on_evaluate_config_fn: Optional[Callable[[int],
|
|
100
|
+
on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
101
|
+
on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
102
102
|
accept_failures: bool = True,
|
|
103
103
|
initial_parameters: Parameters,
|
|
104
104
|
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
|
|
@@ -137,9 +137,9 @@ class FedYogi(FedOpt):
|
|
|
137
137
|
def aggregate_fit(
|
|
138
138
|
self,
|
|
139
139
|
server_round: int,
|
|
140
|
-
results:
|
|
141
|
-
failures:
|
|
142
|
-
) ->
|
|
140
|
+
results: list[tuple[ClientProxy, FitRes]],
|
|
141
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
|
142
|
+
) -> tuple[Optional[Parameters], dict[str, Scalar]]:
|
|
143
143
|
"""Aggregate fit results using weighted average."""
|
|
144
144
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
145
145
|
server_round=server_round, results=results, failures=failures
|
flwr/server/strategy/krum.py
CHANGED
|
@@ -21,7 +21,7 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
from logging import WARNING
|
|
24
|
-
from typing import Callable,
|
|
24
|
+
from typing import Callable, Optional, Union
|
|
25
25
|
|
|
26
26
|
from flwr.common import (
|
|
27
27
|
FitRes,
|
|
@@ -87,12 +87,12 @@ class Krum(FedAvg):
|
|
|
87
87
|
num_clients_to_keep: int = 0,
|
|
88
88
|
evaluate_fn: Optional[
|
|
89
89
|
Callable[
|
|
90
|
-
[int, NDArrays,
|
|
91
|
-
Optional[
|
|
90
|
+
[int, NDArrays, dict[str, Scalar]],
|
|
91
|
+
Optional[tuple[float, dict[str, Scalar]]],
|
|
92
92
|
]
|
|
93
93
|
] = None,
|
|
94
|
-
on_fit_config_fn: Optional[Callable[[int],
|
|
95
|
-
on_evaluate_config_fn: Optional[Callable[[int],
|
|
94
|
+
on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
95
|
+
on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
96
96
|
accept_failures: bool = True,
|
|
97
97
|
initial_parameters: Optional[Parameters] = None,
|
|
98
98
|
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
|
|
@@ -123,9 +123,9 @@ class Krum(FedAvg):
|
|
|
123
123
|
def aggregate_fit(
|
|
124
124
|
self,
|
|
125
125
|
server_round: int,
|
|
126
|
-
results:
|
|
127
|
-
failures:
|
|
128
|
-
) ->
|
|
126
|
+
results: list[tuple[ClientProxy, FitRes]],
|
|
127
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
|
128
|
+
) -> tuple[Optional[Parameters], dict[str, Scalar]]:
|
|
129
129
|
"""Aggregate fit results using Krum."""
|
|
130
130
|
if not results:
|
|
131
131
|
return None, {}
|
flwr/server/strategy/qfedavg.py
CHANGED
|
@@ -19,7 +19,7 @@ Paper: openreview.net/pdf?id=ByexElSYDr
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from logging import WARNING
|
|
22
|
-
from typing import Callable,
|
|
22
|
+
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
|
|
@@ -60,12 +60,12 @@ class QFedAvg(FedAvg):
|
|
|
60
60
|
min_available_clients: int = 1,
|
|
61
61
|
evaluate_fn: Optional[
|
|
62
62
|
Callable[
|
|
63
|
-
[int, NDArrays,
|
|
64
|
-
Optional[
|
|
63
|
+
[int, NDArrays, dict[str, Scalar]],
|
|
64
|
+
Optional[tuple[float, dict[str, Scalar]]],
|
|
65
65
|
]
|
|
66
66
|
] = None,
|
|
67
|
-
on_fit_config_fn: Optional[Callable[[int],
|
|
68
|
-
on_evaluate_config_fn: Optional[Callable[[int],
|
|
67
|
+
on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
68
|
+
on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
|
|
69
69
|
accept_failures: bool = True,
|
|
70
70
|
initial_parameters: Optional[Parameters] = None,
|
|
71
71
|
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
|
|
@@ -95,19 +95,19 @@ class QFedAvg(FedAvg):
|
|
|
95
95
|
rep += f"q_param={self.q_param}, pre_weights={self.pre_weights})"
|
|
96
96
|
return rep
|
|
97
97
|
|
|
98
|
-
def num_fit_clients(self, num_available_clients: int) ->
|
|
98
|
+
def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]:
|
|
99
99
|
"""Return the sample size and the required number of available clients."""
|
|
100
100
|
num_clients = int(num_available_clients * self.fraction_fit)
|
|
101
101
|
return max(num_clients, self.min_fit_clients), self.min_available_clients
|
|
102
102
|
|
|
103
|
-
def num_evaluation_clients(self, num_available_clients: int) ->
|
|
103
|
+
def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]:
|
|
104
104
|
"""Use a fraction of available clients for evaluation."""
|
|
105
105
|
num_clients = int(num_available_clients * self.fraction_evaluate)
|
|
106
106
|
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
|
107
107
|
|
|
108
108
|
def configure_fit(
|
|
109
109
|
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
|
110
|
-
) ->
|
|
110
|
+
) -> list[tuple[ClientProxy, FitIns]]:
|
|
111
111
|
"""Configure the next round of training."""
|
|
112
112
|
weights = parameters_to_ndarrays(parameters)
|
|
113
113
|
self.pre_weights = weights
|
|
@@ -131,7 +131,7 @@ class QFedAvg(FedAvg):
|
|
|
131
131
|
|
|
132
132
|
def configure_evaluate(
|
|
133
133
|
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
|
134
|
-
) ->
|
|
134
|
+
) -> list[tuple[ClientProxy, EvaluateIns]]:
|
|
135
135
|
"""Configure the next round of evaluation."""
|
|
136
136
|
# Do not configure federated evaluation if fraction_evaluate is 0
|
|
137
137
|
if self.fraction_evaluate == 0.0:
|
|
@@ -158,9 +158,9 @@ class QFedAvg(FedAvg):
|
|
|
158
158
|
def aggregate_fit(
|
|
159
159
|
self,
|
|
160
160
|
server_round: int,
|
|
161
|
-
results:
|
|
162
|
-
failures:
|
|
163
|
-
) ->
|
|
161
|
+
results: list[tuple[ClientProxy, FitRes]],
|
|
162
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
|
163
|
+
) -> tuple[Optional[Parameters], dict[str, Scalar]]:
|
|
164
164
|
"""Aggregate fit results using weighted average."""
|
|
165
165
|
if not results:
|
|
166
166
|
return None, {}
|
|
@@ -229,9 +229,9 @@ class QFedAvg(FedAvg):
|
|
|
229
229
|
def aggregate_evaluate(
|
|
230
230
|
self,
|
|
231
231
|
server_round: int,
|
|
232
|
-
results:
|
|
233
|
-
failures:
|
|
234
|
-
) ->
|
|
232
|
+
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
233
|
+
failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
|
|
234
|
+
) -> tuple[Optional[float], dict[str, Scalar]]:
|
|
235
235
|
"""Aggregate evaluation losses using weighted average."""
|
|
236
236
|
if not results:
|
|
237
237
|
return None, {}
|
flwr/server/strategy/strategy.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional, Union
|
|
20
20
|
|
|
21
21
|
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
|
|
22
22
|
from flwr.server.client_manager import ClientManager
|
|
@@ -47,7 +47,7 @@ class Strategy(ABC):
|
|
|
47
47
|
@abstractmethod
|
|
48
48
|
def configure_fit(
|
|
49
49
|
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
|
50
|
-
) ->
|
|
50
|
+
) -> list[tuple[ClientProxy, FitIns]]:
|
|
51
51
|
"""Configure the next round of training.
|
|
52
52
|
|
|
53
53
|
Parameters
|
|
@@ -72,9 +72,9 @@ class Strategy(ABC):
|
|
|
72
72
|
def aggregate_fit(
|
|
73
73
|
self,
|
|
74
74
|
server_round: int,
|
|
75
|
-
results:
|
|
76
|
-
failures:
|
|
77
|
-
) ->
|
|
75
|
+
results: list[tuple[ClientProxy, FitRes]],
|
|
76
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
|
|
77
|
+
) -> tuple[Optional[Parameters], dict[str, Scalar]]:
|
|
78
78
|
"""Aggregate training results.
|
|
79
79
|
|
|
80
80
|
Parameters
|
|
@@ -108,7 +108,7 @@ class Strategy(ABC):
|
|
|
108
108
|
@abstractmethod
|
|
109
109
|
def configure_evaluate(
|
|
110
110
|
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
|
111
|
-
) ->
|
|
111
|
+
) -> list[tuple[ClientProxy, EvaluateIns]]:
|
|
112
112
|
"""Configure the next round of evaluation.
|
|
113
113
|
|
|
114
114
|
Parameters
|
|
@@ -134,9 +134,9 @@ class Strategy(ABC):
|
|
|
134
134
|
def aggregate_evaluate(
|
|
135
135
|
self,
|
|
136
136
|
server_round: int,
|
|
137
|
-
results:
|
|
138
|
-
failures:
|
|
139
|
-
) ->
|
|
137
|
+
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
138
|
+
failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
|
|
139
|
+
) -> tuple[Optional[float], dict[str, Scalar]]:
|
|
140
140
|
"""Aggregate evaluation results.
|
|
141
141
|
|
|
142
142
|
Parameters
|
|
@@ -164,7 +164,7 @@ class Strategy(ABC):
|
|
|
164
164
|
@abstractmethod
|
|
165
165
|
def evaluate(
|
|
166
166
|
self, server_round: int, parameters: Parameters
|
|
167
|
-
) -> Optional[
|
|
167
|
+
) -> Optional[tuple[float, dict[str, Scalar]]]:
|
|
168
168
|
"""Evaluate the current model parameters.
|
|
169
169
|
|
|
170
170
|
This function can be used to perform centralized (i.e., server-side) evaluation
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Driver gRPC API."""
|
|
16
16
|
|
|
17
17
|
from logging import INFO
|
|
18
|
-
from typing import Optional
|
|
18
|
+
from typing import Optional
|
|
19
19
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
@@ -35,7 +35,7 @@ def run_driver_api_grpc(
|
|
|
35
35
|
address: str,
|
|
36
36
|
state_factory: StateFactory,
|
|
37
37
|
ffs_factory: FfsFactory,
|
|
38
|
-
certificates: Optional[
|
|
38
|
+
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
39
39
|
) -> grpc.Server:
|
|
40
40
|
"""Run Driver API (gRPC, request-response)."""
|
|
41
41
|
# Create Driver API gRPC server
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
from logging import DEBUG
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import Optional
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
@@ -68,8 +68,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
68
68
|
"""Get available nodes."""
|
|
69
69
|
log(DEBUG, "DriverServicer.GetNodes")
|
|
70
70
|
state: State = self.state_factory.state()
|
|
71
|
-
all_ids:
|
|
72
|
-
nodes:
|
|
71
|
+
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
72
|
+
nodes: list[Node] = [
|
|
73
73
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
74
74
|
]
|
|
75
75
|
return GetNodesResponse(nodes=nodes)
|
|
@@ -119,7 +119,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
119
119
|
state: State = self.state_factory.state()
|
|
120
120
|
|
|
121
121
|
# Store each TaskIns
|
|
122
|
-
task_ids:
|
|
122
|
+
task_ids: list[Optional[UUID]] = []
|
|
123
123
|
for task_ins in request.task_ins_list:
|
|
124
124
|
task_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
|
|
125
125
|
task_ids.append(task_id)
|
|
@@ -135,7 +135,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
135
135
|
log(DEBUG, "DriverServicer.PullTaskRes")
|
|
136
136
|
|
|
137
137
|
# Convert each task_id str to UUID
|
|
138
|
-
task_ids:
|
|
138
|
+
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
139
139
|
|
|
140
140
|
# Init state
|
|
141
141
|
state: State = self.state_factory.state()
|
|
@@ -155,7 +155,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
155
155
|
context.add_callback(on_rpc_done)
|
|
156
156
|
|
|
157
157
|
# Read from state
|
|
158
|
-
task_res_list:
|
|
158
|
+
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None)
|
|
159
159
|
|
|
160
160
|
context.set_code(grpc.StatusCode.OK)
|
|
161
161
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import hashlib
|
|
18
18
|
import json
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import Optional
|
|
21
21
|
|
|
22
22
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
23
23
|
|
|
@@ -35,7 +35,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
35
35
|
"""
|
|
36
36
|
self.base_dir = Path(base_dir)
|
|
37
37
|
|
|
38
|
-
def put(self, content: bytes, meta:
|
|
38
|
+
def put(self, content: bytes, meta: dict[str, str]) -> str:
|
|
39
39
|
"""Store bytes and metadata and return key (hash of content).
|
|
40
40
|
|
|
41
41
|
Parameters
|
|
@@ -58,7 +58,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
58
58
|
|
|
59
59
|
return content_hash
|
|
60
60
|
|
|
61
|
-
def get(self, key: str) -> Optional[
|
|
61
|
+
def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]:
|
|
62
62
|
"""Return tuple containing the object content and metadata.
|
|
63
63
|
|
|
64
64
|
Parameters
|
|
@@ -90,7 +90,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
90
90
|
(self.base_dir / key).unlink()
|
|
91
91
|
(self.base_dir / f"{key}.META").unlink()
|
|
92
92
|
|
|
93
|
-
def list(self) ->
|
|
93
|
+
def list(self) -> list[str]:
|
|
94
94
|
"""List all keys.
|
|
95
95
|
|
|
96
96
|
Return all available keys in this `Ffs` instance.
|
flwr/server/superlink/ffs/ffs.py
CHANGED
|
@@ -16,14 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class Ffs(abc.ABC): # pylint: disable=R0904
|
|
23
23
|
"""Abstract Flower File Storage interface for large objects."""
|
|
24
24
|
|
|
25
25
|
@abc.abstractmethod
|
|
26
|
-
def put(self, content: bytes, meta:
|
|
26
|
+
def put(self, content: bytes, meta: dict[str, str]) -> str:
|
|
27
27
|
"""Store bytes and metadata and return sha256hex hash of data as str.
|
|
28
28
|
|
|
29
29
|
Parameters
|
|
@@ -40,7 +40,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
42
|
@abc.abstractmethod
|
|
43
|
-
def get(self, key: str) -> Optional[
|
|
43
|
+
def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]:
|
|
44
44
|
"""Return tuple containing the object content and metadata.
|
|
45
45
|
|
|
46
46
|
Parameters
|
|
@@ -65,7 +65,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
65
65
|
"""
|
|
66
66
|
|
|
67
67
|
@abc.abstractmethod
|
|
68
|
-
def list(self) ->
|
|
68
|
+
def list(self) -> list[str]:
|
|
69
69
|
"""List keys of all stored objects.
|
|
70
70
|
|
|
71
71
|
Return all available keys in this `Ffs` instance.
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import DEBUG, INFO
|
|
19
|
-
from typing import Callable,
|
|
19
|
+
from typing import Callable, TypeVar
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -47,7 +47,7 @@ T = TypeVar("T", bound=GrpcMessage)
|
|
|
47
47
|
|
|
48
48
|
def _handle(
|
|
49
49
|
msg_container: MessageContainer,
|
|
50
|
-
request_type:
|
|
50
|
+
request_type: type[T],
|
|
51
51
|
handler: Callable[[T], GrpcMessage],
|
|
52
52
|
) -> MessageContainer:
|
|
53
53
|
req = request_type.FromString(msg_container.grpc_message_content)
|
|
@@ -15,10 +15,11 @@
|
|
|
15
15
|
"""Provides class GrpcBridge."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Iterator
|
|
18
19
|
from dataclasses import dataclass
|
|
19
20
|
from enum import Enum
|
|
20
21
|
from threading import Condition
|
|
21
|
-
from typing import
|
|
22
|
+
from typing import Optional
|
|
22
23
|
|
|
23
24
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
24
25
|
ClientMessage,
|
|
@@ -17,8 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import concurrent.futures
|
|
19
19
|
import sys
|
|
20
|
+
from collections.abc import Sequence
|
|
20
21
|
from logging import ERROR
|
|
21
|
-
from typing import Any, Callable, Optional,
|
|
22
|
+
from typing import Any, Callable, Optional, Union
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
@@ -46,7 +47,7 @@ INVALID_CERTIFICATES_ERR_MSG = """
|
|
|
46
47
|
AddServicerToServerFn = Callable[..., Any]
|
|
47
48
|
|
|
48
49
|
|
|
49
|
-
def valid_certificates(certificates:
|
|
50
|
+
def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
|
|
50
51
|
"""Validate certificates tuple."""
|
|
51
52
|
is_valid = (
|
|
52
53
|
all(isinstance(certificate, bytes) for certificate in certificates)
|
|
@@ -65,7 +66,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
|
|
|
65
66
|
max_concurrent_workers: int = 1000,
|
|
66
67
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
67
68
|
keepalive_time_ms: int = 210000,
|
|
68
|
-
certificates: Optional[
|
|
69
|
+
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
|
69
70
|
) -> grpc.Server:
|
|
70
71
|
"""Create and start a gRPC server running FlowerServiceServicer.
|
|
71
72
|
|
|
@@ -157,16 +158,16 @@ def start_grpc_server( # pylint: disable=too-many-arguments
|
|
|
157
158
|
|
|
158
159
|
def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
159
160
|
servicer_and_add_fn: Union[
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
161
|
+
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
|
+
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
163
|
+
tuple[FlowerServiceServicer, AddServicerToServerFn],
|
|
164
|
+
tuple[DriverServicer, AddServicerToServerFn],
|
|
164
165
|
],
|
|
165
166
|
server_address: str,
|
|
166
167
|
max_concurrent_workers: int = 1000,
|
|
167
168
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
168
169
|
keepalive_time_ms: int = 210000,
|
|
169
|
-
certificates: Optional[
|
|
170
|
+
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
|
170
171
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
171
172
|
) -> grpc.Server:
|
|
172
173
|
"""Create a gRPC server with a single servicer.
|
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
|
+
from collections.abc import Sequence
|
|
19
20
|
from logging import INFO, WARNING
|
|
20
|
-
from typing import Any, Callable, Optional,
|
|
21
|
+
from typing import Any, Callable, Optional, Union
|
|
21
22
|
|
|
22
23
|
import grpc
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -68,7 +69,7 @@ Response = Union[
|
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
def _get_value_from_tuples(
|
|
71
|
-
key_string: str, tuples: Sequence[
|
|
72
|
+
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
|
72
73
|
) -> bytes:
|
|
73
74
|
value = next((value for key, value in tuples if key == key_string), "")
|
|
74
75
|
if isinstance(value, str):
|
|
@@ -188,7 +189,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
188
189
|
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
189
190
|
) -> bool:
|
|
190
191
|
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
191
|
-
|
|
192
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
193
|
+
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
|
192
194
|
|
|
193
195
|
def _create_authenticated_node(
|
|
194
196
|
self,
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
@@ -83,7 +83,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
|
|
|
83
83
|
node_id: Optional[int] = None if node.anonymous else node.node_id
|
|
84
84
|
|
|
85
85
|
# Retrieve TaskIns from State
|
|
86
|
-
task_ins_list:
|
|
86
|
+
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
|
87
87
|
|
|
88
88
|
# Build response
|
|
89
89
|
response = PullTaskInsResponse(
|
|
@@ -15,17 +15,16 @@
|
|
|
15
15
|
"""Simulation Engine Backends."""
|
|
16
16
|
|
|
17
17
|
import importlib
|
|
18
|
-
from typing import Dict, Type
|
|
19
18
|
|
|
20
19
|
from .backend import Backend, BackendConfig
|
|
21
20
|
|
|
22
21
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
|
23
22
|
|
|
24
23
|
# Mapping of supported backends
|
|
25
|
-
supported_backends:
|
|
24
|
+
supported_backends: dict[str, type[Backend]] = {}
|
|
26
25
|
|
|
27
26
|
# To log backend-specific error message when chosen backend isn't available
|
|
28
|
-
error_messages_backends:
|
|
27
|
+
error_messages_backends: dict[str, str] = {}
|
|
29
28
|
|
|
30
29
|
if is_ray_installed:
|
|
31
30
|
from .raybackend import RayBackend
|
|
@@ -16,14 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
|
-
from typing import Callable
|
|
19
|
+
from typing import Callable
|
|
20
20
|
|
|
21
21
|
from flwr.client.client_app import ClientApp
|
|
22
22
|
from flwr.common.context import Context
|
|
23
23
|
from flwr.common.message import Message
|
|
24
24
|
from flwr.common.typing import ConfigsRecordValues
|
|
25
25
|
|
|
26
|
-
BackendConfig =
|
|
26
|
+
BackendConfig = dict[str, dict[str, ConfigsRecordValues]]
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class Backend(ABC):
|
|
@@ -62,5 +62,5 @@ class Backend(ABC):
|
|
|
62
62
|
self,
|
|
63
63
|
message: Message,
|
|
64
64
|
context: Context,
|
|
65
|
-
) ->
|
|
65
|
+
) -> tuple[Message, Context]:
|
|
66
66
|
"""Submit a job to the backend."""
|