flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241022__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +11 -11
- flwr/client/node_state_tests.py +7 -8
- flwr/client/{node_state.py → run_info_store.py} +3 -3
- flwr/common/constant.py +2 -3
- flwr/server/app.py +51 -9
- flwr/server/driver/inmemory_driver.py +2 -2
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +20 -0
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +9 -7
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +8 -8
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +10 -10
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +14 -14
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +3 -3
- flwr/superexec/app.py +9 -2
- flwr/superexec/simulation.py +1 -1
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/RECORD +32 -30
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/entry_points.txt +1 -0
- /flwr/server/superlink/{state → linkstate}/utils.py +0 -0
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -52,15 +52,15 @@ from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
|
52
52
|
from flwr.common.typing import Fab, Run, UserConfig
|
|
53
53
|
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
|
|
54
54
|
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
55
|
-
from flwr.server.superlink.
|
|
55
|
+
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
56
56
|
|
|
57
57
|
from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
|
|
58
58
|
from .grpc_adapter_client.connection import grpc_adapter
|
|
59
59
|
from .grpc_client.connection import grpc_connection
|
|
60
60
|
from .grpc_rere_client.connection import grpc_request_response
|
|
61
61
|
from .message_handler.message_handler import handle_control_message
|
|
62
|
-
from .node_state import NodeState
|
|
63
62
|
from .numpy_client import NumPyClient
|
|
63
|
+
from .run_info_store import DeprecatedRunInfoStore
|
|
64
64
|
|
|
65
65
|
ISOLATION_MODE_SUBPROCESS = "subprocess"
|
|
66
66
|
ISOLATION_MODE_PROCESS = "process"
|
|
@@ -364,8 +364,8 @@ def start_client_internal(
|
|
|
364
364
|
on_backoff=_on_backoff,
|
|
365
365
|
)
|
|
366
366
|
|
|
367
|
-
#
|
|
368
|
-
|
|
367
|
+
# DeprecatedRunInfoStore gets initialized when the first connection is established
|
|
368
|
+
run_info_store: Optional[DeprecatedRunInfoStore] = None
|
|
369
369
|
|
|
370
370
|
runs: dict[int, Run] = {}
|
|
371
371
|
|
|
@@ -382,7 +382,7 @@ def start_client_internal(
|
|
|
382
382
|
receive, send, create_node, delete_node, get_run, get_fab = conn
|
|
383
383
|
|
|
384
384
|
# Register node when connecting the first time
|
|
385
|
-
if
|
|
385
|
+
if run_info_store is None:
|
|
386
386
|
if create_node is None:
|
|
387
387
|
if transport not in ["grpc-bidi", None]:
|
|
388
388
|
raise NotImplementedError(
|
|
@@ -391,7 +391,7 @@ def start_client_internal(
|
|
|
391
391
|
)
|
|
392
392
|
# gRPC-bidi doesn't have the concept of node_id,
|
|
393
393
|
# so we set it to -1
|
|
394
|
-
|
|
394
|
+
run_info_store = DeprecatedRunInfoStore(
|
|
395
395
|
node_id=-1,
|
|
396
396
|
node_config={},
|
|
397
397
|
)
|
|
@@ -402,7 +402,7 @@ def start_client_internal(
|
|
|
402
402
|
) # pylint: disable=not-callable
|
|
403
403
|
if node_id is None:
|
|
404
404
|
raise ValueError("Node registration failed")
|
|
405
|
-
|
|
405
|
+
run_info_store = DeprecatedRunInfoStore(
|
|
406
406
|
node_id=node_id,
|
|
407
407
|
node_config=node_config,
|
|
408
408
|
)
|
|
@@ -461,7 +461,7 @@ def start_client_internal(
|
|
|
461
461
|
run.fab_id, run.fab_version = fab_id, fab_version
|
|
462
462
|
|
|
463
463
|
# Register context for this run
|
|
464
|
-
|
|
464
|
+
run_info_store.register_context(
|
|
465
465
|
run_id=run_id,
|
|
466
466
|
run=run,
|
|
467
467
|
flwr_path=flwr_path,
|
|
@@ -469,7 +469,7 @@ def start_client_internal(
|
|
|
469
469
|
)
|
|
470
470
|
|
|
471
471
|
# Retrieve context for this run
|
|
472
|
-
context =
|
|
472
|
+
context = run_info_store.retrieve_context(run_id=run_id)
|
|
473
473
|
# Create an error reply message that will never be used to prevent
|
|
474
474
|
# the used-before-assignment linting error
|
|
475
475
|
reply_message = message.create_error_reply(
|
|
@@ -542,7 +542,7 @@ def start_client_internal(
|
|
|
542
542
|
# Raise exception, crash process
|
|
543
543
|
raise ex
|
|
544
544
|
|
|
545
|
-
# Don't update/change
|
|
545
|
+
# Don't update/change DeprecatedRunInfoStore
|
|
546
546
|
|
|
547
547
|
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
548
548
|
# Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
@@ -567,7 +567,7 @@ def start_client_internal(
|
|
|
567
567
|
)
|
|
568
568
|
else:
|
|
569
569
|
# No exception, update node state
|
|
570
|
-
|
|
570
|
+
run_info_store.update_context(
|
|
571
571
|
run_id=run_id,
|
|
572
572
|
context=context,
|
|
573
573
|
)
|
flwr/client/node_state_tests.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
from typing import cast
|
|
19
19
|
|
|
20
|
-
from flwr.client.
|
|
20
|
+
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
21
21
|
from flwr.common import ConfigsRecord, Context
|
|
22
22
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
23
23
|
|
|
@@ -34,32 +34,31 @@ def _run_dummy_task(context: Context) -> Context:
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def test_multirun_in_node_state() -> None:
|
|
37
|
-
"""Test basic
|
|
37
|
+
"""Test basic DeprecatedRunInfoStore logic."""
|
|
38
38
|
# Tasks to perform
|
|
39
39
|
tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]]
|
|
40
40
|
# the "tasks" is to count how many times each run is executed
|
|
41
41
|
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
|
|
42
42
|
|
|
43
|
-
|
|
44
|
-
node_state = NodeState(node_id=0, node_config={})
|
|
43
|
+
node_info_store = DeprecatedRunInfoStore(node_id=0, node_config={})
|
|
45
44
|
|
|
46
45
|
for task in tasks:
|
|
47
46
|
run_id = task.run_id
|
|
48
47
|
|
|
49
48
|
# Register
|
|
50
|
-
|
|
49
|
+
node_info_store.register_context(run_id=run_id)
|
|
51
50
|
|
|
52
51
|
# Get run state
|
|
53
|
-
context =
|
|
52
|
+
context = node_info_store.retrieve_context(run_id=run_id)
|
|
54
53
|
|
|
55
54
|
# Run "task"
|
|
56
55
|
updated_state = _run_dummy_task(context)
|
|
57
56
|
|
|
58
57
|
# Update run state
|
|
59
|
-
|
|
58
|
+
node_info_store.update_context(run_id=run_id, context=updated_state)
|
|
60
59
|
|
|
61
60
|
# Verify values
|
|
62
|
-
for run_id, run_info in
|
|
61
|
+
for run_id, run_info in node_info_store.run_infos.items():
|
|
63
62
|
assert (
|
|
64
63
|
run_info.context.state.configs_records["counter"]["count"]
|
|
65
64
|
== expected_values[run_id]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Deprecated Run Info Store."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
@@ -36,7 +36,7 @@ class RunInfo:
|
|
|
36
36
|
initial_run_config: UserConfig
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
class
|
|
39
|
+
class DeprecatedRunInfoStore:
|
|
40
40
|
"""State of a node where client nodes execute runs."""
|
|
41
41
|
|
|
42
42
|
def __init__(
|
flwr/common/constant.py
CHANGED
|
@@ -40,15 +40,14 @@ TRANSPORT_TYPES = [
|
|
|
40
40
|
# Addresses
|
|
41
41
|
# SuperNode
|
|
42
42
|
CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094"
|
|
43
|
-
# SuperExec
|
|
44
|
-
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
45
43
|
# SuperLink
|
|
46
44
|
DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091"
|
|
47
45
|
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092"
|
|
48
46
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
|
|
49
47
|
"[::]:8080" # IPv6 to keep start_server compatible
|
|
50
48
|
)
|
|
51
|
-
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:
|
|
49
|
+
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095"
|
|
50
|
+
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
52
51
|
|
|
53
52
|
# Constants for ping
|
|
54
53
|
PING_DEFAULT_INTERVAL = 30
|
flwr/server/app.py
CHANGED
|
@@ -35,9 +35,10 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
35
35
|
|
|
36
36
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
37
37
|
from flwr.common.address import parse_address
|
|
38
|
-
from flwr.common.config import get_flwr_dir
|
|
38
|
+
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
39
39
|
from flwr.common.constant import (
|
|
40
40
|
DRIVER_API_DEFAULT_ADDRESS,
|
|
41
|
+
EXEC_API_DEFAULT_ADDRESS,
|
|
41
42
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
42
43
|
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
43
44
|
FLEET_API_REST_DEFAULT_ADDRESS,
|
|
@@ -56,6 +57,8 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
|
56
57
|
add_FleetServicer_to_server,
|
|
57
58
|
)
|
|
58
59
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
60
|
+
from flwr.superexec.app import load_executor
|
|
61
|
+
from flwr.superexec.exec_grpc import run_superexec_api_grpc
|
|
59
62
|
|
|
60
63
|
from .client_manager import ClientManager
|
|
61
64
|
from .history import History
|
|
@@ -71,7 +74,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
71
74
|
)
|
|
72
75
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
73
76
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
74
|
-
from .superlink.
|
|
77
|
+
from .superlink.linkstate import LinkStateFactory
|
|
75
78
|
|
|
76
79
|
DATABASE = ":flwr-in-memory-state:"
|
|
77
80
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
@@ -205,14 +208,15 @@ def run_superlink() -> None:
|
|
|
205
208
|
|
|
206
209
|
event(EventType.RUN_SUPERLINK_ENTER)
|
|
207
210
|
|
|
208
|
-
# Parse IP
|
|
211
|
+
# Parse IP addresses
|
|
209
212
|
driver_address, _, _ = _format_address(args.driver_api_address)
|
|
213
|
+
exec_address, _, _ = _format_address(args.exec_api_address)
|
|
210
214
|
|
|
211
215
|
# Obtain certificates
|
|
212
216
|
certificates = _try_obtain_certificates(args)
|
|
213
217
|
|
|
214
218
|
# Initialize StateFactory
|
|
215
|
-
state_factory =
|
|
219
|
+
state_factory = LinkStateFactory(args.database)
|
|
216
220
|
|
|
217
221
|
# Initialize FfsFactory
|
|
218
222
|
ffs_factory = FfsFactory(args.storage_dir)
|
|
@@ -224,8 +228,9 @@ def run_superlink() -> None:
|
|
|
224
228
|
ffs_factory=ffs_factory,
|
|
225
229
|
certificates=certificates,
|
|
226
230
|
)
|
|
227
|
-
|
|
228
231
|
grpc_servers = [driver_server]
|
|
232
|
+
|
|
233
|
+
# Start Fleet API
|
|
229
234
|
bckg_threads = []
|
|
230
235
|
if not args.fleet_api_address:
|
|
231
236
|
if args.fleet_api_type in [
|
|
@@ -250,7 +255,6 @@ def run_superlink() -> None:
|
|
|
250
255
|
)
|
|
251
256
|
num_workers = 1
|
|
252
257
|
|
|
253
|
-
# Start Fleet API
|
|
254
258
|
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
255
259
|
if (
|
|
256
260
|
importlib.util.find_spec("requests")
|
|
@@ -318,6 +322,17 @@ def run_superlink() -> None:
|
|
|
318
322
|
else:
|
|
319
323
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
320
324
|
|
|
325
|
+
# Start Exec API
|
|
326
|
+
exec_server: grpc.Server = run_superexec_api_grpc(
|
|
327
|
+
address=exec_address,
|
|
328
|
+
executor=load_executor(args),
|
|
329
|
+
certificates=certificates,
|
|
330
|
+
config=parse_config_args(
|
|
331
|
+
[args.executor_config] if args.executor_config else args.executor_config
|
|
332
|
+
),
|
|
333
|
+
)
|
|
334
|
+
grpc_servers.append(exec_server)
|
|
335
|
+
|
|
321
336
|
# Graceful shutdown
|
|
322
337
|
register_exit_handlers(
|
|
323
338
|
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
|
@@ -489,7 +504,7 @@ def _try_obtain_certificates(
|
|
|
489
504
|
|
|
490
505
|
def _run_fleet_api_grpc_rere(
|
|
491
506
|
address: str,
|
|
492
|
-
state_factory:
|
|
507
|
+
state_factory: LinkStateFactory,
|
|
493
508
|
ffs_factory: FfsFactory,
|
|
494
509
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
495
510
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
@@ -517,7 +532,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
517
532
|
|
|
518
533
|
def _run_fleet_api_grpc_adapter(
|
|
519
534
|
address: str,
|
|
520
|
-
state_factory:
|
|
535
|
+
state_factory: LinkStateFactory,
|
|
521
536
|
ffs_factory: FfsFactory,
|
|
522
537
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
523
538
|
) -> grpc.Server:
|
|
@@ -548,7 +563,7 @@ def _run_fleet_api_rest(
|
|
|
548
563
|
port: int,
|
|
549
564
|
ssl_keyfile: Optional[str],
|
|
550
565
|
ssl_certfile: Optional[str],
|
|
551
|
-
state_factory:
|
|
566
|
+
state_factory: LinkStateFactory,
|
|
552
567
|
ffs_factory: FfsFactory,
|
|
553
568
|
num_workers: int,
|
|
554
569
|
) -> None:
|
|
@@ -587,6 +602,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
|
587
602
|
_add_args_common(parser=parser)
|
|
588
603
|
_add_args_driver_api(parser=parser)
|
|
589
604
|
_add_args_fleet_api(parser=parser)
|
|
605
|
+
_add_args_exec_api(parser=parser)
|
|
590
606
|
|
|
591
607
|
return parser
|
|
592
608
|
|
|
@@ -681,3 +697,29 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
|
681
697
|
type=int,
|
|
682
698
|
help="Set the number of concurrent workers for the Fleet API server.",
|
|
683
699
|
)
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
|
|
703
|
+
"""Add command line arguments for Exec API."""
|
|
704
|
+
parser.add_argument(
|
|
705
|
+
"--exec-api-address",
|
|
706
|
+
help="Exec API server address (IPv4, IPv6, or a domain name)",
|
|
707
|
+
default=EXEC_API_DEFAULT_ADDRESS,
|
|
708
|
+
)
|
|
709
|
+
parser.add_argument(
|
|
710
|
+
"--executor",
|
|
711
|
+
help="For example: `deployment:exec` or `project.package.module:wrapper.exec`. "
|
|
712
|
+
"The default is `flwr.superexec.deployment:executor`",
|
|
713
|
+
default="flwr.superexec.deployment:executor",
|
|
714
|
+
)
|
|
715
|
+
parser.add_argument(
|
|
716
|
+
"--executor-dir",
|
|
717
|
+
help="The directory for the executor.",
|
|
718
|
+
default=".",
|
|
719
|
+
)
|
|
720
|
+
parser.add_argument(
|
|
721
|
+
"--executor-config",
|
|
722
|
+
help="Key-value pairs for the executor config, separated by spaces. "
|
|
723
|
+
"For example:\n\n`--executor-config 'verbose=true "
|
|
724
|
+
'root-certificates="certificates/superlink-ca.crt"\'`',
|
|
725
|
+
)
|
|
@@ -25,7 +25,7 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
|
25
25
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from .driver import Driver
|
|
31
31
|
|
|
@@ -46,7 +46,7 @@ class InMemoryDriver(Driver):
|
|
|
46
46
|
def __init__(
|
|
47
47
|
self,
|
|
48
48
|
run_id: int,
|
|
49
|
-
state_factory:
|
|
49
|
+
state_factory: LinkStateFactory,
|
|
50
50
|
pull_interval: float = 0.1,
|
|
51
51
|
) -> None:
|
|
52
52
|
self._run_id = run_id
|
|
@@ -12,17 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower AppIO service."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .sqlite_state import SqliteState as SqliteState
|
|
20
|
-
from .state import State as State
|
|
21
|
-
from .state_factory import StateFactory as StateFactory
|
|
18
|
+
from .app import flwr_serverapp as flwr_serverapp
|
|
22
19
|
|
|
23
20
|
__all__ = [
|
|
24
|
-
"
|
|
25
|
-
"SqliteState",
|
|
26
|
-
"State",
|
|
27
|
-
"StateFactory",
|
|
21
|
+
"flwr_serverapp",
|
|
28
22
|
]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower ServerApp process."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def flwr_serverapp() -> None:
|
|
19
|
+
"""Run process-isolated Flower ServerApp."""
|
|
20
|
+
raise NotImplementedError()
|
|
@@ -25,7 +25,7 @@ from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
|
|
25
25
|
add_DriverServicer_to_server,
|
|
26
26
|
)
|
|
27
27
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
31
31
|
from .driver_servicer import DriverServicer
|
|
@@ -33,7 +33,7 @@ from .driver_servicer import DriverServicer
|
|
|
33
33
|
|
|
34
34
|
def run_driver_api_grpc(
|
|
35
35
|
address: str,
|
|
36
|
-
state_factory:
|
|
36
|
+
state_factory: LinkStateFactory,
|
|
37
37
|
ffs_factory: FfsFactory,
|
|
38
38
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
39
39
|
) -> grpc.Server:
|
|
@@ -51,14 +51,16 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
51
51
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
52
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
53
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
54
|
-
from flwr.server.superlink.
|
|
54
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
55
55
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
59
59
|
"""Driver API servicer."""
|
|
60
60
|
|
|
61
|
-
def __init__(
|
|
61
|
+
def __init__(
|
|
62
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
63
|
+
) -> None:
|
|
62
64
|
self.state_factory = state_factory
|
|
63
65
|
self.ffs_factory = ffs_factory
|
|
64
66
|
|
|
@@ -67,7 +69,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
67
69
|
) -> GetNodesResponse:
|
|
68
70
|
"""Get available nodes."""
|
|
69
71
|
log(DEBUG, "DriverServicer.GetNodes")
|
|
70
|
-
state:
|
|
72
|
+
state: LinkState = self.state_factory.state()
|
|
71
73
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
72
74
|
nodes: list[Node] = [
|
|
73
75
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -79,7 +81,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
79
81
|
) -> CreateRunResponse:
|
|
80
82
|
"""Create run ID."""
|
|
81
83
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
82
|
-
state:
|
|
84
|
+
state: LinkState = self.state_factory.state()
|
|
83
85
|
if request.HasField("fab"):
|
|
84
86
|
fab = fab_from_proto(request.fab)
|
|
85
87
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
@@ -116,7 +118,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
116
118
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
117
119
|
|
|
118
120
|
# Init state
|
|
119
|
-
state:
|
|
121
|
+
state: LinkState = self.state_factory.state()
|
|
120
122
|
|
|
121
123
|
# Store each TaskIns
|
|
122
124
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -138,7 +140,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
138
140
|
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
139
141
|
|
|
140
142
|
# Init state
|
|
141
|
-
state:
|
|
143
|
+
state: LinkState = self.state_factory.state()
|
|
142
144
|
|
|
143
145
|
# Register callback
|
|
144
146
|
def on_rpc_done() -> None:
|
|
@@ -167,7 +169,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
167
169
|
log(DEBUG, "DriverServicer.GetRun")
|
|
168
170
|
|
|
169
171
|
# Init state
|
|
170
|
-
state:
|
|
172
|
+
state: LinkState = self.state_factory.state()
|
|
171
173
|
|
|
172
174
|
# Retrieve run information
|
|
173
175
|
run = state.get_run(request.run_id)
|
|
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
|
48
48
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
49
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
50
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
51
|
-
from flwr.server.superlink.
|
|
51
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
52
52
|
|
|
53
53
|
T = TypeVar("T", bound=GrpcMessage)
|
|
54
54
|
|
|
@@ -77,7 +77,9 @@ def _handle(
|
|
|
77
77
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
78
78
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
79
|
|
|
80
|
-
def __init__(
|
|
80
|
+
def __init__(
|
|
81
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
82
|
+
) -> None:
|
|
81
83
|
self.state_factory = state_factory
|
|
82
84
|
self.ffs_factory = ffs_factory
|
|
83
85
|
|
|
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
38
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
|
-
from flwr.server.superlink.
|
|
40
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
44
44
|
"""Fleet API servicer."""
|
|
45
45
|
|
|
46
|
-
def __init__(
|
|
46
|
+
def __init__(
|
|
47
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
48
|
+
) -> None:
|
|
47
49
|
self.state_factory = state_factory
|
|
48
50
|
self.ffs_factory = ffs_factory
|
|
49
51
|
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self, state:
|
|
87
|
+
def __init__(self, state: LinkState):
|
|
88
88
|
self.state = state
|
|
89
89
|
|
|
90
90
|
self.node_public_keys = state.get_node_public_keys()
|
|
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
43
|
)
|
|
44
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
45
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
|
-
from flwr.server.superlink.
|
|
46
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def create_node(
|
|
50
50
|
request: CreateNodeRequest, # pylint: disable=unused-argument
|
|
51
|
-
state:
|
|
51
|
+
state: LinkState,
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
54
|
# Create node
|
|
@@ -56,7 +56,7 @@ def create_node(
|
|
|
56
56
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
57
57
|
|
|
58
58
|
|
|
59
|
-
def delete_node(request: DeleteNodeRequest, state:
|
|
59
|
+
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
|
60
60
|
"""."""
|
|
61
61
|
# Validate node_id
|
|
62
62
|
if request.node.anonymous or request.node.node_id == 0:
|
|
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
69
69
|
|
|
70
70
|
def ping(
|
|
71
71
|
request: PingRequest, # pylint: disable=unused-argument
|
|
72
|
-
state:
|
|
72
|
+
state: LinkState, # pylint: disable=unused-argument
|
|
73
73
|
) -> PingResponse:
|
|
74
74
|
"""."""
|
|
75
75
|
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
76
76
|
return PingResponse(success=res)
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
def pull_task_ins(request: PullTaskInsRequest, state:
|
|
79
|
+
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
|
80
80
|
"""Pull TaskIns handler."""
|
|
81
81
|
# Get node_id if client node is not anonymous
|
|
82
82
|
node = request.node # pylint: disable=no-member
|
|
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
|
|
|
92
92
|
return response
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def push_task_res(request: PushTaskResRequest, state:
|
|
95
|
+
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
|
96
96
|
"""Push TaskRes handler."""
|
|
97
97
|
# pylint: disable=no-member
|
|
98
98
|
task_res: TaskRes = request.task_res_list[0]
|
|
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
def get_run(
|
|
116
|
-
request: GetRunRequest, state:
|
|
116
|
+
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
117
|
) -> GetRunResponse:
|
|
118
118
|
"""Get run information."""
|
|
119
119
|
run = state.get_run(request.run_id)
|
|
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
40
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
41
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
42
42
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
|
-
from flwr.server.superlink.
|
|
43
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
44
44
|
|
|
45
45
|
try:
|
|
46
46
|
from starlette.applications import Starlette
|
|
@@ -90,7 +90,7 @@ def rest_request_response(
|
|
|
90
90
|
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
91
91
|
"""Create Node."""
|
|
92
92
|
# Get state from app
|
|
93
|
-
state:
|
|
93
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
94
94
|
|
|
95
95
|
# Handle message
|
|
96
96
|
return message_handler.create_node(request=request, state=state)
|
|
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
|
100
100
|
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
101
101
|
"""Delete Node Id."""
|
|
102
102
|
# Get state from app
|
|
103
|
-
state:
|
|
103
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
104
104
|
|
|
105
105
|
# Handle message
|
|
106
106
|
return message_handler.delete_node(request=request, state=state)
|
|
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
110
110
|
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
111
111
|
"""Pull TaskIns."""
|
|
112
112
|
# Get state from app
|
|
113
|
-
state:
|
|
113
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
114
114
|
|
|
115
115
|
# Handle message
|
|
116
116
|
return message_handler.pull_task_ins(request=request, state=state)
|
|
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
|
121
121
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
122
122
|
"""Push TaskRes."""
|
|
123
123
|
# Get state from app
|
|
124
|
-
state:
|
|
124
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
125
125
|
|
|
126
126
|
# Handle message
|
|
127
127
|
return message_handler.push_task_res(request=request, state=state)
|
|
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
|
131
131
|
async def ping(request: PingRequest) -> PingResponse:
|
|
132
132
|
"""Ping."""
|
|
133
133
|
# Get state from app
|
|
134
|
-
state:
|
|
134
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
135
135
|
|
|
136
136
|
# Handle message
|
|
137
137
|
return message_handler.ping(request=request, state=state)
|
|
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
|
|
|
141
141
|
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
142
142
|
"""GetRun."""
|
|
143
143
|
# Get state from app
|
|
144
|
-
state:
|
|
144
|
+
state: LinkState = app.state.STATE_FACTORY.state()
|
|
145
145
|
|
|
146
146
|
# Handle message
|
|
147
147
|
return message_handler.get_run(request=request, state=state)
|