flwr-nightly 1.7.0.dev20240117__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. flwr/client/app.py +1 -1
  2. flwr/client/grpc_client/connection.py +7 -4
  3. flwr/client/grpc_rere_client/connection.py +4 -4
  4. flwr/client/message_handler/message_handler.py +11 -2
  5. flwr/client/message_handler/task_handler.py +7 -4
  6. flwr/client/node_state_tests.py +1 -1
  7. flwr/client/rest_client/connection.py +3 -3
  8. flwr/client/typing.py +1 -1
  9. flwr/common/configsrecord.py +98 -0
  10. flwr/common/logger.py +14 -0
  11. flwr/common/metricsrecord.py +96 -0
  12. flwr/common/recordset.py +3 -10
  13. flwr/common/serde.py +2 -2
  14. flwr/common/typing.py +9 -0
  15. flwr/driver/app.py +5 -3
  16. flwr/driver/driver.py +3 -3
  17. flwr/driver/driver_client_proxy.py +24 -15
  18. flwr/driver/grpc_driver.py +2 -2
  19. flwr/proto/driver_pb2.py +23 -88
  20. flwr/proto/fleet_pb2.py +29 -111
  21. flwr/proto/node_pb2.py +7 -15
  22. flwr/proto/task_pb2.py +33 -127
  23. flwr/proto/transport_pb2.py +69 -278
  24. flwr/server/app.py +9 -3
  25. flwr/server/driver/driver_servicer.py +4 -4
  26. flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
  27. flwr/server/fleet/grpc_bidi/grpc_bridge.py +4 -1
  28. flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
  29. flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
  30. flwr/server/fleet/grpc_bidi/ins_scheduler.py +6 -3
  31. flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
  32. flwr/server/fleet/message_handler/message_handler.py +3 -3
  33. flwr/server/fleet/rest_rere/rest_api.py +1 -1
  34. flwr/server/state/in_memory_state.py +1 -1
  35. flwr/server/state/sqlite_state.py +6 -3
  36. flwr/server/state/state.py +1 -1
  37. flwr/server/strategy/fedxgb_nn_avg.py +9 -2
  38. flwr/server/utils/validator.py +1 -1
  39. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
  40. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +43 -41
  41. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
  43. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
35
35
  TRANSPORT_TYPES,
36
36
  )
37
37
  from flwr.common.logger import log, warn_experimental_feature
38
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
39
39
 
40
40
  from .flower import load_flower_callable
41
41
  from .grpc_client.connection import grpc_connection
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
26
  from flwr.common.grpc import create_channel
27
27
  from flwr.common.logger import log
28
- from flwr.proto.node_pb2 import Node
29
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
30
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
31
- from flwr.proto.transport_pb2_grpc import FlowerServiceStub
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
31
+ ClientMessage,
32
+ ServerMessage,
33
+ )
34
+ from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
32
35
 
33
36
  # The following flags can be uncommented for debugging. Other possible values:
34
37
  # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.grpc import create_channel
31
31
  from flwr.common.logger import log, warn_experimental_feature
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  DeleteNodeRequest,
35
35
  PullTaskInsRequest,
36
36
  PushTaskResRequest,
37
37
  )
38
- from flwr.proto.fleet_pb2_grpc import FleetStub
39
- from flwr.proto.node_pb2 import Node
40
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
41
41
 
42
42
  KEY_NODE = "node"
43
43
  KEY_TASK_INS = "current_task_ins"
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
32
32
  from flwr.client.secure_aggregation import SecureAggregationHandler
33
33
  from flwr.client.typing import ClientFn
34
34
  from flwr.common import serde
35
- from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
36
- from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
35
+ from flwr.proto.task_pb2 import ( # pylint: disable=E0611
36
+ SecureAggregation,
37
+ Task,
38
+ TaskIns,
39
+ TaskRes,
40
+ )
41
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
+ ClientMessage,
43
+ Reason,
44
+ ServerMessage,
45
+ )
37
46
 
38
47
 
39
48
  class UnexpectedServerMessage(Exception):
@@ -17,10 +17,13 @@
17
17
 
18
18
  from typing import Optional
19
19
 
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse
21
- from flwr.proto.node_pb2 import Node
22
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
23
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
20
+ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
22
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
23
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
24
+ ClientMessage,
25
+ ServerMessage,
26
+ )
24
27
 
25
28
 
26
29
  def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -17,7 +17,7 @@
17
17
 
18
18
  from flwr.client.node_state import NodeState
19
19
  from flwr.client.run_state import RunState
20
- from flwr.proto.task_pb2 import TaskIns
20
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
21
 
22
22
 
23
23
  def _run_dummy_task(state: RunState) -> RunState:
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.constant import MISSING_EXTRA_REST
31
31
  from flwr.common.logger import log
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  CreateNodeResponse,
35
35
  DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
38
38
  PushTaskResRequest,
39
39
  PushTaskResResponse,
40
40
  )
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes
41
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
 
44
44
  try:
45
45
  import requests
flwr/client/typing.py CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
18
18
  from typing import Callable
19
19
 
20
20
  from flwr.client.run_state import RunState
21
- from flwr.proto.task_pb2 import TaskIns, TaskRes
21
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
22
 
23
23
  from .client import Client as Client
24
24
 
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ConfigsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import ConfigsRecordValues, ConfigsScalar
22
+
23
+
24
+ @dataclass
25
+ class ConfigsRecord:
26
+ """Configs record."""
27
+
28
+ keep_input: bool
29
+ data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)
30
+
31
+ def __init__(
32
+ self,
33
+ configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
34
+ keep_input: bool = True,
35
+ ):
36
+ """Construct a ConfigsRecord object.
37
+
38
+ Parameters
39
+ ----------
40
+ configs_dict : Optional[Dict[str, ConfigsRecordValues]]
41
+ A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
42
+ defined in `ConfigsScalar`) and lists of such types (see
43
+ `ConfigsScalarList`).
44
+ keep_input : bool (default: True)
45
+ A boolean indicating whether the configs passed should be kept in the input
46
+ dictionary after they are added to the record. When set to True, the
47
+ data is duplicated in memory. If memory is a concern, set it to False
48
+ so that entries are removed from the input dictionary as they are added.
49
+ """
50
+ self.keep_input = keep_input
51
+ self.data = {}
52
+ if configs_dict:
53
+ self.set_configs(configs_dict)
54
+
55
+ def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
56
+ """Add configs to the record.
57
+
58
+ Parameters
59
+ ----------
60
+ configs_dict : Dict[str, ConfigsRecordValues]
61
+ A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
62
+ defined in `ConfigsScalar`) and lists of such types (see
63
+ `ConfigsScalarList`).
64
+ """
65
+ if any(not isinstance(k, str) for k in configs_dict.keys()):
66
+ raise TypeError(f"Not all keys are of valid type. Expected {str}")
67
+
68
+ def is_valid(value: ConfigsScalar) -> None:
69
+ """Check if value is of expected type."""
70
+ if not isinstance(value, get_args(ConfigsScalar)):
71
+ raise TypeError(
72
+ "Not all values are of valid type."
73
+ f" Expected {ConfigsRecordValues} but you passed {type(value)}."
74
+ )
75
+
76
+ # Check types of values
77
+ # Split between those values that are list and those that aren't
78
+ # then process in the same way
79
+ for value in configs_dict.values():
80
+ if isinstance(value, list):
81
+ # If your lists are large (e.g. 1M+ elements) this will be slow
82
+ # 1s to check 10M element list on a M2 Pro
83
+ # In such settings, you'd be better off treating such config as
84
+ # an array and pass it to a ParametersRecord.
85
+ for list_value in value:
86
+ is_valid(list_value)
87
+ else:
88
+ is_valid(value)
89
+
90
+ # Add configs to record
91
+ if self.keep_input:
92
+ # Copy
93
+ self.data = configs_dict.copy()
94
+ else:
95
+ # Add entries to dataclass without duplicating memory
96
+ for key in list(configs_dict.keys()):
97
+ self.data[key] = configs_dict[key]
98
+ del configs_dict[key]
flwr/common/logger.py CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
111
111
  """,
112
112
  name,
113
113
  )
114
+
115
+
116
+ def warn_deprecated_feature(name: str) -> None:
117
+ """Warn the user when they use a deprecated feature."""
118
+ log(
119
+ WARN,
120
+ """
121
+ DEPRECATED FEATURE: %s
122
+
123
+ This is a deprecated feature. It will be removed
124
+ entirely in future versions of Flower.
125
+ """,
126
+ name,
127
+ )
@@ -0,0 +1,96 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MetricsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import MetricsRecordValues, MetricsScalar
22
+
23
+
24
+ @dataclass
25
+ class MetricsRecord:
26
+ """Metrics record."""
27
+
28
+ keep_input: bool
29
+ data: Dict[str, MetricsRecordValues] = field(default_factory=dict)
30
+
31
+ def __init__(
32
+ self,
33
+ metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
34
+ keep_input: bool = True,
35
+ ):
36
+ """Construct a MetricsRecord object.
37
+
38
+ Parameters
39
+ ----------
40
+ metrics_dict : Optional[Dict[str, MetricsRecordValues]]
41
+ A dictionary that stores basic types (i.e. `int`, `float` as defined
42
+ in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
43
+ keep_input : bool (default: True)
44
+ A boolean indicating whether the metrics passed should be kept in the input
45
+ dictionary after they are added to the record. When set to True, the
46
+ data is duplicated in memory. If memory is a concern, set it to False
47
+ so that entries are removed from the input dictionary as they are added.
48
+ """
49
+ self.keep_input = keep_input
50
+ self.data = {}
51
+ if metrics_dict:
52
+ self.set_metrics(metrics_dict)
53
+
54
+ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
55
+ """Add metrics to the record.
56
+
57
+ Parameters
58
+ ----------
59
+ metrics_dict : Dict[str, MetricsRecordValues]
60
+ A dictionary that stores basic types (i.e. `int`, `float` as defined
61
+ in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
62
+ """
63
+ if any(not isinstance(k, str) for k in metrics_dict.keys()):
64
+ raise TypeError(f"Not all keys are of valid type. Expected {str}.")
65
+
66
+ def is_valid(value: MetricsScalar) -> None:
67
+ """Check if value is of expected type."""
68
+ if not isinstance(value, get_args(MetricsScalar)):
69
+ raise TypeError(
70
+ "Not all values are of valid type."
71
+ f" Expected {MetricsRecordValues} but you passed {type(value)}."
72
+ )
73
+
74
+ # Check types of values
75
+ # Split between those values that are list and those that aren't
76
+ # then process in the same way
77
+ for value in metrics_dict.values():
78
+ if isinstance(value, list):
79
+ # If your lists are large (e.g. 1M+ elements) this will be slow
80
+ # 1s to check 10M element list on a M2 Pro
81
+ # In such settings, you'd be better off treating such metric as
82
+ # an array and pass it to a ParametersRecord.
83
+ for list_value in value:
84
+ is_valid(list_value)
85
+ else:
86
+ is_valid(value)
87
+
88
+ # Add metrics to record
89
+ if self.keep_input:
90
+ # Copy
91
+ self.data = metrics_dict.copy()
92
+ else:
93
+ # Add entries to dataclass without duplicating memory
94
+ for key in list(metrics_dict.keys()):
95
+ self.data[key] = metrics_dict[key]
96
+ del metrics_dict[key]
flwr/common/recordset.py CHANGED
@@ -14,22 +14,15 @@
14
14
  # ==============================================================================
15
15
  """RecordSet."""
16
16
 
17
+
17
18
  from dataclasses import dataclass, field
18
19
  from typing import Dict
19
20
 
21
+ from .configsrecord import ConfigsRecord
22
+ from .metricsrecord import MetricsRecord
20
23
  from .parametersrecord import ParametersRecord
21
24
 
22
25
 
23
- @dataclass
24
- class MetricsRecord:
25
- """Metrics record."""
26
-
27
-
28
- @dataclass
29
- class ConfigsRecord:
30
- """Configs record."""
31
-
32
-
33
26
  @dataclass
34
27
  class RecordSet:
35
28
  """Definition of RecordSet."""
flwr/common/serde.py CHANGED
@@ -17,8 +17,8 @@
17
17
 
18
18
  from typing import Any, Dict, List, MutableMapping, cast
19
19
 
20
- from flwr.proto.task_pb2 import Value
21
- from flwr.proto.transport_pb2 import (
20
+ from flwr.proto.task_pb2 import Value # pylint: disable=E0611
21
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
22
22
  ClientMessage,
23
23
  Code,
24
24
  Parameters,
flwr/common/typing.py CHANGED
@@ -45,6 +45,15 @@ Value = Union[
45
45
  List[str],
46
46
  ]
47
47
 
48
+ # Value types for common.MetricsRecord
49
+ MetricsScalar = Union[int, float]
50
+ MetricsScalarList = Union[List[int], List[float]]
51
+ MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
52
+ # Value types for common.ConfigsRecord
53
+ ConfigsScalar = Union[MetricsScalar, str, bytes]
54
+ ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes]]
55
+ ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
56
+
48
57
  Metrics = Dict[str, Scalar]
49
58
  MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
50
59
 
flwr/driver/app.py CHANGED
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
25
25
  from flwr.common import EventType, event
26
26
  from flwr.common.address import parse_address
27
27
  from flwr.common.logger import log
28
- from flwr.proto import driver_pb2
28
+ from flwr.proto import driver_pb2 # pylint: disable=E0611
29
29
  from flwr.server.app import ServerConfig, init_defaults, run_fl
30
30
  from flwr.server.client_manager import ClientManager
31
31
  from flwr.server.history import History
@@ -171,7 +171,9 @@ def update_client_manager(
171
171
  `client_manager.unregister()`.
172
172
  """
173
173
  # Request for run_id
174
- run_id = driver.create_run(driver_pb2.CreateRunRequest()).run_id
174
+ run_id = driver.create_run(
175
+ driver_pb2.CreateRunRequest() # pylint: disable=E1101
176
+ ).run_id
175
177
 
176
178
  # Loop until the driver is disconnected
177
179
  registered_nodes: Dict[int, DriverClientProxy] = {}
@@ -181,7 +183,7 @@ def update_client_manager(
181
183
  if driver.stub is None:
182
184
  break
183
185
  get_nodes_res = driver.get_nodes(
184
- req=driver_pb2.GetNodesRequest(run_id=run_id)
186
+ req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
185
187
  )
186
188
  all_node_ids = {node.node_id for node in get_nodes_res.nodes}
187
189
  dead_nodes = set(registered_nodes).difference(all_node_ids)
flwr/driver/driver.py CHANGED
@@ -18,14 +18,14 @@
18
18
  from typing import Iterable, List, Optional, Tuple
19
19
 
20
20
  from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
21
- from flwr.proto.driver_pb2 import (
21
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
22
22
  CreateRunRequest,
23
23
  GetNodesRequest,
24
24
  PullTaskResRequest,
25
25
  PushTaskInsRequest,
26
26
  )
27
- from flwr.proto.node_pb2 import Node
28
- from flwr.proto.task_pb2 import TaskIns, TaskRes
27
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
29
29
 
30
30
 
31
31
  class Driver:
@@ -20,7 +20,12 @@ from typing import List, Optional, cast
20
20
 
21
21
  from flwr import common
22
22
  from flwr.common import serde
23
- from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2
23
+ from flwr.proto import ( # pylint: disable=E0611
24
+ driver_pb2,
25
+ node_pb2,
26
+ task_pb2,
27
+ transport_pb2,
28
+ )
24
29
  from flwr.server.client_proxy import ClientProxy
25
30
 
26
31
  from .grpc_driver import GrpcDriver
@@ -42,7 +47,7 @@ class DriverClientProxy(ClientProxy):
42
47
  self, ins: common.GetPropertiesIns, timeout: Optional[float]
43
48
  ) -> common.GetPropertiesRes:
44
49
  """Return client's properties."""
45
- server_message_proto: transport_pb2.ServerMessage = (
50
+ server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
46
51
  serde.server_message_to_proto(
47
52
  server_message=common.ServerMessage(get_properties_ins=ins)
48
53
  )
@@ -56,7 +61,7 @@ class DriverClientProxy(ClientProxy):
56
61
  self, ins: common.GetParametersIns, timeout: Optional[float]
57
62
  ) -> common.GetParametersRes:
58
63
  """Return the current local model parameters."""
59
- server_message_proto: transport_pb2.ServerMessage = (
64
+ server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
60
65
  serde.server_message_to_proto(
61
66
  server_message=common.ServerMessage(get_parameters_ins=ins)
62
67
  )
@@ -68,7 +73,7 @@ class DriverClientProxy(ClientProxy):
68
73
 
69
74
  def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
70
75
  """Train model parameters on the locally held dataset."""
71
- server_message_proto: transport_pb2.ServerMessage = (
76
+ server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
72
77
  serde.server_message_to_proto(
73
78
  server_message=common.ServerMessage(fit_ins=ins)
74
79
  )
@@ -82,7 +87,7 @@ class DriverClientProxy(ClientProxy):
82
87
  self, ins: common.EvaluateIns, timeout: Optional[float]
83
88
  ) -> common.EvaluateRes:
84
89
  """Evaluate model parameters on the locally held dataset."""
85
- server_message_proto: transport_pb2.ServerMessage = (
90
+ server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
86
91
  serde.server_message_to_proto(
87
92
  server_message=common.ServerMessage(evaluate_ins=ins)
88
93
  )
@@ -99,25 +104,29 @@ class DriverClientProxy(ClientProxy):
99
104
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
100
105
 
101
106
  def _send_receive_msg(
102
- self, server_message: transport_pb2.ServerMessage, timeout: Optional[float]
103
- ) -> transport_pb2.ClientMessage:
104
- task_ins = task_pb2.TaskIns(
107
+ self,
108
+ server_message: transport_pb2.ServerMessage, # pylint: disable=E1101
109
+ timeout: Optional[float],
110
+ ) -> transport_pb2.ClientMessage: # pylint: disable=E1101
111
+ task_ins = task_pb2.TaskIns( # pylint: disable=E1101
105
112
  task_id="",
106
113
  group_id="",
107
114
  run_id=self.run_id,
108
- task=task_pb2.Task(
109
- producer=node_pb2.Node(
115
+ task=task_pb2.Task( # pylint: disable=E1101
116
+ producer=node_pb2.Node( # pylint: disable=E1101
110
117
  node_id=0,
111
118
  anonymous=True,
112
119
  ),
113
- consumer=node_pb2.Node(
120
+ consumer=node_pb2.Node( # pylint: disable=E1101
114
121
  node_id=self.node_id,
115
122
  anonymous=self.anonymous,
116
123
  ),
117
124
  legacy_server_message=server_message,
118
125
  ),
119
126
  )
120
- push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=[task_ins])
127
+ push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
128
+ task_ins_list=[task_ins]
129
+ )
121
130
 
122
131
  # Send TaskIns to Driver API
123
132
  push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
@@ -133,15 +142,15 @@ class DriverClientProxy(ClientProxy):
133
142
  start_time = time.time()
134
143
 
135
144
  while True:
136
- pull_task_res_req = driver_pb2.PullTaskResRequest(
137
- node=node_pb2.Node(node_id=0, anonymous=True),
145
+ pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
146
+ node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
138
147
  task_ids=[task_id],
139
148
  )
140
149
 
141
150
  # Ask Driver API for TaskRes
142
151
  pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
143
152
 
144
- task_res_list: List[task_pb2.TaskRes] = list(
153
+ task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
145
154
  pull_task_res_res.task_res_list
146
155
  )
147
156
  if len(task_res_list) == 1:
@@ -23,7 +23,7 @@ import grpc
23
23
  from flwr.common import EventType, event
24
24
  from flwr.common.grpc import create_channel
25
25
  from flwr.common.logger import log
26
- from flwr.proto.driver_pb2 import (
26
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
27
27
  CreateRunRequest,
28
28
  CreateRunResponse,
29
29
  GetNodesRequest,
@@ -33,7 +33,7 @@ from flwr.proto.driver_pb2 import (
33
33
  PushTaskInsRequest,
34
34
  PushTaskInsResponse,
35
35
  )
36
- from flwr.proto.driver_pb2_grpc import DriverStub
36
+ from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
37
37
 
38
38
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
39
39