flwr-nightly 1.7.0.dev20240117__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +1 -1
- flwr/client/grpc_client/connection.py +7 -4
- flwr/client/grpc_rere_client/connection.py +4 -4
- flwr/client/message_handler/message_handler.py +11 -2
- flwr/client/message_handler/task_handler.py +7 -4
- flwr/client/node_state_tests.py +1 -1
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/typing.py +1 -1
- flwr/common/configsrecord.py +98 -0
- flwr/common/logger.py +14 -0
- flwr/common/metricsrecord.py +96 -0
- flwr/common/recordset.py +3 -10
- flwr/common/serde.py +2 -2
- flwr/common/typing.py +9 -0
- flwr/driver/app.py +5 -3
- flwr/driver/driver.py +3 -3
- flwr/driver/driver_client_proxy.py +24 -15
- flwr/driver/grpc_driver.py +2 -2
- flwr/proto/driver_pb2.py +23 -88
- flwr/proto/fleet_pb2.py +29 -111
- flwr/proto/node_pb2.py +7 -15
- flwr/proto/task_pb2.py +33 -127
- flwr/proto/transport_pb2.py +69 -278
- flwr/server/app.py +9 -3
- flwr/server/driver/driver_servicer.py +4 -4
- flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
- flwr/server/fleet/grpc_bidi/grpc_bridge.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/fleet/grpc_bidi/ins_scheduler.py +6 -3
- flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/fleet/message_handler/message_handler.py +3 -3
- flwr/server/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/state/in_memory_state.py +1 -1
- flwr/server/state/sqlite_state.py +6 -3
- flwr/server/state/state.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +9 -2
- flwr/server/utils/validator.py +1 -1
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +43 -41
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
|
|
35
35
|
TRANSPORT_TYPES,
|
36
36
|
)
|
37
37
|
from flwr.common.logger import log, warn_experimental_feature
|
38
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
39
39
|
|
40
40
|
from .flower import load_flower_callable
|
41
41
|
from .grpc_client.connection import grpc_connection
|
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
|
|
25
25
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
26
26
|
from flwr.common.grpc import create_channel
|
27
27
|
from flwr.common.logger import log
|
28
|
-
from flwr.proto.node_pb2 import Node
|
29
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
30
|
-
from flwr.proto.transport_pb2 import
|
31
|
-
|
28
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
29
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
30
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
31
|
+
ClientMessage,
|
32
|
+
ServerMessage,
|
33
|
+
)
|
34
|
+
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
32
35
|
|
33
36
|
# The following flags can be uncommented for debugging. Other possible values:
|
34
37
|
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.grpc import create_channel
|
31
31
|
from flwr.common.logger import log, warn_experimental_feature
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
DeleteNodeRequest,
|
35
35
|
PullTaskInsRequest,
|
36
36
|
PushTaskResRequest,
|
37
37
|
)
|
38
|
-
from flwr.proto.fleet_pb2_grpc import FleetStub
|
39
|
-
from flwr.proto.node_pb2 import Node
|
40
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
39
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
40
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
41
41
|
|
42
42
|
KEY_NODE = "node"
|
43
43
|
KEY_TASK_INS = "current_task_ins"
|
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
|
|
32
32
|
from flwr.client.secure_aggregation import SecureAggregationHandler
|
33
33
|
from flwr.client.typing import ClientFn
|
34
34
|
from flwr.common import serde
|
35
|
-
from flwr.proto.task_pb2 import
|
36
|
-
|
35
|
+
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
|
36
|
+
SecureAggregation,
|
37
|
+
Task,
|
38
|
+
TaskIns,
|
39
|
+
TaskRes,
|
40
|
+
)
|
41
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
42
|
+
ClientMessage,
|
43
|
+
Reason,
|
44
|
+
ServerMessage,
|
45
|
+
)
|
37
46
|
|
38
47
|
|
39
48
|
class UnexpectedServerMessage(Exception):
|
@@ -17,10 +17,13 @@
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from flwr.proto.fleet_pb2 import PullTaskInsResponse
|
21
|
-
from flwr.proto.node_pb2 import Node
|
22
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
23
|
-
from flwr.proto.transport_pb2 import
|
20
|
+
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
|
21
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
22
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
23
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
24
|
+
ClientMessage,
|
25
|
+
ServerMessage,
|
26
|
+
)
|
24
27
|
|
25
28
|
|
26
29
|
def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
|
flwr/client/node_state_tests.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from flwr.client.node_state import NodeState
|
19
19
|
from flwr.client.run_state import RunState
|
20
|
-
from flwr.proto.task_pb2 import TaskIns
|
20
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
21
21
|
|
22
22
|
|
23
23
|
def _run_dummy_task(state: RunState) -> RunState:
|
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.constant import MISSING_EXTRA_REST
|
31
31
|
from flwr.common.logger import log
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
CreateNodeResponse,
|
35
35
|
DeleteNodeRequest,
|
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
|
|
38
38
|
PushTaskResRequest,
|
39
39
|
PushTaskResResponse,
|
40
40
|
)
|
41
|
-
from flwr.proto.node_pb2 import Node
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
41
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
42
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
43
43
|
|
44
44
|
try:
|
45
45
|
import requests
|
flwr/client/typing.py
CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
|
18
18
|
from typing import Callable
|
19
19
|
|
20
20
|
from flwr.client.run_state import RunState
|
21
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
21
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
22
22
|
|
23
23
|
from .client import Client as Client
|
24
24
|
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ConfigsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import ConfigsRecordValues, ConfigsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class ConfigsRecord:
    """Configs record.

    A mapping from `str` keys to config values. Each value is either a scalar
    of one of the types in `ConfigsScalar` (`int`, `float`, `str`, `bytes`) or
    a list of such scalars.
    """

    # True: `set_configs` leaves the caller's dictionary untouched (entries are
    # copied). False: entries are moved out of the caller's dictionary.
    keep_input: bool
    # The stored configs, keyed by name.
    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
            defined in `ConfigsScalar`) and lists of such types (see
            `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether the input dictionary is left intact. When True, its
            entries are copied into the record (a shallow copy: the dictionary
            structure is duplicated, the values themselves are shared with the
            input). When False, each entry is removed from the input dictionary
            as soon as it has been added to the record, so the two dictionaries
            do not coexist. If memory is a concern, set it to False.

        Raises
        ------
        TypeError
            If `configs_dict` contains a non-`str` key or a value of an
            unsupported type (see `set_configs`).
        """
        self.keep_input = keep_input
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict)

    def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
        """Add configs to the record.

        Entries are merged into the record: keys that already exist are
        overwritten and all other existing entries are kept. Every key and
        value is validated before anything is added. An invalid input therefore
        leaves the record unchanged, and with `keep_input=False` it also leaves
        the input dictionary unchanged.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
            defined in `ConfigsScalar`) and list of such types (see
            `ConfigsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or if a value (or an element of a list
            value) is not an instance of one of the `ConfigsScalar` types.
        """
        if any(not isinstance(k, str) for k in configs_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        def is_valid(value: ConfigsScalar) -> None:
            """Check if value is of expected type."""
            # `get_args` on the (flattened) Union yields the accepted classes.
            if not isinstance(value, get_args(ConfigsScalar)):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in configs_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such config as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add configs to record
        if self.keep_input:
            # Shallow-copy the entries in. Merge rather than rebind `self.data`
            # so configs added by earlier calls are not discarded, matching the
            # `keep_input=False` branch below.
            self.data.update(configs_dict)
        else:
            # Move entries out of the input so both dicts never hold them
            # at the same time; iterate over a snapshot of the keys since
            # the input is mutated inside the loop.
            for key in list(configs_dict.keys()):
                self.data[key] = configs_dict.pop(key)
|
flwr/common/logger.py
CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
|
|
111
111
|
""",
|
112
112
|
name,
|
113
113
|
)
|
114
|
+
|
115
|
+
|
116
|
+
def warn_deprecated_feature(name: str) -> None:
    """Warn the user when they use a deprecated feature.

    Parameters
    ----------
    name : str
        Name of the deprecated feature; it fills the ``%s`` placeholder in
        the logged message.
    """
    # `name` is passed as a separate argument to `log` to fill the `%s`
    # placeholder, the same pattern `warn_experimental_feature` uses.
    log(
        WARN,
        """
        DEPRECATED FEATURE: %s

        This is a deprecated feature. It will be removed
        entirely in future versions of Flower.
        """,
        name,
    )
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""MetricsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import MetricsRecordValues, MetricsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class MetricsRecord:
    """Metrics record.

    A mapping from `str` keys to metric values. Each value is either a scalar
    of one of the types in `MetricsScalar` (`int`, `float`) or a list of such
    scalars.
    """

    # True: `set_metrics` leaves the caller's dictionary untouched (entries are
    # copied). False: entries are moved out of the caller's dictionary.
    keep_input: bool
    # The stored metrics, keyed by name.
    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether the input dictionary is left intact. When True, its
            entries are copied into the record (a shallow copy: the dictionary
            structure is duplicated, the values themselves are shared with the
            input). When False, each entry is removed from the input dictionary
            as soon as it has been added to the record, so the two dictionaries
            do not coexist. If memory is a concern, set it to False.

        Raises
        ------
        TypeError
            If `metrics_dict` contains a non-`str` key or a value of an
            unsupported type (see `set_metrics`).
        """
        self.keep_input = keep_input
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict)

    def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
        """Add metrics to the record.

        Entries are merged into the record: keys that already exist are
        overwritten and all other existing entries are kept. Every key and
        value is validated before anything is added. An invalid input therefore
        leaves the record unchanged, and with `keep_input=False` it also leaves
        the input dictionary unchanged.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or if a value (or an element of a list
            value) is not an instance of one of the `MetricsScalar` types.
        """
        if any(not isinstance(k, str) for k in metrics_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        def is_valid(value: MetricsScalar) -> None:
            """Check if value is of expected type."""
            # `get_args` on the Union yields the accepted classes.
            if not isinstance(value, get_args(MetricsScalar)):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in metrics_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such metric as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add metrics to record
        if self.keep_input:
            # Shallow-copy the entries in. Merge rather than rebind `self.data`
            # so metrics added by earlier calls are not discarded, matching the
            # `keep_input=False` branch below.
            self.data.update(metrics_dict)
        else:
            # Move entries out of the input so both dicts never hold them
            # at the same time; iterate over a snapshot of the keys since
            # the input is mutated inside the loop.
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict.pop(key)
|
flwr/common/recordset.py
CHANGED
@@ -14,22 +14,15 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""RecordSet."""
|
16
16
|
|
17
|
+
|
17
18
|
from dataclasses import dataclass, field
|
18
19
|
from typing import Dict
|
19
20
|
|
21
|
+
from .configsrecord import ConfigsRecord
|
22
|
+
from .metricsrecord import MetricsRecord
|
20
23
|
from .parametersrecord import ParametersRecord
|
21
24
|
|
22
25
|
|
23
|
-
@dataclass
|
24
|
-
class MetricsRecord:
|
25
|
-
"""Metrics record."""
|
26
|
-
|
27
|
-
|
28
|
-
@dataclass
|
29
|
-
class ConfigsRecord:
|
30
|
-
"""Configs record."""
|
31
|
-
|
32
|
-
|
33
26
|
@dataclass
|
34
27
|
class RecordSet:
|
35
28
|
"""Definition of RecordSet."""
|
flwr/common/serde.py
CHANGED
@@ -17,8 +17,8 @@
|
|
17
17
|
|
18
18
|
from typing import Any, Dict, List, MutableMapping, cast
|
19
19
|
|
20
|
-
from flwr.proto.task_pb2 import Value
|
21
|
-
from flwr.proto.transport_pb2 import (
|
20
|
+
from flwr.proto.task_pb2 import Value # pylint: disable=E0611
|
21
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
22
22
|
ClientMessage,
|
23
23
|
Code,
|
24
24
|
Parameters,
|
flwr/common/typing.py
CHANGED
@@ -45,6 +45,15 @@ Value = Union[
|
|
45
45
|
List[str],
|
46
46
|
]
|
47
47
|
|
48
|
+
# Value types for common.MetricsRecord
# A single metric value.
MetricsScalar = Union[int, float]
# A list of metric values, annotated as all-int or all-float. NOTE: the
# record classes validate list elements one by one, so a mixed int/float
# list is not rejected at runtime.
MetricsScalarList = Union[List[int], List[float]]
# Any value a MetricsRecord entry may hold: a scalar or a list of scalars.
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
# Value types for common.ConfigsRecord
# A single config value: any metric scalar, or `str` / `bytes`. Nested
# Unions are flattened by `typing`, so `get_args(ConfigsScalar)` yields
# the individual member types used for `isinstance` checks.
ConfigsScalar = Union[MetricsScalar, str, bytes]
# A list of config values, annotated as homogeneous per element type.
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes]]
# Any value a ConfigsRecord entry may hold: a scalar or a list of scalars.
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
|
56
|
+
|
48
57
|
Metrics = Dict[str, Scalar]
|
49
58
|
MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
|
50
59
|
|
flwr/driver/app.py
CHANGED
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
|
|
25
25
|
from flwr.common import EventType, event
|
26
26
|
from flwr.common.address import parse_address
|
27
27
|
from flwr.common.logger import log
|
28
|
-
from flwr.proto import driver_pb2
|
28
|
+
from flwr.proto import driver_pb2 # pylint: disable=E0611
|
29
29
|
from flwr.server.app import ServerConfig, init_defaults, run_fl
|
30
30
|
from flwr.server.client_manager import ClientManager
|
31
31
|
from flwr.server.history import History
|
@@ -171,7 +171,9 @@ def update_client_manager(
|
|
171
171
|
`client_manager.unregister()`.
|
172
172
|
"""
|
173
173
|
# Request for run_id
|
174
|
-
run_id = driver.create_run(
|
174
|
+
run_id = driver.create_run(
|
175
|
+
driver_pb2.CreateRunRequest() # pylint: disable=E1101
|
176
|
+
).run_id
|
175
177
|
|
176
178
|
# Loop until the driver is disconnected
|
177
179
|
registered_nodes: Dict[int, DriverClientProxy] = {}
|
@@ -181,7 +183,7 @@ def update_client_manager(
|
|
181
183
|
if driver.stub is None:
|
182
184
|
break
|
183
185
|
get_nodes_res = driver.get_nodes(
|
184
|
-
req=driver_pb2.GetNodesRequest(run_id=run_id)
|
186
|
+
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
|
185
187
|
)
|
186
188
|
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
|
187
189
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
flwr/driver/driver.py
CHANGED
@@ -18,14 +18,14 @@
|
|
18
18
|
from typing import Iterable, List, Optional, Tuple
|
19
19
|
|
20
20
|
from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
|
21
|
-
from flwr.proto.driver_pb2 import (
|
21
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
22
22
|
CreateRunRequest,
|
23
23
|
GetNodesRequest,
|
24
24
|
PullTaskResRequest,
|
25
25
|
PushTaskInsRequest,
|
26
26
|
)
|
27
|
-
from flwr.proto.node_pb2 import Node
|
28
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
27
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
28
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
29
29
|
|
30
30
|
|
31
31
|
class Driver:
|
@@ -20,7 +20,12 @@ from typing import List, Optional, cast
|
|
20
20
|
|
21
21
|
from flwr import common
|
22
22
|
from flwr.common import serde
|
23
|
-
from flwr.proto import
|
23
|
+
from flwr.proto import ( # pylint: disable=E0611
|
24
|
+
driver_pb2,
|
25
|
+
node_pb2,
|
26
|
+
task_pb2,
|
27
|
+
transport_pb2,
|
28
|
+
)
|
24
29
|
from flwr.server.client_proxy import ClientProxy
|
25
30
|
|
26
31
|
from .grpc_driver import GrpcDriver
|
@@ -42,7 +47,7 @@ class DriverClientProxy(ClientProxy):
|
|
42
47
|
self, ins: common.GetPropertiesIns, timeout: Optional[float]
|
43
48
|
) -> common.GetPropertiesRes:
|
44
49
|
"""Return client's properties."""
|
45
|
-
server_message_proto: transport_pb2.ServerMessage = (
|
50
|
+
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
|
46
51
|
serde.server_message_to_proto(
|
47
52
|
server_message=common.ServerMessage(get_properties_ins=ins)
|
48
53
|
)
|
@@ -56,7 +61,7 @@ class DriverClientProxy(ClientProxy):
|
|
56
61
|
self, ins: common.GetParametersIns, timeout: Optional[float]
|
57
62
|
) -> common.GetParametersRes:
|
58
63
|
"""Return the current local model parameters."""
|
59
|
-
server_message_proto: transport_pb2.ServerMessage = (
|
64
|
+
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
|
60
65
|
serde.server_message_to_proto(
|
61
66
|
server_message=common.ServerMessage(get_parameters_ins=ins)
|
62
67
|
)
|
@@ -68,7 +73,7 @@ class DriverClientProxy(ClientProxy):
|
|
68
73
|
|
69
74
|
def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
|
70
75
|
"""Train model parameters on the locally held dataset."""
|
71
|
-
server_message_proto: transport_pb2.ServerMessage = (
|
76
|
+
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
|
72
77
|
serde.server_message_to_proto(
|
73
78
|
server_message=common.ServerMessage(fit_ins=ins)
|
74
79
|
)
|
@@ -82,7 +87,7 @@ class DriverClientProxy(ClientProxy):
|
|
82
87
|
self, ins: common.EvaluateIns, timeout: Optional[float]
|
83
88
|
) -> common.EvaluateRes:
|
84
89
|
"""Evaluate model parameters on the locally held dataset."""
|
85
|
-
server_message_proto: transport_pb2.ServerMessage = (
|
90
|
+
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
|
86
91
|
serde.server_message_to_proto(
|
87
92
|
server_message=common.ServerMessage(evaluate_ins=ins)
|
88
93
|
)
|
@@ -99,25 +104,29 @@ class DriverClientProxy(ClientProxy):
|
|
99
104
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
100
105
|
|
101
106
|
def _send_receive_msg(
|
102
|
-
self,
|
103
|
-
|
104
|
-
|
107
|
+
self,
|
108
|
+
server_message: transport_pb2.ServerMessage, # pylint: disable=E1101
|
109
|
+
timeout: Optional[float],
|
110
|
+
) -> transport_pb2.ClientMessage: # pylint: disable=E1101
|
111
|
+
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
105
112
|
task_id="",
|
106
113
|
group_id="",
|
107
114
|
run_id=self.run_id,
|
108
|
-
task=task_pb2.Task(
|
109
|
-
producer=node_pb2.Node(
|
115
|
+
task=task_pb2.Task( # pylint: disable=E1101
|
116
|
+
producer=node_pb2.Node( # pylint: disable=E1101
|
110
117
|
node_id=0,
|
111
118
|
anonymous=True,
|
112
119
|
),
|
113
|
-
consumer=node_pb2.Node(
|
120
|
+
consumer=node_pb2.Node( # pylint: disable=E1101
|
114
121
|
node_id=self.node_id,
|
115
122
|
anonymous=self.anonymous,
|
116
123
|
),
|
117
124
|
legacy_server_message=server_message,
|
118
125
|
),
|
119
126
|
)
|
120
|
-
push_task_ins_req = driver_pb2.PushTaskInsRequest(
|
127
|
+
push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
|
128
|
+
task_ins_list=[task_ins]
|
129
|
+
)
|
121
130
|
|
122
131
|
# Send TaskIns to Driver API
|
123
132
|
push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
|
@@ -133,15 +142,15 @@ class DriverClientProxy(ClientProxy):
|
|
133
142
|
start_time = time.time()
|
134
143
|
|
135
144
|
while True:
|
136
|
-
pull_task_res_req = driver_pb2.PullTaskResRequest(
|
137
|
-
node=node_pb2.Node(node_id=0, anonymous=True),
|
145
|
+
pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
|
146
|
+
node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
|
138
147
|
task_ids=[task_id],
|
139
148
|
)
|
140
149
|
|
141
150
|
# Ask Driver API for TaskRes
|
142
151
|
pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
|
143
152
|
|
144
|
-
task_res_list: List[task_pb2.TaskRes] = list(
|
153
|
+
task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
|
145
154
|
pull_task_res_res.task_res_list
|
146
155
|
)
|
147
156
|
if len(task_res_list) == 1:
|
flwr/driver/grpc_driver.py
CHANGED
@@ -23,7 +23,7 @@ import grpc
|
|
23
23
|
from flwr.common import EventType, event
|
24
24
|
from flwr.common.grpc import create_channel
|
25
25
|
from flwr.common.logger import log
|
26
|
-
from flwr.proto.driver_pb2 import (
|
26
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
27
27
|
CreateRunRequest,
|
28
28
|
CreateRunResponse,
|
29
29
|
GetNodesRequest,
|
@@ -33,7 +33,7 @@ from flwr.proto.driver_pb2 import (
|
|
33
33
|
PushTaskInsRequest,
|
34
34
|
PushTaskInsResponse,
|
35
35
|
)
|
36
|
-
from flwr.proto.driver_pb2_grpc import DriverStub
|
36
|
+
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
37
37
|
|
38
38
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
39
39
|
|