flwr-1.18.0-py3-none-any.whl → flwr-1.19.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
flwr/server/app.py
CHANGED
|
@@ -27,7 +27,7 @@ from collections.abc import Sequence
|
|
|
27
27
|
from logging import DEBUG, INFO, WARN
|
|
28
28
|
from pathlib import Path
|
|
29
29
|
from time import sleep
|
|
30
|
-
from typing import Any, Optional
|
|
30
|
+
from typing import Any, Callable, Optional, TypeVar
|
|
31
31
|
|
|
32
32
|
import grpc
|
|
33
33
|
import yaml
|
|
@@ -37,13 +37,13 @@ from cryptography.hazmat.primitives.serialization import load_ssh_public_key
|
|
|
37
37
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
38
38
|
from flwr.common.address import parse_address
|
|
39
39
|
from flwr.common.args import try_obtain_server_certificates
|
|
40
|
-
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
40
|
+
from flwr.common.auth_plugin import ExecAuthPlugin, ExecAuthzPlugin
|
|
41
41
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
42
42
|
from flwr.common.constant import (
|
|
43
43
|
AUTH_TYPE_YAML_KEY,
|
|
44
|
+
AUTHZ_TYPE_YAML_KEY,
|
|
44
45
|
CLIENT_OCTET,
|
|
45
46
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
46
|
-
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
47
47
|
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
48
48
|
FLEET_API_REST_DEFAULT_ADDRESS,
|
|
49
49
|
ISOLATION_MODE_PROCESS,
|
|
@@ -60,7 +60,7 @@ from flwr.common.event_log_plugin import EventLogWriterPlugin
|
|
|
60
60
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
61
61
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
62
62
|
from flwr.common.grpc import generic_create_grpc_server
|
|
63
|
-
from flwr.common.logger import log
|
|
63
|
+
from flwr.common.logger import log
|
|
64
64
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
65
65
|
public_key_to_bytes,
|
|
66
66
|
)
|
|
@@ -71,17 +71,12 @@ from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
|
71
71
|
from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
|
|
72
72
|
from flwr.server.serverapp.app import flwr_serverapp
|
|
73
73
|
from flwr.simulation.app import flwr_simulation
|
|
74
|
+
from flwr.supercore.object_store import ObjectStoreFactory
|
|
74
75
|
from flwr.superexec.app import load_executor
|
|
75
76
|
from flwr.superexec.exec_grpc import run_exec_api_grpc
|
|
76
77
|
|
|
77
|
-
from .client_manager import ClientManager
|
|
78
|
-
from .history import History
|
|
79
|
-
from .server import Server, init_defaults, run_fl
|
|
80
|
-
from .server_config import ServerConfig
|
|
81
|
-
from .strategy import Strategy
|
|
82
78
|
from .superlink.ffs.ffs_factory import FfsFactory
|
|
83
79
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
84
|
-
from .superlink.fleet.grpc_bidi.grpc_server import start_grpc_server
|
|
85
80
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
86
81
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
87
82
|
from .superlink.linkstate import LinkStateFactory
|
|
@@ -90,13 +85,14 @@ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
|
|
90
85
|
|
|
91
86
|
DATABASE = ":flwr-in-memory-state:"
|
|
92
87
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
88
|
+
P = TypeVar("P", ExecAuthPlugin, ExecAuthzPlugin)
|
|
93
89
|
|
|
94
90
|
|
|
95
91
|
try:
|
|
96
92
|
from flwr.ee import (
|
|
97
93
|
add_ee_args_superlink,
|
|
98
|
-
get_dashboard_server,
|
|
99
94
|
get_exec_auth_plugins,
|
|
95
|
+
get_exec_authz_plugins,
|
|
100
96
|
get_exec_event_log_writer_plugins,
|
|
101
97
|
get_fleet_event_log_writer_plugins,
|
|
102
98
|
)
|
|
@@ -110,6 +106,10 @@ except ImportError:
|
|
|
110
106
|
"""Return all Exec API authentication plugins."""
|
|
111
107
|
raise NotImplementedError("No authentication plugins are currently supported.")
|
|
112
108
|
|
|
109
|
+
def get_exec_authz_plugins() -> dict[str, type[ExecAuthzPlugin]]:
|
|
110
|
+
"""Return all Exec API authorization plugins."""
|
|
111
|
+
raise NotImplementedError("No authorization plugins are currently supported.")
|
|
112
|
+
|
|
113
113
|
def get_exec_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
|
114
114
|
"""Return all Exec API event log writer plugins."""
|
|
115
115
|
raise NotImplementedError(
|
|
@@ -123,148 +123,6 @@ except ImportError:
|
|
|
123
123
|
)
|
|
124
124
|
|
|
125
125
|
|
|
126
|
-
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
127
|
-
*,
|
|
128
|
-
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
129
|
-
server: Optional[Server] = None,
|
|
130
|
-
config: Optional[ServerConfig] = None,
|
|
131
|
-
strategy: Optional[Strategy] = None,
|
|
132
|
-
client_manager: Optional[ClientManager] = None,
|
|
133
|
-
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
134
|
-
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
|
135
|
-
) -> History:
|
|
136
|
-
"""Start a Flower server using the gRPC transport layer.
|
|
137
|
-
|
|
138
|
-
Warning
|
|
139
|
-
-------
|
|
140
|
-
This function is deprecated since 1.13.0. Use the :code:`flower-superlink` command
|
|
141
|
-
instead to start a SuperLink.
|
|
142
|
-
|
|
143
|
-
Parameters
|
|
144
|
-
----------
|
|
145
|
-
server_address : Optional[str]
|
|
146
|
-
The IPv4 or IPv6 address of the server. Defaults to `"[::]:8080"`.
|
|
147
|
-
server : Optional[flwr.server.Server] (default: None)
|
|
148
|
-
A server implementation, either `flwr.server.Server` or a subclass
|
|
149
|
-
thereof. If no instance is provided, then `start_server` will create
|
|
150
|
-
one.
|
|
151
|
-
config : Optional[ServerConfig] (default: None)
|
|
152
|
-
Currently supported values are `num_rounds` (int, default: 1) and
|
|
153
|
-
`round_timeout` in seconds (float, default: None).
|
|
154
|
-
strategy : Optional[flwr.server.Strategy] (default: None).
|
|
155
|
-
An implementation of the abstract base class
|
|
156
|
-
`flwr.server.strategy.Strategy`. If no strategy is provided, then
|
|
157
|
-
`start_server` will use `flwr.server.strategy.FedAvg`.
|
|
158
|
-
client_manager : Optional[flwr.server.ClientManager] (default: None)
|
|
159
|
-
An implementation of the abstract base class
|
|
160
|
-
`flwr.server.ClientManager`. If no implementation is provided, then
|
|
161
|
-
`start_server` will use
|
|
162
|
-
`flwr.server.client_manager.SimpleClientManager`.
|
|
163
|
-
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
|
|
164
|
-
The maximum length of gRPC messages that can be exchanged with the
|
|
165
|
-
Flower clients. The default should be sufficient for most models.
|
|
166
|
-
Users who train very large models might need to increase this
|
|
167
|
-
value. Note that the Flower clients need to be started with the
|
|
168
|
-
same value (see `flwr.client.start_client`), otherwise clients will
|
|
169
|
-
not know about the increased limit and block larger messages.
|
|
170
|
-
certificates : Tuple[bytes, bytes, bytes] (default: None)
|
|
171
|
-
Tuple containing root certificate, server certificate, and private key
|
|
172
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
173
|
-
three bytes elements in the following order:
|
|
174
|
-
|
|
175
|
-
* CA certificate.
|
|
176
|
-
* server certificate.
|
|
177
|
-
* server private key.
|
|
178
|
-
|
|
179
|
-
Returns
|
|
180
|
-
-------
|
|
181
|
-
hist : flwr.server.history.History
|
|
182
|
-
Object containing training and evaluation metrics.
|
|
183
|
-
|
|
184
|
-
Examples
|
|
185
|
-
--------
|
|
186
|
-
Starting an insecure server::
|
|
187
|
-
|
|
188
|
-
start_server()
|
|
189
|
-
|
|
190
|
-
Starting a TLS-enabled server::
|
|
191
|
-
|
|
192
|
-
start_server(
|
|
193
|
-
certificates=(
|
|
194
|
-
Path("/crts/root.pem").read_bytes(),
|
|
195
|
-
Path("/crts/localhost.crt").read_bytes(),
|
|
196
|
-
Path("/crts/localhost.key").read_bytes()
|
|
197
|
-
)
|
|
198
|
-
)
|
|
199
|
-
"""
|
|
200
|
-
msg = (
|
|
201
|
-
"flwr.server.start_server() is deprecated."
|
|
202
|
-
"\n\tInstead, use the `flower-superlink` CLI command to start a SuperLink "
|
|
203
|
-
"as shown below:"
|
|
204
|
-
"\n\n\t\t$ flower-superlink --insecure"
|
|
205
|
-
"\n\n\tTo view usage and all available options, run:"
|
|
206
|
-
"\n\n\t\t$ flower-superlink --help"
|
|
207
|
-
"\n\n\tUsing `start_server()` is deprecated."
|
|
208
|
-
)
|
|
209
|
-
warn_deprecated_feature(name=msg)
|
|
210
|
-
|
|
211
|
-
event(EventType.START_SERVER_ENTER)
|
|
212
|
-
|
|
213
|
-
# Parse IP address
|
|
214
|
-
parsed_address = parse_address(server_address)
|
|
215
|
-
if not parsed_address:
|
|
216
|
-
sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
|
|
217
|
-
host, port, is_v6 = parsed_address
|
|
218
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
219
|
-
|
|
220
|
-
# Initialize server and server config
|
|
221
|
-
initialized_server, initialized_config = init_defaults(
|
|
222
|
-
server=server,
|
|
223
|
-
config=config,
|
|
224
|
-
strategy=strategy,
|
|
225
|
-
client_manager=client_manager,
|
|
226
|
-
)
|
|
227
|
-
log(
|
|
228
|
-
INFO,
|
|
229
|
-
"Starting Flower server, config: %s",
|
|
230
|
-
initialized_config,
|
|
231
|
-
)
|
|
232
|
-
|
|
233
|
-
# Start gRPC server
|
|
234
|
-
grpc_server = start_grpc_server(
|
|
235
|
-
client_manager=initialized_server.client_manager(),
|
|
236
|
-
server_address=address,
|
|
237
|
-
max_message_length=grpc_max_message_length,
|
|
238
|
-
certificates=certificates,
|
|
239
|
-
)
|
|
240
|
-
log(
|
|
241
|
-
INFO,
|
|
242
|
-
"Flower ECE: gRPC server running (%s rounds), SSL is %s",
|
|
243
|
-
initialized_config.num_rounds,
|
|
244
|
-
"enabled" if certificates is not None else "disabled",
|
|
245
|
-
)
|
|
246
|
-
|
|
247
|
-
# Graceful shutdown
|
|
248
|
-
register_exit_handlers(
|
|
249
|
-
event_type=EventType.START_SERVER_LEAVE,
|
|
250
|
-
exit_message="Flower server terminated gracefully.",
|
|
251
|
-
grpc_servers=[grpc_server],
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
# Start training
|
|
255
|
-
hist = run_fl(
|
|
256
|
-
server=initialized_server,
|
|
257
|
-
config=initialized_config,
|
|
258
|
-
)
|
|
259
|
-
|
|
260
|
-
# Stop the gRPC server
|
|
261
|
-
grpc_server.stop(grace=1)
|
|
262
|
-
|
|
263
|
-
event(EventType.START_SERVER_LEAVE)
|
|
264
|
-
|
|
265
|
-
return hist
|
|
266
|
-
|
|
267
|
-
|
|
268
126
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
269
127
|
def run_superlink() -> None:
|
|
270
128
|
"""Run Flower SuperLink (ServerAppIo API and Fleet API)."""
|
|
@@ -293,10 +151,13 @@ def run_superlink() -> None:
|
|
|
293
151
|
verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
|
|
294
152
|
|
|
295
153
|
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
154
|
+
authz_plugin: Optional[ExecAuthzPlugin] = None
|
|
296
155
|
event_log_plugin: Optional[EventLogWriterPlugin] = None
|
|
297
156
|
# Load the auth plugin if the args.user_auth_config is provided
|
|
298
157
|
if cfg_path := getattr(args, "user_auth_config", None):
|
|
299
|
-
auth_plugin =
|
|
158
|
+
auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins(
|
|
159
|
+
Path(cfg_path), verify_tls_cert
|
|
160
|
+
)
|
|
300
161
|
# Enable event logging if the args.enable_event_log is True
|
|
301
162
|
if args.enable_event_log:
|
|
302
163
|
event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
|
|
@@ -307,18 +168,23 @@ def run_superlink() -> None:
|
|
|
307
168
|
# Initialize FfsFactory
|
|
308
169
|
ffs_factory = FfsFactory(args.storage_dir)
|
|
309
170
|
|
|
171
|
+
# Initialize ObjectStoreFactory
|
|
172
|
+
objectstore_factory = ObjectStoreFactory()
|
|
173
|
+
|
|
310
174
|
# Start Exec API
|
|
311
175
|
executor = load_executor(args)
|
|
312
176
|
exec_server: grpc.Server = run_exec_api_grpc(
|
|
313
177
|
address=exec_address,
|
|
314
178
|
state_factory=state_factory,
|
|
315
179
|
ffs_factory=ffs_factory,
|
|
180
|
+
objectstore_factory=objectstore_factory,
|
|
316
181
|
executor=executor,
|
|
317
182
|
certificates=certificates,
|
|
318
183
|
config=parse_config_args(
|
|
319
184
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
320
185
|
),
|
|
321
186
|
auth_plugin=auth_plugin,
|
|
187
|
+
authz_plugin=authz_plugin,
|
|
322
188
|
event_log_plugin=event_log_plugin,
|
|
323
189
|
)
|
|
324
190
|
grpc_servers = [exec_server]
|
|
@@ -343,6 +209,7 @@ def run_superlink() -> None:
|
|
|
343
209
|
address=serverappio_address,
|
|
344
210
|
state_factory=state_factory,
|
|
345
211
|
ffs_factory=ffs_factory,
|
|
212
|
+
objectstore_factory=objectstore_factory,
|
|
346
213
|
certificates=None, # ServerAppIo API doesn't support SSL yet
|
|
347
214
|
)
|
|
348
215
|
grpc_servers.append(serverappio_server)
|
|
@@ -388,6 +255,7 @@ def run_superlink() -> None:
|
|
|
388
255
|
args.ssl_certfile,
|
|
389
256
|
state_factory,
|
|
390
257
|
ffs_factory,
|
|
258
|
+
objectstore_factory,
|
|
391
259
|
num_workers,
|
|
392
260
|
),
|
|
393
261
|
daemon=True,
|
|
@@ -421,6 +289,7 @@ def run_superlink() -> None:
|
|
|
421
289
|
address=fleet_address,
|
|
422
290
|
state_factory=state_factory,
|
|
423
291
|
ffs_factory=ffs_factory,
|
|
292
|
+
objectstore_factory=objectstore_factory,
|
|
424
293
|
certificates=certificates,
|
|
425
294
|
interceptors=interceptors,
|
|
426
295
|
)
|
|
@@ -430,6 +299,7 @@ def run_superlink() -> None:
|
|
|
430
299
|
address=fleet_address,
|
|
431
300
|
state_factory=state_factory,
|
|
432
301
|
ffs_factory=ffs_factory,
|
|
302
|
+
objectstore_factory=objectstore_factory,
|
|
433
303
|
certificates=certificates,
|
|
434
304
|
)
|
|
435
305
|
grpc_servers.append(fleet_server)
|
|
@@ -462,17 +332,6 @@ def run_superlink() -> None:
|
|
|
462
332
|
scheduler_th.start()
|
|
463
333
|
bckg_threads.append(scheduler_th)
|
|
464
334
|
|
|
465
|
-
# Add Dashboard server if available
|
|
466
|
-
if dashboard_address := getattr(args, "dashboard_address", None):
|
|
467
|
-
dashboard_address_str, _, _ = _format_address(dashboard_address)
|
|
468
|
-
dashboard_server = get_dashboard_server(
|
|
469
|
-
address=dashboard_address_str,
|
|
470
|
-
state_factory=state_factory,
|
|
471
|
-
certificates=None,
|
|
472
|
-
)
|
|
473
|
-
|
|
474
|
-
grpc_servers.append(dashboard_server)
|
|
475
|
-
|
|
476
335
|
# Graceful shutdown
|
|
477
336
|
register_exit_handlers(
|
|
478
337
|
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
|
@@ -611,33 +470,50 @@ def _try_load_public_keys_node_authentication(
|
|
|
611
470
|
return node_public_keys
|
|
612
471
|
|
|
613
472
|
|
|
614
|
-
def
|
|
473
|
+
def _try_obtain_exec_auth_plugins(
|
|
615
474
|
config_path: Path, verify_tls_cert: bool
|
|
616
|
-
) ->
|
|
475
|
+
) -> tuple[ExecAuthPlugin, ExecAuthzPlugin]:
|
|
476
|
+
"""Obtain Exec API authentication and authorization plugins."""
|
|
617
477
|
# Load YAML file
|
|
618
478
|
with config_path.open("r", encoding="utf-8") as file:
|
|
619
479
|
config: dict[str, Any] = yaml.safe_load(file)
|
|
620
480
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
481
|
+
def _load_plugin(
|
|
482
|
+
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
|
483
|
+
) -> P:
|
|
484
|
+
section_cfg = config.get(section, {})
|
|
485
|
+
auth_plugin_name = section_cfg.get(yaml_key, "")
|
|
486
|
+
try:
|
|
487
|
+
plugins: dict[str, type[P]] = loader()
|
|
488
|
+
plugin_cls: type[P] = plugins[auth_plugin_name]
|
|
489
|
+
return plugin_cls(
|
|
490
|
+
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
|
491
|
+
)
|
|
492
|
+
except KeyError:
|
|
493
|
+
if auth_plugin_name:
|
|
494
|
+
sys.exit(
|
|
495
|
+
f"{yaml_key}: {auth_plugin_name} is not supported. "
|
|
496
|
+
f"Please provide a valid {section} type in the configuration."
|
|
497
|
+
)
|
|
498
|
+
sys.exit(f"No {section} type is provided in the configuration.")
|
|
499
|
+
except NotImplementedError:
|
|
500
|
+
sys.exit(f"No {section} plugins are currently supported.")
|
|
624
501
|
|
|
625
502
|
# Load authentication plugin
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
sys.exit("No authentication plugins are currently supported.")
|
|
503
|
+
auth_plugin = _load_plugin(
|
|
504
|
+
section="authentication",
|
|
505
|
+
yaml_key=AUTH_TYPE_YAML_KEY,
|
|
506
|
+
loader=get_exec_auth_plugins,
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
# Load authorization plugin
|
|
510
|
+
authz_plugin = _load_plugin(
|
|
511
|
+
section="authorization",
|
|
512
|
+
yaml_key=AUTHZ_TYPE_YAML_KEY,
|
|
513
|
+
loader=get_exec_authz_plugins,
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
return auth_plugin, authz_plugin
|
|
641
517
|
|
|
642
518
|
|
|
643
519
|
def _try_obtain_exec_event_log_writer_plugin() -> Optional[EventLogWriterPlugin]:
|
|
@@ -668,10 +544,11 @@ def _try_obtain_fleet_event_log_writer_plugin() -> Optional[EventLogWriterPlugin
|
|
|
668
544
|
sys.exit("No Fleet API event log writer plugins are currently supported.")
|
|
669
545
|
|
|
670
546
|
|
|
671
|
-
def _run_fleet_api_grpc_rere(
|
|
547
|
+
def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
|
|
672
548
|
address: str,
|
|
673
549
|
state_factory: LinkStateFactory,
|
|
674
550
|
ffs_factory: FfsFactory,
|
|
551
|
+
objectstore_factory: ObjectStoreFactory,
|
|
675
552
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
676
553
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
677
554
|
) -> grpc.Server:
|
|
@@ -680,6 +557,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
680
557
|
fleet_servicer = FleetServicer(
|
|
681
558
|
state_factory=state_factory,
|
|
682
559
|
ffs_factory=ffs_factory,
|
|
560
|
+
objectstore_factory=objectstore_factory,
|
|
683
561
|
)
|
|
684
562
|
fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
|
|
685
563
|
fleet_grpc_server = generic_create_grpc_server(
|
|
@@ -700,6 +578,7 @@ def _run_fleet_api_grpc_adapter(
|
|
|
700
578
|
address: str,
|
|
701
579
|
state_factory: LinkStateFactory,
|
|
702
580
|
ffs_factory: FfsFactory,
|
|
581
|
+
objectstore_factory: ObjectStoreFactory,
|
|
703
582
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
704
583
|
) -> grpc.Server:
|
|
705
584
|
"""Run Fleet API (GrpcAdapter)."""
|
|
@@ -707,6 +586,7 @@ def _run_fleet_api_grpc_adapter(
|
|
|
707
586
|
fleet_servicer = GrpcAdapterServicer(
|
|
708
587
|
state_factory=state_factory,
|
|
709
588
|
ffs_factory=ffs_factory,
|
|
589
|
+
objectstore_factory=objectstore_factory,
|
|
710
590
|
)
|
|
711
591
|
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
|
|
712
592
|
fleet_grpc_server = generic_create_grpc_server(
|
|
@@ -731,6 +611,7 @@ def _run_fleet_api_rest(
|
|
|
731
611
|
ssl_certfile: Optional[str],
|
|
732
612
|
state_factory: LinkStateFactory,
|
|
733
613
|
ffs_factory: FfsFactory,
|
|
614
|
+
objectstore_factory: ObjectStoreFactory,
|
|
734
615
|
num_workers: int,
|
|
735
616
|
) -> None:
|
|
736
617
|
"""Run ServerAppIo API (REST-based)."""
|
|
@@ -746,6 +627,7 @@ def _run_fleet_api_rest(
|
|
|
746
627
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
|
747
628
|
fast_api_app.state.STATE_FACTORY = state_factory
|
|
748
629
|
fast_api_app.state.FFS_FACTORY = ffs_factory
|
|
630
|
+
fast_api_app.state.OBJECTSTORE_FACTORY = objectstore_factory
|
|
749
631
|
|
|
750
632
|
uvicorn.run(
|
|
751
633
|
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
|
+
from typing import Any, Callable
|
|
19
20
|
|
|
20
21
|
from flwr.common.typing import RunNotRunningException
|
|
21
22
|
|
|
@@ -80,36 +81,57 @@ def _update_client_manager(
|
|
|
80
81
|
"""Update the nodes list in the client manager."""
|
|
81
82
|
# Loop until the grid is disconnected
|
|
82
83
|
registered_nodes: dict[int, GridClientProxy] = {}
|
|
84
|
+
lock = threading.RLock()
|
|
85
|
+
|
|
86
|
+
def update_registered_nodes() -> None:
|
|
87
|
+
with lock:
|
|
88
|
+
all_node_ids = set(grid.get_node_ids())
|
|
89
|
+
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
|
90
|
+
new_nodes = all_node_ids.difference(registered_nodes)
|
|
91
|
+
|
|
92
|
+
# Unregister dead nodes
|
|
93
|
+
for node_id in dead_nodes:
|
|
94
|
+
client_proxy = registered_nodes[node_id]
|
|
95
|
+
client_manager.unregister(client_proxy)
|
|
96
|
+
del registered_nodes[node_id]
|
|
97
|
+
|
|
98
|
+
# Register new nodes
|
|
99
|
+
for node_id in new_nodes:
|
|
100
|
+
client_proxy = GridClientProxy(
|
|
101
|
+
node_id=node_id,
|
|
102
|
+
grid=grid,
|
|
103
|
+
run_id=grid.run.run_id,
|
|
104
|
+
)
|
|
105
|
+
if client_manager.register(client_proxy):
|
|
106
|
+
registered_nodes[node_id] = client_proxy
|
|
107
|
+
else:
|
|
108
|
+
raise RuntimeError("Could not register node.")
|
|
109
|
+
|
|
110
|
+
# Get the wrapped method of ClientManager instance
|
|
111
|
+
def get_wrapped_method(method_name: str) -> Callable[..., Any]:
|
|
112
|
+
original_method = getattr(client_manager, method_name)
|
|
113
|
+
|
|
114
|
+
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
|
|
115
|
+
# Update registered nodes before calling the original method
|
|
116
|
+
update_registered_nodes()
|
|
117
|
+
return original_method(*args, **kwargs)
|
|
118
|
+
|
|
119
|
+
return wrapped_method
|
|
120
|
+
|
|
121
|
+
# Wrap the ClientManager
|
|
122
|
+
for method_name in ["num_available", "all", "sample"]:
|
|
123
|
+
setattr(client_manager, method_name, get_wrapped_method(method_name))
|
|
124
|
+
|
|
125
|
+
c_done.set()
|
|
126
|
+
|
|
83
127
|
while not f_stop.is_set():
|
|
128
|
+
# Sleep for 5 seconds
|
|
129
|
+
if not f_stop.is_set():
|
|
130
|
+
f_stop.wait(5)
|
|
131
|
+
|
|
84
132
|
try:
|
|
85
|
-
|
|
133
|
+
# Update registered nodes
|
|
134
|
+
update_registered_nodes()
|
|
86
135
|
except RunNotRunningException:
|
|
87
136
|
f_stop.set()
|
|
88
137
|
break
|
|
89
|
-
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
|
90
|
-
new_nodes = all_node_ids.difference(registered_nodes)
|
|
91
|
-
|
|
92
|
-
# Unregister dead nodes
|
|
93
|
-
for node_id in dead_nodes:
|
|
94
|
-
client_proxy = registered_nodes[node_id]
|
|
95
|
-
client_manager.unregister(client_proxy)
|
|
96
|
-
del registered_nodes[node_id]
|
|
97
|
-
|
|
98
|
-
# Register new nodes
|
|
99
|
-
for node_id in new_nodes:
|
|
100
|
-
client_proxy = GridClientProxy(
|
|
101
|
-
node_id=node_id,
|
|
102
|
-
grid=grid,
|
|
103
|
-
run_id=grid.run.run_id,
|
|
104
|
-
)
|
|
105
|
-
if client_manager.register(client_proxy):
|
|
106
|
-
registered_nodes[node_id] = client_proxy
|
|
107
|
-
else:
|
|
108
|
-
raise RuntimeError("Could not register node.")
|
|
109
|
-
|
|
110
|
-
# Flag first pass for nodes registration is completed
|
|
111
|
-
c_done.set()
|
|
112
|
-
|
|
113
|
-
# Sleep for 3 seconds
|
|
114
|
-
if not f_stop.is_set():
|
|
115
|
-
f_stop.wait(3)
|
|
@@ -59,7 +59,7 @@ class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
59
59
|
log_entry = self.log_plugin.compose_log_before_event(
|
|
60
60
|
request=request,
|
|
61
61
|
context=context,
|
|
62
|
-
|
|
62
|
+
account_info=None,
|
|
63
63
|
method_name=method_name,
|
|
64
64
|
)
|
|
65
65
|
self.log_plugin.write_log(log_entry)
|
|
@@ -75,7 +75,7 @@ class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
75
75
|
log_entry = self.log_plugin.compose_log_after_event(
|
|
76
76
|
request=request,
|
|
77
77
|
context=context,
|
|
78
|
-
|
|
78
|
+
account_info=None,
|
|
79
79
|
method_name=method_name,
|
|
80
80
|
response=unary_response or error,
|
|
81
81
|
)
|