flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241022__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. flwr/client/app.py +11 -11
  2. flwr/client/node_state_tests.py +7 -8
  3. flwr/client/{node_state.py → run_info_store.py} +3 -3
  4. flwr/common/constant.py +2 -3
  5. flwr/server/app.py +51 -9
  6. flwr/server/driver/inmemory_driver.py +2 -2
  7. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  8. flwr/server/serverapp/app.py +20 -0
  9. flwr/server/superlink/driver/driver_grpc.py +2 -2
  10. flwr/server/superlink/driver/driver_servicer.py +9 -7
  11. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  12. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  13. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  14. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  15. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  16. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  17. flwr/server/superlink/linkstate/__init__.py +28 -0
  18. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +8 -8
  19. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +10 -10
  20. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  21. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +14 -14
  22. flwr/simulation/app.py +1 -1
  23. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  24. flwr/simulation/run_simulation.py +3 -3
  25. flwr/superexec/app.py +9 -2
  26. flwr/superexec/simulation.py +1 -1
  27. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/METADATA +1 -1
  28. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/RECORD +32 -30
  29. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/entry_points.txt +1 -0
  30. flwr/server/superlink/{state → linkstate}/utils.py +0 -0
  31. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/LICENSE +0 -0
  32. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241022.dist-info}/WHEEL +0 -0
flwr/client/app.py CHANGED
@@ -52,15 +52,15 @@ from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
52
52
  from flwr.common.typing import Fab, Run, UserConfig
53
53
  from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
54
54
  from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
55
- from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
55
+ from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
56
56
 
57
57
  from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
58
58
  from .grpc_adapter_client.connection import grpc_adapter
59
59
  from .grpc_client.connection import grpc_connection
60
60
  from .grpc_rere_client.connection import grpc_request_response
61
61
  from .message_handler.message_handler import handle_control_message
62
- from .node_state import NodeState
63
62
  from .numpy_client import NumPyClient
63
+ from .run_info_store import DeprecatedRunInfoStore
64
64
 
65
65
  ISOLATION_MODE_SUBPROCESS = "subprocess"
66
66
  ISOLATION_MODE_PROCESS = "process"
@@ -364,8 +364,8 @@ def start_client_internal(
364
364
  on_backoff=_on_backoff,
365
365
  )
366
366
 
367
- # NodeState gets initialized when the first connection is established
368
- node_state: Optional[NodeState] = None
367
+ # DeprecatedRunInfoStore gets initialized when the first connection is established
368
+ run_info_store: Optional[DeprecatedRunInfoStore] = None
369
369
 
370
370
  runs: dict[int, Run] = {}
371
371
 
@@ -382,7 +382,7 @@ def start_client_internal(
382
382
  receive, send, create_node, delete_node, get_run, get_fab = conn
383
383
 
384
384
  # Register node when connecting the first time
385
- if node_state is None:
385
+ if run_info_store is None:
386
386
  if create_node is None:
387
387
  if transport not in ["grpc-bidi", None]:
388
388
  raise NotImplementedError(
@@ -391,7 +391,7 @@ def start_client_internal(
391
391
  )
392
392
  # gRPC-bidi doesn't have the concept of node_id,
393
393
  # so we set it to -1
394
- node_state = NodeState(
394
+ run_info_store = DeprecatedRunInfoStore(
395
395
  node_id=-1,
396
396
  node_config={},
397
397
  )
@@ -402,7 +402,7 @@ def start_client_internal(
402
402
  ) # pylint: disable=not-callable
403
403
  if node_id is None:
404
404
  raise ValueError("Node registration failed")
405
- node_state = NodeState(
405
+ run_info_store = DeprecatedRunInfoStore(
406
406
  node_id=node_id,
407
407
  node_config=node_config,
408
408
  )
@@ -461,7 +461,7 @@ def start_client_internal(
461
461
  run.fab_id, run.fab_version = fab_id, fab_version
462
462
 
463
463
  # Register context for this run
464
- node_state.register_context(
464
+ run_info_store.register_context(
465
465
  run_id=run_id,
466
466
  run=run,
467
467
  flwr_path=flwr_path,
@@ -469,7 +469,7 @@ def start_client_internal(
469
469
  )
470
470
 
471
471
  # Retrieve context for this run
472
- context = node_state.retrieve_context(run_id=run_id)
472
+ context = run_info_store.retrieve_context(run_id=run_id)
473
473
  # Create an error reply message that will never be used to prevent
474
474
  # the used-before-assignment linting error
475
475
  reply_message = message.create_error_reply(
@@ -542,7 +542,7 @@ def start_client_internal(
542
542
  # Raise exception, crash process
543
543
  raise ex
544
544
 
545
- # Don't update/change NodeState
545
+ # Don't update/change DeprecatedRunInfoStore
546
546
 
547
547
  e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
548
548
  # Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
@@ -567,7 +567,7 @@ def start_client_internal(
567
567
  )
568
568
  else:
569
569
  # No exception, update node state
570
- node_state.update_context(
570
+ run_info_store.update_context(
571
571
  run_id=run_id,
572
572
  context=context,
573
573
  )
flwr/client/node_state_tests.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  from typing import cast
19
19
 
20
- from flwr.client.node_state import NodeState
20
+ from flwr.client.run_info_store import DeprecatedRunInfoStore
21
21
  from flwr.common import ConfigsRecord, Context
22
22
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
23
23
 
@@ -34,32 +34,31 @@ def _run_dummy_task(context: Context) -> Context:
34
34
 
35
35
 
36
36
  def test_multirun_in_node_state() -> None:
37
- """Test basic NodeState logic."""
37
+ """Test basic DeprecatedRunInfoStore logic."""
38
38
  # Tasks to perform
39
39
  tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]]
40
40
  # the "tasks" is to count how many times each run is executed
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
- # NodeState
44
- node_state = NodeState(node_id=0, node_config={})
43
+ node_info_store = DeprecatedRunInfoStore(node_id=0, node_config={})
45
44
 
46
45
  for task in tasks:
47
46
  run_id = task.run_id
48
47
 
49
48
  # Register
50
- node_state.register_context(run_id=run_id)
49
+ node_info_store.register_context(run_id=run_id)
51
50
 
52
51
  # Get run state
53
- context = node_state.retrieve_context(run_id=run_id)
52
+ context = node_info_store.retrieve_context(run_id=run_id)
54
53
 
55
54
  # Run "task"
56
55
  updated_state = _run_dummy_task(context)
57
56
 
58
57
  # Update run state
59
- node_state.update_context(run_id=run_id, context=updated_state)
58
+ node_info_store.update_context(run_id=run_id, context=updated_state)
60
59
 
61
60
  # Verify values
62
- for run_id, run_info in node_state.run_infos.items():
61
+ for run_id, run_info in node_info_store.run_infos.items():
63
62
  assert (
64
63
  run_info.context.state.configs_records["counter"]["count"]
65
64
  == expected_values[run_id]
flwr/client/{node_state.py → run_info_store.py} RENAMED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Node state."""
15
+ """Deprecated Run Info Store."""
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
@@ -36,7 +36,7 @@ class RunInfo:
36
36
  initial_run_config: UserConfig
37
37
 
38
38
 
39
- class NodeState:
39
+ class DeprecatedRunInfoStore:
40
40
  """State of a node where client nodes execute runs."""
41
41
 
42
42
  def __init__(
flwr/common/constant.py CHANGED
@@ -40,15 +40,14 @@ TRANSPORT_TYPES = [
40
40
  # Addresses
41
41
  # SuperNode
42
42
  CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094"
43
- # SuperExec
44
- EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
45
43
  # SuperLink
46
44
  DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091"
47
45
  FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092"
48
46
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
49
47
  "[::]:8080" # IPv6 to keep start_server compatible
50
48
  )
51
- FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9093"
49
+ FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095"
50
+ EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
52
51
 
53
52
  # Constants for ping
54
53
  PING_DEFAULT_INTERVAL = 30
flwr/server/app.py CHANGED
@@ -35,9 +35,10 @@ from cryptography.hazmat.primitives.serialization import (
35
35
 
36
36
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
37
37
  from flwr.common.address import parse_address
38
- from flwr.common.config import get_flwr_dir
38
+ from flwr.common.config import get_flwr_dir, parse_config_args
39
39
  from flwr.common.constant import (
40
40
  DRIVER_API_DEFAULT_ADDRESS,
41
+ EXEC_API_DEFAULT_ADDRESS,
41
42
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
42
43
  FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
43
44
  FLEET_API_REST_DEFAULT_ADDRESS,
@@ -56,6 +57,8 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
56
57
  add_FleetServicer_to_server,
57
58
  )
58
59
  from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
60
+ from flwr.superexec.app import load_executor
61
+ from flwr.superexec.exec_grpc import run_superexec_api_grpc
59
62
 
60
63
  from .client_manager import ClientManager
61
64
  from .history import History
@@ -71,7 +74,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
71
74
  )
72
75
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
73
76
  from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
74
- from .superlink.state import StateFactory
77
+ from .superlink.linkstate import LinkStateFactory
75
78
 
76
79
  DATABASE = ":flwr-in-memory-state:"
77
80
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
@@ -205,14 +208,15 @@ def run_superlink() -> None:
205
208
 
206
209
  event(EventType.RUN_SUPERLINK_ENTER)
207
210
 
208
- # Parse IP address
211
+ # Parse IP addresses
209
212
  driver_address, _, _ = _format_address(args.driver_api_address)
213
+ exec_address, _, _ = _format_address(args.exec_api_address)
210
214
 
211
215
  # Obtain certificates
212
216
  certificates = _try_obtain_certificates(args)
213
217
 
214
218
  # Initialize StateFactory
215
- state_factory = StateFactory(args.database)
219
+ state_factory = LinkStateFactory(args.database)
216
220
 
217
221
  # Initialize FfsFactory
218
222
  ffs_factory = FfsFactory(args.storage_dir)
@@ -224,8 +228,9 @@ def run_superlink() -> None:
224
228
  ffs_factory=ffs_factory,
225
229
  certificates=certificates,
226
230
  )
227
-
228
231
  grpc_servers = [driver_server]
232
+
233
+ # Start Fleet API
229
234
  bckg_threads = []
230
235
  if not args.fleet_api_address:
231
236
  if args.fleet_api_type in [
@@ -250,7 +255,6 @@ def run_superlink() -> None:
250
255
  )
251
256
  num_workers = 1
252
257
 
253
- # Start Fleet API
254
258
  if args.fleet_api_type == TRANSPORT_TYPE_REST:
255
259
  if (
256
260
  importlib.util.find_spec("requests")
@@ -318,6 +322,17 @@ def run_superlink() -> None:
318
322
  else:
319
323
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
320
324
 
325
+ # Start Exec API
326
+ exec_server: grpc.Server = run_superexec_api_grpc(
327
+ address=exec_address,
328
+ executor=load_executor(args),
329
+ certificates=certificates,
330
+ config=parse_config_args(
331
+ [args.executor_config] if args.executor_config else args.executor_config
332
+ ),
333
+ )
334
+ grpc_servers.append(exec_server)
335
+
321
336
  # Graceful shutdown
322
337
  register_exit_handlers(
323
338
  event_type=EventType.RUN_SUPERLINK_LEAVE,
@@ -489,7 +504,7 @@ def _try_obtain_certificates(
489
504
 
490
505
  def _run_fleet_api_grpc_rere(
491
506
  address: str,
492
- state_factory: StateFactory,
507
+ state_factory: LinkStateFactory,
493
508
  ffs_factory: FfsFactory,
494
509
  certificates: Optional[tuple[bytes, bytes, bytes]],
495
510
  interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
@@ -517,7 +532,7 @@ def _run_fleet_api_grpc_rere(
517
532
 
518
533
  def _run_fleet_api_grpc_adapter(
519
534
  address: str,
520
- state_factory: StateFactory,
535
+ state_factory: LinkStateFactory,
521
536
  ffs_factory: FfsFactory,
522
537
  certificates: Optional[tuple[bytes, bytes, bytes]],
523
538
  ) -> grpc.Server:
@@ -548,7 +563,7 @@ def _run_fleet_api_rest(
548
563
  port: int,
549
564
  ssl_keyfile: Optional[str],
550
565
  ssl_certfile: Optional[str],
551
- state_factory: StateFactory,
566
+ state_factory: LinkStateFactory,
552
567
  ffs_factory: FfsFactory,
553
568
  num_workers: int,
554
569
  ) -> None:
@@ -587,6 +602,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
587
602
  _add_args_common(parser=parser)
588
603
  _add_args_driver_api(parser=parser)
589
604
  _add_args_fleet_api(parser=parser)
605
+ _add_args_exec_api(parser=parser)
590
606
 
591
607
  return parser
592
608
 
@@ -681,3 +697,29 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
681
697
  type=int,
682
698
  help="Set the number of concurrent workers for the Fleet API server.",
683
699
  )
700
+
701
+
702
+ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
703
+ """Add command line arguments for Exec API."""
704
+ parser.add_argument(
705
+ "--exec-api-address",
706
+ help="Exec API server address (IPv4, IPv6, or a domain name)",
707
+ default=EXEC_API_DEFAULT_ADDRESS,
708
+ )
709
+ parser.add_argument(
710
+ "--executor",
711
+ help="For example: `deployment:exec` or `project.package.module:wrapper.exec`. "
712
+ "The default is `flwr.superexec.deployment:executor`",
713
+ default="flwr.superexec.deployment:executor",
714
+ )
715
+ parser.add_argument(
716
+ "--executor-dir",
717
+ help="The directory for the executor.",
718
+ default=".",
719
+ )
720
+ parser.add_argument(
721
+ "--executor-config",
722
+ help="Key-value pairs for the executor config, separated by spaces. "
723
+ "For example:\n\n`--executor-config 'verbose=true "
724
+ 'root-certificates="certificates/superlink-ca.crt"\'`',
725
+ )
flwr/server/driver/inmemory_driver.py CHANGED
@@ -25,7 +25,7 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
25
  from flwr.common.serde import message_from_taskres, message_to_taskins
26
26
  from flwr.common.typing import Run
27
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
- from flwr.server.superlink.state import StateFactory
28
+ from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
30
  from .driver import Driver
31
31
 
@@ -46,7 +46,7 @@ class InMemoryDriver(Driver):
46
46
  def __init__(
47
47
  self,
48
48
  run_id: int,
49
- state_factory: StateFactory,
49
+ state_factory: LinkStateFactory,
50
50
  pull_interval: float = 0.1,
51
51
  ) -> None:
52
52
  self._run_id = run_id
flwr/server/{superlink/state → serverapp}/__init__.py RENAMED
@@ -12,17 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower server state."""
15
+ """Flower AppIO service."""
16
16
 
17
17
 
18
- from .in_memory_state import InMemoryState as InMemoryState
19
- from .sqlite_state import SqliteState as SqliteState
20
- from .state import State as State
21
- from .state_factory import StateFactory as StateFactory
18
+ from .app import flwr_serverapp as flwr_serverapp
22
19
 
23
20
  __all__ = [
24
- "InMemoryState",
25
- "SqliteState",
26
- "State",
27
- "StateFactory",
21
+ "flwr_serverapp",
28
22
  ]
flwr/server/serverapp/app.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower ServerApp process."""
16
+
17
+
18
+ def flwr_serverapp() -> None:
19
+ """Run process-isolated Flower ServerApp."""
20
+ raise NotImplementedError()
flwr/server/superlink/driver/driver_grpc.py CHANGED
@@ -25,7 +25,7 @@ from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
25
  add_DriverServicer_to_server,
26
26
  )
27
27
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
28
- from flwr.server.superlink.state import StateFactory
28
+ from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
30
  from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
31
31
  from .driver_servicer import DriverServicer
@@ -33,7 +33,7 @@ from .driver_servicer import DriverServicer
33
33
 
34
34
  def run_driver_api_grpc(
35
35
  address: str,
36
- state_factory: StateFactory,
36
+ state_factory: LinkStateFactory,
37
37
  ffs_factory: FfsFactory,
38
38
  certificates: Optional[tuple[bytes, bytes, bytes]],
39
39
  ) -> grpc.Server:
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -51,14 +51,16 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
51
51
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
52
  from flwr.server.superlink.ffs.ffs import Ffs
53
53
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
54
- from flwr.server.superlink.state import State, StateFactory
54
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
55
55
  from flwr.server.utils.validator import validate_task_ins_or_res
56
56
 
57
57
 
58
58
  class DriverServicer(driver_pb2_grpc.DriverServicer):
59
59
  """Driver API servicer."""
60
60
 
61
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
61
+ def __init__(
62
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
63
+ ) -> None:
62
64
  self.state_factory = state_factory
63
65
  self.ffs_factory = ffs_factory
64
66
 
@@ -67,7 +69,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
67
69
  ) -> GetNodesResponse:
68
70
  """Get available nodes."""
69
71
  log(DEBUG, "DriverServicer.GetNodes")
70
- state: State = self.state_factory.state()
72
+ state: LinkState = self.state_factory.state()
71
73
  all_ids: set[int] = state.get_nodes(request.run_id)
72
74
  nodes: list[Node] = [
73
75
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -79,7 +81,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
79
81
  ) -> CreateRunResponse:
80
82
  """Create run ID."""
81
83
  log(DEBUG, "DriverServicer.CreateRun")
82
- state: State = self.state_factory.state()
84
+ state: LinkState = self.state_factory.state()
83
85
  if request.HasField("fab"):
84
86
  fab = fab_from_proto(request.fab)
85
87
  ffs: Ffs = self.ffs_factory.ffs()
@@ -116,7 +118,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
116
118
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
117
119
 
118
120
  # Init state
119
- state: State = self.state_factory.state()
121
+ state: LinkState = self.state_factory.state()
120
122
 
121
123
  # Store each TaskIns
122
124
  task_ids: list[Optional[UUID]] = []
@@ -138,7 +140,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
138
140
  task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
139
141
 
140
142
  # Init state
141
- state: State = self.state_factory.state()
143
+ state: LinkState = self.state_factory.state()
142
144
 
143
145
  # Register callback
144
146
  def on_rpc_done() -> None:
@@ -167,7 +169,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
167
169
  log(DEBUG, "DriverServicer.GetRun")
168
170
 
169
171
  # Init state
170
- state: State = self.state_factory.state()
172
+ state: LinkState = self.state_factory.state()
171
173
 
172
174
  # Retrieve run information
173
175
  run = state.get_run(request.run_id)
flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py CHANGED
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
48
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
49
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
50
  from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.state import StateFactory
51
+ from flwr.server.superlink.linkstate import LinkStateFactory
52
52
 
53
53
  T = TypeVar("T", bound=GrpcMessage)
54
54
 
@@ -77,7 +77,9 @@ def _handle(
77
77
  class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
78
78
  """Fleet API via GrpcAdapter servicer."""
79
79
 
80
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
80
+ def __init__(
81
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
+ ) -> None:
81
83
  self.state_factory = state_factory
82
84
  self.ffs_factory = ffs_factory
83
85
 
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py CHANGED
@@ -37,13 +37,15 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
37
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
38
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
39
  from flwr.server.superlink.fleet.message_handler import message_handler
40
- from flwr.server.superlink.state import StateFactory
40
+ from flwr.server.superlink.linkstate import LinkStateFactory
41
41
 
42
42
 
43
43
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
44
44
  """Fleet API servicer."""
45
45
 
46
- def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
46
+ def __init__(
47
+ self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
48
+ ) -> None:
47
49
  self.state_factory = state_factory
48
50
  self.ffs_factory = ffs_factory
49
51
 
flwr/server/superlink/fleet/grpc_rere/server_interceptor.py CHANGED
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.state import State
48
+ from flwr.server.superlink.linkstate import LinkState
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,7 +84,7 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: State):
87
+ def __init__(self, state: LinkState):
88
88
  self.state = state
89
89
 
90
90
  self.node_public_keys = state.get_node_public_keys()
flwr/server/superlink/fleet/message_handler/message_handler.py CHANGED
@@ -43,12 +43,12 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
43
  )
44
44
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
45
  from flwr.server.superlink.ffs.ffs import Ffs
46
- from flwr.server.superlink.state import State
46
+ from flwr.server.superlink.linkstate import LinkState
47
47
 
48
48
 
49
49
  def create_node(
50
50
  request: CreateNodeRequest, # pylint: disable=unused-argument
51
- state: State,
51
+ state: LinkState,
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
54
  # Create node
@@ -56,7 +56,7 @@ def create_node(
56
56
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
57
57
 
58
58
 
59
- def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
59
+ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
60
60
  """."""
61
61
  # Validate node_id
62
62
  if request.node.anonymous or request.node.node_id == 0:
@@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
69
69
 
70
70
  def ping(
71
71
  request: PingRequest, # pylint: disable=unused-argument
72
- state: State, # pylint: disable=unused-argument
72
+ state: LinkState, # pylint: disable=unused-argument
73
73
  ) -> PingResponse:
74
74
  """."""
75
75
  res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
76
76
  return PingResponse(success=res)
77
77
 
78
78
 
79
- def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
79
+ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
80
80
  """Pull TaskIns handler."""
81
81
  # Get node_id if client node is not anonymous
82
82
  node = request.node # pylint: disable=no-member
@@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
92
92
  return response
93
93
 
94
94
 
95
- def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
95
+ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
96
96
  """Push TaskRes handler."""
97
97
  # pylint: disable=no-member
98
98
  task_res: TaskRes = request.task_res_list[0]
@@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
113
113
 
114
114
 
115
115
  def get_run(
116
- request: GetRunRequest, state: State # pylint: disable=W0613
116
+ request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
117
  ) -> GetRunResponse:
118
118
  """Get run information."""
119
119
  run = state.get_run(request.run_id)
flwr/server/superlink/fleet/rest_rere/rest_api.py CHANGED
@@ -40,7 +40,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
41
  from flwr.server.superlink.ffs.ffs import Ffs
42
42
  from flwr.server.superlink.fleet.message_handler import message_handler
43
- from flwr.server.superlink.state import State
43
+ from flwr.server.superlink.linkstate import LinkState
44
44
 
45
45
  try:
46
46
  from starlette.applications import Starlette
@@ -90,7 +90,7 @@ def rest_request_response(
90
90
  async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
91
91
  """Create Node."""
92
92
  # Get state from app
93
- state: State = app.state.STATE_FACTORY.state()
93
+ state: LinkState = app.state.STATE_FACTORY.state()
94
94
 
95
95
  # Handle message
96
96
  return message_handler.create_node(request=request, state=state)
@@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
100
100
  async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
101
101
  """Delete Node Id."""
102
102
  # Get state from app
103
- state: State = app.state.STATE_FACTORY.state()
103
+ state: LinkState = app.state.STATE_FACTORY.state()
104
104
 
105
105
  # Handle message
106
106
  return message_handler.delete_node(request=request, state=state)
@@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
110
110
  async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
111
111
  """Pull TaskIns."""
112
112
  # Get state from app
113
- state: State = app.state.STATE_FACTORY.state()
113
+ state: LinkState = app.state.STATE_FACTORY.state()
114
114
 
115
115
  # Handle message
116
116
  return message_handler.pull_task_ins(request=request, state=state)
@@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
121
121
  async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
122
122
  """Push TaskRes."""
123
123
  # Get state from app
124
- state: State = app.state.STATE_FACTORY.state()
124
+ state: LinkState = app.state.STATE_FACTORY.state()
125
125
 
126
126
  # Handle message
127
127
  return message_handler.push_task_res(request=request, state=state)
@@ -131,7 +131,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
131
131
  async def ping(request: PingRequest) -> PingResponse:
132
132
  """Ping."""
133
133
  # Get state from app
134
- state: State = app.state.STATE_FACTORY.state()
134
+ state: LinkState = app.state.STATE_FACTORY.state()
135
135
 
136
136
  # Handle message
137
137
  return message_handler.ping(request=request, state=state)
@@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse:
141
141
  async def get_run(request: GetRunRequest) -> GetRunResponse:
142
142
  """GetRun."""
143
143
  # Get state from app
144
- state: State = app.state.STATE_FACTORY.state()
144
+ state: LinkState = app.state.STATE_FACTORY.state()
145
145
 
146
146
  # Handle message
147
147
  return message_handler.get_run(request=request, state=state)