flwr-nightly 1.15.0.dev20250114__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/config_utils.py +23 -146
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +17 -2
- flwr/cli/login/login.py +9 -1
- flwr/cli/ls.py +10 -2
- flwr/cli/run/run.py +20 -10
- flwr/cli/stop.py +9 -1
- flwr/client/app.py +23 -43
- flwr/client/clientapp/app.py +4 -6
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +0 -6
- flwr/client/grpc_rere_client/client_interceptor.py +19 -125
- flwr/client/grpc_rere_client/connection.py +10 -0
- flwr/client/rest_client/connection.py +12 -3
- flwr/client/supernode/app.py +14 -20
- flwr/common/auth_plugin/auth_plugin.py +1 -0
- flwr/common/config.py +152 -15
- flwr/common/constant.py +9 -8
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +7 -0
- flwr/common/logger.py +1 -1
- flwr/common/serde.py +6 -4
- flwr/proto/clientappio_pb2.py +13 -3
- flwr/proto/clientappio_pb2_grpc.py +63 -12
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/exec_pb2.py +15 -5
- flwr/proto/exec_pb2_grpc.py +105 -24
- flwr/proto/fab_pb2.py +13 -3
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fleet_pb2.py +15 -5
- flwr/proto/fleet_pb2_grpc.py +147 -36
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/recordset_pb2.py +18 -8
- flwr/proto/recordset_pb2_grpc.py +20 -0
- flwr/proto/run_pb2.py +16 -6
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/serverappio_pb2.py +32 -14
- flwr/proto/serverappio_pb2.pyi +56 -0
- flwr/proto/serverappio_pb2_grpc.py +261 -44
- flwr/proto/serverappio_pb2_grpc.pyi +20 -0
- flwr/proto/simulationio_pb2.py +13 -3
- flwr/proto/simulationio_pb2_grpc.py +105 -24
- flwr/proto/task_pb2.py +13 -3
- flwr/proto/task_pb2_grpc.py +20 -0
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/server/app.py +40 -11
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +5 -2
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +5 -6
- flwr/server/superlink/driver/serverappio_servicer.py +110 -6
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -169
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -26
- flwr/server/superlink/linkstate/linkstate.py +5 -18
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -70
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +4 -4
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +82 -78
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
flwr/client/clientapp/utils.py
CHANGED
@@ -66,7 +66,7 @@ def get_load_client_app_fn(
|
|
66
66
|
# `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
|
67
67
|
elif app_path is not None:
|
68
68
|
config = get_project_config(runtime_app_dir)
|
69
|
-
|
69
|
+
this_fab_id, this_fab_version = get_metadata_from_config(config)
|
70
70
|
|
71
71
|
if this_fab_version != fab_version or this_fab_id != fab_id:
|
72
72
|
raise LoadClientAppError(
|
flwr/client/grpc_client/connection.py
CHANGED
@@ -47,12 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
47
47
|
)
|
48
48
|
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
49
49
|
|
50
|
-
# The following flags can be uncommented for debugging. Other possible values:
|
51
|
-
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
52
|
-
# import os
|
53
|
-
# os.environ["GRPC_VERBOSITY"] = "debug"
|
54
|
-
# os.environ["GRPC_TRACE"] = "tcp,http"
|
55
|
-
|
56
50
|
|
57
51
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
58
52
|
"""Log channel connectivity."""
|
flwr/client/grpc_rere_client/client_interceptor.py
CHANGED
@@ -15,71 +15,18 @@
|
|
15
15
|
"""Flower client interceptor."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
import collections
|
20
|
-
from collections.abc import Sequence
|
21
|
-
from logging import WARNING
|
22
|
-
from typing import Any, Callable, Optional, Union
|
18
|
+
from typing import Any, Callable
|
23
19
|
|
24
20
|
import grpc
|
25
21
|
from cryptography.hazmat.primitives.asymmetric import ec
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
26
23
|
|
27
|
-
from flwr.common
|
24
|
+
from flwr.common import now
|
25
|
+
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
28
26
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
29
|
-
bytes_to_public_key,
|
30
|
-
compute_hmac,
|
31
|
-
generate_shared_key,
|
32
27
|
public_key_to_bytes,
|
28
|
+
sign_message,
|
33
29
|
)
|
34
|
-
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
35
|
-
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
36
|
-
CreateNodeRequest,
|
37
|
-
DeleteNodeRequest,
|
38
|
-
PingRequest,
|
39
|
-
PullMessagesRequest,
|
40
|
-
PullTaskInsRequest,
|
41
|
-
PushMessagesRequest,
|
42
|
-
PushTaskResRequest,
|
43
|
-
)
|
44
|
-
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
45
|
-
|
46
|
-
_PUBLIC_KEY_HEADER = "public-key"
|
47
|
-
_AUTH_TOKEN_HEADER = "auth-token"
|
48
|
-
|
49
|
-
Request = Union[
|
50
|
-
CreateNodeRequest,
|
51
|
-
DeleteNodeRequest,
|
52
|
-
PullTaskInsRequest,
|
53
|
-
PushTaskResRequest,
|
54
|
-
GetRunRequest,
|
55
|
-
PingRequest,
|
56
|
-
GetFabRequest,
|
57
|
-
PullMessagesRequest,
|
58
|
-
PushMessagesRequest,
|
59
|
-
]
|
60
|
-
|
61
|
-
|
62
|
-
def _get_value_from_tuples(
|
63
|
-
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
64
|
-
) -> bytes:
|
65
|
-
value = next((value for key, value in tuples if key == key_string), "")
|
66
|
-
if isinstance(value, str):
|
67
|
-
return value.encode()
|
68
|
-
|
69
|
-
return value
|
70
|
-
|
71
|
-
|
72
|
-
class _ClientCallDetails(
|
73
|
-
collections.namedtuple(
|
74
|
-
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
|
75
|
-
),
|
76
|
-
grpc.ClientCallDetails, # type: ignore
|
77
|
-
):
|
78
|
-
"""Details for each client call.
|
79
|
-
|
80
|
-
The class will be passed on as the first argument in continuation function.
|
81
|
-
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
|
82
|
-
"""
|
83
30
|
|
84
31
|
|
85
32
|
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
@@ -91,86 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
91
38
|
public_key: ec.EllipticCurvePublicKey,
|
92
39
|
):
|
93
40
|
self.private_key = private_key
|
94
|
-
self.
|
95
|
-
self.shared_secret: Optional[bytes] = None
|
96
|
-
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
|
97
|
-
self.encoded_public_key = base64.urlsafe_b64encode(
|
98
|
-
public_key_to_bytes(self.public_key)
|
99
|
-
)
|
41
|
+
self.public_key_bytes = public_key_to_bytes(public_key)
|
100
42
|
|
101
43
|
def intercept_unary_unary(
|
102
44
|
self,
|
103
45
|
continuation: Callable[[Any, Any], Any],
|
104
46
|
client_call_details: grpc.ClientCallDetails,
|
105
|
-
request:
|
47
|
+
request: GrpcMessage,
|
106
48
|
) -> grpc.Call:
|
107
49
|
"""Flower client interceptor.
|
108
50
|
|
109
51
|
Intercept unary call from client and add necessary authentication header in the
|
110
52
|
RPC metadata.
|
111
53
|
"""
|
112
|
-
metadata = []
|
113
|
-
postprocess = False
|
114
|
-
if client_call_details.metadata is not None:
|
115
|
-
metadata = list(client_call_details.metadata)
|
116
|
-
|
117
|
-
# Always add the public key header
|
118
|
-
metadata.append(
|
119
|
-
(
|
120
|
-
_PUBLIC_KEY_HEADER,
|
121
|
-
self.encoded_public_key,
|
122
|
-
)
|
123
|
-
)
|
124
|
-
|
125
|
-
if isinstance(request, CreateNodeRequest):
|
126
|
-
postprocess = True
|
127
|
-
elif isinstance(
|
128
|
-
request,
|
129
|
-
(
|
130
|
-
DeleteNodeRequest,
|
131
|
-
PullTaskInsRequest,
|
132
|
-
PushTaskResRequest,
|
133
|
-
GetRunRequest,
|
134
|
-
PingRequest,
|
135
|
-
GetFabRequest,
|
136
|
-
PullMessagesRequest,
|
137
|
-
PushMessagesRequest,
|
138
|
-
),
|
139
|
-
):
|
140
|
-
if self.shared_secret is None:
|
141
|
-
raise RuntimeError("Failure to compute hmac")
|
142
|
-
|
143
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
144
|
-
metadata.append(
|
145
|
-
(
|
146
|
-
_AUTH_TOKEN_HEADER,
|
147
|
-
base64.urlsafe_b64encode(
|
148
|
-
compute_hmac(self.shared_secret, message_bytes)
|
149
|
-
),
|
150
|
-
)
|
151
|
-
)
|
54
|
+
metadata = list(client_call_details.metadata or [])
|
152
55
|
|
153
|
-
|
154
|
-
|
155
|
-
client_call_details.timeout,
|
156
|
-
metadata,
|
157
|
-
client_call_details.credentials,
|
158
|
-
)
|
56
|
+
# Add the public key
|
57
|
+
metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
|
159
58
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
164
|
-
)
|
59
|
+
# Add timestamp
|
60
|
+
timestamp = now().isoformat()
|
61
|
+
metadata.append((TIMESTAMP_HEADER, timestamp))
|
165
62
|
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
63
|
+
# Sign and add the signature
|
64
|
+
signature = sign_message(self.private_key, timestamp.encode("ascii"))
|
65
|
+
metadata.append((SIGNATURE_HEADER, signature))
|
170
66
|
|
171
|
-
|
172
|
-
|
173
|
-
self.private_key, self.server_public_key
|
174
|
-
)
|
67
|
+
# Overwrite the metadata
|
68
|
+
details = client_call_details._replace(metadata=metadata)
|
175
69
|
|
176
|
-
return
|
70
|
+
return continuation(details, request)
|
flwr/client/grpc_rere_client/connection.py
CHANGED
@@ -311,3 +311,13 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
311
311
|
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
312
312
|
except Exception as exc: # pylint: disable=broad-except
|
313
313
|
log(ERROR, exc)
|
314
|
+
# Cleanup
|
315
|
+
finally:
|
316
|
+
try:
|
317
|
+
if node is not None:
|
318
|
+
# Disable retrying
|
319
|
+
retry_invoker.max_tries = 1
|
320
|
+
delete_node()
|
321
|
+
except grpc.RpcError:
|
322
|
+
pass
|
323
|
+
channel.close()
|
flwr/client/rest_client/connection.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import random
|
19
|
-
import sys
|
20
19
|
import threading
|
21
20
|
from collections.abc import Iterator
|
22
21
|
from contextlib import contextmanager
|
@@ -26,17 +25,18 @@ from typing import Callable, Optional, TypeVar, Union
|
|
26
25
|
|
27
26
|
from cryptography.hazmat.primitives.asymmetric import ec
|
28
27
|
from google.protobuf.message import Message as GrpcMessage
|
28
|
+
from requests.exceptions import ConnectionError as RequestsConnectionError
|
29
29
|
|
30
30
|
from flwr.client.heartbeat import start_ping_loop
|
31
31
|
from flwr.client.message_handler.message_handler import validate_out_message
|
32
32
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
33
33
|
from flwr.common.constant import (
|
34
|
-
MISSING_EXTRA_REST,
|
35
34
|
PING_BASE_MULTIPLIER,
|
36
35
|
PING_CALL_TIMEOUT,
|
37
36
|
PING_DEFAULT_INTERVAL,
|
38
37
|
PING_RANDOM_RANGE,
|
39
38
|
)
|
39
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
40
40
|
from flwr.common.logger import log
|
41
41
|
from flwr.common.message import Message, Metadata
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
@@ -61,7 +61,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
61
61
|
try:
|
62
62
|
import requests
|
63
63
|
except ModuleNotFoundError:
|
64
|
-
|
64
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
65
65
|
|
66
66
|
|
67
67
|
PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
@@ -379,3 +379,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
379
379
|
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
380
380
|
except Exception as exc: # pylint: disable=broad-except
|
381
381
|
log(ERROR, exc)
|
382
|
+
# Cleanup
|
383
|
+
finally:
|
384
|
+
try:
|
385
|
+
if node is not None:
|
386
|
+
# Disable retrying
|
387
|
+
retry_invoker.max_tries = 1
|
388
|
+
delete_node()
|
389
|
+
except RequestsConnectionError:
|
390
|
+
pass
|
flwr/client/supernode/app.py
CHANGED
@@ -40,6 +40,7 @@ from flwr.common.constant import (
|
|
40
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
41
41
|
TRANSPORT_TYPE_REST,
|
42
42
|
)
|
43
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
43
44
|
from flwr.common.exit_handlers import register_exit_handlers
|
44
45
|
from flwr.common.logger import log, warn_deprecated_feature
|
45
46
|
|
@@ -86,6 +87,12 @@ def run_supernode() -> None:
|
|
86
87
|
|
87
88
|
log(DEBUG, "Isolation mode: %s", args.isolation)
|
88
89
|
|
90
|
+
# Register handlers for graceful shutdown
|
91
|
+
register_exit_handlers(
|
92
|
+
event_type=EventType.RUN_SUPERNODE_LEAVE,
|
93
|
+
exit_message="SuperNode terminated gracefully.",
|
94
|
+
)
|
95
|
+
|
89
96
|
start_client_internal(
|
90
97
|
server_address=args.superlink,
|
91
98
|
load_client_app_fn=load_fn,
|
@@ -103,11 +110,6 @@ def run_supernode() -> None:
|
|
103
110
|
clientappio_api_address=args.clientappio_api_address,
|
104
111
|
)
|
105
112
|
|
106
|
-
# Graceful shutdown
|
107
|
-
register_exit_handlers(
|
108
|
-
event_type=EventType.RUN_SUPERNODE_LEAVE,
|
109
|
-
)
|
110
|
-
|
111
113
|
|
112
114
|
def run_client_app() -> None:
|
113
115
|
"""Run Flower client app."""
|
@@ -280,11 +282,7 @@ def _try_setup_client_authentication(
|
|
280
282
|
return None
|
281
283
|
|
282
284
|
if not args.auth_supernode_private_key or not args.auth_supernode_public_key:
|
283
|
-
|
284
|
-
"Authentication requires file paths to both "
|
285
|
-
"'--auth-supernode-private-key' and '--auth-supernode-public-key'"
|
286
|
-
"to be provided (providing only one of them is not sufficient)."
|
287
|
-
)
|
285
|
+
flwr_exit(ExitCode.SUPERNODE_NODE_AUTH_KEYS_REQUIRED)
|
288
286
|
|
289
287
|
try:
|
290
288
|
ssh_private_key = load_ssh_private_key(
|
@@ -294,11 +292,9 @@ def _try_setup_client_authentication(
|
|
294
292
|
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
295
293
|
raise ValueError()
|
296
294
|
except (ValueError, UnsupportedAlgorithm):
|
297
|
-
|
298
|
-
|
299
|
-
"
|
300
|
-
"curve private and public key pair. Please ensure that the file "
|
301
|
-
"path points to a valid private key file and try again."
|
295
|
+
flwr_exit(
|
296
|
+
ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID,
|
297
|
+
"Unable to parse the private key file.",
|
302
298
|
)
|
303
299
|
|
304
300
|
try:
|
@@ -308,11 +304,9 @@ def _try_setup_client_authentication(
|
|
308
304
|
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
309
305
|
raise ValueError()
|
310
306
|
except (ValueError, UnsupportedAlgorithm):
|
311
|
-
|
312
|
-
|
313
|
-
"
|
314
|
-
"curve private and public key pair. Please ensure that the file "
|
315
|
-
"path points to a valid public key file and try again."
|
307
|
+
flwr_exit(
|
308
|
+
ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID,
|
309
|
+
"Unable to parse the public key file.",
|
316
310
|
)
|
317
311
|
|
318
312
|
return (
|
flwr/common/config.py
CHANGED
@@ -17,13 +17,13 @@
|
|
17
17
|
|
18
18
|
import os
|
19
19
|
import re
|
20
|
+
import zipfile
|
21
|
+
from io import BytesIO
|
20
22
|
from pathlib import Path
|
21
|
-
from typing import Any, Optional, Union, cast, get_args
|
23
|
+
from typing import IO, Any, Optional, TypeVar, Union, cast, get_args
|
22
24
|
|
23
25
|
import tomli
|
24
26
|
|
25
|
-
from flwr.cli.config_utils import get_fab_config, validate_fields
|
26
|
-
from flwr.common import ConfigsRecord
|
27
27
|
from flwr.common.constant import (
|
28
28
|
APP_DIR,
|
29
29
|
FAB_CONFIG_FILE,
|
@@ -33,6 +33,10 @@ from flwr.common.constant import (
|
|
33
33
|
)
|
34
34
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
35
35
|
|
36
|
+
from . import ConfigsRecord, object_ref
|
37
|
+
|
38
|
+
T_dict = TypeVar("T_dict", bound=dict[str, Any]) # pylint: disable=invalid-name
|
39
|
+
|
36
40
|
|
37
41
|
def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
38
42
|
"""Return the Flower home directory based on env variables."""
|
@@ -80,7 +84,7 @@ def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
|
|
80
84
|
config = tomli.loads(toml_file.read())
|
81
85
|
|
82
86
|
# Validate pyproject.toml fields
|
83
|
-
is_valid, errors, _ =
|
87
|
+
is_valid, errors, _ = validate_fields_in_config(config)
|
84
88
|
if not is_valid:
|
85
89
|
error_msg = "\n".join([f" - {error}" for error in errors])
|
86
90
|
raise ValueError(
|
@@ -91,19 +95,28 @@ def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
|
|
91
95
|
|
92
96
|
|
93
97
|
def fuse_dicts(
|
94
|
-
main_dict:
|
95
|
-
override_dict:
|
96
|
-
|
98
|
+
main_dict: T_dict,
|
99
|
+
override_dict: T_dict,
|
100
|
+
check_keys: bool = True,
|
101
|
+
) -> T_dict:
|
97
102
|
"""Merge a config with the overrides.
|
98
103
|
|
99
|
-
|
100
|
-
|
104
|
+
If `check_keys` is set to True, an error will be raised if the override
|
105
|
+
dictionary contains keys that are not present in the main dictionary.
|
106
|
+
Otherwise, only the keys present in the main dictionary will be updated.
|
101
107
|
"""
|
102
|
-
|
108
|
+
if not isinstance(main_dict, dict) or not isinstance(override_dict, dict):
|
109
|
+
raise ValueError("Both dictionaries must be of type dict")
|
110
|
+
|
111
|
+
fused_dict = cast(T_dict, main_dict.copy())
|
103
112
|
|
104
113
|
for key, value in override_dict.items():
|
105
114
|
if key in main_dict:
|
115
|
+
if isinstance(value, dict):
|
116
|
+
fused_dict[key] = fuse_dicts(main_dict[key], value)
|
106
117
|
fused_dict[key] = value
|
118
|
+
elif check_keys:
|
119
|
+
raise ValueError(f"Key '{key}' is not present in the main dictionary")
|
107
120
|
|
108
121
|
return fused_dict
|
109
122
|
|
@@ -192,8 +205,8 @@ def unflatten_dict(flat_dict: dict[str, Any]) -> dict[str, Any]:
|
|
192
205
|
|
193
206
|
|
194
207
|
def parse_config_args(
|
195
|
-
config: Optional[list[str]],
|
196
|
-
) ->
|
208
|
+
config: Optional[list[str]], flatten: bool = True
|
209
|
+
) -> dict[str, Any]:
|
197
210
|
"""Parse separator separated list of key-value pairs separated by '='."""
|
198
211
|
overrides: UserConfig = {}
|
199
212
|
|
@@ -221,16 +234,16 @@ def parse_config_args(
|
|
221
234
|
matches = pattern.findall(config_line)
|
222
235
|
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
223
236
|
overrides.update(tomli.loads(toml_str))
|
224
|
-
flat_overrides = flatten_dict(overrides)
|
237
|
+
flat_overrides = flatten_dict(overrides) if flatten else overrides
|
225
238
|
|
226
239
|
return flat_overrides
|
227
240
|
|
228
241
|
|
229
242
|
def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
230
|
-
"""Extract `
|
243
|
+
"""Extract `fab_id` and `fab_version` from a project config."""
|
231
244
|
return (
|
232
|
-
config["project"]["version"],
|
233
245
|
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
|
246
|
+
config["project"]["version"],
|
234
247
|
)
|
235
248
|
|
236
249
|
|
@@ -241,3 +254,127 @@ def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
|
|
241
254
|
c_record[k] = v
|
242
255
|
|
243
256
|
return c_record
|
257
|
+
|
258
|
+
|
259
|
+
def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]:
|
260
|
+
"""Extract the config from a FAB file or path.
|
261
|
+
|
262
|
+
Parameters
|
263
|
+
----------
|
264
|
+
fab_file : Union[Path, bytes]
|
265
|
+
The Flower App Bundle file to validate and extract the metadata from.
|
266
|
+
It can either be a path to the file or the file itself as bytes.
|
267
|
+
|
268
|
+
Returns
|
269
|
+
-------
|
270
|
+
Dict[str, Any]
|
271
|
+
The `config` of the given Flower App Bundle.
|
272
|
+
"""
|
273
|
+
fab_file_archive: Union[Path, IO[bytes]]
|
274
|
+
if isinstance(fab_file, bytes):
|
275
|
+
fab_file_archive = BytesIO(fab_file)
|
276
|
+
elif isinstance(fab_file, Path):
|
277
|
+
fab_file_archive = fab_file
|
278
|
+
else:
|
279
|
+
raise ValueError("fab_file must be either a Path or bytes")
|
280
|
+
|
281
|
+
with zipfile.ZipFile(fab_file_archive, "r") as zipf:
|
282
|
+
with zipf.open("pyproject.toml") as file:
|
283
|
+
toml_content = file.read().decode("utf-8")
|
284
|
+
try:
|
285
|
+
conf = tomli.loads(toml_content)
|
286
|
+
except tomli.TOMLDecodeError:
|
287
|
+
raise ValueError("Invalid TOML content in pyproject.toml") from None
|
288
|
+
|
289
|
+
is_valid, errors, _ = validate_config(conf, check_module=False)
|
290
|
+
if not is_valid:
|
291
|
+
raise ValueError(errors)
|
292
|
+
|
293
|
+
return conf
|
294
|
+
|
295
|
+
|
296
|
+
def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None:
|
297
|
+
for key, value in config_dict.items():
|
298
|
+
if isinstance(value, dict):
|
299
|
+
_validate_run_config(config_dict[key], errors)
|
300
|
+
elif not isinstance(value, get_args(UserConfigValue)):
|
301
|
+
raise ValueError(
|
302
|
+
f"The value for key {key} needs to be of type `int`, `float`, "
|
303
|
+
"`bool, `str`, or a `dict` of those.",
|
304
|
+
)
|
305
|
+
|
306
|
+
|
307
|
+
# pylint: disable=too-many-branches
|
308
|
+
def validate_fields_in_config(
|
309
|
+
config: dict[str, Any]
|
310
|
+
) -> tuple[bool, list[str], list[str]]:
|
311
|
+
"""Validate pyproject.toml fields."""
|
312
|
+
errors = []
|
313
|
+
warnings = []
|
314
|
+
|
315
|
+
if "project" not in config:
|
316
|
+
errors.append("Missing [project] section")
|
317
|
+
else:
|
318
|
+
if "name" not in config["project"]:
|
319
|
+
errors.append('Property "name" missing in [project]')
|
320
|
+
if "version" not in config["project"]:
|
321
|
+
errors.append('Property "version" missing in [project]')
|
322
|
+
if "description" not in config["project"]:
|
323
|
+
warnings.append('Recommended property "description" missing in [project]')
|
324
|
+
if "license" not in config["project"]:
|
325
|
+
warnings.append('Recommended property "license" missing in [project]')
|
326
|
+
if "authors" not in config["project"]:
|
327
|
+
warnings.append('Recommended property "authors" missing in [project]')
|
328
|
+
|
329
|
+
if (
|
330
|
+
"tool" not in config
|
331
|
+
or "flwr" not in config["tool"]
|
332
|
+
or "app" not in config["tool"]["flwr"]
|
333
|
+
):
|
334
|
+
errors.append("Missing [tool.flwr.app] section")
|
335
|
+
else:
|
336
|
+
if "publisher" not in config["tool"]["flwr"]["app"]:
|
337
|
+
errors.append('Property "publisher" missing in [tool.flwr.app]')
|
338
|
+
if "config" in config["tool"]["flwr"]["app"]:
|
339
|
+
_validate_run_config(config["tool"]["flwr"]["app"]["config"], errors)
|
340
|
+
if "components" not in config["tool"]["flwr"]["app"]:
|
341
|
+
errors.append("Missing [tool.flwr.app.components] section")
|
342
|
+
else:
|
343
|
+
if "serverapp" not in config["tool"]["flwr"]["app"]["components"]:
|
344
|
+
errors.append(
|
345
|
+
'Property "serverapp" missing in [tool.flwr.app.components]'
|
346
|
+
)
|
347
|
+
if "clientapp" not in config["tool"]["flwr"]["app"]["components"]:
|
348
|
+
errors.append(
|
349
|
+
'Property "clientapp" missing in [tool.flwr.app.components]'
|
350
|
+
)
|
351
|
+
|
352
|
+
return len(errors) == 0, errors, warnings
|
353
|
+
|
354
|
+
|
355
|
+
def validate_config(
|
356
|
+
config: dict[str, Any],
|
357
|
+
check_module: bool = True,
|
358
|
+
project_dir: Optional[Union[str, Path]] = None,
|
359
|
+
) -> tuple[bool, list[str], list[str]]:
|
360
|
+
"""Validate pyproject.toml."""
|
361
|
+
is_valid, errors, warnings = validate_fields_in_config(config)
|
362
|
+
|
363
|
+
if not is_valid:
|
364
|
+
return False, errors, warnings
|
365
|
+
|
366
|
+
# Validate serverapp
|
367
|
+
serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
368
|
+
is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)
|
369
|
+
|
370
|
+
if not is_valid and isinstance(reason, str):
|
371
|
+
return False, [reason], []
|
372
|
+
|
373
|
+
# Validate clientapp
|
374
|
+
clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
375
|
+
is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)
|
376
|
+
|
377
|
+
if not is_valid and isinstance(reason, str):
|
378
|
+
return False, [reason], []
|
379
|
+
|
380
|
+
return True, [], []
|
flwr/common/constant.py
CHANGED
@@ -17,14 +17,6 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
MISSING_EXTRA_REST = """
|
21
|
-
Extra dependencies required for using the REST-based Fleet API are missing.
|
22
|
-
|
23
|
-
To use the REST API, install `flwr` with the `rest` extra:
|
24
|
-
|
25
|
-
`pip install flwr[rest]`.
|
26
|
-
"""
|
27
|
-
|
28
20
|
TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
|
29
21
|
TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
|
30
22
|
TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter"
|
@@ -83,6 +75,9 @@ FAB_HASH_TRUNCATION = 8
|
|
83
75
|
FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
|
84
76
|
FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
|
85
77
|
|
78
|
+
# Constant for SuperLink
|
79
|
+
SUPERLINK_NODE_ID = 1
|
80
|
+
|
86
81
|
# Constants entries in Node config for Simulation
|
87
82
|
PARTITION_ID_KEY = "partition-id"
|
88
83
|
NUM_PARTITIONS_KEY = "num-partitions"
|
@@ -117,6 +112,12 @@ AUTH_TYPE = "auth_type"
|
|
117
112
|
ACCESS_TOKEN_KEY = "access_token"
|
118
113
|
REFRESH_TOKEN_KEY = "refresh_token"
|
119
114
|
|
115
|
+
# Constants for node authentication
|
116
|
+
PUBLIC_KEY_HEADER = "public-key-bin" # Must end with "-bin" for binary data
|
117
|
+
SIGNATURE_HEADER = "signature-bin" # Must end with "-bin" for binary data
|
118
|
+
TIMESTAMP_HEADER = "timestamp"
|
119
|
+
TIMESTAMP_TOLERANCE = 10 # Tolerance for timestamp verification
|
120
|
+
|
120
121
|
|
121
122
|
class MessageType:
|
122
123
|
"""Message type."""
|
flwr/common/exit/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower exit functionality."""
|
16
|
+
|
17
|
+
|
18
|
+
from .exit import flwr_exit
|
19
|
+
from .exit_code import ExitCode
|
20
|
+
|
21
|
+
__all__ = [
|
22
|
+
"ExitCode",
|
23
|
+
"flwr_exit",
|
24
|
+
]
|