flwr-nightly 1.7.0.dev20240117__py3-none-any.whl → 1.7.0.dev20240119__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +1 -1
- flwr/client/grpc_client/connection.py +7 -4
- flwr/client/grpc_rere_client/connection.py +4 -4
- flwr/client/message_handler/message_handler.py +11 -2
- flwr/client/message_handler/task_handler.py +7 -4
- flwr/client/node_state_tests.py +1 -1
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/typing.py +1 -1
- flwr/common/configsrecord.py +107 -0
- flwr/common/flowercontext.py +77 -0
- flwr/common/logger.py +14 -0
- flwr/common/metricsrecord.py +107 -0
- flwr/common/parametersrecord.py +12 -5
- flwr/common/recordset.py +3 -10
- flwr/common/serde.py +2 -2
- flwr/common/typing.py +9 -0
- flwr/driver/app.py +5 -3
- flwr/driver/driver.py +3 -3
- flwr/driver/driver_client_proxy.py +24 -15
- flwr/driver/grpc_driver.py +2 -2
- flwr/proto/driver_pb2.py +23 -88
- flwr/proto/fleet_pb2.py +29 -111
- flwr/proto/node_pb2.py +7 -15
- flwr/proto/task_pb2.py +34 -128
- flwr/proto/task_pb2.pyi +4 -1
- flwr/proto/transport_pb2.py +69 -278
- flwr/server/app.py +9 -3
- flwr/server/driver/driver_servicer.py +4 -4
- flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
- flwr/server/fleet/grpc_bidi/grpc_bridge.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/fleet/grpc_bidi/ins_scheduler.py +6 -3
- flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/fleet/message_handler/message_handler.py +3 -3
- flwr/server/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/state/in_memory_state.py +1 -1
- flwr/server/state/sqlite_state.py +6 -3
- flwr/server/state/state.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +9 -2
- flwr/server/utils/validator.py +1 -1
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/METADATA +3 -3
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/RECORD +46 -43
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240117.dist-info → flwr_nightly-1.7.0.dev20240119.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
|
|
35
35
|
TRANSPORT_TYPES,
|
36
36
|
)
|
37
37
|
from flwr.common.logger import log, warn_experimental_feature
|
38
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
39
39
|
|
40
40
|
from .flower import load_flower_callable
|
41
41
|
from .grpc_client.connection import grpc_connection
|
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
|
|
25
25
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
26
26
|
from flwr.common.grpc import create_channel
|
27
27
|
from flwr.common.logger import log
|
28
|
-
from flwr.proto.node_pb2 import Node
|
29
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
30
|
-
from flwr.proto.transport_pb2 import
|
31
|
-
|
28
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
29
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
30
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
31
|
+
ClientMessage,
|
32
|
+
ServerMessage,
|
33
|
+
)
|
34
|
+
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
32
35
|
|
33
36
|
# The following flags can be uncommented for debugging. Other possible values:
|
34
37
|
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.grpc import create_channel
|
31
31
|
from flwr.common.logger import log, warn_experimental_feature
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
DeleteNodeRequest,
|
35
35
|
PullTaskInsRequest,
|
36
36
|
PushTaskResRequest,
|
37
37
|
)
|
38
|
-
from flwr.proto.fleet_pb2_grpc import FleetStub
|
39
|
-
from flwr.proto.node_pb2 import Node
|
40
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
39
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
40
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
41
41
|
|
42
42
|
KEY_NODE = "node"
|
43
43
|
KEY_TASK_INS = "current_task_ins"
|
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
|
|
32
32
|
from flwr.client.secure_aggregation import SecureAggregationHandler
|
33
33
|
from flwr.client.typing import ClientFn
|
34
34
|
from flwr.common import serde
|
35
|
-
from flwr.proto.task_pb2 import
|
36
|
-
|
35
|
+
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
|
36
|
+
SecureAggregation,
|
37
|
+
Task,
|
38
|
+
TaskIns,
|
39
|
+
TaskRes,
|
40
|
+
)
|
41
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
42
|
+
ClientMessage,
|
43
|
+
Reason,
|
44
|
+
ServerMessage,
|
45
|
+
)
|
37
46
|
|
38
47
|
|
39
48
|
class UnexpectedServerMessage(Exception):
|
@@ -17,10 +17,13 @@
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from flwr.proto.fleet_pb2 import PullTaskInsResponse
|
21
|
-
from flwr.proto.node_pb2 import Node
|
22
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
23
|
-
from flwr.proto.transport_pb2 import
|
20
|
+
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
|
21
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
22
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
23
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
24
|
+
ClientMessage,
|
25
|
+
ServerMessage,
|
26
|
+
)
|
24
27
|
|
25
28
|
|
26
29
|
def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
|
flwr/client/node_state_tests.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from flwr.client.node_state import NodeState
|
19
19
|
from flwr.client.run_state import RunState
|
20
|
-
from flwr.proto.task_pb2 import TaskIns
|
20
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
21
21
|
|
22
22
|
|
23
23
|
def _run_dummy_task(state: RunState) -> RunState:
|
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.constant import MISSING_EXTRA_REST
|
31
31
|
from flwr.common.logger import log
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
CreateNodeResponse,
|
35
35
|
DeleteNodeRequest,
|
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
|
|
38
38
|
PushTaskResRequest,
|
39
39
|
PushTaskResResponse,
|
40
40
|
)
|
41
|
-
from flwr.proto.node_pb2 import Node
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
41
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
42
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
43
43
|
|
44
44
|
try:
|
45
45
|
import requests
|
flwr/client/typing.py
CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
|
18
18
|
from typing import Callable
|
19
19
|
|
20
20
|
from flwr.client.run_state import RunState
|
21
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
21
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
22
22
|
|
23
23
|
from .client import Client as Client
|
24
24
|
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ConfigsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import ConfigsRecordValues, ConfigsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class ConfigsRecord:
    """Configs record.

    Maps `str` keys to configuration values. Each value is either a scalar
    (`int`, `float`, `str`, `bytes` or `bool`, see `ConfigsScalar`) or a list
    of such scalars (see `ConfigsScalarList`).
    """

    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ) -> None:
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`,
            `bytes`, `bool` as defined in `ConfigsScalar`) and lists of such
            types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            A boolean indicating whether the input dictionary should be left
            untouched. When set to True, the entries are copied into the
            record, so the data is duplicated in memory. When set to False,
            each entry is moved into the record and deleted from
            `configs_dict`; use this if memory is a concern.

        Raises
        ------
        TypeError
            If a key or value in `configs_dict` is not of a supported type.
        """
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict, keep_input=keep_input)

    def set_configs(
        self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True
    ) -> None:
        """Add configs to the record.

        Entries are merged into the record: keys already in the record but not
        in `configs_dict` are kept, and keys present in both are overwritten
        with the new value. This holds for both values of `keep_input`.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`,
            `bytes`, `bool` as defined in `ConfigsScalar`) and lists of such
            types (see `ConfigsScalarList`).
        keep_input : bool (default: True)
            A boolean indicating whether the input dictionary should be left
            untouched. When set to True, the entries are copied into the
            record, so the data is duplicated in memory. When set to False,
            each entry is moved into the record and deleted from
            `configs_dict`; use this if memory is a concern.

        Raises
        ------
        TypeError
            If any key is not a `str`, or any value (or list element) is not an
            instance of one of the `ConfigsScalar` types. All checks run before
            either the record or `configs_dict` is modified.
        """
        if any(not isinstance(k, str) for k in configs_dict):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        # Flattened tuple of the scalar types, computed once for all checks.
        allowed_types = get_args(ConfigsScalar)

        def is_valid(value: ConfigsScalar) -> None:
            """Check if value is of expected type."""
            if not isinstance(value, allowed_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in configs_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such config as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add configs to record
        if keep_input:
            # Copy entries into the existing record (merge, don't replace)
            self.data.update(configs_dict)
        else:
            # Move entries into the record without duplicating memory
            for key in list(configs_dict):
                self.data[key] = configs_dict.pop(key)

    def __getitem__(self, key: str) -> ConfigsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
|
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""FlowerContext and Metadata."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
|
20
|
+
from .recordset import RecordSet
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass()
class Metadata:
    """Descriptive information about the task currently being handled.

    Parameters
    ----------
    run_id : int
        Identifier of the run this task belongs to.
    task_id : str
        Identifier of this particular task.
    group_id : str
        Identifier used to group related tasks together; in some settings
        it holds the FL round.
    ttl : str
        Time-to-live associated with this task.
    task_type : str
        Encodes the action the receiving end is expected to execute.
    """

    run_id: int  # which run
    task_id: str  # which task
    group_id: str  # grouping key (e.g. FL round)
    ttl: str  # time-to-live
    task_type: str  # action to perform on the receiving end
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass()
class FlowerContext:
    """Application state as seen by the entity (client or server) using it.

    Parameters
    ----------
    in_message : RecordSet
        Records received from another entity, e.g. sent by the server-side
        logic to a client, or vice-versa.
    out_message : RecordSet
        Records added by the current entity that will be sent out, e.g. back
        to the server-side for aggregation of parameters, or to the client to
        perform a certain task.
    local : RecordSet
        Records added by the current entity that stay local: the data held
        here never leaves the system it is running on. It can serve as an
        intermediate storage or scratchpad when executing middleware layers,
        or as memory accessible at different points of this entity's
        lifecycle (e.g. across multiple rounds).
    metadata : Metadata
        Information about the task to be executed.
    """

    in_message: RecordSet  # incoming records
    out_message: RecordSet  # outgoing records
    local: RecordSet  # records that never leave this system
    metadata: Metadata  # task information
|
flwr/common/logger.py
CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
|
|
111
111
|
""",
|
112
112
|
name,
|
113
113
|
)
|
114
|
+
|
115
|
+
|
116
|
+
def warn_deprecated_feature(name: str) -> None:
    """Warn the user when they use a deprecated feature.

    Emits a single WARN-level log record whose message announces that the
    feature is deprecated and will be removed in a future Flower version.

    Parameters
    ----------
    name : str
        Name of the deprecated feature; it is passed to `log` as the argument
        for the ``%s`` placeholder in the message.
    """
    log(
        WARN,
        """
DEPRECATED FEATURE: %s

This is a deprecated feature. It will be removed
entirely in future versions of Flower.
""",
        name,
    )
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""MetricsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import MetricsRecordValues, MetricsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class MetricsRecord:
    """Metrics record.

    Maps `str` keys to metric values. Each value is either a numeric scalar
    (`int` or `float`, see `MetricsScalar`; `bool` is rejected) or a list of
    such scalars (see `MetricsScalarList`).
    """

    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ) -> None:
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            A boolean indicating whether the input dictionary should be left
            untouched. When set to True, the entries are copied into the
            record, so the data is duplicated in memory. When set to False,
            each entry is moved into the record and deleted from
            `metrics_dict`; use this if memory is a concern.

        Raises
        ------
        TypeError
            If a key or value in `metrics_dict` is not of a supported type.
        """
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict, keep_input=keep_input)

    def set_metrics(
        self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True
    ) -> None:
        """Add metrics to the record.

        Entries are merged into the record: keys already in the record but not
        in `metrics_dict` are kept, and keys present in both are overwritten
        with the new value. This holds for both values of `keep_input`.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            A boolean indicating whether the input dictionary should be left
            untouched. When set to True, the entries are copied into the
            record, so the data is duplicated in memory. When set to False,
            each entry is moved into the record and deleted from
            `metrics_dict`; use this if memory is a concern.

        Raises
        ------
        TypeError
            If any key is not a `str`, or any value (or list element) is not an
            `int`/`float` (booleans are rejected). All checks run before either
            the record or `metrics_dict` is modified.
        """
        if any(not isinstance(k, str) for k in metrics_dict):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        # Flattened tuple of the scalar types, computed once for all checks.
        allowed_types = get_args(MetricsScalar)

        def is_valid(value: MetricsScalar) -> None:
            """Check if value is of expected type."""
            # `bool` subclasses `int`, so it must be excluded explicitly.
            if not isinstance(value, allowed_types) or isinstance(value, bool):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in metrics_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such metric as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add metrics to record
        if keep_input:
            # Copy entries into the existing record (merge, don't replace)
            self.data.update(metrics_dict)
        else:
            # Move entries into the record without duplicating memory
            for key in list(metrics_dict):
                self.data[key] = metrics_dict.pop(key)

    def __getitem__(self, key: str) -> MetricsRecordValues:
        """Retrieve an element stored in record."""
        return self.data[key]
|
flwr/common/parametersrecord.py
CHANGED
@@ -59,7 +59,6 @@ class ParametersRecord:
|
|
59
59
|
PyTorch's state_dict, but holding serialised tensors instead.
|
60
60
|
"""
|
61
61
|
|
62
|
-
keep_input: bool
|
63
62
|
data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
|
64
63
|
|
65
64
|
def __init__(
|
@@ -82,25 +81,29 @@ class ParametersRecord:
|
|
82
81
|
parameters after adding it to the record, set this flag to True. When set
|
83
82
|
to True, the data is duplicated in memory.
|
84
83
|
"""
|
85
|
-
self.keep_input = keep_input
|
86
84
|
self.data = OrderedDict()
|
87
85
|
if array_dict:
|
88
|
-
self.set_parameters(array_dict)
|
86
|
+
self.set_parameters(array_dict, keep_input=keep_input)
|
89
87
|
|
90
|
-
def set_parameters(
|
88
|
+
def set_parameters(
|
89
|
+
self, array_dict: OrderedDict[str, Array], keep_input: bool = False
|
90
|
+
) -> None:
|
91
91
|
"""Add parameters to record.
|
92
92
|
|
93
93
|
Parameters
|
94
94
|
----------
|
95
95
|
array_dict : OrderedDict[str, Array]
|
96
96
|
A dictionary that stores serialized array-like or tensor-like objects.
|
97
|
+
keep_input : bool (default: False)
|
98
|
+
A boolean indicating whether parameters should be deleted from the input
|
99
|
+
dictionary immediately after adding them to the record.
|
97
100
|
"""
|
98
101
|
if any(not isinstance(k, str) for k in array_dict.keys()):
|
99
102
|
raise TypeError(f"Not all keys are of valid type. Expected {str}")
|
100
103
|
if any(not isinstance(v, Array) for v in array_dict.values()):
|
101
104
|
raise TypeError(f"Not all values are of valid type. Expected {Array}")
|
102
105
|
|
103
|
-
if
|
106
|
+
if keep_input:
|
104
107
|
# Copy
|
105
108
|
self.data = OrderedDict(array_dict)
|
106
109
|
else:
|
@@ -108,3 +111,7 @@ class ParametersRecord:
|
|
108
111
|
for key in list(array_dict.keys()):
|
109
112
|
self.data[key] = array_dict[key]
|
110
113
|
del array_dict[key]
|
114
|
+
|
115
|
+
def __getitem__(self, key: str) -> Array:
|
116
|
+
"""Retrieve an element stored in record."""
|
117
|
+
return self.data[key]
|
flwr/common/recordset.py
CHANGED
@@ -14,22 +14,15 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""RecordSet."""
|
16
16
|
|
17
|
+
|
17
18
|
from dataclasses import dataclass, field
|
18
19
|
from typing import Dict
|
19
20
|
|
21
|
+
from .configsrecord import ConfigsRecord
|
22
|
+
from .metricsrecord import MetricsRecord
|
20
23
|
from .parametersrecord import ParametersRecord
|
21
24
|
|
22
25
|
|
23
|
-
@dataclass
|
24
|
-
class MetricsRecord:
|
25
|
-
"""Metrics record."""
|
26
|
-
|
27
|
-
|
28
|
-
@dataclass
|
29
|
-
class ConfigsRecord:
|
30
|
-
"""Configs record."""
|
31
|
-
|
32
|
-
|
33
26
|
@dataclass
|
34
27
|
class RecordSet:
|
35
28
|
"""Definition of RecordSet."""
|
flwr/common/serde.py
CHANGED
@@ -17,8 +17,8 @@
|
|
17
17
|
|
18
18
|
from typing import Any, Dict, List, MutableMapping, cast
|
19
19
|
|
20
|
-
from flwr.proto.task_pb2 import Value
|
21
|
-
from flwr.proto.transport_pb2 import (
|
20
|
+
from flwr.proto.task_pb2 import Value # pylint: disable=E0611
|
21
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
22
22
|
ClientMessage,
|
23
23
|
Code,
|
24
24
|
Parameters,
|
flwr/common/typing.py
CHANGED
@@ -45,6 +45,15 @@ Value = Union[
|
|
45
45
|
List[str],
|
46
46
|
]
|
47
47
|
|
48
|
+
# Value types for common.MetricsRecord
# Numeric scalars allowed as metric values. Note that `bool` is a subclass of
# `int`, so an `isinstance` check against these types alone would accept it.
MetricsScalar = Union[int, float]
# List-valued metrics, annotated as a list of ints or a list of floats.
MetricsScalarList = Union[List[int], List[float]]
# Any value a metrics entry may hold: a scalar or a list of scalars.
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
# Value types for common.ConfigsRecord
# Config scalars extend the metric scalars with `str`, `bytes` and `bool`.
ConfigsScalar = Union[MetricsScalar, str, bytes, bool]
# List-valued configs, annotated as a list of a single config scalar type.
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]]
# Any value a configs entry may hold: a scalar or a list of scalars.
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]
|
56
|
+
|
48
57
|
Metrics = Dict[str, Scalar]
|
49
58
|
MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
|
50
59
|
|
flwr/driver/app.py
CHANGED
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
|
|
25
25
|
from flwr.common import EventType, event
|
26
26
|
from flwr.common.address import parse_address
|
27
27
|
from flwr.common.logger import log
|
28
|
-
from flwr.proto import driver_pb2
|
28
|
+
from flwr.proto import driver_pb2 # pylint: disable=E0611
|
29
29
|
from flwr.server.app import ServerConfig, init_defaults, run_fl
|
30
30
|
from flwr.server.client_manager import ClientManager
|
31
31
|
from flwr.server.history import History
|
@@ -171,7 +171,9 @@ def update_client_manager(
|
|
171
171
|
`client_manager.unregister()`.
|
172
172
|
"""
|
173
173
|
# Request for run_id
|
174
|
-
run_id = driver.create_run(
|
174
|
+
run_id = driver.create_run(
|
175
|
+
driver_pb2.CreateRunRequest() # pylint: disable=E1101
|
176
|
+
).run_id
|
175
177
|
|
176
178
|
# Loop until the driver is disconnected
|
177
179
|
registered_nodes: Dict[int, DriverClientProxy] = {}
|
@@ -181,7 +183,7 @@ def update_client_manager(
|
|
181
183
|
if driver.stub is None:
|
182
184
|
break
|
183
185
|
get_nodes_res = driver.get_nodes(
|
184
|
-
req=driver_pb2.GetNodesRequest(run_id=run_id)
|
186
|
+
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
|
185
187
|
)
|
186
188
|
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
|
187
189
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
flwr/driver/driver.py
CHANGED
@@ -18,14 +18,14 @@
|
|
18
18
|
from typing import Iterable, List, Optional, Tuple
|
19
19
|
|
20
20
|
from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
|
21
|
-
from flwr.proto.driver_pb2 import (
|
21
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
22
22
|
CreateRunRequest,
|
23
23
|
GetNodesRequest,
|
24
24
|
PullTaskResRequest,
|
25
25
|
PushTaskInsRequest,
|
26
26
|
)
|
27
|
-
from flwr.proto.node_pb2 import Node
|
28
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
27
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
28
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
29
29
|
|
30
30
|
|
31
31
|
class Driver:
|