flwr-nightly 1.7.0.dev20240117__py3-none-any.whl → 1.7.0.dev20240119__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46)
  1. flwr/client/app.py +1 -1
  2. flwr/client/grpc_client/connection.py +7 -4
  3. flwr/client/grpc_rere_client/connection.py +4 -4
  4. flwr/client/message_handler/message_handler.py +11 -2
  5. flwr/client/message_handler/task_handler.py +7 -4
  6. flwr/client/node_state_tests.py +1 -1
  7. flwr/client/rest_client/connection.py +3 -3
  8. flwr/client/typing.py +1 -1
  9. flwr/common/configsrecord.py +107 -0
  10. flwr/common/flowercontext.py +77 -0
  11. flwr/common/logger.py +14 -0
  12. flwr/common/metricsrecord.py +107 -0
  13. flwr/common/parametersrecord.py +12 -5
  14. flwr/common/recordset.py +3 -10
  15. flwr/common/serde.py +2 -2
  16. flwr/common/typing.py +9 -0
  17. flwr/driver/app.py +5 -3
  18. flwr/driver/driver.py +3 -3
  19. flwr/driver/driver_client_proxy.py +24 -15
  20. flwr/driver/grpc_driver.py +2 -2
  21. flwr/proto/driver_pb2.py +23 -88
  22. flwr/proto/fleet_pb2.py +29 -111
  23. flwr/proto/node_pb2.py +7 -15
  24. flwr/proto/task_pb2.py +34 -128
  25. flwr/proto/task_pb2.pyi +4 -1
  26. flwr/proto/transport_pb2.py +69 -278
  27. flwr/server/app.py +9 -3
  28. flwr/server/driver/driver_servicer.py +4 -4
  29. flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
  30. flwr/server/fleet/grpc_bidi/grpc_bridge.py +4 -1
  31. flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
  32. flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
  33. flwr/server/fleet/grpc_bidi/ins_scheduler.py +6 -3
  34. flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
  35. flwr/server/fleet/message_handler/message_handler.py +3 -3
  36. flwr/server/fleet/rest_rere/rest_api.py +1 -1
  37. flwr/server/state/in_memory_state.py +1 -1
  38. flwr/server/state/sqlite_state.py +6 -3
  39. flwr/server/state/state.py +1 -1
  40. flwr/server/strategy/fedxgb_nn_avg.py +9 -2
  41. flwr/server/utils/validator.py +1 -1
  42. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/METADATA +3 -3
  43. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/RECORD +46 -43
  44. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/LICENSE +0 -0
  45. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/WHEEL +0 -0
  46. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
35
35
  TRANSPORT_TYPES,
36
36
  )
37
37
  from flwr.common.logger import log, warn_experimental_feature
38
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
39
39
 
40
40
  from .flower import load_flower_callable
41
41
  from .grpc_client.connection import grpc_connection
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
26
  from flwr.common.grpc import create_channel
27
27
  from flwr.common.logger import log
28
- from flwr.proto.node_pb2 import Node
29
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
30
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
31
- from flwr.proto.transport_pb2_grpc import FlowerServiceStub
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
31
+ ClientMessage,
32
+ ServerMessage,
33
+ )
34
+ from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
32
35
 
33
36
  # The following flags can be uncommented for debugging. Other possible values:
34
37
  # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.grpc import create_channel
31
31
  from flwr.common.logger import log, warn_experimental_feature
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  DeleteNodeRequest,
35
35
  PullTaskInsRequest,
36
36
  PushTaskResRequest,
37
37
  )
38
- from flwr.proto.fleet_pb2_grpc import FleetStub
39
- from flwr.proto.node_pb2 import Node
40
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
41
41
 
42
42
  KEY_NODE = "node"
43
43
  KEY_TASK_INS = "current_task_ins"
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
32
32
  from flwr.client.secure_aggregation import SecureAggregationHandler
33
33
  from flwr.client.typing import ClientFn
34
34
  from flwr.common import serde
35
- from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
36
- from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
35
+ from flwr.proto.task_pb2 import ( # pylint: disable=E0611
36
+ SecureAggregation,
37
+ Task,
38
+ TaskIns,
39
+ TaskRes,
40
+ )
41
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
+ ClientMessage,
43
+ Reason,
44
+ ServerMessage,
45
+ )
37
46
 
38
47
 
39
48
  class UnexpectedServerMessage(Exception):
@@ -17,10 +17,13 @@
17
17
 
18
18
  from typing import Optional
19
19
 
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse
21
- from flwr.proto.node_pb2 import Node
22
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
23
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
20
+ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
22
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
23
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
24
+ ClientMessage,
25
+ ServerMessage,
26
+ )
24
27
 
25
28
 
26
29
  def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -17,7 +17,7 @@
17
17
 
18
18
  from flwr.client.node_state import NodeState
19
19
  from flwr.client.run_state import RunState
20
- from flwr.proto.task_pb2 import TaskIns
20
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
21
 
22
22
 
23
23
  def _run_dummy_task(state: RunState) -> RunState:
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.constant import MISSING_EXTRA_REST
31
31
  from flwr.common.logger import log
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  CreateNodeResponse,
35
35
  DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
38
38
  PushTaskResRequest,
39
39
  PushTaskResResponse,
40
40
  )
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes
41
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
 
44
44
  try:
45
45
  import requests
flwr/client/typing.py CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
18
18
  from typing import Callable
19
19
 
20
20
  from flwr.client.run_state import RunState
21
- from flwr.proto.task_pb2 import TaskIns, TaskRes
21
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
22
 
23
23
  from .client import Client as Client
24
24
 
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ConfigsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import ConfigsRecordValues, ConfigsScalar
22
+
23
+
24
@dataclass
class ConfigsRecord:
    """Configs record.

    A mapping from string keys to configuration values. Each value is either a
    single scalar (one of the types in `ConfigsScalar`: `int`, `float`, `str`,
    `bytes`, `bool`) or a list of such scalars.
    """

    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`,
            `bytes`, `bool` as defined in `ConfigsScalar`) and lists of such
            types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether `configs_dict` should be left untouched. When True, its
            entries are copied into the record, so the data is referenced from
            both places. When False, each entry is moved into the record and
            deleted from `configs_dict`; set it to False if memory is a concern.
        """
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict, keep_input=keep_input)

    def set_configs(
        self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
    ) -> None:
        """Add configs to the record.

        Entries are merged into the record: keys already present are
        overwritten, all other existing entries are kept.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`,
            `bytes`, `bool` as defined in `ConfigsScalar`) and lists of such
            types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether `configs_dict` should be left untouched. When True, its
            entries are copied into the record. When False, each entry is
            moved into the record and deleted from `configs_dict`; set it to
            False if memory is a concern.

        Raises
        ------
        TypeError
            If a key is not a `str`, or a value (or list element) is not one of
            the types in `ConfigsScalar`. Validation runs before anything is
            modified, so on error neither the record nor `configs_dict` changes.
        """
        if any(not isinstance(k, str) for k in configs_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        # `get_args` flattens the nested Union, giving a tuple of concrete
        # types usable with `isinstance`. Computed once, not per element.
        allowed_types = get_args(ConfigsScalar)

        def is_valid(value: ConfigsScalar) -> None:
            """Check if value is of expected type."""
            if not isinstance(value, allowed_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        # Check types of values: lists are validated element by element,
        # everything else is validated directly.
        for value in configs_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # (~1s to check a 10M element list on an M2 Pro). In such
                # settings, you'd be better off treating such config as an
                # array and passing it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add configs to record
        if keep_input:
            # Shallow copy into the existing dict. Using `update` instead of
            # rebinding `self.data` keeps entries added by earlier calls, which
            # matches the behaviour of the `keep_input=False` branch below.
            self.data.update(configs_dict)
        else:
            # Move entries one by one so the input dict shrinks as the record
            # grows, avoiding holding both full dicts at once.
            for key in list(configs_dict.keys()):
                self.data[key] = configs_dict.pop(key)

    def __getitem__(self, key: str) -> ConfigsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """FlowerContext and Metadata."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+
20
+ from .recordset import RecordSet
21
+
22
+
23
@dataclass(init=True, repr=True, eq=True)
class Metadata:
    """Descriptive information attached to the task currently being handled.

    Parameters
    ----------
    run_id : int
        Identifies the run this task belongs to.
    task_id : str
        Identifies the task itself.
    group_id : str
        Groups related tasks together; in some settings it is used as the
        FL round.
    ttl : str
        The task's time-to-live.
    task_type : str
        Encodes the action the receiving end is asked to execute.
    """

    # Run the task belongs to.
    run_id: int
    # Unique handle of the task.
    task_id: str
    # Grouping key (e.g. the FL round).
    group_id: str
    # Time-to-live of the task.
    ttl: str
    # Action to perform on the receiving end.
    task_type: str
48
+
49
+
50
@dataclass(init=True, repr=True, eq=True)
class FlowerContext:
    """Application state as seen by the entity (e.g. client or server) using it.

    Parameters
    ----------
    in_message : RecordSet
        Records received from another entity (e.g. sent by the server-side
        logic to a client, or vice versa).
    out_message : RecordSet
        Records produced by the current entity. This `RecordSet` is meant to
        be sent out (e.g. back to the server side for aggregation of
        parameters, or to a client so it performs a certain task).
    local : RecordSet
        Records added by the current entity that are meant to stay local,
        i.e. never leave the system it is running on. It can serve as
        intermediate storage or a scratchpad when executing middleware layers,
        or as memory accessed at different points during the lifecycle of the
        entity (e.g. across multiple rounds).
    metadata : Metadata
        Information about the task to be executed.
    """

    # Records received from the peer entity.
    in_message: RecordSet
    # Records intended to be sent to the peer entity.
    out_message: RecordSet
    # Records intended to stay on this entity.
    local: RecordSet
    # Information about the task being executed.
    metadata: Metadata
flwr/common/logger.py CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
111
111
  """,
112
112
  name,
113
113
  )
114
+
115
+
116
def warn_deprecated_feature(name: str) -> None:
    """Warn the user when they use a deprecated feature.

    Parameters
    ----------
    name : str
        Name of the deprecated feature. It is passed to `log` as the argument
        for the `%s` placeholder in the message, which is logged at `WARN`
        level.
    """
    log(
        WARN,
        """
DEPRECATED FEATURE: %s

This is a deprecated feature. It will be removed
entirely in future versions of Flower.
""",
        name,
    )
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MetricsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import MetricsRecordValues, MetricsScalar
22
+
23
+
24
@dataclass
class MetricsRecord:
    """Metrics record.

    A mapping from string keys to metric values. Each value is either a single
    scalar (one of the types in `MetricsScalar`: `int`, `float`; `bool` is
    explicitly rejected) or a list of such scalars.
    """

    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and lists of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether `metrics_dict` should be left untouched. When True, its
            entries are copied into the record, so the data is referenced from
            both places. When False, each entry is moved into the record and
            deleted from `metrics_dict`; set it to False if memory is a concern.
        """
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict, keep_input=keep_input)

    def set_metrics(
        self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True
    ) -> None:
        """Add metrics to the record.

        Entries are merged into the record: keys already present are
        overwritten, all other existing entries are kept.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and lists of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether `metrics_dict` should be left untouched. When True, its
            entries are copied into the record. When False, each entry is
            moved into the record and deleted from `metrics_dict`; set it to
            False if memory is a concern.

        Raises
        ------
        TypeError
            If a key is not a `str`, or a value (or list element) is not an
            `int`/`float` (`bool` is rejected even though it subclasses `int`).
            Validation runs before anything is modified, so on error neither
            the record nor `metrics_dict` changes.
        """
        if any(not isinstance(k, str) for k in metrics_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        # Tuple of concrete types for `isinstance`; computed once, not per element.
        allowed_types = get_args(MetricsScalar)

        def is_valid(value: MetricsScalar) -> None:
            """Check if value is of expected type."""
            # `bool` is a subclass of `int`, so it must be excluded explicitly.
            if not isinstance(value, allowed_types) or isinstance(value, bool):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Check types of values: lists are validated element by element,
        # everything else is validated directly.
        for value in metrics_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # (~1s to check a 10M element list on an M2 Pro). In such
                # settings, you'd be better off treating such metric as an
                # array and passing it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add metrics to record
        if keep_input:
            # Shallow copy into the existing dict. Using `update` instead of
            # rebinding `self.data` keeps entries added by earlier calls, which
            # matches the behaviour of the `keep_input=False` branch below.
            self.data.update(metrics_dict)
        else:
            # Move entries one by one so the input dict shrinks as the record
            # grows, avoiding holding both full dicts at once.
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict.pop(key)

    def __getitem__(self, key: str) -> MetricsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
@@ -59,7 +59,6 @@ class ParametersRecord:
59
59
  PyTorch's state_dict, but holding serialised tensors instead.
60
60
  """
61
61
 
62
- keep_input: bool
63
62
  data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
64
63
 
65
64
  def __init__(
@@ -82,25 +81,29 @@ class ParametersRecord:
82
81
  parameters after adding it to the record, set this flag to True. When set
83
82
  to True, the data is duplicated in memory.
84
83
  """
85
- self.keep_input = keep_input
86
84
  self.data = OrderedDict()
87
85
  if array_dict:
88
- self.set_parameters(array_dict)
86
+ self.set_parameters(array_dict, keep_input=keep_input)
89
87
 
90
- def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
88
+ def set_parameters(
89
+ self, array_dict: OrderedDict[str, Array], keep_input: bool = False
90
+ ) -> None:
91
91
  """Add parameters to record.
92
92
 
93
93
  Parameters
94
94
  ----------
95
95
  array_dict : OrderedDict[str, Array]
96
96
  A dictionary that stores serialized array-like or tensor-like objects.
97
+ keep_input : bool (default: False)
98
+ A boolean indicating whether parameters should be deleted from the input
99
+ dictionary immediately after adding them to the record.
97
100
  """
98
101
  if any(not isinstance(k, str) for k in array_dict.keys()):
99
102
  raise TypeError(f"Not all keys are of valid type. Expected {str}")
100
103
  if any(not isinstance(v, Array) for v in array_dict.values()):
101
104
  raise TypeError(f"Not all values are of valid type. Expected {Array}")
102
105
 
103
- if self.keep_input:
106
+ if keep_input:
104
107
  # Copy
105
108
  self.data = OrderedDict(array_dict)
106
109
  else:
@@ -108,3 +111,7 @@ class ParametersRecord:
108
111
  for key in list(array_dict.keys()):
109
112
  self.data[key] = array_dict[key]
110
113
  del array_dict[key]
114
+
115
+ def __getitem__(self, key: str) -> Array:
116
+ """Retrieve an element stored in record."""
117
+ return self.data[key]
flwr/common/recordset.py CHANGED
@@ -14,22 +14,15 @@
14
14
  # ==============================================================================
15
15
  """RecordSet."""
16
16
 
17
+
17
18
  from dataclasses import dataclass, field
18
19
  from typing import Dict
19
20
 
21
+ from .configsrecord import ConfigsRecord
22
+ from .metricsrecord import MetricsRecord
20
23
  from .parametersrecord import ParametersRecord
21
24
 
22
25
 
23
- @dataclass
24
- class MetricsRecord:
25
- """Metrics record."""
26
-
27
-
28
- @dataclass
29
- class ConfigsRecord:
30
- """Configs record."""
31
-
32
-
33
26
  @dataclass
34
27
  class RecordSet:
35
28
  """Definition of RecordSet."""
flwr/common/serde.py CHANGED
@@ -17,8 +17,8 @@
17
17
 
18
18
  from typing import Any, Dict, List, MutableMapping, cast
19
19
 
20
- from flwr.proto.task_pb2 import Value
21
- from flwr.proto.transport_pb2 import (
20
+ from flwr.proto.task_pb2 import Value # pylint: disable=E0611
21
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
22
22
  ClientMessage,
23
23
  Code,
24
24
  Parameters,
flwr/common/typing.py CHANGED
@@ -45,6 +45,15 @@ Value = Union[
45
45
  List[str],
46
46
  ]
47
47
 
48
# Value types for common.MetricsRecord
# A single numeric metric value. Because `bool` subclasses `int`, static type
# checkers accept `bool` here; MetricsRecord additionally rejects it at runtime.
MetricsScalar = Union[int, float]
# A homogeneous list of metric values.
MetricsScalarList = Union[List[int], List[float]]
# Any value storable in a MetricsRecord: a scalar or a list of scalars.
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
# Value types for common.ConfigsRecord
# Nested Unions are flattened by `typing`, so `get_args(ConfigsScalar)` is
# (int, float, str, bytes, bool).
ConfigsScalar = Union[MetricsScalar, str, bytes, bool]
# A homogeneous list of config values.
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]]
# Any value storable in a ConfigsRecord: a scalar or a list of scalars.
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
56
+
48
57
  Metrics = Dict[str, Scalar]
49
58
  MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
50
59
 
flwr/driver/app.py CHANGED
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
25
25
  from flwr.common import EventType, event
26
26
  from flwr.common.address import parse_address
27
27
  from flwr.common.logger import log
28
- from flwr.proto import driver_pb2
28
+ from flwr.proto import driver_pb2 # pylint: disable=E0611
29
29
  from flwr.server.app import ServerConfig, init_defaults, run_fl
30
30
  from flwr.server.client_manager import ClientManager
31
31
  from flwr.server.history import History
@@ -171,7 +171,9 @@ def update_client_manager(
171
171
  `client_manager.unregister()`.
172
172
  """
173
173
  # Request for run_id
174
- run_id = driver.create_run(driver_pb2.CreateRunRequest()).run_id
174
+ run_id = driver.create_run(
175
+ driver_pb2.CreateRunRequest() # pylint: disable=E1101
176
+ ).run_id
175
177
 
176
178
  # Loop until the driver is disconnected
177
179
  registered_nodes: Dict[int, DriverClientProxy] = {}
@@ -181,7 +183,7 @@ def update_client_manager(
181
183
  if driver.stub is None:
182
184
  break
183
185
  get_nodes_res = driver.get_nodes(
184
- req=driver_pb2.GetNodesRequest(run_id=run_id)
186
+ req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
185
187
  )
186
188
  all_node_ids = {node.node_id for node in get_nodes_res.nodes}
187
189
  dead_nodes = set(registered_nodes).difference(all_node_ids)
flwr/driver/driver.py CHANGED
@@ -18,14 +18,14 @@
18
18
  from typing import Iterable, List, Optional, Tuple
19
19
 
20
20
  from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
21
- from flwr.proto.driver_pb2 import (
21
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
22
22
  CreateRunRequest,
23
23
  GetNodesRequest,
24
24
  PullTaskResRequest,
25
25
  PushTaskInsRequest,
26
26
  )
27
- from flwr.proto.node_pb2 import Node
28
- from flwr.proto.task_pb2 import TaskIns, TaskRes
27
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
29
29
 
30
30
 
31
31
  class Driver: