flwr-nightly 1.7.0.dev20240117__py3-none-any.whl → 1.7.0.dev20240119__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. flwr/client/app.py +1 -1
  2. flwr/client/grpc_client/connection.py +7 -4
  3. flwr/client/grpc_rere_client/connection.py +4 -4
  4. flwr/client/message_handler/message_handler.py +11 -2
  5. flwr/client/message_handler/task_handler.py +7 -4
  6. flwr/client/node_state_tests.py +1 -1
  7. flwr/client/rest_client/connection.py +3 -3
  8. flwr/client/typing.py +1 -1
  9. flwr/common/configsrecord.py +107 -0
  10. flwr/common/flowercontext.py +77 -0
  11. flwr/common/logger.py +14 -0
  12. flwr/common/metricsrecord.py +107 -0
  13. flwr/common/parametersrecord.py +12 -5
  14. flwr/common/recordset.py +3 -10
  15. flwr/common/serde.py +2 -2
  16. flwr/common/typing.py +9 -0
  17. flwr/driver/app.py +5 -3
  18. flwr/driver/driver.py +3 -3
  19. flwr/driver/driver_client_proxy.py +24 -15
  20. flwr/driver/grpc_driver.py +2 -2
  21. flwr/proto/driver_pb2.py +23 -88
  22. flwr/proto/fleet_pb2.py +29 -111
  23. flwr/proto/node_pb2.py +7 -15
  24. flwr/proto/task_pb2.py +34 -128
  25. flwr/proto/task_pb2.pyi +4 -1
  26. flwr/proto/transport_pb2.py +69 -278
  27. flwr/server/app.py +9 -3
  28. flwr/server/driver/driver_servicer.py +4 -4
  29. flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
  30. flwr/server/fleet/grpc_bidi/grpc_bridge.py +4 -1
  31. flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
  32. flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
  33. flwr/server/fleet/grpc_bidi/ins_scheduler.py +6 -3
  34. flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
  35. flwr/server/fleet/message_handler/message_handler.py +3 -3
  36. flwr/server/fleet/rest_rere/rest_api.py +1 -1
  37. flwr/server/state/in_memory_state.py +1 -1
  38. flwr/server/state/sqlite_state.py +6 -3
  39. flwr/server/state/state.py +1 -1
  40. flwr/server/strategy/fedxgb_nn_avg.py +9 -2
  41. flwr/server/utils/validator.py +1 -1
  42. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/METADATA +3 -3
  43. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/RECORD +46 -43
  44. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/LICENSE +0 -0
  45. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/WHEEL +0 -0
  46. {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
35
35
  TRANSPORT_TYPES,
36
36
  )
37
37
  from flwr.common.logger import log, warn_experimental_feature
38
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
39
39
 
40
40
  from .flower import load_flower_callable
41
41
  from .grpc_client.connection import grpc_connection
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
26
  from flwr.common.grpc import create_channel
27
27
  from flwr.common.logger import log
28
- from flwr.proto.node_pb2 import Node
29
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
30
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
31
- from flwr.proto.transport_pb2_grpc import FlowerServiceStub
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
31
+ ClientMessage,
32
+ ServerMessage,
33
+ )
34
+ from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
32
35
 
33
36
  # The following flags can be uncommented for debugging. Other possible values:
34
37
  # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.grpc import create_channel
31
31
  from flwr.common.logger import log, warn_experimental_feature
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  DeleteNodeRequest,
35
35
  PullTaskInsRequest,
36
36
  PushTaskResRequest,
37
37
  )
38
- from flwr.proto.fleet_pb2_grpc import FleetStub
39
- from flwr.proto.node_pb2 import Node
40
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
41
41
 
42
42
  KEY_NODE = "node"
43
43
  KEY_TASK_INS = "current_task_ins"
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
32
32
  from flwr.client.secure_aggregation import SecureAggregationHandler
33
33
  from flwr.client.typing import ClientFn
34
34
  from flwr.common import serde
35
- from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
36
- from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
35
+ from flwr.proto.task_pb2 import ( # pylint: disable=E0611
36
+ SecureAggregation,
37
+ Task,
38
+ TaskIns,
39
+ TaskRes,
40
+ )
41
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
+ ClientMessage,
43
+ Reason,
44
+ ServerMessage,
45
+ )
37
46
 
38
47
 
39
48
  class UnexpectedServerMessage(Exception):
@@ -17,10 +17,13 @@
17
17
 
18
18
  from typing import Optional
19
19
 
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse
21
- from flwr.proto.node_pb2 import Node
22
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
23
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
20
+ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
22
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
23
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
24
+ ClientMessage,
25
+ ServerMessage,
26
+ )
24
27
 
25
28
 
26
29
  def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -17,7 +17,7 @@
17
17
 
18
18
  from flwr.client.node_state import NodeState
19
19
  from flwr.client.run_state import RunState
20
- from flwr.proto.task_pb2 import TaskIns
20
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
21
 
22
22
 
23
23
  def _run_dummy_task(state: RunState) -> RunState:
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.constant import MISSING_EXTRA_REST
31
31
  from flwr.common.logger import log
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  CreateNodeResponse,
35
35
  DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
38
38
  PushTaskResRequest,
39
39
  PushTaskResResponse,
40
40
  )
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes
41
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
 
44
44
  try:
45
45
  import requests
flwr/client/typing.py CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
18
18
  from typing import Callable
19
19
 
20
20
  from flwr.client.run_state import RunState
21
- from flwr.proto.task_pb2 import TaskIns, TaskRes
21
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
22
 
23
23
  from .client import Client as Client
24
24
 
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ConfigsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import ConfigsRecordValues, ConfigsScalar
22
+
23
+
24
+ @dataclass
25
+ class ConfigsRecord:
26
+ """Configs record."""
27
+
28
+ data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)
29
+
30
+ def __init__(
31
+ self,
32
+ configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
33
+ keep_input: bool = True,
34
+ ):
35
+ """Construct a ConfigsRecord object.
36
+
37
+ Parameters
38
+ ----------
39
+ configs_dict : Optional[Dict[str, ConfigsRecordValues]]
40
+ A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
41
+ defined in `ConfigsScalar`) and lists of such types (see
42
+ `ConfigsScalarList`).
43
+ keep_input : bool (default: True)
44
+ A boolean indicating whether config passed should be deleted from the input
45
+ dictionary immediately after adding them to the record. When set
46
+ to True, the data is duplicated in memory. If memory is a concern, set
47
+ it to False.
48
+ """
49
+ self.data = {}
50
+ if configs_dict:
51
+ self.set_configs(configs_dict, keep_input=keep_input)
52
+
53
+ def set_configs(
54
+ self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
55
+ ) -> None:
56
+ """Add configs to the record.
57
+
58
+ Parameters
59
+ ----------
60
+ configs_dict : Dict[str, ConfigsRecordValues]
61
+ A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
62
+ defined in `ConfigsRecordValues`) and list of such types (see
63
+ `ConfigsScalarList`).
64
+ keep_input : bool (default: True)
65
+ A boolean indicating whether config passed should be deleted from the input
66
+ dictionary immediately after adding them to the record. When set
67
+ to True, the data is duplicated in memory. If memory is a concern, set
68
+ it to False.
69
+ """
70
+ if any(not isinstance(k, str) for k in configs_dict.keys()):
71
+ raise TypeError(f"Not all keys are of valid type. Expected {str}")
72
+
73
+ def is_valid(value: ConfigsScalar) -> None:
74
+ """Check if value is of expected type."""
75
+ if not isinstance(value, get_args(ConfigsScalar)):
76
+ raise TypeError(
77
+ "Not all values are of valid type."
78
+ f" Expected {ConfigsRecordValues} but you passed {type(value)}."
79
+ )
80
+
81
+ # Check types of values
82
+ # Split between those values that are list and those that aren't
83
+ # then process in the same way
84
+ for value in configs_dict.values():
85
+ if isinstance(value, list):
86
+ # If your lists are large (e.g. 1M+ elements) this will be slow
87
+ # 1s to check 10M element list on a M2 Pro
88
+ # In such settings, you'd be better of treating such config as
89
+ # an array and pass it to a ParametersRecord.
90
+ for list_value in value:
91
+ is_valid(list_value)
92
+ else:
93
+ is_valid(value)
94
+
95
+ # Add configs to record
96
+ if keep_input:
97
+ # Copy
98
+ self.data = configs_dict.copy()
99
+ else:
100
+ # Add entries to dataclass without duplicating memory
101
+ for key in list(configs_dict.keys()):
102
+ self.data[key] = configs_dict[key]
103
+ del configs_dict[key]
104
+
105
+ def __getitem__(self, key: str) -> ConfigsRecordValues:
106
+ """Retrieve an element stored in record."""
107
+ return self.data[key]
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """FlowerContext and Metadata."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+
20
+ from .recordset import RecordSet
21
+
22
+
23
+ @dataclass
24
+ class Metadata:
25
+ """A dataclass holding metadata associated with the current task.
26
+
27
+ Parameters
28
+ ----------
29
+ run_id : int
30
+ An identifier for the current run.
31
+ task_id : str
32
+ An identifier for the current task.
33
+ group_id : str
34
+ An identifier for grouping tasks. In some settings
35
+ this is used as the FL round.
36
+ ttl : str
37
+ Time-to-live for this task.
38
+ task_type : str
39
+ A string that encodes the action to be executed on
40
+ the receiving end.
41
+ """
42
+
43
+ run_id: int
44
+ task_id: str
45
+ group_id: str
46
+ ttl: str
47
+ task_type: str
48
+
49
+
50
+ @dataclass
51
+ class FlowerContext:
52
+ """State of your application from the viewpoint of the entity using it.
53
+
54
+ Parameters
55
+ ----------
56
+ in_message : RecordSet
57
+ Holds records sent by another entity (e.g. sent by the server-side
58
+ logic to a client, or vice-versa)
59
+ out_message : RecordSet
60
+ Holds records added by the current entity. This `RecordSet` will
61
+ be sent out (e.g. back to the server-side for aggregation of
62
+ parameter, or to the client to perform a certain task)
63
+ local : RecordSet
64
+ Holds record added by the current entity and that will stay local.
65
+ This means that the data it holds will never leave the system it's running from.
66
+ This can be used as an intermediate storage or scratchpad when
67
+ executing middleware layers. It can also be used as a memory to access
68
+ at different points during the lifecycle of this entity (e.g. across
69
+ multiple rounds)
70
+ metadata : Metadata
71
+ A dataclass including information about the task to be executed.
72
+ """
73
+
74
+ in_message: RecordSet
75
+ out_message: RecordSet
76
+ local: RecordSet
77
+ metadata: Metadata
flwr/common/logger.py CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
111
111
  """,
112
112
  name,
113
113
  )
114
+
115
+
116
+ def warn_deprecated_feature(name: str) -> None:
117
+ """Warn the user when they use a deprecated feature."""
118
+ log(
119
+ WARN,
120
+ """
121
+ DEPRECATED FEATURE: %s
122
+
123
+ This is a deprecated feature. It will be removed
124
+ entirely in future versions of Flower.
125
+ """,
126
+ name,
127
+ )
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MetricsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import MetricsRecordValues, MetricsScalar
22
+
23
+
24
+ @dataclass
25
+ class MetricsRecord:
26
+ """Metrics record."""
27
+
28
+ data: Dict[str, MetricsRecordValues] = field(default_factory=dict)
29
+
30
+ def __init__(
31
+ self,
32
+ metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
33
+ keep_input: bool = True,
34
+ ):
35
+ """Construct a MetricsRecord object.
36
+
37
+ Parameters
38
+ ----------
39
+ metrics_dict : Optional[Dict[str, MetricsRecordValues]]
40
+ A dictionary that stores basic types (i.e. `int`, `float` as defined
41
+ in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
42
+ keep_input : bool (default: True)
43
+ A boolean indicating whether metrics should be deleted from the input
44
+ dictionary immediately after adding them to the record. When set
45
+ to True, the data is duplicated in memory. If memory is a concern, set
46
+ it to False.
47
+ """
48
+ self.data = {}
49
+ if metrics_dict:
50
+ self.set_metrics(metrics_dict, keep_input=keep_input)
51
+
52
+ def set_metrics(
53
+ self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True
54
+ ) -> None:
55
+ """Add metrics to the record.
56
+
57
+ Parameters
58
+ ----------
59
+ metrics_dict : Dict[str, MetricsRecordValues]
60
+ A dictionary that stores basic types (i.e. `int`, `float` as defined
61
+ in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
62
+ keep_input : bool (default: True)
63
+ A boolean indicating whether metrics should be deleted from the input
64
+ dictionary immediately after adding them to the record. When set
65
+ to True, the data is duplicated in memory. If memory is a concern, set
66
+ it to False.
67
+ """
68
+ if any(not isinstance(k, str) for k in metrics_dict.keys()):
69
+ raise TypeError(f"Not all keys are of valid type. Expected {str}.")
70
+
71
+ def is_valid(value: MetricsScalar) -> None:
72
+ """Check if value is of expected type."""
73
+ if not isinstance(value, get_args(MetricsScalar)) or isinstance(
74
+ value, bool
75
+ ):
76
+ raise TypeError(
77
+ "Not all values are of valid type."
78
+ f" Expected {MetricsRecordValues} but you passed {type(value)}."
79
+ )
80
+
81
+ # Check types of values
82
+ # Split between those values that are list and those that aren't
83
+ # then process in the same way
84
+ for value in metrics_dict.values():
85
+ if isinstance(value, list):
86
+ # If your lists are large (e.g. 1M+ elements) this will be slow
87
+ # 1s to check 10M element list on a M2 Pro
88
+ # In such settings, you'd be better of treating such metric as
89
+ # an array and pass it to a ParametersRecord.
90
+ for list_value in value:
91
+ is_valid(list_value)
92
+ else:
93
+ is_valid(value)
94
+
95
+ # Add metrics to record
96
+ if keep_input:
97
+ # Copy
98
+ self.data = metrics_dict.copy()
99
+ else:
100
+ # Add entries to dataclass without duplicating memory
101
+ for key in list(metrics_dict.keys()):
102
+ self.data[key] = metrics_dict[key]
103
+ del metrics_dict[key]
104
+
105
+ def __getitem__(self, key: str) -> MetricsRecordValues:
106
+ """Retrieve an element stored in record."""
107
+ return self.data[key]
@@ -59,7 +59,6 @@ class ParametersRecord:
59
59
  PyTorch's state_dict, but holding serialised tensors instead.
60
60
  """
61
61
 
62
- keep_input: bool
63
62
  data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
64
63
 
65
64
  def __init__(
@@ -82,25 +81,29 @@ class ParametersRecord:
82
81
  parameters after adding it to the record, set this flag to True. When set
83
82
  to True, the data is duplicated in memory.
84
83
  """
85
- self.keep_input = keep_input
86
84
  self.data = OrderedDict()
87
85
  if array_dict:
88
- self.set_parameters(array_dict)
86
+ self.set_parameters(array_dict, keep_input=keep_input)
89
87
 
90
- def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
88
+ def set_parameters(
89
+ self, array_dict: OrderedDict[str, Array], keep_input: bool = False
90
+ ) -> None:
91
91
  """Add parameters to record.
92
92
 
93
93
  Parameters
94
94
  ----------
95
95
  array_dict : OrderedDict[str, Array]
96
96
  A dictionary that stores serialized array-like or tensor-like objects.
97
+ keep_input : bool (default: False)
98
+ A boolean indicating whether parameters should be deleted from the input
99
+ dictionary immediately after adding them to the record.
97
100
  """
98
101
  if any(not isinstance(k, str) for k in array_dict.keys()):
99
102
  raise TypeError(f"Not all keys are of valid type. Expected {str}")
100
103
  if any(not isinstance(v, Array) for v in array_dict.values()):
101
104
  raise TypeError(f"Not all values are of valid type. Expected {Array}")
102
105
 
103
- if self.keep_input:
106
+ if keep_input:
104
107
  # Copy
105
108
  self.data = OrderedDict(array_dict)
106
109
  else:
@@ -108,3 +111,7 @@ class ParametersRecord:
108
111
  for key in list(array_dict.keys()):
109
112
  self.data[key] = array_dict[key]
110
113
  del array_dict[key]
114
+
115
+ def __getitem__(self, key: str) -> Array:
116
+ """Retrieve an element stored in record."""
117
+ return self.data[key]
flwr/common/recordset.py CHANGED
@@ -14,22 +14,15 @@
14
14
  # ==============================================================================
15
15
  """RecordSet."""
16
16
 
17
+
17
18
  from dataclasses import dataclass, field
18
19
  from typing import Dict
19
20
 
21
+ from .configsrecord import ConfigsRecord
22
+ from .metricsrecord import MetricsRecord
20
23
  from .parametersrecord import ParametersRecord
21
24
 
22
25
 
23
- @dataclass
24
- class MetricsRecord:
25
- """Metrics record."""
26
-
27
-
28
- @dataclass
29
- class ConfigsRecord:
30
- """Configs record."""
31
-
32
-
33
26
  @dataclass
34
27
  class RecordSet:
35
28
  """Definition of RecordSet."""
flwr/common/serde.py CHANGED
@@ -17,8 +17,8 @@
17
17
 
18
18
  from typing import Any, Dict, List, MutableMapping, cast
19
19
 
20
- from flwr.proto.task_pb2 import Value
21
- from flwr.proto.transport_pb2 import (
20
+ from flwr.proto.task_pb2 import Value # pylint: disable=E0611
21
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
22
22
  ClientMessage,
23
23
  Code,
24
24
  Parameters,
flwr/common/typing.py CHANGED
@@ -45,6 +45,15 @@ Value = Union[
45
45
  List[str],
46
46
  ]
47
47
 
48
+ # Value types for common.MetricsRecord
49
+ MetricsScalar = Union[int, float]
50
+ MetricsScalarList = Union[List[int], List[float]]
51
+ MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
52
+ # Value types for common.ConfigsRecord
53
+ ConfigsScalar = Union[MetricsScalar, str, bytes, bool]
54
+ ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]]
55
+ ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
56
+
48
57
  Metrics = Dict[str, Scalar]
49
58
  MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
50
59
 
flwr/driver/app.py CHANGED
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
25
25
  from flwr.common import EventType, event
26
26
  from flwr.common.address import parse_address
27
27
  from flwr.common.logger import log
28
- from flwr.proto import driver_pb2
28
+ from flwr.proto import driver_pb2 # pylint: disable=E0611
29
29
  from flwr.server.app import ServerConfig, init_defaults, run_fl
30
30
  from flwr.server.client_manager import ClientManager
31
31
  from flwr.server.history import History
@@ -171,7 +171,9 @@ def update_client_manager(
171
171
  `client_manager.unregister()`.
172
172
  """
173
173
  # Request for run_id
174
- run_id = driver.create_run(driver_pb2.CreateRunRequest()).run_id
174
+ run_id = driver.create_run(
175
+ driver_pb2.CreateRunRequest() # pylint: disable=E1101
176
+ ).run_id
175
177
 
176
178
  # Loop until the driver is disconnected
177
179
  registered_nodes: Dict[int, DriverClientProxy] = {}
@@ -181,7 +183,7 @@ def update_client_manager(
181
183
  if driver.stub is None:
182
184
  break
183
185
  get_nodes_res = driver.get_nodes(
184
- req=driver_pb2.GetNodesRequest(run_id=run_id)
186
+ req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
185
187
  )
186
188
  all_node_ids = {node.node_id for node in get_nodes_res.nodes}
187
189
  dead_nodes = set(registered_nodes).difference(all_node_ids)
flwr/driver/driver.py CHANGED
@@ -18,14 +18,14 @@
18
18
  from typing import Iterable, List, Optional, Tuple
19
19
 
20
20
  from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
21
- from flwr.proto.driver_pb2 import (
21
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
22
22
  CreateRunRequest,
23
23
  GetNodesRequest,
24
24
  PullTaskResRequest,
25
25
  PushTaskInsRequest,
26
26
  )
27
- from flwr.proto.node_pb2 import Node
28
- from flwr.proto.task_pb2 import TaskIns, TaskRes
27
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
29
29
 
30
30
 
31
31
  class Driver: