flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241216__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +11 -31
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +25 -55
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +17 -39
- flwr/cli/stop.py +129 -0
- flwr/cli/utils.py +96 -1
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -0
- flwr/client/clientapp/app.py +4 -1
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -0
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +6 -1
- flwr/common/logger.py +17 -1
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +2 -1
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +27 -3
- flwr/proto/exec_pb2.pyi +103 -0
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +52 -1
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/serverapp/app.py +9 -2
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +72 -22
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
- flwr/server/superlink/linkstate/linkstate.py +13 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
- flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +1 -15
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/METADATA +8 -7
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/RECORD +101 -93
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower server app."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import csv
|
|
19
20
|
import importlib.util
|
|
@@ -24,9 +25,10 @@ from collections.abc import Sequence
|
|
|
24
25
|
from logging import DEBUG, INFO, WARN
|
|
25
26
|
from pathlib import Path
|
|
26
27
|
from time import sleep
|
|
27
|
-
from typing import Optional
|
|
28
|
+
from typing import Any, Optional
|
|
28
29
|
|
|
29
30
|
import grpc
|
|
31
|
+
import yaml
|
|
30
32
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
31
33
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
32
34
|
from cryptography.hazmat.primitives.serialization import (
|
|
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
37
39
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
38
40
|
from flwr.common.address import parse_address
|
|
39
41
|
from flwr.common.args import try_obtain_server_certificates
|
|
42
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
40
43
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
41
44
|
from flwr.common.constant import (
|
|
45
|
+
AUTH_TYPE,
|
|
42
46
|
CLIENT_OCTET,
|
|
43
47
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
44
48
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -88,6 +92,19 @@ DATABASE = ":flwr-in-memory-state:"
|
|
|
88
92
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
89
93
|
|
|
90
94
|
|
|
95
|
+
try:
|
|
96
|
+
from flwr.ee import add_ee_args_superlink, get_exec_auth_plugins
|
|
97
|
+
except ImportError:
|
|
98
|
+
|
|
99
|
+
# pylint: disable-next=unused-argument
|
|
100
|
+
def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
|
|
101
|
+
"""Add EE-specific arguments to the parser."""
|
|
102
|
+
|
|
103
|
+
def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
|
|
104
|
+
"""Return all Exec API authentication plugins."""
|
|
105
|
+
raise NotImplementedError("No authentication plugins are currently supported.")
|
|
106
|
+
|
|
107
|
+
|
|
91
108
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
92
109
|
*,
|
|
93
110
|
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -246,6 +263,12 @@ def run_superlink() -> None:
|
|
|
246
263
|
# Obtain certificates
|
|
247
264
|
certificates = try_obtain_server_certificates(args, args.fleet_api_type)
|
|
248
265
|
|
|
266
|
+
user_auth_config = _try_obtain_user_auth_config(args)
|
|
267
|
+
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
268
|
+
# user_auth_config is None only if the args.user_auth_config is not provided
|
|
269
|
+
if user_auth_config is not None:
|
|
270
|
+
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
|
|
271
|
+
|
|
249
272
|
# Initialize StateFactory
|
|
250
273
|
state_factory = LinkStateFactory(args.database)
|
|
251
274
|
|
|
@@ -263,6 +286,7 @@ def run_superlink() -> None:
|
|
|
263
286
|
config=parse_config_args(
|
|
264
287
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
265
288
|
),
|
|
289
|
+
auth_plugin=auth_plugin,
|
|
266
290
|
)
|
|
267
291
|
grpc_servers = [exec_server]
|
|
268
292
|
|
|
@@ -559,6 +583,32 @@ def _try_setup_node_authentication(
|
|
|
559
583
|
)
|
|
560
584
|
|
|
561
585
|
|
|
586
|
+
def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
|
|
587
|
+
if getattr(args, "user_auth_config", None) is not None:
|
|
588
|
+
with open(args.user_auth_config, encoding="utf-8") as file:
|
|
589
|
+
config: dict[str, Any] = yaml.safe_load(file)
|
|
590
|
+
return config
|
|
591
|
+
return None
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
|
|
595
|
+
auth_config: dict[str, Any] = config.get("authentication", {})
|
|
596
|
+
auth_type: str = auth_config.get(AUTH_TYPE, "")
|
|
597
|
+
try:
|
|
598
|
+
all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
|
|
599
|
+
auth_plugin_class = all_plugins[auth_type]
|
|
600
|
+
return auth_plugin_class(config=auth_config)
|
|
601
|
+
except KeyError:
|
|
602
|
+
if auth_type != "":
|
|
603
|
+
sys.exit(
|
|
604
|
+
f'Authentication type "{auth_type}" is not supported. '
|
|
605
|
+
"Please provide a valid authentication type in the configuration."
|
|
606
|
+
)
|
|
607
|
+
sys.exit("No authentication type is provided in the configuration.")
|
|
608
|
+
except NotImplementedError:
|
|
609
|
+
sys.exit("No authentication plugins are currently supported.")
|
|
610
|
+
|
|
611
|
+
|
|
562
612
|
def _run_fleet_api_grpc_rere(
|
|
563
613
|
address: str,
|
|
564
614
|
state_factory: LinkStateFactory,
|
|
@@ -657,6 +707,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
|
657
707
|
)
|
|
658
708
|
|
|
659
709
|
_add_args_common(parser=parser)
|
|
710
|
+
add_ee_args_superlink(parser=parser)
|
|
660
711
|
_add_args_serverappio_api(parser=parser)
|
|
661
712
|
_add_args_fleet_api(parser=parser)
|
|
662
713
|
_add_args_exec_api(parser=parser)
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
|
|
20
|
+
from flwr.common.typing import RunNotRunningException
|
|
21
|
+
|
|
20
22
|
from ..client_manager import ClientManager
|
|
21
23
|
from ..compat.driver_client_proxy import DriverClientProxy
|
|
22
24
|
from ..driver import Driver
|
|
@@ -74,7 +76,11 @@ def _update_client_manager(
|
|
|
74
76
|
# Loop until the driver is disconnected
|
|
75
77
|
registered_nodes: dict[int, DriverClientProxy] = {}
|
|
76
78
|
while not f_stop.is_set():
|
|
77
|
-
all_node_ids = set(driver.get_node_ids())
|
|
79
|
+
try:
|
|
80
|
+
all_node_ids = set(driver.get_node_ids())
|
|
81
|
+
except RunNotRunningException:
|
|
82
|
+
f_stop.set()
|
|
83
|
+
break
|
|
78
84
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
|
79
85
|
new_nodes = all_node_ids.difference(registered_nodes)
|
|
80
86
|
|
flwr/server/driver/grpc_driver.py
CHANGED
|
@@ -14,19 +14,20 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower gRPC Driver."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import time
|
|
18
19
|
import warnings
|
|
19
20
|
from collections.abc import Iterable
|
|
20
|
-
from logging import DEBUG, INFO, WARN, WARNING
|
|
21
|
-
from typing import Any, Optional, cast
|
|
21
|
+
from logging import DEBUG, WARNING
|
|
22
|
+
from typing import Optional, cast
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
25
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
27
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
27
28
|
from flwr.common.grpc import create_channel
|
|
28
29
|
from flwr.common.logger import log
|
|
29
|
-
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
30
|
+
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
30
31
|
from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
|
|
31
32
|
from flwr.common.typing import Run
|
|
32
33
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
|
|
|
203
204
|
task_ins_list.append(taskins)
|
|
204
205
|
# Call GrpcDriverStub method
|
|
205
206
|
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
206
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
207
|
+
PushTaskInsRequest(
|
|
208
|
+
task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
|
|
209
|
+
)
|
|
207
210
|
)
|
|
208
211
|
return list(res.task_ids)
|
|
209
212
|
|
|
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
|
|
|
215
218
|
"""
|
|
216
219
|
# Pull TaskRes
|
|
217
220
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
218
|
-
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
221
|
+
PullTaskResRequest(
|
|
222
|
+
node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
|
|
223
|
+
)
|
|
219
224
|
)
|
|
220
225
|
# Convert TaskRes to Message
|
|
221
226
|
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
|
|
|
258
263
|
return
|
|
259
264
|
# Disconnect
|
|
260
265
|
self._disconnect()
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
264
|
-
"""Create a simple gRPC retry invoker."""
|
|
265
|
-
|
|
266
|
-
def _on_sucess(retry_state: RetryState) -> None:
|
|
267
|
-
if retry_state.tries > 1:
|
|
268
|
-
log(
|
|
269
|
-
INFO,
|
|
270
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
271
|
-
retry_state.elapsed_time,
|
|
272
|
-
retry_state.tries,
|
|
273
|
-
)
|
|
274
|
-
|
|
275
|
-
def _on_backoff(retry_state: RetryState) -> None:
|
|
276
|
-
if retry_state.tries == 1:
|
|
277
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
278
|
-
else:
|
|
279
|
-
log(
|
|
280
|
-
WARN,
|
|
281
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
282
|
-
retry_state.actual_wait,
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
def _on_giveup(retry_state: RetryState) -> None:
|
|
286
|
-
if retry_state.tries > 1:
|
|
287
|
-
log(
|
|
288
|
-
WARN,
|
|
289
|
-
"Giving up reconnection after %.2f seconds and %s tries.",
|
|
290
|
-
retry_state.elapsed_time,
|
|
291
|
-
retry_state.tries,
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
return RetryInvoker(
|
|
295
|
-
wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
|
|
296
|
-
recoverable_exceptions=grpc.RpcError,
|
|
297
|
-
max_tries=None,
|
|
298
|
-
max_time=None,
|
|
299
|
-
on_success=_on_sucess,
|
|
300
|
-
on_backoff=_on_backoff,
|
|
301
|
-
on_giveup=_on_giveup,
|
|
302
|
-
should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
|
|
307
|
-
"""Wrap the gRPC stub with a retry invoker."""
|
|
308
|
-
|
|
309
|
-
def make_lambda(original_method: Any) -> Any:
|
|
310
|
-
return lambda *args, **kwargs: retry_invoker.invoke(
|
|
311
|
-
original_method, *args, **kwargs
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
for method_name in vars(stub):
|
|
315
|
-
method = getattr(stub, method_name)
|
|
316
|
-
if callable(method):
|
|
317
|
-
setattr(stub, method_name, make_lambda(method))
|
flwr/server/driver/inmemory_driver.py
CHANGED
|
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
|
|
|
142
142
|
# Pull TaskRes
|
|
143
143
|
task_res_list = self.state.get_task_res(task_ids=msg_ids)
|
|
144
144
|
# Delete tasks in state
|
|
145
|
-
|
|
145
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
146
|
+
task_ins_ids_to_delete = {
|
|
147
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
148
|
+
}
|
|
149
|
+
self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
146
150
|
# Convert TaskRes to Message
|
|
147
151
|
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
|
|
148
152
|
return msgs
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower ServerApp process."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
@@ -50,7 +51,7 @@ from flwr.common.serde import (
|
|
|
50
51
|
run_from_proto,
|
|
51
52
|
run_status_to_proto,
|
|
52
53
|
)
|
|
53
|
-
from flwr.common.typing import RunStatus
|
|
54
|
+
from flwr.common.typing import RunNotRunningException, RunStatus
|
|
54
55
|
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
55
56
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
56
57
|
PullServerAppInputsRequest,
|
|
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
|
|
|
96
97
|
restore_output()
|
|
97
98
|
|
|
98
99
|
|
|
99
|
-
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
100
|
+
def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
100
101
|
serverappio_api_address: str,
|
|
101
102
|
log_queue: Queue[Optional[str]],
|
|
102
103
|
run_once: bool,
|
|
@@ -187,6 +188,12 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
187
188
|
|
|
188
189
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
189
190
|
|
|
191
|
+
except RunNotRunningException:
|
|
192
|
+
log(INFO, "")
|
|
193
|
+
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
194
|
+
log(INFO, "")
|
|
195
|
+
run_status = None
|
|
196
|
+
|
|
190
197
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
191
198
|
exc_entity = "ServerApp"
|
|
192
199
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
flwr/server/superlink/driver/serverappio_servicer.py
CHANGED
|
@@ -32,6 +32,7 @@ from flwr.common.serde import (
|
|
|
32
32
|
fab_from_proto,
|
|
33
33
|
fab_to_proto,
|
|
34
34
|
run_status_from_proto,
|
|
35
|
+
run_status_to_proto,
|
|
35
36
|
run_to_proto,
|
|
36
37
|
user_config_from_proto,
|
|
37
38
|
)
|
|
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
49
|
CreateRunResponse,
|
|
49
50
|
GetRunRequest,
|
|
50
51
|
GetRunResponse,
|
|
52
|
+
GetRunStatusRequest,
|
|
53
|
+
GetRunStatusResponse,
|
|
51
54
|
UpdateRunStatusRequest,
|
|
52
55
|
UpdateRunStatusResponse,
|
|
53
56
|
)
|
|
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
|
67
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
68
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
69
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
73
|
+
from flwr.server.superlink.utils import abort_if
|
|
70
74
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
71
75
|
|
|
72
76
|
|
|
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
85
89
|
) -> GetNodesResponse:
|
|
86
90
|
"""Get available nodes."""
|
|
87
91
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
92
|
+
|
|
93
|
+
# Init state
|
|
88
94
|
state: LinkState = self.state_factory.state()
|
|
95
|
+
|
|
96
|
+
# Abort if the run is not running
|
|
97
|
+
abort_if(
|
|
98
|
+
request.run_id,
|
|
99
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
100
|
+
state,
|
|
101
|
+
context,
|
|
102
|
+
)
|
|
103
|
+
|
|
89
104
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
90
105
|
nodes: list[Node] = [
|
|
91
106
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
123
138
|
"""Push a set of TaskIns."""
|
|
124
139
|
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
125
140
|
|
|
141
|
+
# Init state
|
|
142
|
+
state: LinkState = self.state_factory.state()
|
|
143
|
+
|
|
144
|
+
# Abort if the run is not running
|
|
145
|
+
abort_if(
|
|
146
|
+
request.run_id,
|
|
147
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
148
|
+
state,
|
|
149
|
+
context,
|
|
150
|
+
)
|
|
151
|
+
|
|
126
152
|
# Set pushed_at (timestamp in seconds)
|
|
127
153
|
pushed_at = time.time()
|
|
128
154
|
for task_ins in request.task_ins_list:
|
|
@@ -134,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
134
160
|
validation_errors = validate_task_ins_or_res(task_ins)
|
|
135
161
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
136
162
|
|
|
137
|
-
# Init state
|
|
138
|
-
state: LinkState = self.state_factory.state()
|
|
139
|
-
|
|
140
163
|
# Store each TaskIns
|
|
141
164
|
task_ids: list[Optional[UUID]] = []
|
|
142
165
|
for task_ins in request.task_ins_list:
|
|
@@ -153,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
153
176
|
"""Pull a set of TaskRes."""
|
|
154
177
|
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
155
178
|
|
|
156
|
-
# Convert each task_id str to UUID
|
|
157
|
-
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
158
|
-
|
|
159
179
|
# Init state
|
|
160
180
|
state: LinkState = self.state_factory.state()
|
|
161
181
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
if context.is_active():
|
|
170
|
-
return
|
|
171
|
-
if context.code() != grpc.StatusCode.OK:
|
|
172
|
-
return
|
|
173
|
-
|
|
174
|
-
# Delete delivered TaskIns and TaskRes
|
|
175
|
-
state.delete_tasks(task_ids=task_ids)
|
|
182
|
+
# Abort if the run is not running
|
|
183
|
+
abort_if(
|
|
184
|
+
request.run_id,
|
|
185
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
186
|
+
state,
|
|
187
|
+
context,
|
|
188
|
+
)
|
|
176
189
|
|
|
177
|
-
|
|
190
|
+
# Convert each task_id str to UUID
|
|
191
|
+
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
178
192
|
|
|
179
193
|
# Read from state
|
|
180
194
|
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
181
195
|
|
|
182
|
-
|
|
196
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
197
|
+
task_ins_ids_to_delete = {
|
|
198
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
199
|
+
}
|
|
200
|
+
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
201
|
+
|
|
183
202
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
184
203
|
|
|
185
204
|
def GetRun(
|
|
@@ -255,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
255
274
|
) -> PushServerAppOutputsResponse:
|
|
256
275
|
"""Push ServerApp process outputs."""
|
|
257
276
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
277
|
+
|
|
278
|
+
# Init state
|
|
258
279
|
state = self.state_factory.state()
|
|
280
|
+
|
|
281
|
+
# Abort if the run is not running
|
|
282
|
+
abort_if(
|
|
283
|
+
request.run_id,
|
|
284
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
285
|
+
state,
|
|
286
|
+
context,
|
|
287
|
+
)
|
|
288
|
+
|
|
259
289
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
290
|
return PushServerAppOutputsResponse()
|
|
261
291
|
|
|
@@ -264,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
264
294
|
) -> UpdateRunStatusResponse:
|
|
265
295
|
"""Update the status of a run."""
|
|
266
296
|
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
|
297
|
+
|
|
298
|
+
# Init state
|
|
267
299
|
state = self.state_factory.state()
|
|
268
300
|
|
|
301
|
+
# Abort if the run is finished
|
|
302
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
303
|
+
|
|
269
304
|
# Update the run status
|
|
270
305
|
state.update_run_status(
|
|
271
306
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
@@ -284,6 +319,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
284
319
|
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
320
|
return PushLogsResponse()
|
|
286
321
|
|
|
322
|
+
def GetRunStatus(
|
|
323
|
+
self, request: GetRunStatusRequest, context: grpc.ServicerContext
|
|
324
|
+
) -> GetRunStatusResponse:
|
|
325
|
+
"""Get the status of a run."""
|
|
326
|
+
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
|
|
327
|
+
state = self.state_factory.state()
|
|
328
|
+
|
|
329
|
+
# Get run status from LinkState
|
|
330
|
+
run_statuses = state.get_run_status(set(request.run_ids))
|
|
331
|
+
run_status_dict = {
|
|
332
|
+
run_id: run_status_to_proto(run_status)
|
|
333
|
+
for run_id, run_status in run_statuses.items()
|
|
334
|
+
}
|
|
335
|
+
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
|
336
|
+
|
|
287
337
|
|
|
288
338
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
289
339
|
if validation_error:
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
22
22
|
from flwr.common.logger import log
|
|
23
|
+
from flwr.common.typing import InvalidRunStatusException
|
|
23
24
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
|
38
39
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
40
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
41
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
42
|
+
from flwr.server.superlink.utils import abort_grpc_context
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
105
107
|
)
|
|
106
108
|
else:
|
|
107
109
|
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
110
|
+
|
|
111
|
+
try:
|
|
112
|
+
res = message_handler.push_task_res(
|
|
113
|
+
request=request,
|
|
114
|
+
state=self.state_factory.state(),
|
|
115
|
+
)
|
|
116
|
+
except InvalidRunStatusException as e:
|
|
117
|
+
abort_grpc_context(e.message, context)
|
|
118
|
+
|
|
119
|
+
return res
|
|
112
120
|
|
|
113
121
|
def GetRun(
|
|
114
122
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
115
123
|
) -> GetRunResponse:
|
|
116
124
|
"""Get run information."""
|
|
117
125
|
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
res = message_handler.get_run(
|
|
129
|
+
request=request,
|
|
130
|
+
state=self.state_factory.state(),
|
|
131
|
+
)
|
|
132
|
+
except InvalidRunStatusException as e:
|
|
133
|
+
abort_grpc_context(e.message, context)
|
|
134
|
+
|
|
135
|
+
return res
|
|
122
136
|
|
|
123
137
|
def GetFab(
|
|
124
138
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
125
139
|
) -> GetFabResponse:
|
|
126
140
|
"""Get FAB."""
|
|
127
141
|
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
142
|
+
try:
|
|
143
|
+
res = message_handler.get_fab(
|
|
144
|
+
request=request,
|
|
145
|
+
ffs=self.ffs_factory.ffs(),
|
|
146
|
+
state=self.state_factory.state(),
|
|
147
|
+
)
|
|
148
|
+
except InvalidRunStatusException as e:
|
|
149
|
+
abort_grpc_context(e.message, context)
|
|
150
|
+
|
|
151
|
+
return res
|
flwr/server/superlink/fleet/message_handler/message_handler.py
CHANGED
|
@@ -19,8 +19,9 @@ import time
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.constant import Status
|
|
22
23
|
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
-
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
26
27
|
CreateNodeRequest,
|
|
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
44
45
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
46
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
47
|
from flwr.server.superlink.linkstate import LinkState
|
|
48
|
+
from flwr.server.superlink.utils import check_abort
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
def create_node(
|
|
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
98
100
|
task_res: TaskRes = request.task_res_list[0]
|
|
99
101
|
# pylint: enable=no-member
|
|
100
102
|
|
|
103
|
+
# Abort if the run is not running
|
|
104
|
+
abort_msg = check_abort(
|
|
105
|
+
task_res.run_id,
|
|
106
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
107
|
+
state,
|
|
108
|
+
)
|
|
109
|
+
if abort_msg:
|
|
110
|
+
raise InvalidRunStatusException(abort_msg)
|
|
111
|
+
|
|
101
112
|
# Set pushed_at (timestamp in seconds)
|
|
102
113
|
task_res.task.pushed_at = time.time()
|
|
103
114
|
|
|
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
112
123
|
return response
|
|
113
124
|
|
|
114
125
|
|
|
115
|
-
def get_run(
|
|
116
|
-
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
|
-
) -> GetRunResponse:
|
|
126
|
+
def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
118
127
|
"""Get run information."""
|
|
119
128
|
run = state.get_run(request.run_id)
|
|
120
129
|
|
|
121
130
|
if run is None:
|
|
122
131
|
return GetRunResponse()
|
|
123
132
|
|
|
133
|
+
# Abort if the run is not running
|
|
134
|
+
abort_msg = check_abort(
|
|
135
|
+
request.run_id,
|
|
136
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
137
|
+
state,
|
|
138
|
+
)
|
|
139
|
+
if abort_msg:
|
|
140
|
+
raise InvalidRunStatusException(abort_msg)
|
|
141
|
+
|
|
124
142
|
return GetRunResponse(
|
|
125
143
|
run=Run(
|
|
126
144
|
run_id=run.run_id,
|
|
@@ -133,9 +151,18 @@ def get_run(
|
|
|
133
151
|
|
|
134
152
|
|
|
135
153
|
def get_fab(
|
|
136
|
-
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
154
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
|
|
137
155
|
) -> GetFabResponse:
|
|
138
156
|
"""Get FAB."""
|
|
157
|
+
# Abort if the run is not running
|
|
158
|
+
abort_msg = check_abort(
|
|
159
|
+
request.run_id,
|
|
160
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
161
|
+
state,
|
|
162
|
+
)
|
|
163
|
+
if abort_msg:
|
|
164
|
+
raise InvalidRunStatusException(abort_msg)
|
|
165
|
+
|
|
139
166
|
if result := ffs.get(request.hash_str):
|
|
140
167
|
fab = Fab(request.hash_str, result[0])
|
|
141
168
|
return GetFabResponse(fab=fab_to_proto(fab))
|
flwr/server/superlink/fleet/rest_rere/rest_api.py
CHANGED
|
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
154
|
# Get ffs from app
|
|
155
155
|
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
156
156
|
|
|
157
|
+
# Get state from app
|
|
158
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
159
|
+
|
|
157
160
|
# Handle message
|
|
158
|
-
return message_handler.get_fab(request=request, ffs=ffs)
|
|
161
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
|
159
162
|
|
|
160
163
|
|
|
161
164
|
routes = [
|