flwr-nightly 1.14.0.dev20241209__py3-none-any.whl → 1.14.0.dev20241211__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic; see the package registry page for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +18 -2
- flwr/cli/log.py +10 -31
- flwr/cli/login/__init__.py +21 -0
- flwr/cli/login/login.py +82 -0
- flwr/cli/ls.py +10 -40
- flwr/cli/run/run.py +14 -25
- flwr/cli/stop.py +9 -39
- flwr/cli/utils.py +108 -1
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +2 -1
- flwr/common/constant.py +6 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +2 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +52 -1
- flwr/server/driver/grpc_driver.py +6 -2
- flwr/server/superlink/driver/serverappio_servicer.py +18 -0
- flwr/superexec/exec_grpc.py +18 -1
- flwr/superexec/exec_servicer.py +22 -3
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- {flwr_nightly-1.14.0.dev20241209.dist-info → flwr_nightly-1.14.0.dev20241211.dist-info}/METADATA +8 -7
- {flwr_nightly-1.14.0.dev20241209.dist-info → flwr_nightly-1.14.0.dev20241211.dist-info}/RECORD +31 -25
- {flwr_nightly-1.14.0.dev20241209.dist-info → flwr_nightly-1.14.0.dev20241211.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241209.dist-info → flwr_nightly-1.14.0.dev20241211.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241209.dist-info → flwr_nightly-1.14.0.dev20241211.dist-info}/entry_points.txt +0 -0
flwr/cli/utils.py
CHANGED
|
@@ -15,12 +15,32 @@
|
|
|
15
15
|
"""Flower command line interface utils."""
|
|
16
16
|
|
|
17
17
|
import hashlib
|
|
18
|
+
import json
|
|
18
19
|
import re
|
|
20
|
+
from logging import DEBUG
|
|
19
21
|
from pathlib import Path
|
|
20
|
-
from typing import Callable, Optional, cast
|
|
22
|
+
from typing import Any, Callable, Optional, cast
|
|
21
23
|
|
|
24
|
+
import grpc
|
|
22
25
|
import typer
|
|
23
26
|
|
|
27
|
+
from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
|
|
28
|
+
from flwr.common.address import parse_address
|
|
29
|
+
from flwr.common.auth_plugin import CliAuthPlugin
|
|
30
|
+
from flwr.common.constant import AUTH_TYPE, CREDENTIALS_DIR, FLWR_DIR
|
|
31
|
+
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
32
|
+
from flwr.common.logger import log
|
|
33
|
+
|
|
34
|
+
from .config_utils import validate_certificate_in_federation_config
|
|
35
|
+
|
|
36
|
+
# The registry of CLI authentication plugins is provided by `flwr.ee`
# (presumably an optional extension package — confirm). If it cannot be
# imported, define a stub with the same signature so importing this module
# still succeeds; the absence of plugins only surfaces when the stub is called.
try:
    from flwr.ee import get_cli_auth_plugins
except ImportError:

    def get_cli_auth_plugins() -> dict[str, type[CliAuthPlugin]]:
        """Return all CLI authentication plugins.

        Fallback used when `flwr.ee` cannot be imported.

        Raises
        ------
        NotImplementedError
            Always; no authentication plugins are available. Callers must
            catch `NotImplementedError` (not `ImportError`) to detect this.
        """
        raise NotImplementedError("No authentication plugins are currently supported.")
|
|
43
|
+
|
|
24
44
|
|
|
25
45
|
def prompt_text(
|
|
26
46
|
text: str,
|
|
@@ -136,3 +156,90 @@ def get_sha256_hash(file_path: Path) -> str:
|
|
|
136
156
|
break
|
|
137
157
|
sha256.update(data)
|
|
138
158
|
return sha256.hexdigest()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_user_auth_config_path(
    root_dir: Path, federation: str, server_address: str
) -> Path:
    """Return the path to the user auth config file.

    The file lives in ``<root_dir>/FLWR_DIR/CREDENTIALS_DIR`` and is named
    ``<federation>_<host>_<port>.json``; IPv6 hosts are wrapped in square
    brackets. The credentials directory is created if it does not exist yet.

    Parameters
    ----------
    root_dir : Path
        Base directory under which the Flower directory is located.
    federation : str
        Federation name, used as the file-name prefix.
    server_address : str
        Server address (``host:port``) the credentials belong to.

    Returns
    -------
    Path
        Absolute path of the JSON credentials file (it may not exist yet).

    Raises
    ------
    ValueError
        If ``server_address`` cannot be parsed.
    """
    parsed = parse_address(server_address)
    if parsed is None:
        raise ValueError(f"Invalid server address: {server_address}")
    host, port, is_v6 = parsed

    # Keep IPv6 hosts visually distinct from the port in the file name
    host_part = f"[{host}]" if is_v6 else host
    file_name = f"{federation}_{host_part}_{port}.json"

    credentials_dir = root_dir.absolute().joinpath(FLWR_DIR, CREDENTIALS_DIR)
    credentials_dir.mkdir(parents=True, exist_ok=True)
    return credentials_dir / file_name
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def try_obtain_cli_auth_plugin(
    root_dir: Path,
    federation: str,
    federation_config: dict[str, Any],
    auth_type: Optional[str] = None,
) -> Optional[CliAuthPlugin]:
    """Load the CLI-side user auth plugin for the given auth type.

    Parameters
    ----------
    root_dir : Path
        Directory under which the credentials folder is located.
    federation : str
        Federation name; part of the credentials file name.
    federation_config : dict[str, Any]
        Federation configuration; its ``"address"`` entry selects the
        credentials file.
    auth_type : Optional[str] (default: None)
        Authentication type to use. If None, the type stored in the
        credentials file (under the ``AUTH_TYPE`` key) is used, if any.

    Returns
    -------
    Optional[CliAuthPlugin]
        The instantiated plugin, or None when user authentication is not
        enabled (no explicit ``auth_type`` and no stored auth type).

    Raises
    ------
    typer.Exit
        With code 1 if no authentication plugins are available or if
        ``auth_type`` does not name a known plugin.
    """
    config_path = get_user_auth_config_path(
        root_dir, federation, federation_config["address"]
    )

    # Load the config file if it exists
    config: dict[str, Any] = {}
    if config_path.exists():
        with config_path.open("r", encoding="utf-8") as file:
            config = json.load(file)
    # This is the case when the user auth is not enabled
    elif auth_type is None:
        return None

    # Get the auth type from the config if not provided
    if auth_type is None:
        if AUTH_TYPE not in config:
            return None
        auth_type = config[AUTH_TYPE]

    # Retrieve the plugin registry. The fallback `get_cli_auth_plugins` used
    # when `flwr.ee` is missing raises NotImplementedError, so that must be
    # handled here alongside ImportError.
    try:
        all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
    except (ImportError, NotImplementedError):
        typer.echo("❌ No authentication plugins are currently supported.")
        raise typer.Exit(code=1) from None

    # Look up the plugin class outside of any KeyError handler so that errors
    # raised by the plugin's own constructor are not misreported as an
    # unknown authentication type.
    auth_plugin_class = all_plugins.get(auth_type)
    if auth_plugin_class is None:
        typer.echo(f"❌ Unknown user authentication type: {auth_type}")
        raise typer.Exit(code=1)
    return auth_plugin_class(config_path)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def init_channel(
    app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
) -> grpc.Channel:
    """Initialize gRPC channel to the Exec API.

    Parameters
    ----------
    app : Path
        Path to the Flower app; used together with ``federation_config`` by
        `validate_certificate_in_federation_config` to resolve TLS settings.
    federation_config : dict[str, Any]
        Federation configuration; its ``"address"`` entry is the Exec API
        server address.
    auth_plugin : Optional[CliAuthPlugin]
        If given, its tokens are loaded and a `CliUserAuthInterceptor` wrapping
        it is attached to the channel.

    Returns
    -------
    grpc.Channel
        The created channel, subscribed to a DEBUG-level connectivity logger.
    """

    def on_channel_state_change(channel_connectivity: str) -> None:
        """Log channel connectivity."""
        log(DEBUG, channel_connectivity)

    insecure, root_certificates_bytes = validate_certificate_in_federation_config(
        app, federation_config
    )

    # Initialize the CLI-side user auth interceptor. `interceptors` is a list
    # of interceptors, so the new interceptor is appended; assigning the
    # interceptor object itself would hand a non-sequence to `create_channel`.
    interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
    if auth_plugin is not None:
        auth_plugin.load_tokens()
        interceptors.append(CliUserAuthInterceptor(auth_plugin))

    # Create the gRPC channel (an empty interceptor list is passed as None)
    channel = create_channel(
        server_address=federation_config["address"],
        insecure=insecure,
        root_certificates=root_certificates_bytes,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        interceptors=interceptors or None,
    )
    channel.subscribe(on_channel_state_change)
    return channel
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Auth plugin components."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .auth_plugin import CliAuthPlugin as CliAuthPlugin
|
|
19
|
+
from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"CliAuthPlugin",
|
|
23
|
+
"ExecAuthPlugin",
|
|
24
|
+
]
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Abstract classes for Flower User Auth Plugin."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from collections.abc import Sequence
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any, Optional, Union
|
|
22
|
+
|
|
23
|
+
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ExecAuthPlugin(ABC):
    """Abstract Flower Auth Plugin class for ExecServicer.

    Server-side counterpart of `CliAuthPlugin`. Concrete subclasses implement
    login details, token issuance, validation and refresh for the Exec API.

    Parameters
    ----------
    config : dict[str, Any]
        The authentication configuration loaded from a YAML file.
        NOTE(review): the SuperLink passes the ``authentication`` section of
        the ``--user-auth-config`` YAML here — confirm for other callers.
    """

    @abstractmethod
    def __init__(self, config: dict[str, Any]) -> None:
        """Abstract constructor."""

    @abstractmethod
    def get_login_details(self) -> dict[str, str]:
        """Get the login details.

        Returns
        -------
        dict[str, str]
            String key/value login details; presumably sent to the CLI to
            start a login flow — confirm against the Exec servicer.
        """

    @abstractmethod
    def validate_tokens_in_metadata(
        self, metadata: Sequence[tuple[str, Union[str, bytes]]]
    ) -> bool:
        """Validate authentication tokens in the provided metadata.

        Parameters
        ----------
        metadata : Sequence[tuple[str, Union[str, bytes]]]
            gRPC-style call metadata as ``(key, value)`` pairs.

        Returns
        -------
        bool
            Whether the tokens in ``metadata`` are valid.
        """

    @abstractmethod
    def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]:
        """Get authentication tokens.

        Parameters
        ----------
        auth_details : dict[str, str]
            Authentication details supplied by the client.

        Returns
        -------
        dict[str, str]
            The issued tokens, keyed by token name.
        """

    @abstractmethod
    def refresh_tokens(
        self, metadata: Sequence[tuple[str, Union[str, bytes]]]
    ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]:
        """Refresh authentication tokens in the provided metadata.

        Parameters
        ----------
        metadata : Sequence[tuple[str, Union[str, bytes]]]
            gRPC-style call metadata carrying the current tokens.

        Returns
        -------
        Optional[Sequence[tuple[str, Union[str, bytes]]]]
            Metadata carrying refreshed tokens; presumably None when the
            tokens cannot be refreshed — confirm with the Exec interceptor.
        """
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class CliAuthPlugin(ABC):
    """Abstract Flower Auth Plugin class for CLI.

    Client-side counterpart of `ExecAuthPlugin`. A concrete plugin logs the
    user in, persists tokens to a JSON credentials file and attaches those
    tokens to outgoing gRPC call metadata.

    Parameters
    ----------
    user_auth_config_path : Path
        The path to the user's authentication configuration file.
    """

    @staticmethod
    @abstractmethod
    def login(
        login_details: dict[str, str],
        exec_stub: ExecStub,
    ) -> dict[str, Any]:
        """Authenticate the user with the SuperLink.

        Declared as a static method, so it can run before any plugin
        instance (and thus any credentials file) exists.

        Parameters
        ----------
        login_details : dict[str, str]
            A dictionary containing the user's login details.
        exec_stub : ExecStub
            An instance of `ExecStub` used for communication with the SuperLink.

        Returns
        -------
        user_auth_config : dict[str, Any]
            A JSON-serializable dictionary containing the user's
            authentication configuration. NOTE(review): the CLI later reads
            the auth type back from this file under the ``AUTH_TYPE`` key,
            so implementations presumably should include it — confirm.
        """

    @abstractmethod
    def __init__(self, user_auth_config_path: Path) -> None:
        """Abstract constructor."""

    @abstractmethod
    def store_tokens(self, user_auth_config: dict[str, Any]) -> None:
        """Store authentication tokens from the provided user_auth_config.

        The configuration, including tokens, will be saved as a JSON file
        at `user_auth_config_path`.
        """

    @abstractmethod
    def load_tokens(self) -> None:
        """Load authentication tokens from the user_auth_config_path.

        Called by the CLI before the plugin is wrapped in a gRPC
        interceptor, so tokens are in memory before the first call.
        """

    @abstractmethod
    def write_tokens_to_metadata(
        self, metadata: Sequence[tuple[str, Union[str, bytes]]]
    ) -> Sequence[tuple[str, Union[str, bytes]]]:
        """Write authentication tokens to the provided metadata.

        Parameters
        ----------
        metadata : Sequence[tuple[str, Union[str, bytes]]]
            Existing gRPC call metadata as ``(key, value)`` pairs.

        Returns
        -------
        Sequence[tuple[str, Union[str, bytes]]]
            Metadata including the authentication tokens.
        """
|
flwr/common/config.py
CHANGED
|
@@ -27,6 +27,7 @@ from flwr.common.constant import (
|
|
|
27
27
|
APP_DIR,
|
|
28
28
|
FAB_CONFIG_FILE,
|
|
29
29
|
FAB_HASH_TRUNCATION,
|
|
30
|
+
FLWR_DIR,
|
|
30
31
|
FLWR_HOME,
|
|
31
32
|
)
|
|
32
33
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
|
@@ -38,7 +39,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
|
38
39
|
return Path(
|
|
39
40
|
os.getenv(
|
|
40
41
|
FLWR_HOME,
|
|
41
|
-
Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") /
|
|
42
|
+
Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / FLWR_DIR,
|
|
42
43
|
)
|
|
43
44
|
)
|
|
44
45
|
return Path(provided_path).absolute()
|
flwr/common/constant.py
CHANGED
|
@@ -80,7 +80,8 @@ FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
|
|
|
80
80
|
FAB_CONFIG_FILE = "pyproject.toml"
|
|
81
81
|
FAB_DATE = (2024, 10, 1, 0, 0, 0)
|
|
82
82
|
FAB_HASH_TRUNCATION = 8
|
|
83
|
-
|
|
83
|
+
FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
|
|
84
|
+
FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
|
|
84
85
|
|
|
85
86
|
# Constants entries in Node config for Simulation
|
|
86
87
|
PARTITION_ID_KEY = "partition-id"
|
|
@@ -110,6 +111,10 @@ LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
|
|
|
110
111
|
# Retry configurations
|
|
111
112
|
MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
|
|
112
113
|
|
|
114
|
+
# Constants for user authentication
|
|
115
|
+
CREDENTIALS_DIR = ".credentials"
|
|
116
|
+
AUTH_TYPE = "auth_type"
|
|
117
|
+
|
|
113
118
|
|
|
114
119
|
class MessageType:
|
|
115
120
|
"""Message type."""
|
|
@@ -93,8 +93,8 @@ def pseudo_rand_gen(
|
|
|
93
93
|
output = []
|
|
94
94
|
for dimension in dimensions_list:
|
|
95
95
|
if len(dimension) == 0:
|
|
96
|
-
arr = np.array(gen.randint(0, num_range - 1), dtype=
|
|
96
|
+
arr = np.array(gen.randint(0, num_range - 1), dtype=np.int64)
|
|
97
97
|
else:
|
|
98
|
-
arr = gen.randint(0, num_range - 1, dimension)
|
|
98
|
+
arr = gen.randint(0, num_range - 1, dimension, dtype=np.int64)
|
|
99
99
|
output.append(arr)
|
|
100
100
|
return output
|
flwr/common/telemetry.py
CHANGED
|
@@ -27,6 +27,7 @@ from enum import Enum, auto
|
|
|
27
27
|
from pathlib import Path
|
|
28
28
|
from typing import Any, Optional, Union, cast
|
|
29
29
|
|
|
30
|
+
from flwr.common.constant import FLWR_DIR
|
|
30
31
|
from flwr.common.version import package_name, package_version
|
|
31
32
|
|
|
32
33
|
FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1")
|
|
@@ -86,7 +87,7 @@ def _get_source_id() -> str:
|
|
|
86
87
|
# If the home directory can’t be resolved, RuntimeError is raised.
|
|
87
88
|
return source_id
|
|
88
89
|
|
|
89
|
-
flwr_dir = home.joinpath(
|
|
90
|
+
flwr_dir = home.joinpath(FLWR_DIR)
|
|
90
91
|
# Create .flwr directory if it does not exist yet.
|
|
91
92
|
try:
|
|
92
93
|
flwr_dir.mkdir(parents=True, exist_ok=True)
|
flwr/proto/serverappio_pb2.py
CHANGED
|
@@ -20,7 +20,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
|
20
20
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"
|
|
23
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
|
|
24
24
|
|
|
25
25
|
_globals = globals()
|
|
26
26
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -32,21 +32,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
32
32
|
_globals['_GETNODESRESPONSE']._serialized_start=217
|
|
33
33
|
_globals['_GETNODESRESPONSE']._serialized_end=268
|
|
34
34
|
_globals['_PUSHTASKINSREQUEST']._serialized_start=270
|
|
35
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_end=
|
|
36
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_start=
|
|
37
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_end=
|
|
38
|
-
_globals['_PULLTASKRESREQUEST']._serialized_start=
|
|
39
|
-
_globals['_PULLTASKRESREQUEST']._serialized_end=
|
|
40
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_start=
|
|
41
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_end=
|
|
42
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=
|
|
43
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=
|
|
44
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=
|
|
45
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=
|
|
46
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=
|
|
47
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=
|
|
48
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=
|
|
49
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=
|
|
50
|
-
_globals['_SERVERAPPIO']._serialized_start=
|
|
51
|
-
_globals['_SERVERAPPIO']._serialized_end=
|
|
35
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_end=350
|
|
36
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_start=352
|
|
37
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_end=391
|
|
38
|
+
_globals['_PULLTASKRESREQUEST']._serialized_start=393
|
|
39
|
+
_globals['_PULLTASKRESREQUEST']._serialized_end=479
|
|
40
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_start=481
|
|
41
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_end=546
|
|
42
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548
|
|
43
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576
|
|
44
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578
|
|
45
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705
|
|
46
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707
|
|
47
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790
|
|
48
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792
|
|
49
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822
|
|
50
|
+
_globals['_SERVERAPPIO']._serialized_start=825
|
|
51
|
+
_globals['_SERVERAPPIO']._serialized_end=1752
|
|
52
52
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/serverappio_pb2.pyi
CHANGED
|
@@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
|
|
|
44
44
|
"""PushTaskIns messages"""
|
|
45
45
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
46
46
|
TASK_INS_LIST_FIELD_NUMBER: builtins.int
|
|
47
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
47
48
|
@property
|
|
48
49
|
def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
|
|
50
|
+
run_id: builtins.int
|
|
49
51
|
def __init__(self,
|
|
50
52
|
*,
|
|
51
53
|
task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
|
|
54
|
+
run_id: builtins.int = ...,
|
|
52
55
|
) -> None: ...
|
|
53
|
-
def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
|
|
56
|
+
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
|
|
54
57
|
global___PushTaskInsRequest = PushTaskInsRequest
|
|
55
58
|
|
|
56
59
|
class PushTaskInsResponse(google.protobuf.message.Message):
|
|
@@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
|
|
|
70
73
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
71
74
|
NODE_FIELD_NUMBER: builtins.int
|
|
72
75
|
TASK_IDS_FIELD_NUMBER: builtins.int
|
|
76
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
73
77
|
@property
|
|
74
78
|
def node(self) -> flwr.proto.node_pb2.Node: ...
|
|
75
79
|
@property
|
|
76
80
|
def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
|
|
81
|
+
run_id: builtins.int
|
|
77
82
|
def __init__(self,
|
|
78
83
|
*,
|
|
79
84
|
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
80
85
|
task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
|
86
|
+
run_id: builtins.int = ...,
|
|
81
87
|
) -> None: ...
|
|
82
88
|
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
|
|
83
|
-
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
|
|
89
|
+
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
|
|
84
90
|
global___PullTaskResRequest = PullTaskResRequest
|
|
85
91
|
|
|
86
92
|
class PullTaskResResponse(google.protobuf.message.Message):
|
|
@@ -62,6 +62,11 @@ class ServerAppIoStub(object):
|
|
|
62
62
|
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
|
|
63
63
|
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
|
|
64
64
|
)
|
|
65
|
+
self.GetRunStatus = channel.unary_unary(
|
|
66
|
+
'/flwr.proto.ServerAppIo/GetRunStatus',
|
|
67
|
+
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
68
|
+
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
69
|
+
)
|
|
65
70
|
self.PushLogs = channel.unary_unary(
|
|
66
71
|
'/flwr.proto.ServerAppIo/PushLogs',
|
|
67
72
|
request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
|
|
@@ -135,6 +140,13 @@ class ServerAppIoServicer(object):
|
|
|
135
140
|
context.set_details('Method not implemented!')
|
|
136
141
|
raise NotImplementedError('Method not implemented!')
|
|
137
142
|
|
|
143
|
+
def GetRunStatus(self, request, context):
|
|
144
|
+
"""Get the status of a given run
|
|
145
|
+
"""
|
|
146
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
147
|
+
context.set_details('Method not implemented!')
|
|
148
|
+
raise NotImplementedError('Method not implemented!')
|
|
149
|
+
|
|
138
150
|
def PushLogs(self, request, context):
|
|
139
151
|
"""Push ServerApp logs
|
|
140
152
|
"""
|
|
@@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
|
|
|
190
202
|
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
|
|
191
203
|
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
|
|
192
204
|
),
|
|
205
|
+
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
|
|
206
|
+
servicer.GetRunStatus,
|
|
207
|
+
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
|
|
208
|
+
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
|
|
209
|
+
),
|
|
193
210
|
'PushLogs': grpc.unary_unary_rpc_method_handler(
|
|
194
211
|
servicer.PushLogs,
|
|
195
212
|
request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
|
|
@@ -358,6 +375,23 @@ class ServerAppIo(object):
|
|
|
358
375
|
options, channel_credentials,
|
|
359
376
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
360
377
|
|
|
378
|
+
@staticmethod
|
|
379
|
+
def GetRunStatus(request,
|
|
380
|
+
target,
|
|
381
|
+
options=(),
|
|
382
|
+
channel_credentials=None,
|
|
383
|
+
call_credentials=None,
|
|
384
|
+
insecure=False,
|
|
385
|
+
compression=None,
|
|
386
|
+
wait_for_ready=None,
|
|
387
|
+
timeout=None,
|
|
388
|
+
metadata=None):
|
|
389
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
|
|
390
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
391
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
392
|
+
options, channel_credentials,
|
|
393
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
394
|
+
|
|
361
395
|
@staticmethod
|
|
362
396
|
def PushLogs(request,
|
|
363
397
|
target,
|
|
@@ -56,6 +56,11 @@ class ServerAppIoStub:
|
|
|
56
56
|
flwr.proto.run_pb2.UpdateRunStatusResponse]
|
|
57
57
|
"""Update the status of a given run"""
|
|
58
58
|
|
|
59
|
+
GetRunStatus: grpc.UnaryUnaryMultiCallable[
|
|
60
|
+
flwr.proto.run_pb2.GetRunStatusRequest,
|
|
61
|
+
flwr.proto.run_pb2.GetRunStatusResponse]
|
|
62
|
+
"""Get the status of a given run"""
|
|
63
|
+
|
|
59
64
|
PushLogs: grpc.UnaryUnaryMultiCallable[
|
|
60
65
|
flwr.proto.log_pb2.PushLogsRequest,
|
|
61
66
|
flwr.proto.log_pb2.PushLogsResponse]
|
|
@@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
|
|
|
135
140
|
"""Update the status of a given run"""
|
|
136
141
|
pass
|
|
137
142
|
|
|
143
|
+
@abc.abstractmethod
|
|
144
|
+
def GetRunStatus(self,
|
|
145
|
+
request: flwr.proto.run_pb2.GetRunStatusRequest,
|
|
146
|
+
context: grpc.ServicerContext,
|
|
147
|
+
) -> flwr.proto.run_pb2.GetRunStatusResponse:
|
|
148
|
+
"""Get the status of a given run"""
|
|
149
|
+
pass
|
|
150
|
+
|
|
138
151
|
@abc.abstractmethod
|
|
139
152
|
def PushLogs(self,
|
|
140
153
|
request: flwr.proto.log_pb2.PushLogsRequest,
|
flwr/server/app.py
CHANGED
|
@@ -24,9 +24,10 @@ from collections.abc import Sequence
|
|
|
24
24
|
from logging import DEBUG, INFO, WARN
|
|
25
25
|
from pathlib import Path
|
|
26
26
|
from time import sleep
|
|
27
|
-
from typing import Optional
|
|
27
|
+
from typing import Any, Optional
|
|
28
28
|
|
|
29
29
|
import grpc
|
|
30
|
+
import yaml
|
|
30
31
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
31
32
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
32
33
|
from cryptography.hazmat.primitives.serialization import (
|
|
@@ -37,8 +38,10 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
37
38
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
38
39
|
from flwr.common.address import parse_address
|
|
39
40
|
from flwr.common.args import try_obtain_server_certificates
|
|
41
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
40
42
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
41
43
|
from flwr.common.constant import (
|
|
44
|
+
AUTH_TYPE,
|
|
42
45
|
CLIENT_OCTET,
|
|
43
46
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
44
47
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -88,6 +91,15 @@ DATABASE = ":flwr-in-memory-state:"
|
|
|
88
91
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
89
92
|
|
|
90
93
|
|
|
94
|
+
# The registry of Exec API authentication plugins is provided by `flwr.ee`
# (presumably an optional extension package — confirm). If it cannot be
# imported, define a stub with the same signature so the SuperLink still
# starts; the absence of plugins only surfaces when the stub is called.
try:
    from flwr.ee import get_exec_auth_plugins
except ImportError:

    def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
        """Return all Exec API authentication plugins.

        Fallback used when `flwr.ee` cannot be imported.

        Raises
        ------
        NotImplementedError
            Always; no authentication plugins are available.
        """
        raise NotImplementedError("No authentication plugins are currently supported.")
|
|
101
|
+
|
|
102
|
+
|
|
91
103
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
92
104
|
*,
|
|
93
105
|
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -246,6 +258,12 @@ def run_superlink() -> None:
|
|
|
246
258
|
# Obtain certificates
|
|
247
259
|
certificates = try_obtain_server_certificates(args, args.fleet_api_type)
|
|
248
260
|
|
|
261
|
+
user_auth_config = _try_obtain_user_auth_config(args)
|
|
262
|
+
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
263
|
+
# user_auth_config is None only if the args.user_auth_config is not provided
|
|
264
|
+
if user_auth_config is not None:
|
|
265
|
+
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
|
|
266
|
+
|
|
249
267
|
# Initialize StateFactory
|
|
250
268
|
state_factory = LinkStateFactory(args.database)
|
|
251
269
|
|
|
@@ -263,6 +281,7 @@ def run_superlink() -> None:
|
|
|
263
281
|
config=parse_config_args(
|
|
264
282
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
265
283
|
),
|
|
284
|
+
auth_plugin=auth_plugin,
|
|
266
285
|
)
|
|
267
286
|
grpc_servers = [exec_server]
|
|
268
287
|
|
|
@@ -559,6 +578,32 @@ def _try_setup_node_authentication(
|
|
|
559
578
|
)
|
|
560
579
|
|
|
561
580
|
|
|
581
|
+
def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
    """Load the user authentication YAML config, if one was provided.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; ``args.user_auth_config`` is the YAML file path
        or None.

    Returns
    -------
    Optional[dict[str, Any]]
        The parsed config mapping, or None if ``--user-auth-config`` was not
        given.

    Exits
    -----
    Via ``sys.exit`` if the file's top level is not a YAML mapping (this
    includes an empty file).
    """
    if args.user_auth_config is None:
        return None

    with open(args.user_auth_config, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # `safe_load` returns None for an empty document. Returning None here
    # would be indistinguishable from "flag not provided" and would silently
    # disable user authentication, so reject anything that is not a mapping.
    if not isinstance(config, dict):
        sys.exit(
            f"Invalid user authentication config file: {args.user_auth_config}. "
            "Expected a YAML mapping at the top level."
        )
    return config
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
    """Instantiate the Exec API auth plugin selected by the user auth config.

    Parameters
    ----------
    config : dict[str, Any]
        Parsed user authentication config. Its ``authentication`` section
        names the plugin under the ``AUTH_TYPE`` key and is passed to the
        plugin constructor as ``config``.

    Returns
    -------
    Optional[ExecAuthPlugin]
        The instantiated plugin.

    Exits
    -----
    Via ``sys.exit`` if no plugins are available, no auth type is given, or
    the auth type is not supported.
    """
    # `authentication:` with no value parses to None; treat it as empty
    auth_config: dict[str, Any] = config.get("authentication") or {}
    auth_type: str = auth_config.get(AUTH_TYPE, "")

    try:
        all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
    except NotImplementedError:
        sys.exit("No authentication plugins are currently supported.")

    # Look up the plugin class outside of any KeyError handler so that a
    # KeyError raised by the plugin constructor (e.g. a missing config entry)
    # is not misreported as an unsupported authentication type.
    auth_plugin_class = all_plugins.get(auth_type)
    if auth_plugin_class is None:
        if auth_type != "":
            sys.exit(
                f'Authentication type "{auth_type}" is not supported. '
                "Please provide a valid authentication type in the configuration."
            )
        sys.exit("No authentication type is provided in the configuration.")
    return auth_plugin_class(config=auth_config)
|
|
605
|
+
|
|
606
|
+
|
|
562
607
|
def _run_fleet_api_grpc_rere(
|
|
563
608
|
address: str,
|
|
564
609
|
state_factory: LinkStateFactory,
|
|
@@ -746,6 +791,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
746
791
|
type=str,
|
|
747
792
|
help="The SuperLink's public key (as a path str) to enable authentication.",
|
|
748
793
|
)
|
|
794
|
+
parser.add_argument(
|
|
795
|
+
"--user-auth-config",
|
|
796
|
+
help="The path to the user authentication configuration YAML file.",
|
|
797
|
+
type=str,
|
|
798
|
+
default=None,
|
|
799
|
+
)
|
|
749
800
|
|
|
750
801
|
|
|
751
802
|
def _add_args_serverappio_api(parser: argparse.ArgumentParser) -> None:
|
|
@@ -203,7 +203,9 @@ class GrpcDriver(Driver):
|
|
|
203
203
|
task_ins_list.append(taskins)
|
|
204
204
|
# Call GrpcDriverStub method
|
|
205
205
|
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
206
|
-
PushTaskInsRequest(
|
|
206
|
+
PushTaskInsRequest(
|
|
207
|
+
task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
|
|
208
|
+
)
|
|
207
209
|
)
|
|
208
210
|
return list(res.task_ids)
|
|
209
211
|
|
|
@@ -215,7 +217,9 @@ class GrpcDriver(Driver):
|
|
|
215
217
|
"""
|
|
216
218
|
# Pull TaskRes
|
|
217
219
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
218
|
-
PullTaskResRequest(
|
|
220
|
+
PullTaskResRequest(
|
|
221
|
+
node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
|
|
222
|
+
)
|
|
219
223
|
)
|
|
220
224
|
# Convert TaskRes to Message
|
|
221
225
|
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|