flwr-nightly 1.14.0.dev20241209__py3-none-any.whl → 1.14.0.dev20241211__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

flwr/cli/utils.py CHANGED
@@ -15,12 +15,32 @@
15
15
  """Flower command line interface utils."""
16
16
 
17
17
  import hashlib
18
+ import json
18
19
  import re
20
+ from logging import DEBUG
19
21
  from pathlib import Path
20
- from typing import Callable, Optional, cast
22
+ from typing import Any, Callable, Optional, cast
21
23
 
24
+ import grpc
22
25
  import typer
23
26
 
27
+ from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
28
+ from flwr.common.address import parse_address
29
+ from flwr.common.auth_plugin import CliAuthPlugin
30
+ from flwr.common.constant import AUTH_TYPE, CREDENTIALS_DIR, FLWR_DIR
31
+ from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
32
+ from flwr.common.logger import log
33
+
34
+ from .config_utils import validate_certificate_in_federation_config
35
+
36
+ try:
37
+ from flwr.ee import get_cli_auth_plugins
38
+ except ImportError:
39
+
40
+ def get_cli_auth_plugins() -> dict[str, type[CliAuthPlugin]]:
41
+ """Return all CLI authentication plugins."""
42
+ raise NotImplementedError("No authentication plugins are currently supported.")
43
+
24
44
 
25
45
  def prompt_text(
26
46
  text: str,
@@ -136,3 +156,90 @@ def get_sha256_hash(file_path: Path) -> str:
136
156
  break
137
157
  sha256.update(data)
138
158
  return sha256.hexdigest()
159
+
160
+
161
+ def get_user_auth_config_path(
162
+ root_dir: Path, federation: str, server_address: str
163
+ ) -> Path:
164
+ """Return the path to the user auth config file."""
165
+ # Parse the server address
166
+ parsed_addr = parse_address(server_address)
167
+ if parsed_addr is None:
168
+ raise ValueError(f"Invalid server address: {server_address}")
169
+ host, port, is_v6 = parsed_addr
170
+ formatted_addr = f"[{host}]_{port}" if is_v6 else f"{host}_{port}"
171
+
172
+ # Locate the credentials directory
173
+ credentials_dir = root_dir.absolute() / FLWR_DIR / CREDENTIALS_DIR
174
+ credentials_dir.mkdir(parents=True, exist_ok=True)
175
+ return credentials_dir / f"{federation}_{formatted_addr}.json"
176
+
177
+
178
+ def try_obtain_cli_auth_plugin(
179
+ root_dir: Path,
180
+ federation: str,
181
+ federation_config: dict[str, Any],
182
+ auth_type: Optional[str] = None,
183
+ ) -> Optional[CliAuthPlugin]:
184
+ """Load the CLI-side user auth plugin for the given auth type."""
185
+ config_path = get_user_auth_config_path(
186
+ root_dir, federation, federation_config["address"]
187
+ )
188
+
189
+ # Load the config file if it exists
190
+ config: dict[str, Any] = {}
191
+ if config_path.exists():
192
+ with config_path.open("r", encoding="utf-8") as file:
193
+ config = json.load(file)
194
+ # This is the case when the user auth is not enabled
195
+ elif auth_type is None:
196
+ return None
197
+
198
+ # Get the auth type from the config if not provided
199
+ if auth_type is None:
200
+ if AUTH_TYPE not in config:
201
+ return None
202
+ auth_type = config[AUTH_TYPE]
203
+
204
+ # Retrieve auth plugin class and instantiate it
205
+ try:
206
+ all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
207
+ auth_plugin_class = all_plugins[auth_type]
208
+ return auth_plugin_class(config_path)
209
+ except KeyError:
210
+ typer.echo(f"❌ Unknown user authentication type: {auth_type}")
211
+ raise typer.Exit(code=1) from None
212
+ except ImportError:
213
+ typer.echo("❌ No authentication plugins are currently supported.")
214
+ raise typer.Exit(code=1) from None
215
+
216
+
217
+ def init_channel(
218
+ app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
219
+ ) -> grpc.Channel:
220
+ """Initialize gRPC channel to the Exec API."""
221
+
222
+ def on_channel_state_change(channel_connectivity: str) -> None:
223
+ """Log channel connectivity."""
224
+ log(DEBUG, channel_connectivity)
225
+
226
+ insecure, root_certificates_bytes = validate_certificate_in_federation_config(
227
+ app, federation_config
228
+ )
229
+
230
+ # Initialize the CLI-side user auth interceptor
231
+ interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
232
+ if auth_plugin is not None:
233
+ auth_plugin.load_tokens()
234
+ interceptors = CliUserAuthInterceptor(auth_plugin)
235
+
236
+ # Create the gRPC channel
237
+ channel = create_channel(
238
+ server_address=federation_config["address"],
239
+ insecure=insecure,
240
+ root_certificates=root_certificates_bytes,
241
+ max_message_length=GRPC_MAX_MESSAGE_LENGTH,
242
+ interceptors=interceptors or None,
243
+ )
244
+ channel.subscribe(on_channel_state_change)
245
+ return channel
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Auth plugin components."""
16
+
17
+
18
+ from .auth_plugin import CliAuthPlugin as CliAuthPlugin
19
+ from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
20
+
21
+ __all__ = [
22
+ "CliAuthPlugin",
23
+ "ExecAuthPlugin",
24
+ ]
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Abstract classes for Flower User Auth Plugin."""
16
+
17
+
18
+ from abc import ABC, abstractmethod
19
+ from collections.abc import Sequence
20
+ from pathlib import Path
21
+ from typing import Any, Optional, Union
22
+
23
+ from flwr.proto.exec_pb2_grpc import ExecStub
24
+
25
+
26
+ class ExecAuthPlugin(ABC):
27
+ """Abstract Flower Auth Plugin class for ExecServicer.
28
+
29
+ Parameters
30
+ ----------
31
+ config : dict[str, Any]
32
+ The authentication configuration loaded from a YAML file.
33
+ """
34
+
35
+ @abstractmethod
36
+ def __init__(self, config: dict[str, Any]):
37
+ """Abstract constructor."""
38
+
39
+ @abstractmethod
40
+ def get_login_details(self) -> dict[str, str]:
41
+ """Get the login details."""
42
+
43
+ @abstractmethod
44
+ def validate_tokens_in_metadata(
45
+ self, metadata: Sequence[tuple[str, Union[str, bytes]]]
46
+ ) -> bool:
47
+ """Validate authentication tokens in the provided metadata."""
48
+
49
+ @abstractmethod
50
+ def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]:
51
+ """Get authentication tokens."""
52
+
53
+ @abstractmethod
54
+ def refresh_tokens(
55
+ self, metadata: Sequence[tuple[str, Union[str, bytes]]]
56
+ ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]:
57
+ """Refresh authentication tokens in the provided metadata."""
58
+
59
+
60
+ class CliAuthPlugin(ABC):
61
+ """Abstract Flower Auth Plugin class for CLI.
62
+
63
+ Parameters
64
+ ----------
65
+ user_auth_config_path : Path
66
+ The path to the user's authentication configuration file.
67
+ """
68
+
69
+ @staticmethod
70
+ @abstractmethod
71
+ def login(
72
+ login_details: dict[str, str],
73
+ exec_stub: ExecStub,
74
+ ) -> dict[str, Any]:
75
+ """Authenticate the user with the SuperLink.
76
+
77
+ Parameters
78
+ ----------
79
+ login_details : dict[str, str]
80
+ A dictionary containing the user's login details.
81
+ exec_stub : ExecStub
82
+ An instance of `ExecStub` used for communication with the SuperLink.
83
+
84
+ Returns
85
+ -------
86
+ user_auth_config : dict[str, Any]
87
+ A dictionary containing the user's authentication configuration
88
+ in JSON format.
89
+ """
90
+
91
+ @abstractmethod
92
+ def __init__(self, user_auth_config_path: Path):
93
+ """Abstract constructor."""
94
+
95
+ @abstractmethod
96
+ def store_tokens(self, user_auth_config: dict[str, Any]) -> None:
97
+ """Store authentication tokens from the provided user_auth_config.
98
+
99
+ The configuration, including tokens, will be saved as a JSON file
100
+ at `user_auth_config_path`.
101
+ """
102
+
103
+ @abstractmethod
104
+ def load_tokens(self) -> None:
105
+ """Load authentication tokens from the user_auth_config_path."""
106
+
107
+ @abstractmethod
108
+ def write_tokens_to_metadata(
109
+ self, metadata: Sequence[tuple[str, Union[str, bytes]]]
110
+ ) -> Sequence[tuple[str, Union[str, bytes]]]:
111
+ """Write authentication tokens to the provided metadata."""
flwr/common/config.py CHANGED
@@ -27,6 +27,7 @@ from flwr.common.constant import (
27
27
  APP_DIR,
28
28
  FAB_CONFIG_FILE,
29
29
  FAB_HASH_TRUNCATION,
30
+ FLWR_DIR,
30
31
  FLWR_HOME,
31
32
  )
32
33
  from flwr.common.typing import Run, UserConfig, UserConfigValue
@@ -38,7 +39,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
38
39
  return Path(
39
40
  os.getenv(
40
41
  FLWR_HOME,
41
- Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
42
+ Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / FLWR_DIR,
42
43
  )
43
44
  )
44
45
  return Path(provided_path).absolute()
flwr/common/constant.py CHANGED
@@ -80,7 +80,8 @@ FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
80
80
  FAB_CONFIG_FILE = "pyproject.toml"
81
81
  FAB_DATE = (2024, 10, 1, 0, 0, 0)
82
82
  FAB_HASH_TRUNCATION = 8
83
- FLWR_HOME = "FLWR_HOME"
83
+ FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
84
+ FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
84
85
 
85
86
  # Constants entries in Node config for Simulation
86
87
  PARTITION_ID_KEY = "partition-id"
@@ -110,6 +111,10 @@ LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
110
111
  # Retry configurations
111
112
  MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
112
113
 
114
+ # Constants for user authentication
115
+ CREDENTIALS_DIR = ".credentials"
116
+ AUTH_TYPE = "auth_type"
117
+
113
118
 
114
119
  class MessageType:
115
120
  """Message type."""
@@ -93,8 +93,8 @@ def pseudo_rand_gen(
93
93
  output = []
94
94
  for dimension in dimensions_list:
95
95
  if len(dimension) == 0:
96
- arr = np.array(gen.randint(0, num_range - 1), dtype=int)
96
+ arr = np.array(gen.randint(0, num_range - 1), dtype=np.int64)
97
97
  else:
98
- arr = gen.randint(0, num_range - 1, dimension)
98
+ arr = gen.randint(0, num_range - 1, dimension, dtype=np.int64)
99
99
  output.append(arr)
100
100
  return output
flwr/common/telemetry.py CHANGED
@@ -27,6 +27,7 @@ from enum import Enum, auto
27
27
  from pathlib import Path
28
28
  from typing import Any, Optional, Union, cast
29
29
 
30
+ from flwr.common.constant import FLWR_DIR
30
31
  from flwr.common.version import package_name, package_version
31
32
 
32
33
  FLWR_TELEMETRY_ENABLED = os.getenv("FLWR_TELEMETRY_ENABLED", "1")
@@ -86,7 +87,7 @@ def _get_source_id() -> str:
86
87
  # If the home directory can’t be resolved, RuntimeError is raised.
87
88
  return source_id
88
89
 
89
- flwr_dir = home.joinpath(".flwr")
90
+ flwr_dir = home.joinpath(FLWR_DIR)
90
91
  # Create .flwr directory if it does not exist yet.
91
92
  try:
92
93
  flwr_dir.mkdir(parents=True, exist_ok=True)
@@ -20,7 +20,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
20
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
21
21
 
22
22
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xca\x06\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
24
24
 
25
25
  _globals = globals()
26
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -32,21 +32,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
32
32
  _globals['_GETNODESRESPONSE']._serialized_start=217
33
33
  _globals['_GETNODESRESPONSE']._serialized_end=268
34
34
  _globals['_PUSHTASKINSREQUEST']._serialized_start=270
35
- _globals['_PUSHTASKINSREQUEST']._serialized_end=334
36
- _globals['_PUSHTASKINSRESPONSE']._serialized_start=336
37
- _globals['_PUSHTASKINSRESPONSE']._serialized_end=375
38
- _globals['_PULLTASKRESREQUEST']._serialized_start=377
39
- _globals['_PULLTASKRESREQUEST']._serialized_end=447
40
- _globals['_PULLTASKRESRESPONSE']._serialized_start=449
41
- _globals['_PULLTASKRESRESPONSE']._serialized_end=514
42
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=516
43
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=544
44
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=546
45
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=673
46
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=675
47
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=758
48
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760
49
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790
50
- _globals['_SERVERAPPIO']._serialized_start=793
51
- _globals['_SERVERAPPIO']._serialized_end=1635
35
+ _globals['_PUSHTASKINSREQUEST']._serialized_end=350
36
+ _globals['_PUSHTASKINSRESPONSE']._serialized_start=352
37
+ _globals['_PUSHTASKINSRESPONSE']._serialized_end=391
38
+ _globals['_PULLTASKRESREQUEST']._serialized_start=393
39
+ _globals['_PULLTASKRESREQUEST']._serialized_end=479
40
+ _globals['_PULLTASKRESRESPONSE']._serialized_start=481
41
+ _globals['_PULLTASKRESRESPONSE']._serialized_end=546
42
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548
43
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576
44
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578
45
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705
46
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707
47
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790
48
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792
49
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822
50
+ _globals['_SERVERAPPIO']._serialized_start=825
51
+ _globals['_SERVERAPPIO']._serialized_end=1752
52
52
  # @@protoc_insertion_point(module_scope)
@@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
44
44
  """PushTaskIns messages"""
45
45
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
46
46
  TASK_INS_LIST_FIELD_NUMBER: builtins.int
47
+ RUN_ID_FIELD_NUMBER: builtins.int
47
48
  @property
48
49
  def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
50
+ run_id: builtins.int
49
51
  def __init__(self,
50
52
  *,
51
53
  task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
54
+ run_id: builtins.int = ...,
52
55
  ) -> None: ...
53
- def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
56
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
54
57
  global___PushTaskInsRequest = PushTaskInsRequest
55
58
 
56
59
  class PushTaskInsResponse(google.protobuf.message.Message):
@@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
70
73
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
71
74
  NODE_FIELD_NUMBER: builtins.int
72
75
  TASK_IDS_FIELD_NUMBER: builtins.int
76
+ RUN_ID_FIELD_NUMBER: builtins.int
73
77
  @property
74
78
  def node(self) -> flwr.proto.node_pb2.Node: ...
75
79
  @property
76
80
  def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
81
+ run_id: builtins.int
77
82
  def __init__(self,
78
83
  *,
79
84
  node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
80
85
  task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
86
+ run_id: builtins.int = ...,
81
87
  ) -> None: ...
82
88
  def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
83
- def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
89
+ def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
84
90
  global___PullTaskResRequest = PullTaskResRequest
85
91
 
86
92
  class PullTaskResResponse(google.protobuf.message.Message):
@@ -62,6 +62,11 @@ class ServerAppIoStub(object):
62
62
  request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
63
63
  response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
64
64
  )
65
+ self.GetRunStatus = channel.unary_unary(
66
+ '/flwr.proto.ServerAppIo/GetRunStatus',
67
+ request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
68
+ response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
69
+ )
65
70
  self.PushLogs = channel.unary_unary(
66
71
  '/flwr.proto.ServerAppIo/PushLogs',
67
72
  request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
@@ -135,6 +140,13 @@ class ServerAppIoServicer(object):
135
140
  context.set_details('Method not implemented!')
136
141
  raise NotImplementedError('Method not implemented!')
137
142
 
143
+ def GetRunStatus(self, request, context):
144
+ """Get the status of a given run
145
+ """
146
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
147
+ context.set_details('Method not implemented!')
148
+ raise NotImplementedError('Method not implemented!')
149
+
138
150
  def PushLogs(self, request, context):
139
151
  """Push ServerApp logs
140
152
  """
@@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
190
202
  request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
191
203
  response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
192
204
  ),
205
+ 'GetRunStatus': grpc.unary_unary_rpc_method_handler(
206
+ servicer.GetRunStatus,
207
+ request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
208
+ response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
209
+ ),
193
210
  'PushLogs': grpc.unary_unary_rpc_method_handler(
194
211
  servicer.PushLogs,
195
212
  request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
@@ -358,6 +375,23 @@ class ServerAppIo(object):
358
375
  options, channel_credentials,
359
376
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
360
377
 
378
+ @staticmethod
379
+ def GetRunStatus(request,
380
+ target,
381
+ options=(),
382
+ channel_credentials=None,
383
+ call_credentials=None,
384
+ insecure=False,
385
+ compression=None,
386
+ wait_for_ready=None,
387
+ timeout=None,
388
+ metadata=None):
389
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
390
+ flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
391
+ flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
392
+ options, channel_credentials,
393
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
394
+
361
395
  @staticmethod
362
396
  def PushLogs(request,
363
397
  target,
@@ -56,6 +56,11 @@ class ServerAppIoStub:
56
56
  flwr.proto.run_pb2.UpdateRunStatusResponse]
57
57
  """Update the status of a given run"""
58
58
 
59
+ GetRunStatus: grpc.UnaryUnaryMultiCallable[
60
+ flwr.proto.run_pb2.GetRunStatusRequest,
61
+ flwr.proto.run_pb2.GetRunStatusResponse]
62
+ """Get the status of a given run"""
63
+
59
64
  PushLogs: grpc.UnaryUnaryMultiCallable[
60
65
  flwr.proto.log_pb2.PushLogsRequest,
61
66
  flwr.proto.log_pb2.PushLogsResponse]
@@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
135
140
  """Update the status of a given run"""
136
141
  pass
137
142
 
143
+ @abc.abstractmethod
144
+ def GetRunStatus(self,
145
+ request: flwr.proto.run_pb2.GetRunStatusRequest,
146
+ context: grpc.ServicerContext,
147
+ ) -> flwr.proto.run_pb2.GetRunStatusResponse:
148
+ """Get the status of a given run"""
149
+ pass
150
+
138
151
  @abc.abstractmethod
139
152
  def PushLogs(self,
140
153
  request: flwr.proto.log_pb2.PushLogsRequest,
flwr/server/app.py CHANGED
@@ -24,9 +24,10 @@ from collections.abc import Sequence
24
24
  from logging import DEBUG, INFO, WARN
25
25
  from pathlib import Path
26
26
  from time import sleep
27
- from typing import Optional
27
+ from typing import Any, Optional
28
28
 
29
29
  import grpc
30
+ import yaml
30
31
  from cryptography.exceptions import UnsupportedAlgorithm
31
32
  from cryptography.hazmat.primitives.asymmetric import ec
32
33
  from cryptography.hazmat.primitives.serialization import (
@@ -37,8 +38,10 @@ from cryptography.hazmat.primitives.serialization import (
37
38
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
38
39
  from flwr.common.address import parse_address
39
40
  from flwr.common.args import try_obtain_server_certificates
41
+ from flwr.common.auth_plugin import ExecAuthPlugin
40
42
  from flwr.common.config import get_flwr_dir, parse_config_args
41
43
  from flwr.common.constant import (
44
+ AUTH_TYPE,
42
45
  CLIENT_OCTET,
43
46
  EXEC_API_DEFAULT_SERVER_ADDRESS,
44
47
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -88,6 +91,15 @@ DATABASE = ":flwr-in-memory-state:"
88
91
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
89
92
 
90
93
 
94
+ try:
95
+ from flwr.ee import get_exec_auth_plugins
96
+ except ImportError:
97
+
98
+ def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
99
+ """Return all Exec API authentication plugins."""
100
+ raise NotImplementedError("No authentication plugins are currently supported.")
101
+
102
+
91
103
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
92
104
  *,
93
105
  server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -246,6 +258,12 @@ def run_superlink() -> None:
246
258
  # Obtain certificates
247
259
  certificates = try_obtain_server_certificates(args, args.fleet_api_type)
248
260
 
261
+ user_auth_config = _try_obtain_user_auth_config(args)
262
+ auth_plugin: Optional[ExecAuthPlugin] = None
263
+ # user_auth_config is None only if the args.user_auth_config is not provided
264
+ if user_auth_config is not None:
265
+ auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
266
+
249
267
  # Initialize StateFactory
250
268
  state_factory = LinkStateFactory(args.database)
251
269
 
@@ -263,6 +281,7 @@ def run_superlink() -> None:
263
281
  config=parse_config_args(
264
282
  [args.executor_config] if args.executor_config else args.executor_config
265
283
  ),
284
+ auth_plugin=auth_plugin,
266
285
  )
267
286
  grpc_servers = [exec_server]
268
287
 
@@ -559,6 +578,32 @@ def _try_setup_node_authentication(
559
578
  )
560
579
 
561
580
 
581
+ def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
582
+ if args.user_auth_config is not None:
583
+ with open(args.user_auth_config, encoding="utf-8") as file:
584
+ config: dict[str, Any] = yaml.safe_load(file)
585
+ return config
586
+ return None
587
+
588
+
589
+ def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
590
+ auth_config: dict[str, Any] = config.get("authentication", {})
591
+ auth_type: str = auth_config.get(AUTH_TYPE, "")
592
+ try:
593
+ all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
594
+ auth_plugin_class = all_plugins[auth_type]
595
+ return auth_plugin_class(config=auth_config)
596
+ except KeyError:
597
+ if auth_type != "":
598
+ sys.exit(
599
+ f'Authentication type "{auth_type}" is not supported. '
600
+ "Please provide a valid authentication type in the configuration."
601
+ )
602
+ sys.exit("No authentication type is provided in the configuration.")
603
+ except NotImplementedError:
604
+ sys.exit("No authentication plugins are currently supported.")
605
+
606
+
562
607
  def _run_fleet_api_grpc_rere(
563
608
  address: str,
564
609
  state_factory: LinkStateFactory,
@@ -746,6 +791,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
746
791
  type=str,
747
792
  help="The SuperLink's public key (as a path str) to enable authentication.",
748
793
  )
794
+ parser.add_argument(
795
+ "--user-auth-config",
796
+ help="The path to the user authentication configuration YAML file.",
797
+ type=str,
798
+ default=None,
799
+ )
749
800
 
750
801
 
751
802
  def _add_args_serverappio_api(parser: argparse.ArgumentParser) -> None:
@@ -203,7 +203,9 @@ class GrpcDriver(Driver):
203
203
  task_ins_list.append(taskins)
204
204
  # Call GrpcDriverStub method
205
205
  res: PushTaskInsResponse = self._stub.PushTaskIns(
206
- PushTaskInsRequest(task_ins_list=task_ins_list)
206
+ PushTaskInsRequest(
207
+ task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
208
+ )
207
209
  )
208
210
  return list(res.task_ids)
209
211
 
@@ -215,7 +217,9 @@ class GrpcDriver(Driver):
215
217
  """
216
218
  # Pull TaskRes
217
219
  res: PullTaskResResponse = self._stub.PullTaskRes(
218
- PullTaskResRequest(node=self.node, task_ids=message_ids)
220
+ PullTaskResRequest(
221
+ node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
222
+ )
219
223
  )
220
224
  # Convert TaskRes to Message
221
225
  msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]