flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241216__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (101) hide show
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +11 -31
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +25 -55
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  15. flwr/cli/run/__init__.py +1 -0
  16. flwr/cli/run/run.py +17 -39
  17. flwr/cli/stop.py +129 -0
  18. flwr/cli/utils.py +96 -1
  19. flwr/client/app.py +14 -3
  20. flwr/client/client.py +1 -0
  21. flwr/client/clientapp/app.py +4 -1
  22. flwr/client/clientapp/utils.py +1 -0
  23. flwr/client/grpc_adapter_client/connection.py +1 -1
  24. flwr/client/grpc_client/connection.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +13 -7
  26. flwr/client/message_handler/message_handler.py +1 -0
  27. flwr/client/mod/comms_mods.py +1 -0
  28. flwr/client/mod/localdp_mod.py +1 -1
  29. flwr/client/nodestate/__init__.py +1 -0
  30. flwr/client/nodestate/nodestate.py +1 -0
  31. flwr/client/nodestate/nodestate_factory.py +1 -0
  32. flwr/client/rest_client/connection.py +3 -3
  33. flwr/client/supernode/app.py +1 -0
  34. flwr/common/address.py +1 -0
  35. flwr/common/args.py +1 -0
  36. flwr/common/auth_plugin/__init__.py +24 -0
  37. flwr/common/auth_plugin/auth_plugin.py +111 -0
  38. flwr/common/config.py +3 -1
  39. flwr/common/constant.py +6 -1
  40. flwr/common/logger.py +17 -1
  41. flwr/common/message.py +1 -0
  42. flwr/common/object_ref.py +57 -54
  43. flwr/common/pyproject.py +1 -0
  44. flwr/common/record/__init__.py +1 -0
  45. flwr/common/record/parametersrecord.py +1 -0
  46. flwr/common/retry_invoker.py +77 -0
  47. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  48. flwr/common/telemetry.py +2 -1
  49. flwr/common/typing.py +12 -0
  50. flwr/common/version.py +1 -0
  51. flwr/proto/exec_pb2.py +27 -3
  52. flwr/proto/exec_pb2.pyi +103 -0
  53. flwr/proto/exec_pb2_grpc.py +102 -0
  54. flwr/proto/exec_pb2_grpc.pyi +39 -0
  55. flwr/proto/fab_pb2.py +4 -4
  56. flwr/proto/fab_pb2.pyi +4 -1
  57. flwr/proto/serverappio_pb2.py +18 -18
  58. flwr/proto/serverappio_pb2.pyi +8 -2
  59. flwr/proto/serverappio_pb2_grpc.py +34 -0
  60. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  61. flwr/proto/simulationio_pb2.py +2 -2
  62. flwr/proto/simulationio_pb2_grpc.py +34 -0
  63. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  64. flwr/server/app.py +52 -1
  65. flwr/server/compat/app_utils.py +7 -1
  66. flwr/server/driver/grpc_driver.py +11 -63
  67. flwr/server/driver/inmemory_driver.py +5 -1
  68. flwr/server/serverapp/app.py +9 -2
  69. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  70. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  71. flwr/server/superlink/driver/serverappio_servicer.py +72 -22
  72. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  73. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  74. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  77. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  78. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  79. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  81. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  82. flwr/server/superlink/linkstate/linkstate.py +13 -2
  83. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  84. flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
  85. flwr/server/superlink/utils.py +65 -0
  86. flwr/simulation/app.py +1 -0
  87. flwr/simulation/ray_transport/ray_actor.py +1 -0
  88. flwr/simulation/ray_transport/utils.py +1 -0
  89. flwr/simulation/run_simulation.py +1 -15
  90. flwr/simulation/simulationio_connection.py +3 -0
  91. flwr/superexec/app.py +1 -0
  92. flwr/superexec/deployment.py +1 -0
  93. flwr/superexec/exec_grpc.py +19 -1
  94. flwr/superexec/exec_servicer.py +76 -2
  95. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  96. flwr/superexec/executor.py +1 -0
  97. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/METADATA +8 -7
  98. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/RECORD +101 -93
  99. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/LICENSE +0 -0
  100. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/WHEEL +0 -0
  101. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower server app."""
16
16
 
17
+
17
18
  import argparse
18
19
  import csv
19
20
  import importlib.util
@@ -24,9 +25,10 @@ from collections.abc import Sequence
24
25
  from logging import DEBUG, INFO, WARN
25
26
  from pathlib import Path
26
27
  from time import sleep
27
- from typing import Optional
28
+ from typing import Any, Optional
28
29
 
29
30
  import grpc
31
+ import yaml
30
32
  from cryptography.exceptions import UnsupportedAlgorithm
31
33
  from cryptography.hazmat.primitives.asymmetric import ec
32
34
  from cryptography.hazmat.primitives.serialization import (
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
37
39
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
38
40
  from flwr.common.address import parse_address
39
41
  from flwr.common.args import try_obtain_server_certificates
42
+ from flwr.common.auth_plugin import ExecAuthPlugin
40
43
  from flwr.common.config import get_flwr_dir, parse_config_args
41
44
  from flwr.common.constant import (
45
+ AUTH_TYPE,
42
46
  CLIENT_OCTET,
43
47
  EXEC_API_DEFAULT_SERVER_ADDRESS,
44
48
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -88,6 +92,19 @@ DATABASE = ":flwr-in-memory-state:"
88
92
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
89
93
 
90
94
 
95
+ try:
96
+ from flwr.ee import add_ee_args_superlink, get_exec_auth_plugins
97
+ except ImportError:
98
+
99
+ # pylint: disable-next=unused-argument
100
+ def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
101
+ """Add EE-specific arguments to the parser."""
102
+
103
+ def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
104
+ """Return all Exec API authentication plugins."""
105
+ raise NotImplementedError("No authentication plugins are currently supported.")
106
+
107
+
91
108
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
92
109
  *,
93
110
  server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -246,6 +263,12 @@ def run_superlink() -> None:
246
263
  # Obtain certificates
247
264
  certificates = try_obtain_server_certificates(args, args.fleet_api_type)
248
265
 
266
+ user_auth_config = _try_obtain_user_auth_config(args)
267
+ auth_plugin: Optional[ExecAuthPlugin] = None
268
+ # user_auth_config is None only if the args.user_auth_config is not provided
269
+ if user_auth_config is not None:
270
+ auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
271
+
249
272
  # Initialize StateFactory
250
273
  state_factory = LinkStateFactory(args.database)
251
274
 
@@ -263,6 +286,7 @@ def run_superlink() -> None:
263
286
  config=parse_config_args(
264
287
  [args.executor_config] if args.executor_config else args.executor_config
265
288
  ),
289
+ auth_plugin=auth_plugin,
266
290
  )
267
291
  grpc_servers = [exec_server]
268
292
 
@@ -559,6 +583,32 @@ def _try_setup_node_authentication(
559
583
  )
560
584
 
561
585
 
586
+ def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
587
+ if getattr(args, "user_auth_config", None) is not None:
588
+ with open(args.user_auth_config, encoding="utf-8") as file:
589
+ config: dict[str, Any] = yaml.safe_load(file)
590
+ return config
591
+ return None
592
+
593
+
594
+ def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
595
+ auth_config: dict[str, Any] = config.get("authentication", {})
596
+ auth_type: str = auth_config.get(AUTH_TYPE, "")
597
+ try:
598
+ all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
599
+ auth_plugin_class = all_plugins[auth_type]
600
+ return auth_plugin_class(config=auth_config)
601
+ except KeyError:
602
+ if auth_type != "":
603
+ sys.exit(
604
+ f'Authentication type "{auth_type}" is not supported. '
605
+ "Please provide a valid authentication type in the configuration."
606
+ )
607
+ sys.exit("No authentication type is provided in the configuration.")
608
+ except NotImplementedError:
609
+ sys.exit("No authentication plugins are currently supported.")
610
+
611
+
562
612
  def _run_fleet_api_grpc_rere(
563
613
  address: str,
564
614
  state_factory: LinkStateFactory,
@@ -657,6 +707,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
657
707
  )
658
708
 
659
709
  _add_args_common(parser=parser)
710
+ add_ee_args_superlink(parser=parser)
660
711
  _add_args_serverappio_api(parser=parser)
661
712
  _add_args_fleet_api(parser=parser)
662
713
  _add_args_exec_api(parser=parser)
@@ -17,6 +17,8 @@
17
17
 
18
18
  import threading
19
19
 
20
+ from flwr.common.typing import RunNotRunningException
21
+
20
22
  from ..client_manager import ClientManager
21
23
  from ..compat.driver_client_proxy import DriverClientProxy
22
24
  from ..driver import Driver
@@ -74,7 +76,11 @@ def _update_client_manager(
74
76
  # Loop until the driver is disconnected
75
77
  registered_nodes: dict[int, DriverClientProxy] = {}
76
78
  while not f_stop.is_set():
77
- all_node_ids = set(driver.get_node_ids())
79
+ try:
80
+ all_node_ids = set(driver.get_node_ids())
81
+ except RunNotRunningException:
82
+ f_stop.set()
83
+ break
78
84
  dead_nodes = set(registered_nodes).difference(all_node_ids)
79
85
  new_nodes = all_node_ids.difference(registered_nodes)
80
86
 
@@ -14,19 +14,20 @@
14
14
  # ==============================================================================
15
15
  """Flower gRPC Driver."""
16
16
 
17
+
17
18
  import time
18
19
  import warnings
19
20
  from collections.abc import Iterable
20
- from logging import DEBUG, INFO, WARN, WARNING
21
- from typing import Any, Optional, cast
21
+ from logging import DEBUG, WARNING
22
+ from typing import Optional, cast
22
23
 
23
24
  import grpc
24
25
 
25
26
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
26
- from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
28
  from flwr.common.grpc import create_channel
28
29
  from flwr.common.logger import log
29
- from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
30
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
30
31
  from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
31
32
  from flwr.common.typing import Run
32
33
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
203
204
  task_ins_list.append(taskins)
204
205
  # Call GrpcDriverStub method
205
206
  res: PushTaskInsResponse = self._stub.PushTaskIns(
206
- PushTaskInsRequest(task_ins_list=task_ins_list)
207
+ PushTaskInsRequest(
208
+ task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
209
+ )
207
210
  )
208
211
  return list(res.task_ids)
209
212
 
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
215
218
  """
216
219
  # Pull TaskRes
217
220
  res: PullTaskResResponse = self._stub.PullTaskRes(
218
- PullTaskResRequest(node=self.node, task_ids=message_ids)
221
+ PullTaskResRequest(
222
+ node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
223
+ )
219
224
  )
220
225
  # Convert TaskRes to Message
221
226
  msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
258
263
  return
259
264
  # Disconnect
260
265
  self._disconnect()
261
-
262
-
263
- def _make_simple_grpc_retry_invoker() -> RetryInvoker:
264
- """Create a simple gRPC retry invoker."""
265
-
266
- def _on_sucess(retry_state: RetryState) -> None:
267
- if retry_state.tries > 1:
268
- log(
269
- INFO,
270
- "Connection successful after %.2f seconds and %s tries.",
271
- retry_state.elapsed_time,
272
- retry_state.tries,
273
- )
274
-
275
- def _on_backoff(retry_state: RetryState) -> None:
276
- if retry_state.tries == 1:
277
- log(WARN, "Connection attempt failed, retrying...")
278
- else:
279
- log(
280
- WARN,
281
- "Connection attempt failed, retrying in %.2f seconds",
282
- retry_state.actual_wait,
283
- )
284
-
285
- def _on_giveup(retry_state: RetryState) -> None:
286
- if retry_state.tries > 1:
287
- log(
288
- WARN,
289
- "Giving up reconnection after %.2f seconds and %s tries.",
290
- retry_state.elapsed_time,
291
- retry_state.tries,
292
- )
293
-
294
- return RetryInvoker(
295
- wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
296
- recoverable_exceptions=grpc.RpcError,
297
- max_tries=None,
298
- max_time=None,
299
- on_success=_on_sucess,
300
- on_backoff=_on_backoff,
301
- on_giveup=_on_giveup,
302
- should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
303
- )
304
-
305
-
306
- def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
307
- """Wrap the gRPC stub with a retry invoker."""
308
-
309
- def make_lambda(original_method: Any) -> Any:
310
- return lambda *args, **kwargs: retry_invoker.invoke(
311
- original_method, *args, **kwargs
312
- )
313
-
314
- for method_name in vars(stub):
315
- method = getattr(stub, method_name)
316
- if callable(method):
317
- setattr(stub, method_name, make_lambda(method))
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
142
142
  # Pull TaskRes
143
143
  task_res_list = self.state.get_task_res(task_ids=msg_ids)
144
144
  # Delete tasks in state
145
- self.state.delete_tasks(msg_ids)
145
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
146
+ task_ins_ids_to_delete = {
147
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
148
+ }
149
+ self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
146
150
  # Convert TaskRes to Message
147
151
  msgs = [message_from_taskres(taskres) for taskres in task_res_list]
148
152
  return msgs
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower ServerApp process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -50,7 +51,7 @@ from flwr.common.serde import (
50
51
  run_from_proto,
51
52
  run_status_to_proto,
52
53
  )
53
- from flwr.common.typing import RunStatus
54
+ from flwr.common.typing import RunNotRunningException, RunStatus
54
55
  from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
55
56
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
56
57
  PullServerAppInputsRequest,
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
96
97
  restore_output()
97
98
 
98
99
 
99
- def run_serverapp( # pylint: disable=R0914, disable=W0212
100
+ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
100
101
  serverappio_api_address: str,
101
102
  log_queue: Queue[Optional[str]],
102
103
  run_once: bool,
@@ -187,6 +188,12 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
187
188
 
188
189
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
189
190
 
191
+ except RunNotRunningException:
192
+ log(INFO, "")
193
+ log(INFO, "Run ID %s stopped.", run.run_id)
194
+ log(INFO, "")
195
+ run_status = None
196
+
190
197
  except Exception as ex: # pylint: disable=broad-exception-caught
191
198
  exc_entity = "ServerApp"
192
199
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
@@ -17,6 +17,7 @@
17
17
  Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
+
20
21
  from typing import Optional, Union
21
22
 
22
23
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """ServerAppIo gRPC API."""
16
16
 
17
+
17
18
  from logging import INFO
18
19
  from typing import Optional
19
20
 
@@ -32,6 +32,7 @@ from flwr.common.serde import (
32
32
  fab_from_proto,
33
33
  fab_to_proto,
34
34
  run_status_from_proto,
35
+ run_status_to_proto,
35
36
  run_to_proto,
36
37
  user_config_from_proto,
37
38
  )
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
49
  CreateRunResponse,
49
50
  GetRunRequest,
50
51
  GetRunResponse,
52
+ GetRunStatusRequest,
53
+ GetRunStatusResponse,
51
54
  UpdateRunStatusRequest,
52
55
  UpdateRunStatusResponse,
53
56
  )
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
67
70
  from flwr.server.superlink.ffs.ffs import Ffs
68
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
69
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
73
+ from flwr.server.superlink.utils import abort_if
70
74
  from flwr.server.utils.validator import validate_task_ins_or_res
71
75
 
72
76
 
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
85
89
  ) -> GetNodesResponse:
86
90
  """Get available nodes."""
87
91
  log(DEBUG, "ServerAppIoServicer.GetNodes")
92
+
93
+ # Init state
88
94
  state: LinkState = self.state_factory.state()
95
+
96
+ # Abort if the run is not running
97
+ abort_if(
98
+ request.run_id,
99
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
100
+ state,
101
+ context,
102
+ )
103
+
89
104
  all_ids: set[int] = state.get_nodes(request.run_id)
90
105
  nodes: list[Node] = [
91
106
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
123
138
  """Push a set of TaskIns."""
124
139
  log(DEBUG, "ServerAppIoServicer.PushTaskIns")
125
140
 
141
+ # Init state
142
+ state: LinkState = self.state_factory.state()
143
+
144
+ # Abort if the run is not running
145
+ abort_if(
146
+ request.run_id,
147
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
148
+ state,
149
+ context,
150
+ )
151
+
126
152
  # Set pushed_at (timestamp in seconds)
127
153
  pushed_at = time.time()
128
154
  for task_ins in request.task_ins_list:
@@ -134,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
134
160
  validation_errors = validate_task_ins_or_res(task_ins)
135
161
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
136
162
 
137
- # Init state
138
- state: LinkState = self.state_factory.state()
139
-
140
163
  # Store each TaskIns
141
164
  task_ids: list[Optional[UUID]] = []
142
165
  for task_ins in request.task_ins_list:
@@ -153,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
153
176
  """Pull a set of TaskRes."""
154
177
  log(DEBUG, "ServerAppIoServicer.PullTaskRes")
155
178
 
156
- # Convert each task_id str to UUID
157
- task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
158
-
159
179
  # Init state
160
180
  state: LinkState = self.state_factory.state()
161
181
 
162
- # Register callback
163
- def on_rpc_done() -> None:
164
- log(
165
- DEBUG,
166
- "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
- )
168
-
169
- if context.is_active():
170
- return
171
- if context.code() != grpc.StatusCode.OK:
172
- return
173
-
174
- # Delete delivered TaskIns and TaskRes
175
- state.delete_tasks(task_ids=task_ids)
182
+ # Abort if the run is not running
183
+ abort_if(
184
+ request.run_id,
185
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
186
+ state,
187
+ context,
188
+ )
176
189
 
177
- context.add_callback(on_rpc_done)
190
+ # Convert each task_id str to UUID
191
+ task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
178
192
 
179
193
  # Read from state
180
194
  task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
181
195
 
182
- context.set_code(grpc.StatusCode.OK)
196
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
197
+ task_ins_ids_to_delete = {
198
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
199
+ }
200
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
201
+
183
202
  return PullTaskResResponse(task_res_list=task_res_list)
184
203
 
185
204
  def GetRun(
@@ -255,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
255
274
  ) -> PushServerAppOutputsResponse:
256
275
  """Push ServerApp process outputs."""
257
276
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
277
+
278
+ # Init state
258
279
  state = self.state_factory.state()
280
+
281
+ # Abort if the run is not running
282
+ abort_if(
283
+ request.run_id,
284
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
285
+ state,
286
+ context,
287
+ )
288
+
259
289
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
290
  return PushServerAppOutputsResponse()
261
291
 
@@ -264,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
264
294
  ) -> UpdateRunStatusResponse:
265
295
  """Update the status of a run."""
266
296
  log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
297
+
298
+ # Init state
267
299
  state = self.state_factory.state()
268
300
 
301
+ # Abort if the run is finished
302
+ abort_if(request.run_id, [Status.FINISHED], state, context)
303
+
269
304
  # Update the run status
270
305
  state.update_run_status(
271
306
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
@@ -284,6 +319,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
284
319
  state.add_serverapp_log(request.run_id, merged_logs)
285
320
  return PushLogsResponse()
286
321
 
322
+ def GetRunStatus(
323
+ self, request: GetRunStatusRequest, context: grpc.ServicerContext
324
+ ) -> GetRunStatusResponse:
325
+ """Get the status of a run."""
326
+ log(DEBUG, "ServerAppIoServicer.GetRunStatus")
327
+ state = self.state_factory.state()
328
+
329
+ # Get run status from LinkState
330
+ run_statuses = state.get_run_status(set(request.run_ids))
331
+ run_status_dict = {
332
+ run_id: run_status_to_proto(run_status)
333
+ for run_id, run_status in run_statuses.items()
334
+ }
335
+ return GetRunStatusResponse(run_status_dict=run_status_dict)
336
+
287
337
 
288
338
  def _raise_if(validation_error: bool, detail: str) -> None:
289
339
  if validation_error:
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Disk based Flower File Storage."""
16
16
 
17
+
17
18
  import hashlib
18
19
  import json
19
20
  from pathlib import Path
@@ -158,4 +158,5 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
158
158
  return message_handler.get_fab(
159
159
  request=request,
160
160
  ffs=self.ffs_factory.ffs(),
161
+ state=self.state_factory.state(),
161
162
  )
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
18
18
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
19
19
  """
20
20
 
21
+
21
22
  import uuid
22
23
  from collections.abc import Iterator
23
24
  from typing import Callable
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
20
20
  import grpc
21
21
 
22
22
  from flwr.common.logger import log
23
+ from flwr.common.typing import InvalidRunStatusException
23
24
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
38
39
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
40
  from flwr.server.superlink.fleet.message_handler import message_handler
40
41
  from flwr.server.superlink.linkstate import LinkStateFactory
42
+ from flwr.server.superlink.utils import abort_grpc_context
41
43
 
42
44
 
43
45
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
105
107
  )
106
108
  else:
107
109
  log(INFO, "[Fleet.PushTaskRes] No task results to push")
108
- return message_handler.push_task_res(
109
- request=request,
110
- state=self.state_factory.state(),
111
- )
110
+
111
+ try:
112
+ res = message_handler.push_task_res(
113
+ request=request,
114
+ state=self.state_factory.state(),
115
+ )
116
+ except InvalidRunStatusException as e:
117
+ abort_grpc_context(e.message, context)
118
+
119
+ return res
112
120
 
113
121
  def GetRun(
114
122
  self, request: GetRunRequest, context: grpc.ServicerContext
115
123
  ) -> GetRunResponse:
116
124
  """Get run information."""
117
125
  log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
118
- return message_handler.get_run(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
126
+
127
+ try:
128
+ res = message_handler.get_run(
129
+ request=request,
130
+ state=self.state_factory.state(),
131
+ )
132
+ except InvalidRunStatusException as e:
133
+ abort_grpc_context(e.message, context)
134
+
135
+ return res
122
136
 
123
137
  def GetFab(
124
138
  self, request: GetFabRequest, context: grpc.ServicerContext
125
139
  ) -> GetFabResponse:
126
140
  """Get FAB."""
127
141
  log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
128
- return message_handler.get_fab(
129
- request=request,
130
- ffs=self.ffs_factory.ffs(),
131
- )
142
+ try:
143
+ res = message_handler.get_fab(
144
+ request=request,
145
+ ffs=self.ffs_factory.ffs(),
146
+ state=self.state_factory.state(),
147
+ )
148
+ except InvalidRunStatusException as e:
149
+ abort_grpc_context(e.message, context)
150
+
151
+ return res
@@ -19,8 +19,9 @@ import time
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.constant import Status
22
23
  from flwr.common.serde import fab_to_proto, user_config_to_proto
23
- from flwr.common.typing import Fab
24
+ from flwr.common.typing import Fab, InvalidRunStatusException
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
27
  CreateNodeRequest,
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
44
45
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
46
  from flwr.server.superlink.ffs.ffs import Ffs
46
47
  from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.utils import check_abort
47
49
 
48
50
 
49
51
  def create_node(
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
98
100
  task_res: TaskRes = request.task_res_list[0]
99
101
  # pylint: enable=no-member
100
102
 
103
+ # Abort if the run is not running
104
+ abort_msg = check_abort(
105
+ task_res.run_id,
106
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
107
+ state,
108
+ )
109
+ if abort_msg:
110
+ raise InvalidRunStatusException(abort_msg)
111
+
101
112
  # Set pushed_at (timestamp in seconds)
102
113
  task_res.task.pushed_at = time.time()
103
114
 
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
112
123
  return response
113
124
 
114
125
 
115
- def get_run(
116
- request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
- ) -> GetRunResponse:
126
+ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
118
127
  """Get run information."""
119
128
  run = state.get_run(request.run_id)
120
129
 
121
130
  if run is None:
122
131
  return GetRunResponse()
123
132
 
133
+ # Abort if the run is not running
134
+ abort_msg = check_abort(
135
+ request.run_id,
136
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
137
+ state,
138
+ )
139
+ if abort_msg:
140
+ raise InvalidRunStatusException(abort_msg)
141
+
124
142
  return GetRunResponse(
125
143
  run=Run(
126
144
  run_id=run.run_id,
@@ -133,9 +151,18 @@ def get_run(
133
151
 
134
152
 
135
153
  def get_fab(
136
- request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
154
+ request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
137
155
  ) -> GetFabResponse:
138
156
  """Get FAB."""
157
+ # Abort if the run is not running
158
+ abort_msg = check_abort(
159
+ request.run_id,
160
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
161
+ state,
162
+ )
163
+ if abort_msg:
164
+ raise InvalidRunStatusException(abort_msg)
165
+
139
166
  if result := ffs.get(request.hash_str):
140
167
  fab = Fab(request.hash_str, result[0])
141
168
  return GetFabResponse(fab=fab_to_proto(fab))
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
154
154
  # Get ffs from app
155
155
  ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
156
156
 
157
+ # Get state from app
158
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
159
+
157
160
  # Handle message
158
- return message_handler.get_fab(request=request, ffs=ffs)
161
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
159
162
 
160
163
 
161
164
  routes = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [