flwr-nightly 1.12.0.dev20240906__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""Utility functions for the SecAgg/SecAgg+ protocol."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import List, Tuple
|
|
19
|
-
|
|
20
18
|
import numpy as np
|
|
21
19
|
|
|
22
20
|
from flwr.common.typing import NDArrayInt
|
|
@@ -54,7 +52,7 @@ def share_keys_plaintext_concat(
|
|
|
54
52
|
)
|
|
55
53
|
|
|
56
54
|
|
|
57
|
-
def share_keys_plaintext_separate(plaintext: bytes) ->
|
|
55
|
+
def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, bytes]:
|
|
58
56
|
"""Retrieve arguments from bytes.
|
|
59
57
|
|
|
60
58
|
Parameters
|
|
@@ -83,8 +81,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by
|
|
|
83
81
|
|
|
84
82
|
|
|
85
83
|
def pseudo_rand_gen(
|
|
86
|
-
seed: bytes, num_range: int, dimensions_list:
|
|
87
|
-
) ->
|
|
84
|
+
seed: bytes, num_range: int, dimensions_list: list[tuple[int, ...]]
|
|
85
|
+
) -> list[NDArrayInt]:
|
|
88
86
|
"""Seeded pseudo-random number generator for noise generation with Numpy."""
|
|
89
87
|
assert len(seed) & 0x3 == 0
|
|
90
88
|
seed32 = 0
|
flwr/common/serde.py
CHANGED
|
@@ -15,7 +15,9 @@
|
|
|
15
15
|
"""ProtoBuf serialization and deserialization."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from collections.abc import MutableMapping
|
|
20
|
+
from typing import Any, TypeVar, cast
|
|
19
21
|
|
|
20
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
21
23
|
|
|
@@ -72,7 +74,7 @@ def parameters_to_proto(parameters: typing.Parameters) -> Parameters:
|
|
|
72
74
|
|
|
73
75
|
def parameters_from_proto(msg: Parameters) -> typing.Parameters:
|
|
74
76
|
"""Deserialize `Parameters` from ProtoBuf."""
|
|
75
|
-
tensors:
|
|
77
|
+
tensors: list[bytes] = list(msg.tensors)
|
|
76
78
|
return typing.Parameters(tensors=tensors, tensor_type=msg.tensor_type)
|
|
77
79
|
|
|
78
80
|
|
|
@@ -390,7 +392,7 @@ T = TypeVar("T")
|
|
|
390
392
|
|
|
391
393
|
|
|
392
394
|
def _record_value_to_proto(
|
|
393
|
-
value: Any, allowed_types:
|
|
395
|
+
value: Any, allowed_types: list[type], proto_class: type[T]
|
|
394
396
|
) -> T:
|
|
395
397
|
"""Serialize `*RecordValue` to ProtoBuf.
|
|
396
398
|
|
|
@@ -427,9 +429,9 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
|
|
427
429
|
|
|
428
430
|
def _record_value_dict_to_proto(
|
|
429
431
|
value_dict: TypedDict[str, Any],
|
|
430
|
-
allowed_types:
|
|
431
|
-
value_proto_class:
|
|
432
|
-
) ->
|
|
432
|
+
allowed_types: list[type],
|
|
433
|
+
value_proto_class: type[T],
|
|
434
|
+
) -> dict[str, T]:
|
|
433
435
|
"""Serialize the record value dict to ProtoBuf.
|
|
434
436
|
|
|
435
437
|
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
@@ -447,7 +449,7 @@ def _record_value_dict_to_proto(
|
|
|
447
449
|
|
|
448
450
|
def _record_value_dict_from_proto(
|
|
449
451
|
value_dict_proto: MutableMapping[str, Any]
|
|
450
|
-
) ->
|
|
452
|
+
) -> dict[str, Any]:
|
|
451
453
|
"""Deserialize the record value dict from ProtoBuf."""
|
|
452
454
|
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
|
453
455
|
|
|
@@ -498,7 +500,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord
|
|
|
498
500
|
"""Deserialize MetricsRecord from ProtoBuf."""
|
|
499
501
|
return MetricsRecord(
|
|
500
502
|
metrics_dict=cast(
|
|
501
|
-
|
|
503
|
+
dict[str, typing.MetricsRecordValues],
|
|
502
504
|
_record_value_dict_from_proto(record_proto.data),
|
|
503
505
|
),
|
|
504
506
|
keep_input=False,
|
|
@@ -520,7 +522,7 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
|
|
|
520
522
|
"""Deserialize ConfigsRecord from ProtoBuf."""
|
|
521
523
|
return ConfigsRecord(
|
|
522
524
|
configs_dict=cast(
|
|
523
|
-
|
|
525
|
+
dict[str, typing.ConfigsRecordValues],
|
|
524
526
|
_record_value_dict_from_proto(record_proto.data),
|
|
525
527
|
),
|
|
526
528
|
keep_input=False,
|
flwr/common/telemetry.py
CHANGED
|
@@ -25,7 +25,7 @@ import uuid
|
|
|
25
25
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
26
26
|
from enum import Enum, auto
|
|
27
27
|
from pathlib import Path
|
|
28
|
-
from typing import Any,
|
|
28
|
+
from typing import Any, Optional, Union, cast
|
|
29
29
|
|
|
30
30
|
from flwr.common.version import package_name, package_version
|
|
31
31
|
|
|
@@ -126,7 +126,7 @@ class EventType(str, Enum):
|
|
|
126
126
|
# The type signature is not compatible with mypy, pylint and flake8
|
|
127
127
|
# so each of those needs to be disabled for this line.
|
|
128
128
|
# pylint: disable-next=no-self-argument,arguments-differ,line-too-long
|
|
129
|
-
def _generate_next_value_(name: str, start: int, count: int, last_values:
|
|
129
|
+
def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any: # type: ignore # noqa: E501
|
|
130
130
|
return name
|
|
131
131
|
|
|
132
132
|
# Ping
|
|
@@ -189,7 +189,7 @@ class EventType(str, Enum):
|
|
|
189
189
|
|
|
190
190
|
# Use the ThreadPoolExecutor with max_workers=1 to have a queue
|
|
191
191
|
# and also ensure that telemetry calls are not blocking.
|
|
192
|
-
state:
|
|
192
|
+
state: dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = {
|
|
193
193
|
# Will be assigned ThreadPoolExecutor(max_workers=1)
|
|
194
194
|
# in event() the first time it's required
|
|
195
195
|
"executor": None,
|
|
@@ -201,7 +201,7 @@ state: Dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = {
|
|
|
201
201
|
|
|
202
202
|
def event(
|
|
203
203
|
event_type: EventType,
|
|
204
|
-
event_details: Optional[
|
|
204
|
+
event_details: Optional[dict[str, Any]] = None,
|
|
205
205
|
) -> Future: # type: ignore
|
|
206
206
|
"""Submit create_event to ThreadPoolExecutor to avoid blocking."""
|
|
207
207
|
if state["executor"] is None:
|
|
@@ -213,7 +213,7 @@ def event(
|
|
|
213
213
|
return result
|
|
214
214
|
|
|
215
215
|
|
|
216
|
-
def create_event(event_type: EventType, event_details: Optional[
|
|
216
|
+
def create_event(event_type: EventType, event_details: Optional[dict[str, Any]]) -> str:
|
|
217
217
|
"""Create telemetry event."""
|
|
218
218
|
if state["source"] is None:
|
|
219
219
|
state["source"] = _get_source_id()
|
flwr/common/typing.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from enum import Enum
|
|
20
|
-
from typing import Any, Callable,
|
|
20
|
+
from typing import Any, Callable, Optional, Union
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
23
|
import numpy.typing as npt
|
|
@@ -25,7 +25,7 @@ import numpy.typing as npt
|
|
|
25
25
|
NDArray = npt.NDArray[Any]
|
|
26
26
|
NDArrayInt = npt.NDArray[np.int_]
|
|
27
27
|
NDArrayFloat = npt.NDArray[np.float_]
|
|
28
|
-
NDArrays =
|
|
28
|
+
NDArrays = list[NDArray]
|
|
29
29
|
|
|
30
30
|
# The following union type contains Python types corresponding to ProtoBuf types that
|
|
31
31
|
# ProtoBuf considers to be "Scalar Value Types", even though some of them arguably do
|
|
@@ -38,31 +38,31 @@ Value = Union[
|
|
|
38
38
|
float,
|
|
39
39
|
int,
|
|
40
40
|
str,
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
41
|
+
list[bool],
|
|
42
|
+
list[bytes],
|
|
43
|
+
list[float],
|
|
44
|
+
list[int],
|
|
45
|
+
list[str],
|
|
46
46
|
]
|
|
47
47
|
|
|
48
48
|
# Value types for common.MetricsRecord
|
|
49
49
|
MetricsScalar = Union[int, float]
|
|
50
|
-
MetricsScalarList = Union[
|
|
50
|
+
MetricsScalarList = Union[list[int], list[float]]
|
|
51
51
|
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
|
|
52
52
|
# Value types for common.ConfigsRecord
|
|
53
53
|
ConfigsScalar = Union[MetricsScalar, str, bytes, bool]
|
|
54
|
-
ConfigsScalarList = Union[MetricsScalarList,
|
|
54
|
+
ConfigsScalarList = Union[MetricsScalarList, list[str], list[bytes], list[bool]]
|
|
55
55
|
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
|
|
56
56
|
|
|
57
|
-
Metrics =
|
|
58
|
-
MetricsAggregationFn = Callable[[
|
|
57
|
+
Metrics = dict[str, Scalar]
|
|
58
|
+
MetricsAggregationFn = Callable[[list[tuple[int, Metrics]]], Metrics]
|
|
59
59
|
|
|
60
|
-
Config =
|
|
61
|
-
Properties =
|
|
60
|
+
Config = dict[str, Scalar]
|
|
61
|
+
Properties = dict[str, Scalar]
|
|
62
62
|
|
|
63
63
|
# Value type for user configs
|
|
64
64
|
UserConfigValue = Union[bool, float, int, str]
|
|
65
|
-
UserConfig =
|
|
65
|
+
UserConfig = dict[str, UserConfigValue]
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
class Code(Enum):
|
|
@@ -103,7 +103,7 @@ class ClientAppOutputStatus:
|
|
|
103
103
|
class Parameters:
|
|
104
104
|
"""Model parameters."""
|
|
105
105
|
|
|
106
|
-
tensors:
|
|
106
|
+
tensors: list[bytes]
|
|
107
107
|
tensor_type: str
|
|
108
108
|
|
|
109
109
|
|
|
@@ -127,7 +127,7 @@ class FitIns:
|
|
|
127
127
|
"""Fit instructions for a client."""
|
|
128
128
|
|
|
129
129
|
parameters: Parameters
|
|
130
|
-
config:
|
|
130
|
+
config: dict[str, Scalar]
|
|
131
131
|
|
|
132
132
|
|
|
133
133
|
@dataclass
|
|
@@ -137,7 +137,7 @@ class FitRes:
|
|
|
137
137
|
status: Status
|
|
138
138
|
parameters: Parameters
|
|
139
139
|
num_examples: int
|
|
140
|
-
metrics:
|
|
140
|
+
metrics: dict[str, Scalar]
|
|
141
141
|
|
|
142
142
|
|
|
143
143
|
@dataclass
|
|
@@ -145,7 +145,7 @@ class EvaluateIns:
|
|
|
145
145
|
"""Evaluate instructions for a client."""
|
|
146
146
|
|
|
147
147
|
parameters: Parameters
|
|
148
|
-
config:
|
|
148
|
+
config: dict[str, Scalar]
|
|
149
149
|
|
|
150
150
|
|
|
151
151
|
@dataclass
|
|
@@ -155,7 +155,7 @@ class EvaluateRes:
|
|
|
155
155
|
status: Status
|
|
156
156
|
loss: float
|
|
157
157
|
num_examples: int
|
|
158
|
-
metrics:
|
|
158
|
+
metrics: dict[str, Scalar]
|
|
159
159
|
|
|
160
160
|
|
|
161
161
|
@dataclass
|
flwr/common/version.py
CHANGED
|
@@ -15,15 +15,14 @@
|
|
|
15
15
|
"""Flower package version helper."""
|
|
16
16
|
|
|
17
17
|
import importlib.metadata as importlib_metadata
|
|
18
|
-
from typing import Tuple
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
def _check_package(name: str) ->
|
|
20
|
+
def _check_package(name: str) -> tuple[str, str]:
|
|
22
21
|
version: str = importlib_metadata.version(name)
|
|
23
22
|
return name, version
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
def _version() ->
|
|
25
|
+
def _version() -> tuple[str, str]:
|
|
27
26
|
"""Read and return Flower package name and version.
|
|
28
27
|
|
|
29
28
|
Returns
|
flwr/server/app.py
CHANGED
|
@@ -19,10 +19,11 @@ import csv
|
|
|
19
19
|
import importlib.util
|
|
20
20
|
import sys
|
|
21
21
|
import threading
|
|
22
|
+
from collections.abc import Sequence
|
|
22
23
|
from logging import INFO, WARN
|
|
23
24
|
from os.path import isfile
|
|
24
25
|
from pathlib import Path
|
|
25
|
-
from typing import Optional
|
|
26
|
+
from typing import Optional
|
|
26
27
|
|
|
27
28
|
import grpc
|
|
28
29
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
@@ -36,6 +37,10 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
|
36
37
|
from flwr.common.address import parse_address
|
|
37
38
|
from flwr.common.config import get_flwr_dir
|
|
38
39
|
from flwr.common.constant import (
|
|
40
|
+
DRIVER_API_DEFAULT_ADDRESS,
|
|
41
|
+
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
42
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
43
|
+
FLEET_API_REST_DEFAULT_ADDRESS,
|
|
39
44
|
MISSING_EXTRA_REST,
|
|
40
45
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
41
46
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -68,24 +73,19 @@ from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
|
68
73
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
69
74
|
from .superlink.state import StateFactory
|
|
70
75
|
|
|
71
|
-
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
72
|
-
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
73
|
-
ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatible
|
|
74
|
-
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
|
|
75
|
-
|
|
76
76
|
DATABASE = ":flwr-in-memory-state:"
|
|
77
77
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
81
81
|
*,
|
|
82
|
-
server_address: str =
|
|
82
|
+
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
83
83
|
server: Optional[Server] = None,
|
|
84
84
|
config: Optional[ServerConfig] = None,
|
|
85
85
|
strategy: Optional[Strategy] = None,
|
|
86
86
|
client_manager: Optional[ClientManager] = None,
|
|
87
87
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
88
|
-
certificates: Optional[
|
|
88
|
+
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
|
89
89
|
) -> History:
|
|
90
90
|
"""Start a Flower server using the gRPC transport layer.
|
|
91
91
|
|
|
@@ -232,9 +232,9 @@ def run_superlink() -> None:
|
|
|
232
232
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
233
233
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
234
234
|
]:
|
|
235
|
-
args.fleet_api_address =
|
|
235
|
+
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
|
|
236
236
|
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
237
|
-
args.fleet_api_address =
|
|
237
|
+
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
|
|
238
238
|
|
|
239
239
|
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
240
240
|
|
|
@@ -334,7 +334,7 @@ def run_superlink() -> None:
|
|
|
334
334
|
driver_server.wait_for_termination(timeout=1)
|
|
335
335
|
|
|
336
336
|
|
|
337
|
-
def _format_address(address: str) ->
|
|
337
|
+
def _format_address(address: str) -> tuple[str, str, int]:
|
|
338
338
|
parsed_address = parse_address(address)
|
|
339
339
|
if not parsed_address:
|
|
340
340
|
sys.exit(
|
|
@@ -346,8 +346,8 @@ def _format_address(address: str) -> Tuple[str, str, int]:
|
|
|
346
346
|
|
|
347
347
|
def _try_setup_node_authentication(
|
|
348
348
|
args: argparse.Namespace,
|
|
349
|
-
certificates: Optional[
|
|
350
|
-
) -> Optional[
|
|
349
|
+
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
350
|
+
) -> Optional[tuple[set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
351
351
|
if (
|
|
352
352
|
not args.auth_list_public_keys
|
|
353
353
|
and not args.auth_superlink_private_key
|
|
@@ -382,7 +382,7 @@ def _try_setup_node_authentication(
|
|
|
382
382
|
"to '--auth-list-public-keys'."
|
|
383
383
|
)
|
|
384
384
|
|
|
385
|
-
node_public_keys:
|
|
385
|
+
node_public_keys: set[bytes] = set()
|
|
386
386
|
|
|
387
387
|
try:
|
|
388
388
|
ssh_private_key = load_ssh_private_key(
|
|
@@ -435,7 +435,7 @@ def _try_setup_node_authentication(
|
|
|
435
435
|
|
|
436
436
|
def _try_obtain_certificates(
|
|
437
437
|
args: argparse.Namespace,
|
|
438
|
-
) -> Optional[
|
|
438
|
+
) -> Optional[tuple[bytes, bytes, bytes]]:
|
|
439
439
|
# Obtain certificates
|
|
440
440
|
if args.insecure:
|
|
441
441
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
@@ -491,7 +491,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
491
491
|
address: str,
|
|
492
492
|
state_factory: StateFactory,
|
|
493
493
|
ffs_factory: FfsFactory,
|
|
494
|
-
certificates: Optional[
|
|
494
|
+
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
495
495
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
496
496
|
) -> grpc.Server:
|
|
497
497
|
"""Run Fleet API (gRPC, request-response)."""
|
|
@@ -519,7 +519,7 @@ def _run_fleet_api_grpc_adapter(
|
|
|
519
519
|
address: str,
|
|
520
520
|
state_factory: StateFactory,
|
|
521
521
|
ffs_factory: FfsFactory,
|
|
522
|
-
certificates: Optional[
|
|
522
|
+
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
523
523
|
) -> grpc.Server:
|
|
524
524
|
"""Run Fleet API (GrpcAdapter)."""
|
|
525
525
|
# Create Fleet API gRPC server
|
|
@@ -653,7 +653,7 @@ def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
|
653
653
|
parser.add_argument(
|
|
654
654
|
"--driver-api-address",
|
|
655
655
|
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
656
|
-
default=
|
|
656
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
657
657
|
)
|
|
658
658
|
|
|
659
659
|
|
flwr/server/client_manager.py
CHANGED
|
@@ -19,7 +19,7 @@ import random
|
|
|
19
19
|
import threading
|
|
20
20
|
from abc import ABC, abstractmethod
|
|
21
21
|
from logging import INFO
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common.logger import log
|
|
25
25
|
|
|
@@ -67,7 +67,7 @@ class ClientManager(ABC):
|
|
|
67
67
|
"""
|
|
68
68
|
|
|
69
69
|
@abstractmethod
|
|
70
|
-
def all(self) ->
|
|
70
|
+
def all(self) -> dict[str, ClientProxy]:
|
|
71
71
|
"""Return all available clients."""
|
|
72
72
|
|
|
73
73
|
@abstractmethod
|
|
@@ -80,7 +80,7 @@ class ClientManager(ABC):
|
|
|
80
80
|
num_clients: int,
|
|
81
81
|
min_num_clients: Optional[int] = None,
|
|
82
82
|
criterion: Optional[Criterion] = None,
|
|
83
|
-
) ->
|
|
83
|
+
) -> list[ClientProxy]:
|
|
84
84
|
"""Sample a number of Flower ClientProxy instances."""
|
|
85
85
|
|
|
86
86
|
|
|
@@ -88,7 +88,7 @@ class SimpleClientManager(ClientManager):
|
|
|
88
88
|
"""Provides a pool of available clients."""
|
|
89
89
|
|
|
90
90
|
def __init__(self) -> None:
|
|
91
|
-
self.clients:
|
|
91
|
+
self.clients: dict[str, ClientProxy] = {}
|
|
92
92
|
self._cv = threading.Condition()
|
|
93
93
|
|
|
94
94
|
def __len__(self) -> int:
|
|
@@ -170,7 +170,7 @@ class SimpleClientManager(ClientManager):
|
|
|
170
170
|
with self._cv:
|
|
171
171
|
self._cv.notify_all()
|
|
172
172
|
|
|
173
|
-
def all(self) ->
|
|
173
|
+
def all(self) -> dict[str, ClientProxy]:
|
|
174
174
|
"""Return all available clients."""
|
|
175
175
|
return self.clients
|
|
176
176
|
|
|
@@ -179,7 +179,7 @@ class SimpleClientManager(ClientManager):
|
|
|
179
179
|
num_clients: int,
|
|
180
180
|
min_num_clients: Optional[int] = None,
|
|
181
181
|
criterion: Optional[Criterion] = None,
|
|
182
|
-
) ->
|
|
182
|
+
) -> list[ClientProxy]:
|
|
183
183
|
"""Sample a number of Flower ClientProxy instances."""
|
|
184
184
|
# Block until at least num_clients are connected.
|
|
185
185
|
if min_num_clients is None:
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
|
-
from typing import Dict, Tuple
|
|
20
19
|
|
|
21
20
|
from ..client_manager import ClientManager
|
|
22
21
|
from ..compat.driver_client_proxy import DriverClientProxy
|
|
@@ -26,7 +25,7 @@ from ..driver import Driver
|
|
|
26
25
|
def start_update_client_manager_thread(
|
|
27
26
|
driver: Driver,
|
|
28
27
|
client_manager: ClientManager,
|
|
29
|
-
) ->
|
|
28
|
+
) -> tuple[threading.Thread, threading.Event]:
|
|
30
29
|
"""Periodically update the nodes list in the client manager in a thread.
|
|
31
30
|
|
|
32
31
|
This function starts a thread that periodically uses the associated driver to
|
|
@@ -73,7 +72,7 @@ def _update_client_manager(
|
|
|
73
72
|
) -> None:
|
|
74
73
|
"""Update the nodes list in the client manager."""
|
|
75
74
|
# Loop until the driver is disconnected
|
|
76
|
-
registered_nodes:
|
|
75
|
+
registered_nodes: dict[int, DriverClientProxy] = {}
|
|
77
76
|
while not f_stop.is_set():
|
|
78
77
|
all_node_ids = set(driver.get_node_ids())
|
|
79
78
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
flwr/server/driver/driver.py
CHANGED
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
|
-
from
|
|
19
|
+
from collections.abc import Iterable
|
|
20
|
+
from typing import Optional
|
|
20
21
|
|
|
21
22
|
from flwr.common import Message, RecordSet
|
|
22
23
|
from flwr.common.typing import Run
|
|
@@ -70,7 +71,7 @@ class Driver(ABC):
|
|
|
70
71
|
"""
|
|
71
72
|
|
|
72
73
|
@abstractmethod
|
|
73
|
-
def get_node_ids(self) ->
|
|
74
|
+
def get_node_ids(self) -> list[int]:
|
|
74
75
|
"""Get node IDs."""
|
|
75
76
|
|
|
76
77
|
@abstractmethod
|
|
@@ -16,12 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
|
+
from collections.abc import Iterable
|
|
19
20
|
from logging import DEBUG, WARNING
|
|
20
|
-
from typing import
|
|
21
|
+
from typing import Optional, cast
|
|
21
22
|
|
|
22
23
|
import grpc
|
|
23
24
|
|
|
24
25
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
+
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
25
27
|
from flwr.common.grpc import create_channel
|
|
26
28
|
from flwr.common.logger import log
|
|
27
29
|
from flwr.common.serde import (
|
|
@@ -45,8 +47,6 @@ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
|
45
47
|
|
|
46
48
|
from .driver import Driver
|
|
47
49
|
|
|
48
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
49
|
-
|
|
50
50
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
51
51
|
[Driver] Error: Not connected.
|
|
52
52
|
|
|
@@ -73,7 +73,7 @@ class GrpcDriver(Driver):
|
|
|
73
73
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
74
|
self,
|
|
75
75
|
run_id: int,
|
|
76
|
-
driver_service_address: str =
|
|
76
|
+
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
77
77
|
root_certificates: Optional[bytes] = None,
|
|
78
78
|
) -> None:
|
|
79
79
|
self._run_id = run_id
|
|
@@ -193,7 +193,7 @@ class GrpcDriver(Driver):
|
|
|
193
193
|
)
|
|
194
194
|
return Message(metadata=metadata, content=content)
|
|
195
195
|
|
|
196
|
-
def get_node_ids(self) ->
|
|
196
|
+
def get_node_ids(self) -> list[int]:
|
|
197
197
|
"""Get node IDs."""
|
|
198
198
|
self._init_run()
|
|
199
199
|
# Call GrpcDriverStub method
|
|
@@ -210,7 +210,7 @@ class GrpcDriver(Driver):
|
|
|
210
210
|
"""
|
|
211
211
|
self._init_run()
|
|
212
212
|
# Construct TaskIns
|
|
213
|
-
task_ins_list:
|
|
213
|
+
task_ins_list: list[TaskIns] = []
|
|
214
214
|
for msg in messages:
|
|
215
215
|
# Check message
|
|
216
216
|
self._check_message(msg)
|
|
@@ -256,7 +256,7 @@ class GrpcDriver(Driver):
|
|
|
256
256
|
|
|
257
257
|
# Pull messages
|
|
258
258
|
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
259
|
-
ret:
|
|
259
|
+
ret: list[Message] = []
|
|
260
260
|
while timeout is None or time.time() < end_time:
|
|
261
261
|
res_msgs = self.pull_messages(msg_ids)
|
|
262
262
|
ret.extend(res_msgs)
|
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
import warnings
|
|
20
|
-
from
|
|
20
|
+
from collections.abc import Iterable
|
|
21
|
+
from typing import Optional, cast
|
|
21
22
|
from uuid import UUID
|
|
22
23
|
|
|
23
24
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
@@ -112,7 +113,7 @@ class InMemoryDriver(Driver):
|
|
|
112
113
|
)
|
|
113
114
|
return Message(metadata=metadata, content=content)
|
|
114
115
|
|
|
115
|
-
def get_node_ids(self) ->
|
|
116
|
+
def get_node_ids(self) -> list[int]:
|
|
116
117
|
"""Get node IDs."""
|
|
117
118
|
self._init_run()
|
|
118
119
|
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
@@ -123,7 +124,7 @@ class InMemoryDriver(Driver):
|
|
|
123
124
|
This method takes an iterable of messages and sends each message
|
|
124
125
|
to the node specified in `dst_node_id`.
|
|
125
126
|
"""
|
|
126
|
-
task_ids:
|
|
127
|
+
task_ids: list[str] = []
|
|
127
128
|
for msg in messages:
|
|
128
129
|
# Check message
|
|
129
130
|
self._check_message(msg)
|
|
@@ -169,7 +170,7 @@ class InMemoryDriver(Driver):
|
|
|
169
170
|
|
|
170
171
|
# Pull messages
|
|
171
172
|
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
172
|
-
ret:
|
|
173
|
+
ret: list[Message] = []
|
|
173
174
|
while timeout is None or time.time() < end_time:
|
|
174
175
|
res_msgs = self.pull_messages(msg_ids)
|
|
175
176
|
ret.extend(res_msgs)
|
flwr/server/history.py
CHANGED
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
|
|
18
18
|
import pprint
|
|
19
19
|
from functools import reduce
|
|
20
|
-
from typing import Dict, List, Tuple
|
|
21
20
|
|
|
22
21
|
from flwr.common.typing import Scalar
|
|
23
22
|
|
|
@@ -26,11 +25,11 @@ class History:
|
|
|
26
25
|
"""History class for training and/or evaluation metrics collection."""
|
|
27
26
|
|
|
28
27
|
def __init__(self) -> None:
|
|
29
|
-
self.losses_distributed:
|
|
30
|
-
self.losses_centralized:
|
|
31
|
-
self.metrics_distributed_fit:
|
|
32
|
-
self.metrics_distributed:
|
|
33
|
-
self.metrics_centralized:
|
|
28
|
+
self.losses_distributed: list[tuple[int, float]] = []
|
|
29
|
+
self.losses_centralized: list[tuple[int, float]] = []
|
|
30
|
+
self.metrics_distributed_fit: dict[str, list[tuple[int, Scalar]]] = {}
|
|
31
|
+
self.metrics_distributed: dict[str, list[tuple[int, Scalar]]] = {}
|
|
32
|
+
self.metrics_centralized: dict[str, list[tuple[int, Scalar]]] = {}
|
|
34
33
|
|
|
35
34
|
def add_loss_distributed(self, server_round: int, loss: float) -> None:
|
|
36
35
|
"""Add one loss entry (from distributed evaluation)."""
|
|
@@ -41,7 +40,7 @@ class History:
|
|
|
41
40
|
self.losses_centralized.append((server_round, loss))
|
|
42
41
|
|
|
43
42
|
def add_metrics_distributed_fit(
|
|
44
|
-
self, server_round: int, metrics:
|
|
43
|
+
self, server_round: int, metrics: dict[str, Scalar]
|
|
45
44
|
) -> None:
|
|
46
45
|
"""Add metrics entries (from distributed fit)."""
|
|
47
46
|
for key in metrics:
|
|
@@ -52,7 +51,7 @@ class History:
|
|
|
52
51
|
self.metrics_distributed_fit[key].append((server_round, metrics[key]))
|
|
53
52
|
|
|
54
53
|
def add_metrics_distributed(
|
|
55
|
-
self, server_round: int, metrics:
|
|
54
|
+
self, server_round: int, metrics: dict[str, Scalar]
|
|
56
55
|
) -> None:
|
|
57
56
|
"""Add metrics entries (from distributed evaluation)."""
|
|
58
57
|
for key in metrics:
|
|
@@ -63,7 +62,7 @@ class History:
|
|
|
63
62
|
self.metrics_distributed[key].append((server_round, metrics[key]))
|
|
64
63
|
|
|
65
64
|
def add_metrics_centralized(
|
|
66
|
-
self, server_round: int, metrics:
|
|
65
|
+
self, server_round: int, metrics: dict[str, Scalar]
|
|
67
66
|
) -> None:
|
|
68
67
|
"""Add metrics entries (from centralized evaluation)."""
|
|
69
68
|
for key in metrics:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -31,6 +31,7 @@ from flwr.common.config import (
|
|
|
31
31
|
get_project_config,
|
|
32
32
|
get_project_dir,
|
|
33
33
|
)
|
|
34
|
+
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
34
35
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
35
36
|
from flwr.common.object_ref import load_app
|
|
36
37
|
from flwr.common.typing import UserConfig
|
|
@@ -44,8 +45,6 @@ from .driver import Driver
|
|
|
44
45
|
from .driver.grpc_driver import GrpcDriver
|
|
45
46
|
from .server_app import LoadServerAppError, ServerApp
|
|
46
47
|
|
|
47
|
-
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
48
|
-
|
|
49
48
|
|
|
50
49
|
def run(
|
|
51
50
|
driver: Driver,
|
|
@@ -112,11 +111,11 @@ def run_server_app() -> None:
|
|
|
112
111
|
"app by executing `flwr new` and following the prompt."
|
|
113
112
|
)
|
|
114
113
|
|
|
115
|
-
if args.server !=
|
|
114
|
+
if args.server != DRIVER_API_DEFAULT_ADDRESS:
|
|
116
115
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
117
116
|
warn_deprecated_feature(warn)
|
|
118
117
|
|
|
119
|
-
if args.superlink !=
|
|
118
|
+
if args.superlink != DRIVER_API_DEFAULT_ADDRESS:
|
|
120
119
|
# if `--superlink` also passed, then
|
|
121
120
|
# warn user that this argument overrides what was passed with `--server`
|
|
122
121
|
log(
|
|
@@ -275,12 +274,12 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
275
274
|
)
|
|
276
275
|
parser.add_argument(
|
|
277
276
|
"--server",
|
|
278
|
-
default=
|
|
277
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
279
278
|
help="Server address",
|
|
280
279
|
)
|
|
281
280
|
parser.add_argument(
|
|
282
281
|
"--superlink",
|
|
283
|
-
default=
|
|
282
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
284
283
|
help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
285
284
|
)
|
|
286
285
|
parser.add_argument(
|