flwr-nightly 1.11.0.dev20240812__py3-none-any.whl → 1.11.0.dev20240815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/run/run.py +6 -2
- flwr/client/app.py +4 -3
- flwr/client/grpc_adapter_client/connection.py +3 -1
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +8 -2
- flwr/client/process/__init__.py +15 -0
- flwr/client/process/clientappio_servicer.py +145 -0
- flwr/client/rest_client/connection.py +9 -3
- flwr/common/config.py +7 -2
- flwr/common/record/recordset.py +9 -7
- flwr/common/record/typeddict.py +20 -58
- flwr/common/recordset_compat.py +6 -6
- flwr/common/serde.py +178 -1
- flwr/common/typing.py +17 -0
- flwr/proto/exec_pb2.py +16 -15
- flwr/proto/exec_pb2.pyi +7 -4
- flwr/proto/message_pb2.py +2 -2
- flwr/proto/message_pb2.pyi +4 -1
- flwr/server/app.py +12 -0
- flwr/server/driver/grpc_driver.py +1 -0
- flwr/server/superlink/driver/driver_grpc.py +3 -0
- flwr/server/superlink/driver/driver_servicer.py +14 -1
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/state/in_memory_state.py +7 -5
- flwr/server/superlink/state/sqlite_state.py +17 -7
- flwr/server/superlink/state/state.py +4 -3
- flwr/simulation/run_simulation.py +4 -1
- flwr/superexec/exec_servicer.py +1 -1
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/METADATA +1 -1
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/RECORD +33 -30
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
|
+
import hashlib
|
|
17
18
|
import subprocess
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG
|
|
@@ -28,7 +29,8 @@ from flwr.cli.config_utils import load_and_validate
|
|
|
28
29
|
from flwr.common.config import flatten_dict, parse_config_args
|
|
29
30
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
30
31
|
from flwr.common.logger import log
|
|
31
|
-
from flwr.common.serde import user_config_to_proto
|
|
32
|
+
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
33
|
+
from flwr.common.typing import Fab
|
|
32
34
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
33
35
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
34
36
|
|
|
@@ -163,9 +165,11 @@ def _run_with_superexec(
|
|
|
163
165
|
stub = ExecStub(channel)
|
|
164
166
|
|
|
165
167
|
fab_path = Path(build(app))
|
|
168
|
+
content = fab_path.read_bytes()
|
|
169
|
+
fab = Fab(hashlib.sha256(content).hexdigest(), content)
|
|
166
170
|
|
|
167
171
|
req = StartRunRequest(
|
|
168
|
-
|
|
172
|
+
fab=fab_to_proto(fab),
|
|
169
173
|
override_config=user_config_to_proto(
|
|
170
174
|
parse_config_args(config_overrides, separator=",")
|
|
171
175
|
),
|
flwr/client/app.py
CHANGED
|
@@ -42,7 +42,7 @@ from flwr.common.constant import (
|
|
|
42
42
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
43
43
|
from flwr.common.message import Error
|
|
44
44
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
45
|
-
from flwr.common.typing import Run, UserConfig
|
|
45
|
+
from flwr.common.typing import Fab, Run, UserConfig
|
|
46
46
|
|
|
47
47
|
from .grpc_adapter_client.connection import grpc_adapter
|
|
48
48
|
from .grpc_client.connection import grpc_connection
|
|
@@ -333,7 +333,7 @@ def _start_client_internal(
|
|
|
333
333
|
root_certificates,
|
|
334
334
|
authentication_keys,
|
|
335
335
|
) as conn:
|
|
336
|
-
receive, send, create_node, delete_node, get_run = conn
|
|
336
|
+
receive, send, create_node, delete_node, get_run, _ = conn
|
|
337
337
|
|
|
338
338
|
# Register node when connecting the first time
|
|
339
339
|
if node_state is None:
|
|
@@ -398,7 +398,7 @@ def _start_client_internal(
|
|
|
398
398
|
runs[run_id] = get_run(run_id)
|
|
399
399
|
# If get_run is None, i.e., in grpc-bidi mode
|
|
400
400
|
else:
|
|
401
|
-
runs[run_id] = Run(run_id, "", "", {})
|
|
401
|
+
runs[run_id] = Run(run_id, "", "", "", {})
|
|
402
402
|
|
|
403
403
|
# Register context for this run
|
|
404
404
|
node_state.register_context(
|
|
@@ -606,6 +606,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
606
606
|
Optional[Callable[[], Optional[int]]],
|
|
607
607
|
Optional[Callable[[], None]],
|
|
608
608
|
Optional[Callable[[int], Run]],
|
|
609
|
+
Optional[Callable[[str], Fab]],
|
|
609
610
|
]
|
|
610
611
|
],
|
|
611
612
|
],
|
|
@@ -27,7 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.message import Message
|
|
29
29
|
from flwr.common.retry_invoker import RetryInvoker
|
|
30
|
-
from flwr.common.typing import Run
|
|
30
|
+
from flwr.common.typing import Fab, Run
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
@contextmanager
|
|
@@ -47,6 +47,7 @@ def grpc_adapter( # pylint: disable=R0913
|
|
|
47
47
|
Optional[Callable[[], Optional[int]]],
|
|
48
48
|
Optional[Callable[[], None]],
|
|
49
49
|
Optional[Callable[[int], Run]],
|
|
50
|
+
Optional[Callable[[str], Fab]],
|
|
50
51
|
]
|
|
51
52
|
]:
|
|
52
53
|
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
|
|
@@ -80,6 +81,7 @@ def grpc_adapter( # pylint: disable=R0913
|
|
|
80
81
|
create_node : Optional[Callable]
|
|
81
82
|
delete_node : Optional[Callable]
|
|
82
83
|
get_run : Optional[Callable]
|
|
84
|
+
get_fab : Optional[Callable]
|
|
83
85
|
"""
|
|
84
86
|
if authentication_keys is not None:
|
|
85
87
|
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
@@ -38,7 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
|
38
38
|
from flwr.common.grpc import create_channel
|
|
39
39
|
from flwr.common.logger import log
|
|
40
40
|
from flwr.common.retry_invoker import RetryInvoker
|
|
41
|
-
from flwr.common.typing import Run
|
|
41
|
+
from flwr.common.typing import Fab, Run
|
|
42
42
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
43
43
|
ClientMessage,
|
|
44
44
|
Reason,
|
|
@@ -75,6 +75,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
75
75
|
Optional[Callable[[], Optional[int]]],
|
|
76
76
|
Optional[Callable[[], None]],
|
|
77
77
|
Optional[Callable[[int], Run]],
|
|
78
|
+
Optional[Callable[[str], Fab]],
|
|
78
79
|
]
|
|
79
80
|
]:
|
|
80
81
|
"""Establish a gRPC connection to a gRPC server.
|
|
@@ -235,7 +236,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
235
236
|
|
|
236
237
|
try:
|
|
237
238
|
# Yield methods
|
|
238
|
-
yield (receive, send, None, None, None)
|
|
239
|
+
yield (receive, send, None, None, None, None)
|
|
239
240
|
finally:
|
|
240
241
|
# Make sure to have a final
|
|
241
242
|
channel.close()
|
|
@@ -45,7 +45,7 @@ from flwr.common.serde import (
|
|
|
45
45
|
message_to_taskres,
|
|
46
46
|
user_config_from_proto,
|
|
47
47
|
)
|
|
48
|
-
from flwr.common.typing import Run
|
|
48
|
+
from flwr.common.typing import Fab, Run
|
|
49
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
50
50
|
CreateNodeRequest,
|
|
51
51
|
DeleteNodeRequest,
|
|
@@ -86,6 +86,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
86
86
|
Optional[Callable[[], Optional[int]]],
|
|
87
87
|
Optional[Callable[[], None]],
|
|
88
88
|
Optional[Callable[[int], Run]],
|
|
89
|
+
Optional[Callable[[str], Fab]],
|
|
89
90
|
]
|
|
90
91
|
]:
|
|
91
92
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -285,11 +286,16 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
285
286
|
run_id,
|
|
286
287
|
get_run_response.run.fab_id,
|
|
287
288
|
get_run_response.run.fab_version,
|
|
289
|
+
get_run_response.run.fab_hash,
|
|
288
290
|
user_config_from_proto(get_run_response.run.override_config),
|
|
289
291
|
)
|
|
290
292
|
|
|
293
|
+
def get_fab(fab_hash: str) -> Fab:
|
|
294
|
+
# Call FleetAPI
|
|
295
|
+
raise NotImplementedError
|
|
296
|
+
|
|
291
297
|
try:
|
|
292
298
|
# Yield methods
|
|
293
|
-
yield (receive, send, create_node, delete_node, get_run)
|
|
299
|
+
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
|
294
300
|
except Exception as exc: # pylint: disable=broad-except
|
|
295
301
|
log(ERROR, exc)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower AppIO service."""
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ClientAppIo API servicer."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from logging import DEBUG, ERROR
|
|
20
|
+
from typing import Optional
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
|
|
24
|
+
from flwr.common import Context, Message, typing
|
|
25
|
+
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.serde import (
|
|
27
|
+
clientappstatus_to_proto,
|
|
28
|
+
context_from_proto,
|
|
29
|
+
context_to_proto,
|
|
30
|
+
message_from_proto,
|
|
31
|
+
message_to_proto,
|
|
32
|
+
run_to_proto,
|
|
33
|
+
)
|
|
34
|
+
from flwr.common.typing import Run
|
|
35
|
+
|
|
36
|
+
# pylint: disable=E0611
|
|
37
|
+
from flwr.proto import clientappio_pb2_grpc
|
|
38
|
+
from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401
|
|
39
|
+
PullClientAppInputsRequest,
|
|
40
|
+
PullClientAppInputsResponse,
|
|
41
|
+
PushClientAppOutputsRequest,
|
|
42
|
+
PushClientAppOutputsResponse,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass()
class ClientAppIoInputs:
    """Specify the inputs to the ClientApp.

    Attributes
    ----------
    message : Message
        The Message handed to the ClientApp.
    context : Context
        The Context handed to the ClientApp.
    run : Run
        Information about the run the ClientApp executes in.
    token : int
        Token that requests from the ClientApp must echo back.
    """

    message: Message
    context: Context
    run: Run
    token: int


@dataclass()
class ClientAppIoOutputs:
    """Specify the outputs from the ClientApp.

    Attributes
    ----------
    message : Message
        The Message produced by the ClientApp.
    context : Context
        The (possibly updated) Context returned by the ClientApp.
    """

    message: Message
    context: Context
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# pylint: disable=C0103,W0613,W0201
class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
    """ClientAppIo API servicer.

    Holds the inputs for one ClientApp execution and collects its outputs.
    The intended call sequence is:

    1. ``set_inputs`` stores the ``ClientAppIoInputs``.
    2. ``PullClientAppInputs`` (gRPC) returns those inputs to the ClientApp.
    3. ``PushClientAppOutputs`` (gRPC) stores the ClientApp's
       ``ClientAppIoOutputs``.
    4. ``get_outputs`` returns the outputs and resets the servicer, so that
       ``set_inputs`` can be called again.
    """

    def __init__(self) -> None:
        # Both slots start empty; they are filled by `set_inputs` and
        # `PushClientAppOutputs`, and cleared again by `get_outputs`.
        self.clientapp_input: Optional[ClientAppIoInputs] = None
        self.clientapp_output: Optional[ClientAppIoOutputs] = None

    def PullClientAppInputs(
        self, request: PullClientAppInputsRequest, context: grpc.ServicerContext
    ) -> PullClientAppInputsResponse:
        """Pull Message, Context, and Run.

        Raises
        ------
        ValueError
            If `set_inputs` has not been called beforehand.
        """
        log(DEBUG, "ClientAppIo.PullClientAppInputs")
        if self.clientapp_input is None:
            raise ValueError(
                "ClientAppIoInputs not set before calling `PullClientAppInputs`."
            )
        # `context.abort` terminates the RPC by raising, so a token mismatch
        # never reaches the return statement below
        if request.token != self.clientapp_input.token:
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "Mismatch between ClientApp and SuperNode token",
            )
        return PullClientAppInputsResponse(
            message=message_to_proto(self.clientapp_input.message),
            context=context_to_proto(self.clientapp_input.context),
            run=run_to_proto(self.clientapp_input.run),
        )

    def PushClientAppOutputs(
        self, request: PushClientAppOutputsRequest, context: grpc.ServicerContext
    ) -> PushClientAppOutputsResponse:
        """Push Message and Context.

        Deserializes the pushed Message and Context and stores them as this
        servicer's ``ClientAppIoOutputs``, to be retrieved with `get_outputs`.

        Returns
        -------
        PushClientAppOutputsResponse
            A response whose status is SUCCESS when the outputs were stored, or
            UNKNOWN_ERROR when deserialization failed (outputs left unchanged).

        Raises
        ------
        ValueError
            If `set_inputs` has not been called beforehand.
        """
        log(DEBUG, "ClientAppIo.PushClientAppOutputs")
        # The outputs are created here rather than required to pre-exist:
        # nothing else assigns `clientapp_output` (`__init__`/`get_outputs` reset
        # it and `set_inputs` demands it be None), so requiring it made every
        # push fail.
        if self.clientapp_input is None:
            raise ValueError(
                "ClientAppIoInputs not set before calling `PushClientAppOutputs`."
            )
        if request.token != self.clientapp_input.token:
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "Mismatch between ClientApp and SuperNode token",
            )
        try:
            # Deserialize both parts first so that a failure in either leaves
            # the stored outputs untouched instead of half-updated
            message = message_from_proto(request.message)
            clientapp_context = context_from_proto(request.context)
            self.clientapp_output = ClientAppIoOutputs(
                message=message, context=clientapp_context
            )
            # Set status
            code = typing.ClientAppOutputCode.SUCCESS
            status = typing.ClientAppOutputStatus(code=code, message="Success")
            proto_status = clientappstatus_to_proto(status=status)
            return PushClientAppOutputsResponse(status=proto_status)
        except Exception as e:  # pylint: disable=broad-exception-caught
            log(ERROR, "ClientApp failed to push message to SuperNode, %s", e)
            code = typing.ClientAppOutputCode.UNKNOWN_ERROR
            status = typing.ClientAppOutputStatus(code=code, message="Push failed")
            proto_status = clientappstatus_to_proto(status=status)
            return PushClientAppOutputsResponse(status=proto_status)

    def set_inputs(self, clientapp_input: ClientAppIoInputs) -> None:
        """Set ClientApp inputs.

        Raises
        ------
        ValueError
            If inputs or outputs from a previous execution are still stored,
            i.e. `get_outputs` has not been called since the last `set_inputs`.
        """
        log(DEBUG, "ClientAppIo.SetInputs")
        if self.clientapp_input is not None or self.clientapp_output is not None:
            raise ValueError(
                "ClientAppIoInputs and ClientAppIoOutputs must not be set before "
                "calling `set_inputs`."
            )
        self.clientapp_input = clientapp_input

    def get_outputs(self) -> ClientAppIoOutputs:
        """Get ClientApp outputs and reset the servicer.

        Returns
        -------
        ClientAppIoOutputs
            The outputs stored by the last successful `PushClientAppOutputs`.

        Raises
        ------
        ValueError
            If no outputs have been pushed yet.
        """
        log(DEBUG, "ClientAppIo.GetOutputs")
        if self.clientapp_output is None:
            raise ValueError("ClientAppIoOutputs not set before calling `get_outputs`.")
        # Set outputs to a local variable and clear both slots, so the next
        # `set_inputs` call is accepted
        output: ClientAppIoOutputs = self.clientapp_output
        self.clientapp_input = None
        self.clientapp_output = None
        return output
|
|
@@ -45,7 +45,7 @@ from flwr.common.serde import (
|
|
|
45
45
|
message_to_taskres,
|
|
46
46
|
user_config_from_proto,
|
|
47
47
|
)
|
|
48
|
-
from flwr.common.typing import Run
|
|
48
|
+
from flwr.common.typing import Fab, Run
|
|
49
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
50
50
|
CreateNodeRequest,
|
|
51
51
|
CreateNodeResponse,
|
|
@@ -97,6 +97,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
97
97
|
Optional[Callable[[], Optional[int]]],
|
|
98
98
|
Optional[Callable[[], None]],
|
|
99
99
|
Optional[Callable[[int], Run]],
|
|
100
|
+
Optional[Callable[[str], Fab]],
|
|
100
101
|
]
|
|
101
102
|
]:
|
|
102
103
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -357,17 +358,22 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
357
358
|
# Send the request
|
|
358
359
|
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
359
360
|
if res is None:
|
|
360
|
-
return Run(run_id, "", "", {})
|
|
361
|
+
return Run(run_id, "", "", "", {})
|
|
361
362
|
|
|
362
363
|
return Run(
|
|
363
364
|
run_id,
|
|
364
365
|
res.run.fab_id,
|
|
365
366
|
res.run.fab_version,
|
|
367
|
+
res.run.fab_hash,
|
|
366
368
|
user_config_from_proto(res.run.override_config),
|
|
367
369
|
)
|
|
368
370
|
|
|
371
|
+
def get_fab(fab_hash: str) -> Fab:
|
|
372
|
+
# Call FleetAPI
|
|
373
|
+
raise NotImplementedError
|
|
374
|
+
|
|
369
375
|
try:
|
|
370
376
|
# Yield methods
|
|
371
|
-
yield (receive, send, create_node, delete_node, get_run)
|
|
377
|
+
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
|
372
378
|
except Exception as exc: # pylint: disable=broad-except
|
|
373
379
|
log(ERROR, exc)
|
flwr/common/config.py
CHANGED
|
@@ -74,10 +74,15 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
|
|
|
74
74
|
return config
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def
|
|
77
|
+
def fuse_dicts(
|
|
78
78
|
main_dict: UserConfig,
|
|
79
79
|
override_dict: UserConfig,
|
|
80
80
|
) -> UserConfig:
|
|
81
|
+
"""Merge a config with the overrides.
|
|
82
|
+
|
|
83
|
+
Remove the nesting by adding the nested keys as prefixes separated by dots, and fuse
|
|
84
|
+
it with the override dict.
|
|
85
|
+
"""
|
|
81
86
|
fused_dict = main_dict.copy()
|
|
82
87
|
|
|
83
88
|
for key, value in override_dict.items():
|
|
@@ -96,7 +101,7 @@ def get_fused_config_from_dir(
|
|
|
96
101
|
)
|
|
97
102
|
flat_default_config = flatten_dict(default_config)
|
|
98
103
|
|
|
99
|
-
return
|
|
104
|
+
return fuse_dicts(flat_default_config, override_config)
|
|
100
105
|
|
|
101
106
|
|
|
102
107
|
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
flwr/common/record/recordset.py
CHANGED
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
"""RecordSet."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from dataclasses import dataclass
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
20
22
|
|
|
21
23
|
from .configsrecord import ConfigsRecord
|
|
22
24
|
from .metricsrecord import MetricsRecord
|
|
@@ -34,9 +36,9 @@ class RecordSetData:
|
|
|
34
36
|
|
|
35
37
|
def __init__(
|
|
36
38
|
self,
|
|
37
|
-
parameters_records:
|
|
38
|
-
metrics_records:
|
|
39
|
-
configs_records:
|
|
39
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
40
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
41
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
40
42
|
) -> None:
|
|
41
43
|
self.parameters_records = TypedDict[str, ParametersRecord](
|
|
42
44
|
self._check_fn_str, self._check_fn_params
|
|
@@ -88,9 +90,9 @@ class RecordSet:
|
|
|
88
90
|
|
|
89
91
|
def __init__(
|
|
90
92
|
self,
|
|
91
|
-
parameters_records:
|
|
92
|
-
metrics_records:
|
|
93
|
-
configs_records:
|
|
93
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
94
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
95
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
94
96
|
) -> None:
|
|
95
97
|
data = RecordSetData(
|
|
96
98
|
parameters_records=parameters_records,
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,99 +15,61 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Callable, Dict, Generic, Iterator, MutableMapping, TypeVar, cast
|
|
19
19
|
|
|
20
20
|
K = TypeVar("K") # Key type
|
|
21
21
|
V = TypeVar("V") # Value type
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
class TypedDict(Generic[K, V]):
|
|
24
|
+
class TypedDict(MutableMapping[K, V], Generic[K, V]):
|
|
25
25
|
"""Typed dictionary."""
|
|
26
26
|
|
|
27
27
|
def __init__(
|
|
28
28
|
self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None]
|
|
29
29
|
):
|
|
30
|
-
self.
|
|
31
|
-
self.
|
|
32
|
-
self.
|
|
30
|
+
self.__dict__["_check_key_fn"] = check_key_fn
|
|
31
|
+
self.__dict__["_check_value_fn"] = check_value_fn
|
|
32
|
+
self.__dict__["_data"] = {}
|
|
33
33
|
|
|
34
34
|
def __setitem__(self, key: K, value: V) -> None:
|
|
35
35
|
"""Set the given key to the given value after type checking."""
|
|
36
36
|
# Check the types of key and value
|
|
37
|
-
self._check_key_fn(key)
|
|
38
|
-
self._check_value_fn(value)
|
|
37
|
+
cast(Callable[[K], None], self.__dict__["_check_key_fn"])(key)
|
|
38
|
+
cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value)
|
|
39
|
+
|
|
39
40
|
# Set key-value pair
|
|
40
|
-
self._data[key] = value
|
|
41
|
+
cast(Dict[K, V], self.__dict__["_data"])[key] = value
|
|
41
42
|
|
|
42
43
|
def __delitem__(self, key: K) -> None:
|
|
43
44
|
"""Remove the item with the specified key."""
|
|
44
|
-
del self._data[key]
|
|
45
|
+
del cast(Dict[K, V], self.__dict__["_data"])[key]
|
|
45
46
|
|
|
46
47
|
def __getitem__(self, item: K) -> V:
|
|
47
48
|
"""Return the value for the specified key."""
|
|
48
|
-
return self._data[item]
|
|
49
|
+
return cast(Dict[K, V], self.__dict__["_data"])[item]
|
|
49
50
|
|
|
50
51
|
def __iter__(self) -> Iterator[K]:
|
|
51
52
|
"""Yield an iterator over the keys of the dictionary."""
|
|
52
|
-
return iter(self._data)
|
|
53
|
+
return iter(cast(Dict[K, V], self.__dict__["_data"]))
|
|
53
54
|
|
|
54
55
|
def __repr__(self) -> str:
|
|
55
56
|
"""Return a string representation of the dictionary."""
|
|
56
|
-
return self._data.__repr__()
|
|
57
|
+
return cast(Dict[K, V], self.__dict__["_data"]).__repr__()
|
|
57
58
|
|
|
58
59
|
def __len__(self) -> int:
|
|
59
60
|
"""Return the number of items in the dictionary."""
|
|
60
|
-
return len(self._data)
|
|
61
|
+
return len(cast(Dict[K, V], self.__dict__["_data"]))
|
|
61
62
|
|
|
62
|
-
def __contains__(self, key:
|
|
63
|
+
def __contains__(self, key: object) -> bool:
|
|
63
64
|
"""Check if the dictionary contains the specified key."""
|
|
64
|
-
return key in self._data
|
|
65
|
+
return key in cast(Dict[K, V], self.__dict__["_data"])
|
|
65
66
|
|
|
66
67
|
def __eq__(self, other: object) -> bool:
|
|
67
68
|
"""Compare this instance to another dictionary or TypedDict."""
|
|
69
|
+
data = cast(Dict[K, V], self.__dict__["_data"])
|
|
68
70
|
if isinstance(other, TypedDict):
|
|
69
|
-
|
|
71
|
+
other_data = cast(Dict[K, V], other.__dict__["_data"])
|
|
72
|
+
return data == other_data
|
|
70
73
|
if isinstance(other, dict):
|
|
71
|
-
return
|
|
74
|
+
return data == other
|
|
72
75
|
return NotImplemented
|
|
73
|
-
|
|
74
|
-
def items(self) -> Iterator[Tuple[K, V]]:
|
|
75
|
-
"""R.items() -> a set-like object providing a view on R's items."""
|
|
76
|
-
return cast(Iterator[Tuple[K, V]], self._data.items())
|
|
77
|
-
|
|
78
|
-
def keys(self) -> Iterator[K]:
|
|
79
|
-
"""R.keys() -> a set-like object providing a view on R's keys."""
|
|
80
|
-
return cast(Iterator[K], self._data.keys())
|
|
81
|
-
|
|
82
|
-
def values(self) -> Iterator[V]:
|
|
83
|
-
"""R.values() -> an object providing a view on R's values."""
|
|
84
|
-
return cast(Iterator[V], self._data.values())
|
|
85
|
-
|
|
86
|
-
def update(self, *args: Any, **kwargs: Any) -> None:
|
|
87
|
-
"""R.update([E, ]**F) -> None.
|
|
88
|
-
|
|
89
|
-
Update R from dict/iterable E and F.
|
|
90
|
-
"""
|
|
91
|
-
for key, value in dict(*args, **kwargs).items():
|
|
92
|
-
self[key] = value
|
|
93
|
-
|
|
94
|
-
def pop(self, key: K) -> V:
|
|
95
|
-
"""R.pop(k[,d]) -> v, remove specified key and return the corresponding value.
|
|
96
|
-
|
|
97
|
-
If key is not found, d is returned if given, otherwise KeyError is raised.
|
|
98
|
-
"""
|
|
99
|
-
return self._data.pop(key)
|
|
100
|
-
|
|
101
|
-
def get(self, key: K, default: V) -> V:
|
|
102
|
-
"""R.get(k[,d]) -> R[k] if k in R, else d.
|
|
103
|
-
|
|
104
|
-
d defaults to None.
|
|
105
|
-
"""
|
|
106
|
-
return self._data.get(key, default)
|
|
107
|
-
|
|
108
|
-
def clear(self) -> None:
|
|
109
|
-
"""R.clear() -> None.
|
|
110
|
-
|
|
111
|
-
Remove all items from R.
|
|
112
|
-
"""
|
|
113
|
-
self._data.clear()
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -145,7 +145,7 @@ def _recordset_to_fit_or_evaluate_ins_components(
|
|
|
145
145
|
# get config dict
|
|
146
146
|
config_record = recordset.configs_records[f"{ins_str}.config"]
|
|
147
147
|
# pylint: disable-next=protected-access
|
|
148
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
148
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
149
149
|
|
|
150
150
|
return parameters, config_dict
|
|
151
151
|
|
|
@@ -213,7 +213,7 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
|
|
|
213
213
|
)
|
|
214
214
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
215
215
|
# pylint: disable-next=protected-access
|
|
216
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
216
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
217
217
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
218
218
|
|
|
219
219
|
return FitRes(
|
|
@@ -274,7 +274,7 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
|
|
|
274
274
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
275
275
|
|
|
276
276
|
# pylint: disable-next=protected-access
|
|
277
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
277
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
278
278
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
279
279
|
|
|
280
280
|
return EvaluateRes(
|
|
@@ -314,7 +314,7 @@ def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
|
|
|
314
314
|
"""Derive GetParametersIns from a RecordSet object."""
|
|
315
315
|
config_record = recordset.configs_records["getparametersins.config"]
|
|
316
316
|
# pylint: disable-next=protected-access
|
|
317
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
317
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
318
318
|
|
|
319
319
|
return GetParametersIns(config=config_dict)
|
|
320
320
|
|
|
@@ -365,7 +365,7 @@ def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
|
|
|
365
365
|
"""Derive GetPropertiesIns from a RecordSet object."""
|
|
366
366
|
config_record = recordset.configs_records["getpropertiesins.config"]
|
|
367
367
|
# pylint: disable-next=protected-access
|
|
368
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
368
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
369
369
|
|
|
370
370
|
return GetPropertiesIns(config=config_dict)
|
|
371
371
|
|
|
@@ -384,7 +384,7 @@ def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
|
|
|
384
384
|
res_str = "getpropertiesres"
|
|
385
385
|
config_record = recordset.configs_records[f"{res_str}.properties"]
|
|
386
386
|
# pylint: disable-next=protected-access
|
|
387
|
-
properties = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
387
|
+
properties = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
388
388
|
|
|
389
389
|
status = _extract_status_from_recordset(res_str, recordset=recordset)
|
|
390
390
|
|