flwr-nightly 1.11.0.dev20240813__py3-none-any.whl → 1.11.0.dev20240816__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +2 -2
- flwr/cli/run/run.py +11 -6
- flwr/client/app.py +97 -12
- flwr/client/grpc_rere_client/connection.py +9 -1
- flwr/client/process/__init__.py +15 -0
- flwr/client/process/clientappio_servicer.py +144 -0
- flwr/client/process/process.py +143 -0
- flwr/client/process/utils.py +108 -0
- flwr/client/rest_client/connection.py +16 -3
- flwr/client/supernode/app.py +25 -97
- flwr/common/config.py +7 -2
- flwr/common/record/recordset.py +9 -7
- flwr/common/record/typeddict.py +20 -58
- flwr/common/recordset_compat.py +6 -6
- flwr/common/serde.py +24 -2
- flwr/common/typing.py +1 -0
- flwr/proto/exec_pb2.py +16 -15
- flwr/proto/exec_pb2.pyi +7 -4
- flwr/proto/message_pb2.py +2 -2
- flwr/proto/message_pb2.pyi +4 -1
- flwr/server/app.py +15 -0
- flwr/server/driver/grpc_driver.py +1 -0
- flwr/server/run_serverapp.py +18 -2
- flwr/server/server.py +3 -1
- flwr/server/superlink/driver/driver_grpc.py +3 -0
- flwr/server/superlink/driver/driver_servicer.py +32 -4
- flwr/server/superlink/ffs/disk_ffs.py +6 -3
- flwr/server/superlink/ffs/ffs.py +3 -3
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +9 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +16 -1
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +7 -5
- flwr/server/superlink/state/sqlite_state.py +17 -7
- flwr/server/superlink/state/state.py +4 -3
- flwr/server/workflow/default_workflows.py +3 -1
- flwr/simulation/run_simulation.py +4 -1
- flwr/superexec/deployment.py +8 -9
- flwr/superexec/exec_servicer.py +1 -1
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/METADATA +1 -1
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/RECORD +44 -39
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower ClientApp loading utils."""
|
|
16
|
+
|
|
17
|
+
from logging import DEBUG
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Callable, Optional
|
|
20
|
+
|
|
21
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
22
|
+
from flwr.common.config import (
|
|
23
|
+
get_flwr_dir,
|
|
24
|
+
get_metadata_from_config,
|
|
25
|
+
get_project_config,
|
|
26
|
+
get_project_dir,
|
|
27
|
+
)
|
|
28
|
+
from flwr.common.logger import log
|
|
29
|
+
from flwr.common.object_ref import load_app, validate
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_load_client_app_fn(
    default_app_ref: str,
    app_path: Optional[str],
    multi_app: bool,
    flwr_dir: Optional[str] = None,
) -> Callable[[str, str], ClientApp]:
    """Get the load_client_app_fn function.

    If `multi_app` is False, the returned function ignores `fab_id` and
    `fab_version` and always loads the ClientApp referenced by
    `default_app_ref`. That reference is validated eagerly, i.e., before this
    function returns.

    If `multi_app` is True, the returned function reads the ClientApp reference
    from the project config (`tool.flwr.app.components.clientapp`):

    - if `app_path` is provided, the project in `app_path` is used, and its
      FAB ID and FAB version must match the requested `fab_id` and
      `fab_version`;
    - otherwise, the project directory is resolved from `fab_id` and
      `fab_version` inside the Flower directory (see `flwr_dir`).

    Parameters
    ----------
    default_app_ref : str
        Reference to the ClientApp to load when `multi_app` is False.
    app_path : Optional[str]
        Project directory to load the ClientApp from. If None (or empty), the
        current working directory is used when `multi_app` is False.
    multi_app : bool
        Whether the ClientApp is selected per call from `fab_id` and
        `fab_version`.
    flwr_dir : Optional[str] (default: None)
        Value passed to `get_flwr_dir` when looking up a project by FAB ID and
        FAB version.

    Returns
    -------
    Callable[[str, str], ClientApp]
        A function that takes `(fab_id, fab_version)` and returns the loaded
        `ClientApp`.

    Raises
    ------
    LoadClientAppError
        Raised immediately if `multi_app` is False and `default_app_ref` fails
        validation. Raised by the returned function if the project config
        cannot be read, the FAB ID/version do not match, the config lacks a
        ClientApp reference, or the referenced object is not a `ClientApp`.
    """
    if not multi_app:
        log(
            DEBUG,
            "Flower SuperNode will load and validate ClientApp `%s`",
            default_app_ref,
        )

        # Validate eagerly so that a bad reference is reported when the loader
        # is created, not the first time it is called
        valid, error_msg = validate(default_app_ref, project_dir=app_path)
        if not valid:
            raise LoadClientAppError(
                error_msg or f"Invalid ClientApp reference `{default_app_ref}`"
            ) from None

    def _load(fab_id: str, fab_version: str) -> ClientApp:
        # `Path("")` resolves to the current working directory
        runtime_app_dir = Path(app_path if app_path else "").absolute()
        # If multi-app feature is disabled
        if not multi_app:
            # Set app reference
            client_app_ref = default_app_ref
        # If multi-app feature is enabled but app directory is provided
        elif app_path is not None:
            # Wrap config errors like the branch below does, so callers only
            # need to handle `LoadClientAppError`
            try:
                config = get_project_config(runtime_app_dir)
                this_fab_version, this_fab_id = get_metadata_from_config(config)
            except Exception as e:
                raise LoadClientAppError("Failed to load ClientApp") from e

            if this_fab_version != fab_version or this_fab_id != fab_id:
                raise LoadClientAppError(
                    f"FAB ID or version mismatch: Expected FAB ID '{this_fab_id}' and "
                    f"FAB version '{this_fab_version}', but received FAB ID '{fab_id}' "
                    f"and FAB version '{fab_version}'.",
                ) from None

            # Set app reference
            client_app_ref = _get_client_app_ref(config, runtime_app_dir)
        # If multi-app feature is enabled
        else:
            try:
                runtime_app_dir = get_project_dir(
                    fab_id, fab_version, get_flwr_dir(flwr_dir)
                )
                config = get_project_config(runtime_app_dir)
            except Exception as e:
                raise LoadClientAppError("Failed to load ClientApp") from e

            # Set app reference
            client_app_ref = _get_client_app_ref(config, runtime_app_dir)

        # Load ClientApp
        log(
            DEBUG,
            "Loading ClientApp `%s`",
            client_app_ref,
        )
        client_app = load_app(client_app_ref, LoadClientAppError, runtime_app_dir)

        if not isinstance(client_app, ClientApp):
            raise LoadClientAppError(
                f"Attribute {client_app_ref} is not of type {ClientApp}",
            ) from None

        return client_app

    return _load


def _get_client_app_ref(config: dict, project_dir: Path) -> str:
    """Return the ClientApp reference stored in a project config.

    A missing or malformed `tool.flwr.app.components.clientapp` entry is
    reported as `LoadClientAppError` instead of a bare `KeyError`/`TypeError`,
    so it is handled like every other ClientApp loading failure.
    """
    try:
        client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
    except (KeyError, TypeError) as e:
        raise LoadClientAppError(
            "Missing `tool.flwr.app.components.clientapp` in the project config "
            f"of '{project_dir}'"
        ) from e
    return client_app_ref
|
|
@@ -46,6 +46,7 @@ from flwr.common.serde import (
|
|
|
46
46
|
user_config_from_proto,
|
|
47
47
|
)
|
|
48
48
|
from flwr.common.typing import Fab, Run
|
|
49
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
49
50
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
50
51
|
CreateNodeRequest,
|
|
51
52
|
CreateNodeResponse,
|
|
@@ -74,6 +75,7 @@ PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
|
|
|
74
75
|
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
75
76
|
PATH_PING: str = "api/v0/fleet/ping"
|
|
76
77
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
|
78
|
+
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
|
77
79
|
|
|
78
80
|
T = TypeVar("T", bound=GrpcMessage)
|
|
79
81
|
|
|
@@ -358,18 +360,29 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
358
360
|
# Send the request
|
|
359
361
|
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
360
362
|
if res is None:
|
|
361
|
-
return Run(run_id, "", "", {})
|
|
363
|
+
return Run(run_id, "", "", "", {})
|
|
362
364
|
|
|
363
365
|
return Run(
|
|
364
366
|
run_id,
|
|
365
367
|
res.run.fab_id,
|
|
366
368
|
res.run.fab_version,
|
|
369
|
+
res.run.fab_hash,
|
|
367
370
|
user_config_from_proto(res.run.override_config),
|
|
368
371
|
)
|
|
369
372
|
|
|
370
373
|
def get_fab(fab_hash: str) -> Fab:
    """Request the FAB identified by `fab_hash` via the `PATH_GET_FAB` endpoint.

    Returns an empty `Fab("", b"")` when `_request` yields no response.
    """
    get_fab_request = GetFabRequest(hash_str=fab_hash)
    get_fab_response = _request(get_fab_request, GetFabResponse, PATH_GET_FAB)

    # Guard against a missing response before touching `.fab`
    if get_fab_response is not None:
        proto_fab = get_fab_response.fab
        return Fab(proto_fab.hash_str, proto_fab.content)
    return Fab("", b"")
|
|
373
386
|
|
|
374
387
|
try:
|
|
375
388
|
# Yield methods
|
flwr/client/supernode/app.py
CHANGED
|
@@ -18,7 +18,7 @@ import argparse
|
|
|
18
18
|
import sys
|
|
19
19
|
from logging import DEBUG, INFO, WARN
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Tuple
|
|
22
22
|
|
|
23
23
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
24
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -27,15 +27,8 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
27
27
|
load_ssh_public_key,
|
|
28
28
|
)
|
|
29
29
|
|
|
30
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
31
30
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.config import
|
|
33
|
-
get_flwr_dir,
|
|
34
|
-
get_metadata_from_config,
|
|
35
|
-
get_project_config,
|
|
36
|
-
get_project_dir,
|
|
37
|
-
parse_config_args,
|
|
38
|
-
)
|
|
31
|
+
from flwr.common.config import parse_config_args
|
|
39
32
|
from flwr.common.constant import (
|
|
40
33
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
41
34
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -43,9 +36,10 @@ from flwr.common.constant import (
|
|
|
43
36
|
)
|
|
44
37
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
45
38
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
46
|
-
from flwr.common.object_ref import load_app, validate
|
|
47
39
|
|
|
48
|
-
from ..app import
|
|
40
|
+
from ..app import start_client_internal
|
|
41
|
+
from ..process.process import run_clientapp
|
|
42
|
+
from ..process.utils import get_load_client_app_fn
|
|
49
43
|
|
|
50
44
|
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
51
45
|
|
|
@@ -61,7 +55,7 @@ def run_supernode() -> None:
|
|
|
61
55
|
_warn_deprecated_server_arg(args)
|
|
62
56
|
|
|
63
57
|
root_certificates = _get_certificates(args)
|
|
64
|
-
load_fn =
|
|
58
|
+
load_fn = get_load_client_app_fn(
|
|
65
59
|
default_app_ref="",
|
|
66
60
|
app_path=args.app,
|
|
67
61
|
flwr_dir=args.flwr_dir,
|
|
@@ -69,7 +63,7 @@ def run_supernode() -> None:
|
|
|
69
63
|
)
|
|
70
64
|
authentication_keys = _try_setup_client_authentication(args)
|
|
71
65
|
|
|
72
|
-
|
|
66
|
+
start_client_internal(
|
|
73
67
|
server_address=args.superlink,
|
|
74
68
|
load_client_app_fn=load_fn,
|
|
75
69
|
transport=args.transport,
|
|
@@ -79,7 +73,8 @@ def run_supernode() -> None:
|
|
|
79
73
|
max_retries=args.max_retries,
|
|
80
74
|
max_wait_time=args.max_wait_time,
|
|
81
75
|
node_config=parse_config_args([args.node_config]),
|
|
82
|
-
|
|
76
|
+
isolate=args.isolate,
|
|
77
|
+
supernode_address=args.supernode_address,
|
|
83
78
|
)
|
|
84
79
|
|
|
85
80
|
# Graceful shutdown
|
|
@@ -99,14 +94,14 @@ def run_client_app() -> None:
|
|
|
99
94
|
_warn_deprecated_server_arg(args)
|
|
100
95
|
|
|
101
96
|
root_certificates = _get_certificates(args)
|
|
102
|
-
load_fn =
|
|
97
|
+
load_fn = get_load_client_app_fn(
|
|
103
98
|
default_app_ref=getattr(args, "client-app"),
|
|
104
99
|
app_path=args.dir,
|
|
105
100
|
multi_app=False,
|
|
106
101
|
)
|
|
107
102
|
authentication_keys = _try_setup_client_authentication(args)
|
|
108
103
|
|
|
109
|
-
|
|
104
|
+
start_client_internal(
|
|
110
105
|
server_address=args.superlink,
|
|
111
106
|
node_config=parse_config_args([args.node_config]),
|
|
112
107
|
load_client_app_fn=load_fn,
|
|
@@ -128,7 +123,7 @@ def flwr_clientapp() -> None:
|
|
|
128
123
|
description="Run a Flower ClientApp",
|
|
129
124
|
)
|
|
130
125
|
parser.add_argument(
|
|
131
|
-
"--
|
|
126
|
+
"--supernode",
|
|
132
127
|
help="Address of SuperNode ClientAppIo gRPC servicer",
|
|
133
128
|
)
|
|
134
129
|
parser.add_argument(
|
|
@@ -140,9 +135,10 @@ def flwr_clientapp() -> None:
|
|
|
140
135
|
DEBUG,
|
|
141
136
|
"Staring isolated `ClientApp` connected to SuperNode ClientAppIo at %s "
|
|
142
137
|
"with the token %s",
|
|
143
|
-
args.
|
|
138
|
+
args.supernode,
|
|
144
139
|
args.token,
|
|
145
140
|
)
|
|
141
|
+
run_clientapp(supernode=args.supernode, token=int(args.token))
|
|
146
142
|
|
|
147
143
|
|
|
148
144
|
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
|
|
@@ -200,85 +196,6 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
200
196
|
return root_certificates
|
|
201
197
|
|
|
202
198
|
|
|
203
|
-
def _get_load_client_app_fn(
|
|
204
|
-
default_app_ref: str,
|
|
205
|
-
app_path: Optional[str],
|
|
206
|
-
multi_app: bool,
|
|
207
|
-
flwr_dir: Optional[str] = None,
|
|
208
|
-
) -> Callable[[str, str], ClientApp]:
|
|
209
|
-
"""Get the load_client_app_fn function.
|
|
210
|
-
|
|
211
|
-
If `multi_app` is True, this function loads the specified ClientApp
|
|
212
|
-
based on `fab_id` and `fab_version`. If `fab_id` is empty, a default
|
|
213
|
-
ClientApp will be loaded.
|
|
214
|
-
|
|
215
|
-
If `multi_app` is False, it ignores `fab_id` and `fab_version` and
|
|
216
|
-
loads a default ClientApp.
|
|
217
|
-
"""
|
|
218
|
-
if not multi_app:
|
|
219
|
-
log(
|
|
220
|
-
DEBUG,
|
|
221
|
-
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
222
|
-
default_app_ref,
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
valid, error_msg = validate(default_app_ref, project_dir=app_path)
|
|
226
|
-
if not valid and error_msg:
|
|
227
|
-
raise LoadClientAppError(error_msg) from None
|
|
228
|
-
|
|
229
|
-
def _load(fab_id: str, fab_version: str) -> ClientApp:
|
|
230
|
-
runtime_app_dir = Path(app_path if app_path else "").absolute()
|
|
231
|
-
# If multi-app feature is disabled
|
|
232
|
-
if not multi_app:
|
|
233
|
-
# Set app reference
|
|
234
|
-
client_app_ref = default_app_ref
|
|
235
|
-
# If multi-app feature is enabled but app directory is provided
|
|
236
|
-
elif app_path is not None:
|
|
237
|
-
config = get_project_config(runtime_app_dir)
|
|
238
|
-
this_fab_version, this_fab_id = get_metadata_from_config(config)
|
|
239
|
-
|
|
240
|
-
if this_fab_version != fab_version or this_fab_id != fab_id:
|
|
241
|
-
raise LoadClientAppError(
|
|
242
|
-
f"FAB ID or version mismatch: Expected FAB ID '{this_fab_id}' and "
|
|
243
|
-
f"FAB version '{this_fab_version}', but received FAB ID '{fab_id}' "
|
|
244
|
-
f"and FAB version '{fab_version}'.",
|
|
245
|
-
) from None
|
|
246
|
-
|
|
247
|
-
# log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
248
|
-
|
|
249
|
-
# Set app reference
|
|
250
|
-
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
251
|
-
# If multi-app feature is enabled
|
|
252
|
-
else:
|
|
253
|
-
try:
|
|
254
|
-
runtime_app_dir = get_project_dir(
|
|
255
|
-
fab_id, fab_version, get_flwr_dir(flwr_dir)
|
|
256
|
-
)
|
|
257
|
-
config = get_project_config(runtime_app_dir)
|
|
258
|
-
except Exception as e:
|
|
259
|
-
raise LoadClientAppError("Failed to load ClientApp") from e
|
|
260
|
-
|
|
261
|
-
# Set app reference
|
|
262
|
-
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
263
|
-
|
|
264
|
-
# Load ClientApp
|
|
265
|
-
log(
|
|
266
|
-
DEBUG,
|
|
267
|
-
"Loading ClientApp `%s`",
|
|
268
|
-
client_app_ref,
|
|
269
|
-
)
|
|
270
|
-
client_app = load_app(client_app_ref, LoadClientAppError, runtime_app_dir)
|
|
271
|
-
|
|
272
|
-
if not isinstance(client_app, ClientApp):
|
|
273
|
-
raise LoadClientAppError(
|
|
274
|
-
f"Attribute {client_app_ref} is not of type {ClientApp}",
|
|
275
|
-
) from None
|
|
276
|
-
|
|
277
|
-
return client_app
|
|
278
|
-
|
|
279
|
-
return _load
|
|
280
|
-
|
|
281
|
-
|
|
282
199
|
def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
283
200
|
"""Parse flower-supernode command line arguments."""
|
|
284
201
|
parser = argparse.ArgumentParser(
|
|
@@ -308,6 +225,17 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
308
225
|
- `$HOME/.flwr/` in all other cases
|
|
309
226
|
""",
|
|
310
227
|
)
|
|
228
|
+
parser.add_argument(
|
|
229
|
+
"--isolate",
|
|
230
|
+
action="store_true",
|
|
231
|
+
help="Run `ClientApp` in an isolated subprocess. By default, `ClientApp` "
|
|
232
|
+
"runs in the same process that executes the SuperNode.",
|
|
233
|
+
)
|
|
234
|
+
parser.add_argument(
|
|
235
|
+
"--supernode-address",
|
|
236
|
+
default="0.0.0.0:9094",
|
|
237
|
+
help="Set the SuperNode gRPC server address. Defaults to `0.0.0.0:9094`.",
|
|
238
|
+
)
|
|
311
239
|
|
|
312
240
|
return parser
|
|
313
241
|
|
flwr/common/config.py
CHANGED
|
@@ -74,10 +74,15 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
|
|
|
74
74
|
return config
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def
|
|
77
|
+
def fuse_dicts(
|
|
78
78
|
main_dict: UserConfig,
|
|
79
79
|
override_dict: UserConfig,
|
|
80
80
|
) -> UserConfig:
|
|
81
|
+
"""Merge a config with the overrides.
|
|
82
|
+
|
|
83
|
+
Remove the nesting by adding the nested keys as prefixes separated by dots, and fuse
|
|
84
|
+
it with the override dict.
|
|
85
|
+
"""
|
|
81
86
|
fused_dict = main_dict.copy()
|
|
82
87
|
|
|
83
88
|
for key, value in override_dict.items():
|
|
@@ -96,7 +101,7 @@ def get_fused_config_from_dir(
|
|
|
96
101
|
)
|
|
97
102
|
flat_default_config = flatten_dict(default_config)
|
|
98
103
|
|
|
99
|
-
return
|
|
104
|
+
return fuse_dicts(flat_default_config, override_config)
|
|
100
105
|
|
|
101
106
|
|
|
102
107
|
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
flwr/common/record/recordset.py
CHANGED
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
"""RecordSet."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from dataclasses import dataclass
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
20
22
|
|
|
21
23
|
from .configsrecord import ConfigsRecord
|
|
22
24
|
from .metricsrecord import MetricsRecord
|
|
@@ -34,9 +36,9 @@ class RecordSetData:
|
|
|
34
36
|
|
|
35
37
|
def __init__(
|
|
36
38
|
self,
|
|
37
|
-
parameters_records:
|
|
38
|
-
metrics_records:
|
|
39
|
-
configs_records:
|
|
39
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
40
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
41
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
40
42
|
) -> None:
|
|
41
43
|
self.parameters_records = TypedDict[str, ParametersRecord](
|
|
42
44
|
self._check_fn_str, self._check_fn_params
|
|
@@ -88,9 +90,9 @@ class RecordSet:
|
|
|
88
90
|
|
|
89
91
|
def __init__(
|
|
90
92
|
self,
|
|
91
|
-
parameters_records:
|
|
92
|
-
metrics_records:
|
|
93
|
-
configs_records:
|
|
93
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
94
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
95
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
94
96
|
) -> None:
|
|
95
97
|
data = RecordSetData(
|
|
96
98
|
parameters_records=parameters_records,
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,99 +15,61 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Callable, Dict, Generic, Iterator, MutableMapping, TypeVar, cast
|
|
19
19
|
|
|
20
20
|
K = TypeVar("K") # Key type
|
|
21
21
|
V = TypeVar("V") # Value type
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
class TypedDict(Generic[K, V]):
|
|
24
|
+
class TypedDict(MutableMapping[K, V], Generic[K, V]):
|
|
25
25
|
"""Typed dictionary."""
|
|
26
26
|
|
|
27
27
|
def __init__(
|
|
28
28
|
self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None]
|
|
29
29
|
):
|
|
30
|
-
self.
|
|
31
|
-
self.
|
|
32
|
-
self.
|
|
30
|
+
self.__dict__["_check_key_fn"] = check_key_fn
|
|
31
|
+
self.__dict__["_check_value_fn"] = check_value_fn
|
|
32
|
+
self.__dict__["_data"] = {}
|
|
33
33
|
|
|
34
34
|
def __setitem__(self, key: K, value: V) -> None:
|
|
35
35
|
"""Set the given key to the given value after type checking."""
|
|
36
36
|
# Check the types of key and value
|
|
37
|
-
self._check_key_fn(key)
|
|
38
|
-
self._check_value_fn(value)
|
|
37
|
+
cast(Callable[[K], None], self.__dict__["_check_key_fn"])(key)
|
|
38
|
+
cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value)
|
|
39
|
+
|
|
39
40
|
# Set key-value pair
|
|
40
|
-
self._data[key] = value
|
|
41
|
+
cast(Dict[K, V], self.__dict__["_data"])[key] = value
|
|
41
42
|
|
|
42
43
|
def __delitem__(self, key: K) -> None:
|
|
43
44
|
"""Remove the item with the specified key."""
|
|
44
|
-
del self._data[key]
|
|
45
|
+
del cast(Dict[K, V], self.__dict__["_data"])[key]
|
|
45
46
|
|
|
46
47
|
def __getitem__(self, item: K) -> V:
|
|
47
48
|
"""Return the value for the specified key."""
|
|
48
|
-
return self._data[item]
|
|
49
|
+
return cast(Dict[K, V], self.__dict__["_data"])[item]
|
|
49
50
|
|
|
50
51
|
def __iter__(self) -> Iterator[K]:
|
|
51
52
|
"""Yield an iterator over the keys of the dictionary."""
|
|
52
|
-
return iter(self._data)
|
|
53
|
+
return iter(cast(Dict[K, V], self.__dict__["_data"]))
|
|
53
54
|
|
|
54
55
|
def __repr__(self) -> str:
|
|
55
56
|
"""Return a string representation of the dictionary."""
|
|
56
|
-
return self._data.__repr__()
|
|
57
|
+
return cast(Dict[K, V], self.__dict__["_data"]).__repr__()
|
|
57
58
|
|
|
58
59
|
def __len__(self) -> int:
|
|
59
60
|
"""Return the number of items in the dictionary."""
|
|
60
|
-
return len(self._data)
|
|
61
|
+
return len(cast(Dict[K, V], self.__dict__["_data"]))
|
|
61
62
|
|
|
62
|
-
def __contains__(self, key:
|
|
63
|
+
def __contains__(self, key: object) -> bool:
|
|
63
64
|
"""Check if the dictionary contains the specified key."""
|
|
64
|
-
return key in self._data
|
|
65
|
+
return key in cast(Dict[K, V], self.__dict__["_data"])
|
|
65
66
|
|
|
66
67
|
def __eq__(self, other: object) -> bool:
|
|
67
68
|
"""Compare this instance to another dictionary or TypedDict."""
|
|
69
|
+
data = cast(Dict[K, V], self.__dict__["_data"])
|
|
68
70
|
if isinstance(other, TypedDict):
|
|
69
|
-
|
|
71
|
+
other_data = cast(Dict[K, V], other.__dict__["_data"])
|
|
72
|
+
return data == other_data
|
|
70
73
|
if isinstance(other, dict):
|
|
71
|
-
return
|
|
74
|
+
return data == other
|
|
72
75
|
return NotImplemented
|
|
73
|
-
|
|
74
|
-
def items(self) -> Iterator[Tuple[K, V]]:
|
|
75
|
-
"""R.items() -> a set-like object providing a view on R's items."""
|
|
76
|
-
return cast(Iterator[Tuple[K, V]], self._data.items())
|
|
77
|
-
|
|
78
|
-
def keys(self) -> Iterator[K]:
|
|
79
|
-
"""R.keys() -> a set-like object providing a view on R's keys."""
|
|
80
|
-
return cast(Iterator[K], self._data.keys())
|
|
81
|
-
|
|
82
|
-
def values(self) -> Iterator[V]:
|
|
83
|
-
"""R.values() -> an object providing a view on R's values."""
|
|
84
|
-
return cast(Iterator[V], self._data.values())
|
|
85
|
-
|
|
86
|
-
def update(self, *args: Any, **kwargs: Any) -> None:
|
|
87
|
-
"""R.update([E, ]**F) -> None.
|
|
88
|
-
|
|
89
|
-
Update R from dict/iterable E and F.
|
|
90
|
-
"""
|
|
91
|
-
for key, value in dict(*args, **kwargs).items():
|
|
92
|
-
self[key] = value
|
|
93
|
-
|
|
94
|
-
def pop(self, key: K) -> V:
|
|
95
|
-
"""R.pop(k[,d]) -> v, remove specified key and return the corresponding value.
|
|
96
|
-
|
|
97
|
-
If key is not found, d is returned if given, otherwise KeyError is raised.
|
|
98
|
-
"""
|
|
99
|
-
return self._data.pop(key)
|
|
100
|
-
|
|
101
|
-
def get(self, key: K, default: V) -> V:
|
|
102
|
-
"""R.get(k[,d]) -> R[k] if k in R, else d.
|
|
103
|
-
|
|
104
|
-
d defaults to None.
|
|
105
|
-
"""
|
|
106
|
-
return self._data.get(key, default)
|
|
107
|
-
|
|
108
|
-
def clear(self) -> None:
|
|
109
|
-
"""R.clear() -> None.
|
|
110
|
-
|
|
111
|
-
Remove all items from R.
|
|
112
|
-
"""
|
|
113
|
-
self._data.clear()
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -145,7 +145,7 @@ def _recordset_to_fit_or_evaluate_ins_components(
|
|
|
145
145
|
# get config dict
|
|
146
146
|
config_record = recordset.configs_records[f"{ins_str}.config"]
|
|
147
147
|
# pylint: disable-next=protected-access
|
|
148
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
148
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
149
149
|
|
|
150
150
|
return parameters, config_dict
|
|
151
151
|
|
|
@@ -213,7 +213,7 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
|
|
|
213
213
|
)
|
|
214
214
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
215
215
|
# pylint: disable-next=protected-access
|
|
216
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
216
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
217
217
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
218
218
|
|
|
219
219
|
return FitRes(
|
|
@@ -274,7 +274,7 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
|
|
|
274
274
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
275
275
|
|
|
276
276
|
# pylint: disable-next=protected-access
|
|
277
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
277
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
278
278
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
279
279
|
|
|
280
280
|
return EvaluateRes(
|
|
@@ -314,7 +314,7 @@ def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
|
|
|
314
314
|
"""Derive GetParametersIns from a RecordSet object."""
|
|
315
315
|
config_record = recordset.configs_records["getparametersins.config"]
|
|
316
316
|
# pylint: disable-next=protected-access
|
|
317
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
317
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
318
318
|
|
|
319
319
|
return GetParametersIns(config=config_dict)
|
|
320
320
|
|
|
@@ -365,7 +365,7 @@ def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
|
|
|
365
365
|
"""Derive GetPropertiesIns from a RecordSet object."""
|
|
366
366
|
config_record = recordset.configs_records["getpropertiesins.config"]
|
|
367
367
|
# pylint: disable-next=protected-access
|
|
368
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
368
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
369
369
|
|
|
370
370
|
return GetPropertiesIns(config=config_dict)
|
|
371
371
|
|
|
@@ -384,7 +384,7 @@ def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
|
|
|
384
384
|
res_str = "getpropertiesres"
|
|
385
385
|
config_record = recordset.configs_records[f"{res_str}.properties"]
|
|
386
386
|
# pylint: disable-next=protected-access
|
|
387
|
-
properties = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
387
|
+
properties = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
388
388
|
|
|
389
389
|
status = _extract_status_from_recordset(res_str, recordset=recordset)
|
|
390
390
|
|
flwr/common/serde.py
CHANGED
|
@@ -22,6 +22,7 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
22
22
|
# pylint: disable=E0611
|
|
23
23
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
24
24
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
25
|
+
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
25
26
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
26
27
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
27
28
|
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
@@ -686,6 +687,19 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
686
687
|
return message
|
|
687
688
|
|
|
688
689
|
|
|
690
|
+
# === FAB ===
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def fab_to_proto(fab: typing.Fab) -> ProtoFab:
    """Serialize a Python `Fab` into its ProtoBuf `Fab` message.

    `hash_str` and `content` are copied over unchanged.
    """
    hash_str, content = fab.hash_str, fab.content
    return ProtoFab(content=content, hash_str=hash_str)
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def fab_from_proto(fab: ProtoFab) -> typing.Fab:
    """Deserialize a ProtoBuf `Fab` message into a Python `Fab`.

    `hash_str` and `content` are passed positionally, in that order.
    """
    hash_str, content = fab.hash_str, fab.content
    return typing.Fab(hash_str, content)
|
|
701
|
+
|
|
702
|
+
|
|
689
703
|
# === User configs ===
|
|
690
704
|
|
|
691
705
|
|
|
@@ -745,6 +759,7 @@ def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
|
|
745
759
|
group_id=metadata.group_id,
|
|
746
760
|
ttl=metadata.ttl,
|
|
747
761
|
message_type=metadata.message_type,
|
|
762
|
+
created_at=metadata.created_at,
|
|
748
763
|
)
|
|
749
764
|
return proto
|
|
750
765
|
|
|
@@ -771,7 +786,9 @@ def message_to_proto(message: Message) -> ProtoMessage:
|
|
|
771
786
|
"""Serialize `Message` to ProtoBuf."""
|
|
772
787
|
proto = ProtoMessage(
|
|
773
788
|
metadata=metadata_to_proto(message.metadata),
|
|
774
|
-
content=
|
|
789
|
+
content=(
|
|
790
|
+
recordset_to_proto(message.content) if message.has_content() else None
|
|
791
|
+
),
|
|
775
792
|
error=error_to_proto(message.error) if message.has_error() else None,
|
|
776
793
|
)
|
|
777
794
|
return proto
|
|
@@ -779,6 +796,7 @@ def message_to_proto(message: Message) -> ProtoMessage:
|
|
|
779
796
|
|
|
780
797
|
def message_from_proto(message_proto: ProtoMessage) -> Message:
|
|
781
798
|
"""Deserialize `Message` from ProtoBuf."""
|
|
799
|
+
created_at = message_proto.metadata.created_at
|
|
782
800
|
message = Message(
|
|
783
801
|
metadata=metadata_from_proto(message_proto.metadata),
|
|
784
802
|
content=(
|
|
@@ -792,6 +810,9 @@ def message_from_proto(message_proto: ProtoMessage) -> Message:
|
|
|
792
810
|
else None
|
|
793
811
|
),
|
|
794
812
|
)
|
|
813
|
+
# `.created_at` is set upon Message object construction
|
|
814
|
+
# we need to manually set it to the original value
|
|
815
|
+
message.metadata.created_at = created_at
|
|
795
816
|
return message
|
|
796
817
|
|
|
797
818
|
|
|
@@ -829,8 +850,8 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
|
829
850
|
run_id=run.run_id,
|
|
830
851
|
fab_id=run.fab_id,
|
|
831
852
|
fab_version=run.fab_version,
|
|
853
|
+
fab_hash=run.fab_hash,
|
|
832
854
|
override_config=user_config_to_proto(run.override_config),
|
|
833
|
-
fab_hash="",
|
|
834
855
|
)
|
|
835
856
|
return proto
|
|
836
857
|
|
|
@@ -841,6 +862,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
|
841
862
|
run_id=run_proto.run_id,
|
|
842
863
|
fab_id=run_proto.fab_id,
|
|
843
864
|
fab_version=run_proto.fab_version,
|
|
865
|
+
fab_hash=run_proto.fab_hash,
|
|
844
866
|
override_config=user_config_from_proto(run_proto.override_config),
|
|
845
867
|
)
|
|
846
868
|
return run
|