flwr 1.13.0__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -37
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +2 -19
- flwr/cli/log.py +18 -36
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +205 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +25 -14
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +89 -39
- flwr/cli/stop.py +130 -0
- flwr/cli/utils.py +172 -8
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +4 -8
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +2 -2
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +17 -1
- flwr/common/logger.py +40 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +38 -14
- flwr/proto/exec_pb2.pyi +107 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +62 -7
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +25 -10
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +82 -23
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
- flwr/server/superlink/linkstate/linkstate.py +17 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +59 -52
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -8
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower server app."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import csv
|
|
19
20
|
import importlib.util
|
|
@@ -24,9 +25,10 @@ from collections.abc import Sequence
|
|
|
24
25
|
from logging import DEBUG, INFO, WARN
|
|
25
26
|
from pathlib import Path
|
|
26
27
|
from time import sleep
|
|
27
|
-
from typing import Optional
|
|
28
|
+
from typing import Any, Optional
|
|
28
29
|
|
|
29
30
|
import grpc
|
|
31
|
+
import yaml
|
|
30
32
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
31
33
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
32
34
|
from cryptography.hazmat.primitives.serialization import (
|
|
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
37
39
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
38
40
|
from flwr.common.address import parse_address
|
|
39
41
|
from flwr.common.args import try_obtain_server_certificates
|
|
42
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
40
43
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
41
44
|
from flwr.common.constant import (
|
|
45
|
+
AUTH_TYPE,
|
|
42
46
|
CLIENT_OCTET,
|
|
43
47
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
44
48
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -66,7 +70,6 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
|
66
70
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
67
71
|
from flwr.superexec.app import load_executor
|
|
68
72
|
from flwr.superexec.exec_grpc import run_exec_api_grpc
|
|
69
|
-
from flwr.superexec.simulation import SimulationEngine
|
|
70
73
|
|
|
71
74
|
from .client_manager import ClientManager
|
|
72
75
|
from .history import History
|
|
@@ -89,6 +92,19 @@ DATABASE = ":flwr-in-memory-state:"
|
|
|
89
92
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
90
93
|
|
|
91
94
|
|
|
95
|
+
try:
|
|
96
|
+
from flwr.ee import add_ee_args_superlink, get_exec_auth_plugins
|
|
97
|
+
except ImportError:
|
|
98
|
+
|
|
99
|
+
# pylint: disable-next=unused-argument
|
|
100
|
+
def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
|
|
101
|
+
"""Add EE-specific arguments to the parser."""
|
|
102
|
+
|
|
103
|
+
def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
|
|
104
|
+
"""Return all Exec API authentication plugins."""
|
|
105
|
+
raise NotImplementedError("No authentication plugins are currently supported.")
|
|
106
|
+
|
|
107
|
+
|
|
92
108
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
93
109
|
*,
|
|
94
110
|
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -247,6 +263,12 @@ def run_superlink() -> None:
|
|
|
247
263
|
# Obtain certificates
|
|
248
264
|
certificates = try_obtain_server_certificates(args, args.fleet_api_type)
|
|
249
265
|
|
|
266
|
+
user_auth_config = _try_obtain_user_auth_config(args)
|
|
267
|
+
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
268
|
+
# user_auth_config is None only if the args.user_auth_config is not provided
|
|
269
|
+
if user_auth_config is not None:
|
|
270
|
+
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
|
|
271
|
+
|
|
250
272
|
# Initialize StateFactory
|
|
251
273
|
state_factory = LinkStateFactory(args.database)
|
|
252
274
|
|
|
@@ -264,13 +286,13 @@ def run_superlink() -> None:
|
|
|
264
286
|
config=parse_config_args(
|
|
265
287
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
266
288
|
),
|
|
289
|
+
auth_plugin=auth_plugin,
|
|
267
290
|
)
|
|
268
291
|
grpc_servers = [exec_server]
|
|
269
292
|
|
|
270
293
|
# Determine Exec plugin
|
|
271
294
|
# If simulation is used, don't start ServerAppIo and Fleet APIs
|
|
272
|
-
sim_exec =
|
|
273
|
-
|
|
295
|
+
sim_exec = executor.__class__.__qualname__ == "SimulationEngine"
|
|
274
296
|
bckg_threads = []
|
|
275
297
|
|
|
276
298
|
if sim_exec:
|
|
@@ -278,7 +300,7 @@ def run_superlink() -> None:
|
|
|
278
300
|
address=simulationio_address,
|
|
279
301
|
state_factory=state_factory,
|
|
280
302
|
ffs_factory=ffs_factory,
|
|
281
|
-
certificates=
|
|
303
|
+
certificates=None, # SimulationAppIo API doesn't support SSL yet
|
|
282
304
|
)
|
|
283
305
|
grpc_servers.append(simulationio_server)
|
|
284
306
|
|
|
@@ -352,6 +374,7 @@ def run_superlink() -> None:
|
|
|
352
374
|
server_public_key,
|
|
353
375
|
) = maybe_keys
|
|
354
376
|
state = state_factory.state()
|
|
377
|
+
state.clear_supernode_auth_keys_and_credentials()
|
|
355
378
|
state.store_node_public_keys(node_public_keys)
|
|
356
379
|
state.store_server_private_public_key(
|
|
357
380
|
private_key_to_bytes(server_private_key),
|
|
@@ -362,7 +385,7 @@ def run_superlink() -> None:
|
|
|
362
385
|
"Node authentication enabled with %d known public keys",
|
|
363
386
|
len(node_public_keys),
|
|
364
387
|
)
|
|
365
|
-
interceptors = [AuthenticateServerInterceptor(
|
|
388
|
+
interceptors = [AuthenticateServerInterceptor(state_factory)]
|
|
366
389
|
|
|
367
390
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
368
391
|
address=fleet_address,
|
|
@@ -389,6 +412,9 @@ def run_superlink() -> None:
|
|
|
389
412
|
io_address = (
|
|
390
413
|
f"{CLIENT_OCTET}:{_port}" if _octet == SERVER_OCTET else serverappio_address
|
|
391
414
|
)
|
|
415
|
+
address_arg = (
|
|
416
|
+
"--simulationio-api-address" if sim_exec else "--serverappio-api-address"
|
|
417
|
+
)
|
|
392
418
|
address = simulationio_address if sim_exec else io_address
|
|
393
419
|
cmd = "flwr-simulation" if sim_exec else "flwr-serverapp"
|
|
394
420
|
|
|
@@ -397,6 +423,7 @@ def run_superlink() -> None:
|
|
|
397
423
|
target=_flwr_scheduler,
|
|
398
424
|
args=(
|
|
399
425
|
state_factory,
|
|
426
|
+
address_arg,
|
|
400
427
|
address,
|
|
401
428
|
cmd,
|
|
402
429
|
),
|
|
@@ -422,6 +449,7 @@ def run_superlink() -> None:
|
|
|
422
449
|
|
|
423
450
|
def _flwr_scheduler(
|
|
424
451
|
state_factory: LinkStateFactory,
|
|
452
|
+
io_api_arg: str,
|
|
425
453
|
io_api_address: str,
|
|
426
454
|
cmd: str,
|
|
427
455
|
) -> None:
|
|
@@ -446,7 +474,7 @@ def _flwr_scheduler(
|
|
|
446
474
|
command = [
|
|
447
475
|
cmd,
|
|
448
476
|
"--run-once",
|
|
449
|
-
|
|
477
|
+
io_api_arg,
|
|
450
478
|
io_api_address,
|
|
451
479
|
"--insecure",
|
|
452
480
|
]
|
|
@@ -556,6 +584,32 @@ def _try_setup_node_authentication(
|
|
|
556
584
|
)
|
|
557
585
|
|
|
558
586
|
|
|
587
|
+
def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
|
|
588
|
+
if getattr(args, "user_auth_config", None) is not None:
|
|
589
|
+
with open(args.user_auth_config, encoding="utf-8") as file:
|
|
590
|
+
config: dict[str, Any] = yaml.safe_load(file)
|
|
591
|
+
return config
|
|
592
|
+
return None
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
|
|
596
|
+
auth_config: dict[str, Any] = config.get("authentication", {})
|
|
597
|
+
auth_type: str = auth_config.get(AUTH_TYPE, "")
|
|
598
|
+
try:
|
|
599
|
+
all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
|
|
600
|
+
auth_plugin_class = all_plugins[auth_type]
|
|
601
|
+
return auth_plugin_class(config=auth_config)
|
|
602
|
+
except KeyError:
|
|
603
|
+
if auth_type != "":
|
|
604
|
+
sys.exit(
|
|
605
|
+
f'Authentication type "{auth_type}" is not supported. '
|
|
606
|
+
"Please provide a valid authentication type in the configuration."
|
|
607
|
+
)
|
|
608
|
+
sys.exit("No authentication type is provided in the configuration.")
|
|
609
|
+
except NotImplementedError:
|
|
610
|
+
sys.exit("No authentication plugins are currently supported.")
|
|
611
|
+
|
|
612
|
+
|
|
559
613
|
def _run_fleet_api_grpc_rere(
|
|
560
614
|
address: str,
|
|
561
615
|
state_factory: LinkStateFactory,
|
|
@@ -654,6 +708,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
|
654
708
|
)
|
|
655
709
|
|
|
656
710
|
_add_args_common(parser=parser)
|
|
711
|
+
add_ee_args_superlink(parser=parser)
|
|
657
712
|
_add_args_serverappio_api(parser=parser)
|
|
658
713
|
_add_args_fleet_api(parser=parser)
|
|
659
714
|
_add_args_exec_api(parser=parser)
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
|
|
20
|
+
from flwr.common.typing import RunNotRunningException
|
|
21
|
+
|
|
20
22
|
from ..client_manager import ClientManager
|
|
21
23
|
from ..compat.driver_client_proxy import DriverClientProxy
|
|
22
24
|
from ..driver import Driver
|
|
@@ -74,7 +76,11 @@ def _update_client_manager(
|
|
|
74
76
|
# Loop until the driver is disconnected
|
|
75
77
|
registered_nodes: dict[int, DriverClientProxy] = {}
|
|
76
78
|
while not f_stop.is_set():
|
|
77
|
-
|
|
79
|
+
try:
|
|
80
|
+
all_node_ids = set(driver.get_node_ids())
|
|
81
|
+
except RunNotRunningException:
|
|
82
|
+
f_stop.set()
|
|
83
|
+
break
|
|
78
84
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
|
79
85
|
new_nodes = all_node_ids.difference(registered_nodes)
|
|
80
86
|
|
|
@@ -14,19 +14,20 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower gRPC Driver."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import time
|
|
18
19
|
import warnings
|
|
19
20
|
from collections.abc import Iterable
|
|
20
|
-
from logging import DEBUG,
|
|
21
|
-
from typing import
|
|
21
|
+
from logging import DEBUG, WARNING
|
|
22
|
+
from typing import Optional, cast
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
25
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import
|
|
27
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
27
28
|
from flwr.common.grpc import create_channel
|
|
28
29
|
from flwr.common.logger import log
|
|
29
|
-
from flwr.common.retry_invoker import
|
|
30
|
+
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
30
31
|
from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
|
|
31
32
|
from flwr.common.typing import Run
|
|
32
33
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
|
|
|
203
204
|
task_ins_list.append(taskins)
|
|
204
205
|
# Call GrpcDriverStub method
|
|
205
206
|
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
206
|
-
PushTaskInsRequest(
|
|
207
|
+
PushTaskInsRequest(
|
|
208
|
+
task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
|
|
209
|
+
)
|
|
207
210
|
)
|
|
208
211
|
return list(res.task_ids)
|
|
209
212
|
|
|
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
|
|
|
215
218
|
"""
|
|
216
219
|
# Pull TaskRes
|
|
217
220
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
218
|
-
PullTaskResRequest(
|
|
221
|
+
PullTaskResRequest(
|
|
222
|
+
node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
|
|
223
|
+
)
|
|
219
224
|
)
|
|
220
225
|
# Convert TaskRes to Message
|
|
221
226
|
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
|
|
|
258
263
|
return
|
|
259
264
|
# Disconnect
|
|
260
265
|
self._disconnect()
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
264
|
-
"""Create a simple gRPC retry invoker."""
|
|
265
|
-
|
|
266
|
-
def _on_sucess(retry_state: RetryState) -> None:
|
|
267
|
-
if retry_state.tries > 1:
|
|
268
|
-
log(
|
|
269
|
-
INFO,
|
|
270
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
271
|
-
retry_state.elapsed_time,
|
|
272
|
-
retry_state.tries,
|
|
273
|
-
)
|
|
274
|
-
|
|
275
|
-
def _on_backoff(retry_state: RetryState) -> None:
|
|
276
|
-
if retry_state.tries == 1:
|
|
277
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
278
|
-
else:
|
|
279
|
-
log(
|
|
280
|
-
WARN,
|
|
281
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
282
|
-
retry_state.actual_wait,
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
def _on_giveup(retry_state: RetryState) -> None:
|
|
286
|
-
if retry_state.tries > 1:
|
|
287
|
-
log(
|
|
288
|
-
WARN,
|
|
289
|
-
"Giving up reconnection after %.2f seconds and %s tries.",
|
|
290
|
-
retry_state.elapsed_time,
|
|
291
|
-
retry_state.tries,
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
return RetryInvoker(
|
|
295
|
-
wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
|
|
296
|
-
recoverable_exceptions=grpc.RpcError,
|
|
297
|
-
max_tries=None,
|
|
298
|
-
max_time=None,
|
|
299
|
-
on_success=_on_sucess,
|
|
300
|
-
on_backoff=_on_backoff,
|
|
301
|
-
on_giveup=_on_giveup,
|
|
302
|
-
should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
|
|
307
|
-
"""Wrap the gRPC stub with a retry invoker."""
|
|
308
|
-
|
|
309
|
-
def make_lambda(original_method: Any) -> Any:
|
|
310
|
-
return lambda *args, **kwargs: retry_invoker.invoke(
|
|
311
|
-
original_method, *args, **kwargs
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
for method_name in vars(stub):
|
|
315
|
-
method = getattr(stub, method_name)
|
|
316
|
-
if callable(method):
|
|
317
|
-
setattr(stub, method_name, make_lambda(method))
|
|
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
|
|
|
142
142
|
# Pull TaskRes
|
|
143
143
|
task_res_list = self.state.get_task_res(task_ids=msg_ids)
|
|
144
144
|
# Delete tasks in state
|
|
145
|
-
|
|
145
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
146
|
+
task_ins_ids_to_delete = {
|
|
147
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
148
|
+
}
|
|
149
|
+
self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
146
150
|
# Convert TaskRes to Message
|
|
147
151
|
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
|
|
148
152
|
return msgs
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -15,12 +15,12 @@
|
|
|
15
15
|
"""Run ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import sys
|
|
19
18
|
from logging import DEBUG, ERROR
|
|
20
19
|
from typing import Optional
|
|
21
20
|
|
|
22
|
-
from flwr.common import Context
|
|
23
|
-
from flwr.common.
|
|
21
|
+
from flwr.common import Context, EventType, event
|
|
22
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
|
23
|
+
from flwr.common.logger import log
|
|
24
24
|
from flwr.common.object_ref import load_app
|
|
25
25
|
|
|
26
26
|
from .driver import Driver
|
|
@@ -66,12 +66,11 @@ def run(
|
|
|
66
66
|
return context
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
70
69
|
def run_server_app() -> None:
|
|
71
70
|
"""Run Flower server app."""
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
71
|
+
event(EventType.RUN_SERVER_APP_ENTER)
|
|
72
|
+
log(
|
|
73
|
+
ERROR,
|
|
74
|
+
"The command `flower-server-app` has been replaced by `flwr run`.",
|
|
75
75
|
)
|
|
76
|
-
|
|
77
|
-
sys.exit()
|
|
76
|
+
register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower ServerApp process."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
@@ -24,6 +25,7 @@ from typing import Optional
|
|
|
24
25
|
|
|
25
26
|
from flwr.cli.config_utils import get_fab_metadata
|
|
26
27
|
from flwr.cli.install import install_from_fab
|
|
28
|
+
from flwr.cli.utils import get_sha256_hash
|
|
27
29
|
from flwr.common.args import add_args_flwr_app_common
|
|
28
30
|
from flwr.common.config import (
|
|
29
31
|
get_flwr_dir,
|
|
@@ -50,7 +52,8 @@ from flwr.common.serde import (
|
|
|
50
52
|
run_from_proto,
|
|
51
53
|
run_status_to_proto,
|
|
52
54
|
)
|
|
53
|
-
from flwr.common.
|
|
55
|
+
from flwr.common.telemetry import EventType, event
|
|
56
|
+
from flwr.common.typing import RunNotRunningException, RunStatus
|
|
54
57
|
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
55
58
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
56
59
|
PullServerAppInputsRequest,
|
|
@@ -96,7 +99,7 @@ def flwr_serverapp() -> None:
|
|
|
96
99
|
restore_output()
|
|
97
100
|
|
|
98
101
|
|
|
99
|
-
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
102
|
+
def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
100
103
|
serverappio_api_address: str,
|
|
101
104
|
log_queue: Queue[Optional[str]],
|
|
102
105
|
run_once: bool,
|
|
@@ -112,7 +115,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
112
115
|
# Resolve directory where FABs are installed
|
|
113
116
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
|
114
117
|
log_uploader = None
|
|
115
|
-
|
|
118
|
+
success = True
|
|
119
|
+
hash_run_id = None
|
|
116
120
|
while True:
|
|
117
121
|
|
|
118
122
|
try:
|
|
@@ -128,6 +132,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
128
132
|
run = run_from_proto(res.run)
|
|
129
133
|
fab = fab_from_proto(res.fab)
|
|
130
134
|
|
|
135
|
+
hash_run_id = get_sha256_hash(run.run_id)
|
|
136
|
+
|
|
131
137
|
driver.set_run(run.run_id)
|
|
132
138
|
|
|
133
139
|
# Start log uploader for this run
|
|
@@ -170,6 +176,11 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
170
176
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
171
177
|
)
|
|
172
178
|
|
|
179
|
+
event(
|
|
180
|
+
EventType.FLWR_SERVERAPP_RUN_ENTER,
|
|
181
|
+
event_details={"run-id-hash": hash_run_id},
|
|
182
|
+
)
|
|
183
|
+
|
|
173
184
|
# Load and run the ServerApp with the Driver
|
|
174
185
|
updated_context = run_(
|
|
175
186
|
driver=driver,
|
|
@@ -186,11 +197,18 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
186
197
|
_ = driver._stub.PushServerAppOutputs(out_req)
|
|
187
198
|
|
|
188
199
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
200
|
+
except RunNotRunningException:
|
|
201
|
+
log(INFO, "")
|
|
202
|
+
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
203
|
+
log(INFO, "")
|
|
204
|
+
run_status = None
|
|
205
|
+
success = False
|
|
189
206
|
|
|
190
207
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
191
208
|
exc_entity = "ServerApp"
|
|
192
209
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
193
210
|
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
|
|
211
|
+
success = False
|
|
194
212
|
|
|
195
213
|
finally:
|
|
196
214
|
# Stop log uploader for this run and upload final logs
|
|
@@ -206,6 +224,10 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
206
224
|
run_id=run.run_id, run_status=run_status_proto
|
|
207
225
|
)
|
|
208
226
|
)
|
|
227
|
+
event(
|
|
228
|
+
EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
|
229
|
+
event_details={"run-id-hash": hash_run_id, "success": success},
|
|
230
|
+
)
|
|
209
231
|
|
|
210
232
|
# Stop the loop if `flwr-serverapp` is expected to process a single run
|
|
211
233
|
if run_once:
|
|
@@ -230,12 +252,5 @@ def _parse_args_run_flwr_serverapp() -> argparse.ArgumentParser:
|
|
|
230
252
|
help="When set, this process will start a single ServerApp for a pending Run. "
|
|
231
253
|
"If there is no pending Run, the process will exit.",
|
|
232
254
|
)
|
|
233
|
-
parser.add_argument(
|
|
234
|
-
"--root-certificates",
|
|
235
|
-
metavar="ROOT_CERT",
|
|
236
|
-
type=str,
|
|
237
|
-
help="Specifies the path to the PEM-encoded root certificate file for "
|
|
238
|
-
"establishing secure HTTPS connections.",
|
|
239
|
-
)
|
|
240
255
|
add_args_flwr_app_common(parser=parser)
|
|
241
256
|
return parser
|
|
@@ -32,6 +32,7 @@ from flwr.common.serde import (
|
|
|
32
32
|
fab_from_proto,
|
|
33
33
|
fab_to_proto,
|
|
34
34
|
run_status_from_proto,
|
|
35
|
+
run_status_to_proto,
|
|
35
36
|
run_to_proto,
|
|
36
37
|
user_config_from_proto,
|
|
37
38
|
)
|
|
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
49
|
CreateRunResponse,
|
|
49
50
|
GetRunRequest,
|
|
50
51
|
GetRunResponse,
|
|
52
|
+
GetRunStatusRequest,
|
|
53
|
+
GetRunStatusResponse,
|
|
51
54
|
UpdateRunStatusRequest,
|
|
52
55
|
UpdateRunStatusResponse,
|
|
53
56
|
)
|
|
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
|
67
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
68
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
69
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
73
|
+
from flwr.server.superlink.utils import abort_if
|
|
70
74
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
71
75
|
|
|
72
76
|
|
|
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
85
89
|
) -> GetNodesResponse:
|
|
86
90
|
"""Get available nodes."""
|
|
87
91
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
92
|
+
|
|
93
|
+
# Init state
|
|
88
94
|
state: LinkState = self.state_factory.state()
|
|
95
|
+
|
|
96
|
+
# Abort if the run is not running
|
|
97
|
+
abort_if(
|
|
98
|
+
request.run_id,
|
|
99
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
100
|
+
state,
|
|
101
|
+
context,
|
|
102
|
+
)
|
|
103
|
+
|
|
89
104
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
90
105
|
nodes: list[Node] = [
|
|
91
106
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
123
138
|
"""Push a set of TaskIns."""
|
|
124
139
|
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
125
140
|
|
|
141
|
+
# Init state
|
|
142
|
+
state: LinkState = self.state_factory.state()
|
|
143
|
+
|
|
144
|
+
# Abort if the run is not running
|
|
145
|
+
abort_if(
|
|
146
|
+
request.run_id,
|
|
147
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
148
|
+
state,
|
|
149
|
+
context,
|
|
150
|
+
)
|
|
151
|
+
|
|
126
152
|
# Set pushed_at (timestamp in seconds)
|
|
127
153
|
pushed_at = time.time()
|
|
128
154
|
for task_ins in request.task_ins_list:
|
|
@@ -133,9 +159,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
133
159
|
for task_ins in request.task_ins_list:
|
|
134
160
|
validation_errors = validate_task_ins_or_res(task_ins)
|
|
135
161
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
162
|
+
_raise_if(
|
|
163
|
+
request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`"
|
|
164
|
+
)
|
|
139
165
|
|
|
140
166
|
# Store each TaskIns
|
|
141
167
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -153,33 +179,35 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
153
179
|
"""Pull a set of TaskRes."""
|
|
154
180
|
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
155
181
|
|
|
156
|
-
# Convert each task_id str to UUID
|
|
157
|
-
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
158
|
-
|
|
159
182
|
# Init state
|
|
160
183
|
state: LinkState = self.state_factory.state()
|
|
161
184
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
if context.is_active():
|
|
170
|
-
return
|
|
171
|
-
if context.code() != grpc.StatusCode.OK:
|
|
172
|
-
return
|
|
173
|
-
|
|
174
|
-
# Delete delivered TaskIns and TaskRes
|
|
175
|
-
state.delete_tasks(task_ids=task_ids)
|
|
185
|
+
# Abort if the run is not running
|
|
186
|
+
abort_if(
|
|
187
|
+
request.run_id,
|
|
188
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
189
|
+
state,
|
|
190
|
+
context,
|
|
191
|
+
)
|
|
176
192
|
|
|
177
|
-
|
|
193
|
+
# Convert each task_id str to UUID
|
|
194
|
+
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
178
195
|
|
|
179
196
|
# Read from state
|
|
180
197
|
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
181
198
|
|
|
182
|
-
|
|
199
|
+
# Validate request
|
|
200
|
+
for task_res in task_res_list:
|
|
201
|
+
_raise_if(
|
|
202
|
+
request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`"
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
206
|
+
task_ins_ids_to_delete = {
|
|
207
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
208
|
+
}
|
|
209
|
+
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
210
|
+
|
|
183
211
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
184
212
|
|
|
185
213
|
def GetRun(
|
|
@@ -255,7 +283,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
255
283
|
) -> PushServerAppOutputsResponse:
|
|
256
284
|
"""Push ServerApp process outputs."""
|
|
257
285
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
286
|
+
|
|
287
|
+
# Init state
|
|
258
288
|
state = self.state_factory.state()
|
|
289
|
+
|
|
290
|
+
# Abort if the run is not running
|
|
291
|
+
abort_if(
|
|
292
|
+
request.run_id,
|
|
293
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
294
|
+
state,
|
|
295
|
+
context,
|
|
296
|
+
)
|
|
297
|
+
|
|
259
298
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
299
|
return PushServerAppOutputsResponse()
|
|
261
300
|
|
|
@@ -263,9 +302,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
263
302
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
264
303
|
) -> UpdateRunStatusResponse:
|
|
265
304
|
"""Update the status of a run."""
|
|
266
|
-
log(DEBUG, "
|
|
305
|
+
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
|
306
|
+
|
|
307
|
+
# Init state
|
|
267
308
|
state = self.state_factory.state()
|
|
268
309
|
|
|
310
|
+
# Abort if the run is finished
|
|
311
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
312
|
+
|
|
269
313
|
# Update the run status
|
|
270
314
|
state.update_run_status(
|
|
271
315
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
@@ -284,6 +328,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
284
328
|
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
329
|
return PushLogsResponse()
|
|
286
330
|
|
|
331
|
+
def GetRunStatus(
|
|
332
|
+
self, request: GetRunStatusRequest, context: grpc.ServicerContext
|
|
333
|
+
) -> GetRunStatusResponse:
|
|
334
|
+
"""Get the status of a run."""
|
|
335
|
+
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
|
|
336
|
+
state = self.state_factory.state()
|
|
337
|
+
|
|
338
|
+
# Get run status from LinkState
|
|
339
|
+
run_statuses = state.get_run_status(set(request.run_ids))
|
|
340
|
+
run_status_dict = {
|
|
341
|
+
run_id: run_status_to_proto(run_status)
|
|
342
|
+
for run_id, run_status in run_statuses.items()
|
|
343
|
+
}
|
|
344
|
+
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
|
345
|
+
|
|
287
346
|
|
|
288
347
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
289
348
|
if validation_error:
|