flwr 1.13.0__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -37
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +2 -19
  7. flwr/cli/log.py +18 -36
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +205 -106
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +25 -14
  13. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  23. flwr/cli/run/__init__.py +1 -0
  24. flwr/cli/run/run.py +89 -39
  25. flwr/cli/stop.py +130 -0
  26. flwr/cli/utils.py +172 -8
  27. flwr/client/app.py +14 -3
  28. flwr/client/client.py +1 -32
  29. flwr/client/clientapp/app.py +4 -8
  30. flwr/client/clientapp/utils.py +1 -0
  31. flwr/client/grpc_adapter_client/connection.py +1 -1
  32. flwr/client/grpc_client/connection.py +1 -1
  33. flwr/client/grpc_rere_client/connection.py +13 -7
  34. flwr/client/message_handler/message_handler.py +1 -2
  35. flwr/client/mod/comms_mods.py +1 -0
  36. flwr/client/mod/localdp_mod.py +1 -1
  37. flwr/client/nodestate/__init__.py +1 -0
  38. flwr/client/nodestate/nodestate.py +1 -0
  39. flwr/client/nodestate/nodestate_factory.py +1 -0
  40. flwr/client/numpy_client.py +0 -44
  41. flwr/client/rest_client/connection.py +3 -3
  42. flwr/client/supernode/app.py +2 -2
  43. flwr/common/address.py +1 -0
  44. flwr/common/args.py +1 -0
  45. flwr/common/auth_plugin/__init__.py +24 -0
  46. flwr/common/auth_plugin/auth_plugin.py +111 -0
  47. flwr/common/config.py +3 -1
  48. flwr/common/constant.py +17 -1
  49. flwr/common/logger.py +40 -0
  50. flwr/common/message.py +1 -0
  51. flwr/common/object_ref.py +57 -54
  52. flwr/common/pyproject.py +1 -0
  53. flwr/common/record/__init__.py +1 -0
  54. flwr/common/record/parametersrecord.py +1 -0
  55. flwr/common/retry_invoker.py +77 -0
  56. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  57. flwr/common/telemetry.py +15 -4
  58. flwr/common/typing.py +12 -0
  59. flwr/common/version.py +1 -0
  60. flwr/proto/exec_pb2.py +38 -14
  61. flwr/proto/exec_pb2.pyi +107 -2
  62. flwr/proto/exec_pb2_grpc.py +102 -0
  63. flwr/proto/exec_pb2_grpc.pyi +39 -0
  64. flwr/proto/fab_pb2.py +4 -4
  65. flwr/proto/fab_pb2.pyi +4 -1
  66. flwr/proto/serverappio_pb2.py +18 -18
  67. flwr/proto/serverappio_pb2.pyi +8 -2
  68. flwr/proto/serverappio_pb2_grpc.py +34 -0
  69. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  70. flwr/proto/simulationio_pb2.py +2 -2
  71. flwr/proto/simulationio_pb2_grpc.py +34 -0
  72. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  73. flwr/server/app.py +62 -7
  74. flwr/server/compat/app_utils.py +7 -1
  75. flwr/server/driver/grpc_driver.py +11 -63
  76. flwr/server/driver/inmemory_driver.py +5 -1
  77. flwr/server/run_serverapp.py +8 -9
  78. flwr/server/serverapp/app.py +25 -10
  79. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  80. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  81. flwr/server/superlink/driver/serverappio_servicer.py +82 -23
  82. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  84. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  85. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  86. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
  87. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  88. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  89. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  90. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  91. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
  93. flwr/server/superlink/linkstate/linkstate.py +17 -2
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
  95. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  96. flwr/server/superlink/utils.py +65 -0
  97. flwr/simulation/app.py +59 -52
  98. flwr/simulation/ray_transport/ray_actor.py +1 -0
  99. flwr/simulation/ray_transport/utils.py +1 -0
  100. flwr/simulation/run_simulation.py +36 -22
  101. flwr/simulation/simulationio_connection.py +3 -0
  102. flwr/superexec/app.py +1 -0
  103. flwr/superexec/deployment.py +1 -0
  104. flwr/superexec/exec_grpc.py +19 -1
  105. flwr/superexec/exec_servicer.py +76 -2
  106. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  107. flwr/superexec/executor.py +1 -0
  108. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -8
  109. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
  110. flwr/proto/common_pb2.py +0 -36
  111. flwr/proto/common_pb2.pyi +0 -121
  112. flwr/proto/common_pb2_grpc.py +0 -4
  113. flwr/proto/common_pb2_grpc.pyi +0 -4
  114. flwr/proto/control_pb2.py +0 -27
  115. flwr/proto/control_pb2.pyi +0 -7
  116. flwr/proto/control_pb2_grpc.py +0 -135
  117. flwr/proto/control_pb2_grpc.pyi +0 -53
  118. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
  119. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
  120. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower server app."""
16
16
 
17
+
17
18
  import argparse
18
19
  import csv
19
20
  import importlib.util
@@ -24,9 +25,10 @@ from collections.abc import Sequence
24
25
  from logging import DEBUG, INFO, WARN
25
26
  from pathlib import Path
26
27
  from time import sleep
27
- from typing import Optional
28
+ from typing import Any, Optional
28
29
 
29
30
  import grpc
31
+ import yaml
30
32
  from cryptography.exceptions import UnsupportedAlgorithm
31
33
  from cryptography.hazmat.primitives.asymmetric import ec
32
34
  from cryptography.hazmat.primitives.serialization import (
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
37
39
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
38
40
  from flwr.common.address import parse_address
39
41
  from flwr.common.args import try_obtain_server_certificates
42
+ from flwr.common.auth_plugin import ExecAuthPlugin
40
43
  from flwr.common.config import get_flwr_dir, parse_config_args
41
44
  from flwr.common.constant import (
45
+ AUTH_TYPE,
42
46
  CLIENT_OCTET,
43
47
  EXEC_API_DEFAULT_SERVER_ADDRESS,
44
48
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -66,7 +70,6 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
66
70
  from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
67
71
  from flwr.superexec.app import load_executor
68
72
  from flwr.superexec.exec_grpc import run_exec_api_grpc
69
- from flwr.superexec.simulation import SimulationEngine
70
73
 
71
74
  from .client_manager import ClientManager
72
75
  from .history import History
@@ -89,6 +92,19 @@ DATABASE = ":flwr-in-memory-state:"
89
92
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
90
93
 
91
94
 
95
+ try:
96
+ from flwr.ee import add_ee_args_superlink, get_exec_auth_plugins
97
+ except ImportError:
98
+
99
+ # pylint: disable-next=unused-argument
100
+ def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
101
+ """Add EE-specific arguments to the parser."""
102
+
103
+ def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
104
+ """Return all Exec API authentication plugins."""
105
+ raise NotImplementedError("No authentication plugins are currently supported.")
106
+
107
+
92
108
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
93
109
  *,
94
110
  server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -247,6 +263,12 @@ def run_superlink() -> None:
247
263
  # Obtain certificates
248
264
  certificates = try_obtain_server_certificates(args, args.fleet_api_type)
249
265
 
266
+ user_auth_config = _try_obtain_user_auth_config(args)
267
+ auth_plugin: Optional[ExecAuthPlugin] = None
268
+ # user_auth_config is None only if the args.user_auth_config is not provided
269
+ if user_auth_config is not None:
270
+ auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
271
+
250
272
  # Initialize StateFactory
251
273
  state_factory = LinkStateFactory(args.database)
252
274
 
@@ -264,13 +286,13 @@ def run_superlink() -> None:
264
286
  config=parse_config_args(
265
287
  [args.executor_config] if args.executor_config else args.executor_config
266
288
  ),
289
+ auth_plugin=auth_plugin,
267
290
  )
268
291
  grpc_servers = [exec_server]
269
292
 
270
293
  # Determine Exec plugin
271
294
  # If simulation is used, don't start ServerAppIo and Fleet APIs
272
- sim_exec = isinstance(executor, SimulationEngine)
273
-
295
+ sim_exec = executor.__class__.__qualname__ == "SimulationEngine"
274
296
  bckg_threads = []
275
297
 
276
298
  if sim_exec:
@@ -278,7 +300,7 @@ def run_superlink() -> None:
278
300
  address=simulationio_address,
279
301
  state_factory=state_factory,
280
302
  ffs_factory=ffs_factory,
281
- certificates=certificates,
303
+ certificates=None, # SimulationAppIo API doesn't support SSL yet
282
304
  )
283
305
  grpc_servers.append(simulationio_server)
284
306
 
@@ -352,6 +374,7 @@ def run_superlink() -> None:
352
374
  server_public_key,
353
375
  ) = maybe_keys
354
376
  state = state_factory.state()
377
+ state.clear_supernode_auth_keys_and_credentials()
355
378
  state.store_node_public_keys(node_public_keys)
356
379
  state.store_server_private_public_key(
357
380
  private_key_to_bytes(server_private_key),
@@ -362,7 +385,7 @@ def run_superlink() -> None:
362
385
  "Node authentication enabled with %d known public keys",
363
386
  len(node_public_keys),
364
387
  )
365
- interceptors = [AuthenticateServerInterceptor(state)]
388
+ interceptors = [AuthenticateServerInterceptor(state_factory)]
366
389
 
367
390
  fleet_server = _run_fleet_api_grpc_rere(
368
391
  address=fleet_address,
@@ -389,6 +412,9 @@ def run_superlink() -> None:
389
412
  io_address = (
390
413
  f"{CLIENT_OCTET}:{_port}" if _octet == SERVER_OCTET else serverappio_address
391
414
  )
415
+ address_arg = (
416
+ "--simulationio-api-address" if sim_exec else "--serverappio-api-address"
417
+ )
392
418
  address = simulationio_address if sim_exec else io_address
393
419
  cmd = "flwr-simulation" if sim_exec else "flwr-serverapp"
394
420
 
@@ -397,6 +423,7 @@ def run_superlink() -> None:
397
423
  target=_flwr_scheduler,
398
424
  args=(
399
425
  state_factory,
426
+ address_arg,
400
427
  address,
401
428
  cmd,
402
429
  ),
@@ -422,6 +449,7 @@ def run_superlink() -> None:
422
449
 
423
450
  def _flwr_scheduler(
424
451
  state_factory: LinkStateFactory,
452
+ io_api_arg: str,
425
453
  io_api_address: str,
426
454
  cmd: str,
427
455
  ) -> None:
@@ -446,7 +474,7 @@ def _flwr_scheduler(
446
474
  command = [
447
475
  cmd,
448
476
  "--run-once",
449
- "--serverappio-api-address",
477
+ io_api_arg,
450
478
  io_api_address,
451
479
  "--insecure",
452
480
  ]
@@ -556,6 +584,32 @@ def _try_setup_node_authentication(
556
584
  )
557
585
 
558
586
 
587
+ def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
588
+ if getattr(args, "user_auth_config", None) is not None:
589
+ with open(args.user_auth_config, encoding="utf-8") as file:
590
+ config: dict[str, Any] = yaml.safe_load(file)
591
+ return config
592
+ return None
593
+
594
+
595
+ def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
596
+ auth_config: dict[str, Any] = config.get("authentication", {})
597
+ auth_type: str = auth_config.get(AUTH_TYPE, "")
598
+ try:
599
+ all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
600
+ auth_plugin_class = all_plugins[auth_type]
601
+ return auth_plugin_class(config=auth_config)
602
+ except KeyError:
603
+ if auth_type != "":
604
+ sys.exit(
605
+ f'Authentication type "{auth_type}" is not supported. '
606
+ "Please provide a valid authentication type in the configuration."
607
+ )
608
+ sys.exit("No authentication type is provided in the configuration.")
609
+ except NotImplementedError:
610
+ sys.exit("No authentication plugins are currently supported.")
611
+
612
+
559
613
  def _run_fleet_api_grpc_rere(
560
614
  address: str,
561
615
  state_factory: LinkStateFactory,
@@ -654,6 +708,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
654
708
  )
655
709
 
656
710
  _add_args_common(parser=parser)
711
+ add_ee_args_superlink(parser=parser)
657
712
  _add_args_serverappio_api(parser=parser)
658
713
  _add_args_fleet_api(parser=parser)
659
714
  _add_args_exec_api(parser=parser)
@@ -17,6 +17,8 @@
17
17
 
18
18
  import threading
19
19
 
20
+ from flwr.common.typing import RunNotRunningException
21
+
20
22
  from ..client_manager import ClientManager
21
23
  from ..compat.driver_client_proxy import DriverClientProxy
22
24
  from ..driver import Driver
@@ -74,7 +76,11 @@ def _update_client_manager(
74
76
  # Loop until the driver is disconnected
75
77
  registered_nodes: dict[int, DriverClientProxy] = {}
76
78
  while not f_stop.is_set():
77
- all_node_ids = set(driver.get_node_ids())
79
+ try:
80
+ all_node_ids = set(driver.get_node_ids())
81
+ except RunNotRunningException:
82
+ f_stop.set()
83
+ break
78
84
  dead_nodes = set(registered_nodes).difference(all_node_ids)
79
85
  new_nodes = all_node_ids.difference(registered_nodes)
80
86
 
@@ -14,19 +14,20 @@
14
14
  # ==============================================================================
15
15
  """Flower gRPC Driver."""
16
16
 
17
+
17
18
  import time
18
19
  import warnings
19
20
  from collections.abc import Iterable
20
- from logging import DEBUG, INFO, WARN, WARNING
21
- from typing import Any, Optional, cast
21
+ from logging import DEBUG, WARNING
22
+ from typing import Optional, cast
22
23
 
23
24
  import grpc
24
25
 
25
26
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
26
- from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
28
  from flwr.common.grpc import create_channel
28
29
  from flwr.common.logger import log
29
- from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
30
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
30
31
  from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
31
32
  from flwr.common.typing import Run
32
33
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
203
204
  task_ins_list.append(taskins)
204
205
  # Call GrpcDriverStub method
205
206
  res: PushTaskInsResponse = self._stub.PushTaskIns(
206
- PushTaskInsRequest(task_ins_list=task_ins_list)
207
+ PushTaskInsRequest(
208
+ task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
209
+ )
207
210
  )
208
211
  return list(res.task_ids)
209
212
 
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
215
218
  """
216
219
  # Pull TaskRes
217
220
  res: PullTaskResResponse = self._stub.PullTaskRes(
218
- PullTaskResRequest(node=self.node, task_ids=message_ids)
221
+ PullTaskResRequest(
222
+ node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
223
+ )
219
224
  )
220
225
  # Convert TaskRes to Message
221
226
  msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
258
263
  return
259
264
  # Disconnect
260
265
  self._disconnect()
261
-
262
-
263
- def _make_simple_grpc_retry_invoker() -> RetryInvoker:
264
- """Create a simple gRPC retry invoker."""
265
-
266
- def _on_sucess(retry_state: RetryState) -> None:
267
- if retry_state.tries > 1:
268
- log(
269
- INFO,
270
- "Connection successful after %.2f seconds and %s tries.",
271
- retry_state.elapsed_time,
272
- retry_state.tries,
273
- )
274
-
275
- def _on_backoff(retry_state: RetryState) -> None:
276
- if retry_state.tries == 1:
277
- log(WARN, "Connection attempt failed, retrying...")
278
- else:
279
- log(
280
- WARN,
281
- "Connection attempt failed, retrying in %.2f seconds",
282
- retry_state.actual_wait,
283
- )
284
-
285
- def _on_giveup(retry_state: RetryState) -> None:
286
- if retry_state.tries > 1:
287
- log(
288
- WARN,
289
- "Giving up reconnection after %.2f seconds and %s tries.",
290
- retry_state.elapsed_time,
291
- retry_state.tries,
292
- )
293
-
294
- return RetryInvoker(
295
- wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
296
- recoverable_exceptions=grpc.RpcError,
297
- max_tries=None,
298
- max_time=None,
299
- on_success=_on_sucess,
300
- on_backoff=_on_backoff,
301
- on_giveup=_on_giveup,
302
- should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
303
- )
304
-
305
-
306
- def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
307
- """Wrap the gRPC stub with a retry invoker."""
308
-
309
- def make_lambda(original_method: Any) -> Any:
310
- return lambda *args, **kwargs: retry_invoker.invoke(
311
- original_method, *args, **kwargs
312
- )
313
-
314
- for method_name in vars(stub):
315
- method = getattr(stub, method_name)
316
- if callable(method):
317
- setattr(stub, method_name, make_lambda(method))
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
142
142
  # Pull TaskRes
143
143
  task_res_list = self.state.get_task_res(task_ids=msg_ids)
144
144
  # Delete tasks in state
145
- self.state.delete_tasks(msg_ids)
145
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
146
+ task_ins_ids_to_delete = {
147
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
148
+ }
149
+ self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
146
150
  # Convert TaskRes to Message
147
151
  msgs = [message_from_taskres(taskres) for taskres in task_res_list]
148
152
  return msgs
@@ -15,12 +15,12 @@
15
15
  """Run ServerApp."""
16
16
 
17
17
 
18
- import sys
19
18
  from logging import DEBUG, ERROR
20
19
  from typing import Optional
21
20
 
22
- from flwr.common import Context
23
- from flwr.common.logger import log, warn_unsupported_feature
21
+ from flwr.common import Context, EventType, event
22
+ from flwr.common.exit_handlers import register_exit_handlers
23
+ from flwr.common.logger import log
24
24
  from flwr.common.object_ref import load_app
25
25
 
26
26
  from .driver import Driver
@@ -66,12 +66,11 @@ def run(
66
66
  return context
67
67
 
68
68
 
69
- # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
70
69
  def run_server_app() -> None:
71
70
  """Run Flower server app."""
72
- warn_unsupported_feature(
73
- "The command `flower-server-app` is deprecated and no longer in use. "
74
- "Use the `flwr-serverapp` exclusively instead."
71
+ event(EventType.RUN_SERVER_APP_ENTER)
72
+ log(
73
+ ERROR,
74
+ "The command `flower-server-app` has been replaced by `flwr run`.",
75
75
  )
76
- log(ERROR, "`flower-server-app` used.")
77
- sys.exit()
76
+ register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower ServerApp process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -24,6 +25,7 @@ from typing import Optional
24
25
 
25
26
  from flwr.cli.config_utils import get_fab_metadata
26
27
  from flwr.cli.install import install_from_fab
28
+ from flwr.cli.utils import get_sha256_hash
27
29
  from flwr.common.args import add_args_flwr_app_common
28
30
  from flwr.common.config import (
29
31
  get_flwr_dir,
@@ -50,7 +52,8 @@ from flwr.common.serde import (
50
52
  run_from_proto,
51
53
  run_status_to_proto,
52
54
  )
53
- from flwr.common.typing import RunStatus
55
+ from flwr.common.telemetry import EventType, event
56
+ from flwr.common.typing import RunNotRunningException, RunStatus
54
57
  from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
55
58
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
56
59
  PullServerAppInputsRequest,
@@ -96,7 +99,7 @@ def flwr_serverapp() -> None:
96
99
  restore_output()
97
100
 
98
101
 
99
- def run_serverapp( # pylint: disable=R0914, disable=W0212
102
+ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
100
103
  serverappio_api_address: str,
101
104
  log_queue: Queue[Optional[str]],
102
105
  run_once: bool,
@@ -112,7 +115,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
112
115
  # Resolve directory where FABs are installed
113
116
  flwr_dir_ = get_flwr_dir(flwr_dir)
114
117
  log_uploader = None
115
-
118
+ success = True
119
+ hash_run_id = None
116
120
  while True:
117
121
 
118
122
  try:
@@ -128,6 +132,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
128
132
  run = run_from_proto(res.run)
129
133
  fab = fab_from_proto(res.fab)
130
134
 
135
+ hash_run_id = get_sha256_hash(run.run_id)
136
+
131
137
  driver.set_run(run.run_id)
132
138
 
133
139
  # Start log uploader for this run
@@ -170,6 +176,11 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
170
176
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
171
177
  )
172
178
 
179
+ event(
180
+ EventType.FLWR_SERVERAPP_RUN_ENTER,
181
+ event_details={"run-id-hash": hash_run_id},
182
+ )
183
+
173
184
  # Load and run the ServerApp with the Driver
174
185
  updated_context = run_(
175
186
  driver=driver,
@@ -186,11 +197,18 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
186
197
  _ = driver._stub.PushServerAppOutputs(out_req)
187
198
 
188
199
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
200
+ except RunNotRunningException:
201
+ log(INFO, "")
202
+ log(INFO, "Run ID %s stopped.", run.run_id)
203
+ log(INFO, "")
204
+ run_status = None
205
+ success = False
189
206
 
190
207
  except Exception as ex: # pylint: disable=broad-exception-caught
191
208
  exc_entity = "ServerApp"
192
209
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
193
210
  run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
211
+ success = False
194
212
 
195
213
  finally:
196
214
  # Stop log uploader for this run and upload final logs
@@ -206,6 +224,10 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
206
224
  run_id=run.run_id, run_status=run_status_proto
207
225
  )
208
226
  )
227
+ event(
228
+ EventType.FLWR_SERVERAPP_RUN_LEAVE,
229
+ event_details={"run-id-hash": hash_run_id, "success": success},
230
+ )
209
231
 
210
232
  # Stop the loop if `flwr-serverapp` is expected to process a single run
211
233
  if run_once:
@@ -230,12 +252,5 @@ def _parse_args_run_flwr_serverapp() -> argparse.ArgumentParser:
230
252
  help="When set, this process will start a single ServerApp for a pending Run. "
231
253
  "If there is no pending Run, the process will exit.",
232
254
  )
233
- parser.add_argument(
234
- "--root-certificates",
235
- metavar="ROOT_CERT",
236
- type=str,
237
- help="Specifies the path to the PEM-encoded root certificate file for "
238
- "establishing secure HTTPS connections.",
239
- )
240
255
  add_args_flwr_app_common(parser=parser)
241
256
  return parser
@@ -17,6 +17,7 @@
17
17
  Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
+
20
21
  from typing import Optional, Union
21
22
 
22
23
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """ServerAppIo gRPC API."""
16
16
 
17
+
17
18
  from logging import INFO
18
19
  from typing import Optional
19
20
 
@@ -32,6 +32,7 @@ from flwr.common.serde import (
32
32
  fab_from_proto,
33
33
  fab_to_proto,
34
34
  run_status_from_proto,
35
+ run_status_to_proto,
35
36
  run_to_proto,
36
37
  user_config_from_proto,
37
38
  )
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
49
  CreateRunResponse,
49
50
  GetRunRequest,
50
51
  GetRunResponse,
52
+ GetRunStatusRequest,
53
+ GetRunStatusResponse,
51
54
  UpdateRunStatusRequest,
52
55
  UpdateRunStatusResponse,
53
56
  )
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
67
70
  from flwr.server.superlink.ffs.ffs import Ffs
68
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
69
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
73
+ from flwr.server.superlink.utils import abort_if
70
74
  from flwr.server.utils.validator import validate_task_ins_or_res
71
75
 
72
76
 
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
85
89
  ) -> GetNodesResponse:
86
90
  """Get available nodes."""
87
91
  log(DEBUG, "ServerAppIoServicer.GetNodes")
92
+
93
+ # Init state
88
94
  state: LinkState = self.state_factory.state()
95
+
96
+ # Abort if the run is not running
97
+ abort_if(
98
+ request.run_id,
99
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
100
+ state,
101
+ context,
102
+ )
103
+
89
104
  all_ids: set[int] = state.get_nodes(request.run_id)
90
105
  nodes: list[Node] = [
91
106
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
123
138
  """Push a set of TaskIns."""
124
139
  log(DEBUG, "ServerAppIoServicer.PushTaskIns")
125
140
 
141
+ # Init state
142
+ state: LinkState = self.state_factory.state()
143
+
144
+ # Abort if the run is not running
145
+ abort_if(
146
+ request.run_id,
147
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
148
+ state,
149
+ context,
150
+ )
151
+
126
152
  # Set pushed_at (timestamp in seconds)
127
153
  pushed_at = time.time()
128
154
  for task_ins in request.task_ins_list:
@@ -133,9 +159,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
133
159
  for task_ins in request.task_ins_list:
134
160
  validation_errors = validate_task_ins_or_res(task_ins)
135
161
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
136
-
137
- # Init state
138
- state: LinkState = self.state_factory.state()
162
+ _raise_if(
163
+ request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`"
164
+ )
139
165
 
140
166
  # Store each TaskIns
141
167
  task_ids: list[Optional[UUID]] = []
@@ -153,33 +179,35 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
153
179
  """Pull a set of TaskRes."""
154
180
  log(DEBUG, "ServerAppIoServicer.PullTaskRes")
155
181
 
156
- # Convert each task_id str to UUID
157
- task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
158
-
159
182
  # Init state
160
183
  state: LinkState = self.state_factory.state()
161
184
 
162
- # Register callback
163
- def on_rpc_done() -> None:
164
- log(
165
- DEBUG,
166
- "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
- )
168
-
169
- if context.is_active():
170
- return
171
- if context.code() != grpc.StatusCode.OK:
172
- return
173
-
174
- # Delete delivered TaskIns and TaskRes
175
- state.delete_tasks(task_ids=task_ids)
185
+ # Abort if the run is not running
186
+ abort_if(
187
+ request.run_id,
188
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
189
+ state,
190
+ context,
191
+ )
176
192
 
177
- context.add_callback(on_rpc_done)
193
+ # Convert each task_id str to UUID
194
+ task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
178
195
 
179
196
  # Read from state
180
197
  task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
181
198
 
182
- context.set_code(grpc.StatusCode.OK)
199
+ # Validate request
200
+ for task_res in task_res_list:
201
+ _raise_if(
202
+ request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`"
203
+ )
204
+
205
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
206
+ task_ins_ids_to_delete = {
207
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
208
+ }
209
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
210
+
183
211
  return PullTaskResResponse(task_res_list=task_res_list)
184
212
 
185
213
  def GetRun(
@@ -255,7 +283,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
255
283
  ) -> PushServerAppOutputsResponse:
256
284
  """Push ServerApp process outputs."""
257
285
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
286
+
287
+ # Init state
258
288
  state = self.state_factory.state()
289
+
290
+ # Abort if the run is not running
291
+ abort_if(
292
+ request.run_id,
293
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
294
+ state,
295
+ context,
296
+ )
297
+
259
298
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
299
  return PushServerAppOutputsResponse()
261
300
 
@@ -263,9 +302,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
263
302
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
264
303
  ) -> UpdateRunStatusResponse:
265
304
  """Update the status of a run."""
266
- log(DEBUG, "ControlServicer.UpdateRunStatus")
305
+ log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
306
+
307
+ # Init state
267
308
  state = self.state_factory.state()
268
309
 
310
+ # Abort if the run is finished
311
+ abort_if(request.run_id, [Status.FINISHED], state, context)
312
+
269
313
  # Update the run status
270
314
  state.update_run_status(
271
315
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
@@ -284,6 +328,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
284
328
  state.add_serverapp_log(request.run_id, merged_logs)
285
329
  return PushLogsResponse()
286
330
 
331
+ def GetRunStatus(
332
+ self, request: GetRunStatusRequest, context: grpc.ServicerContext
333
+ ) -> GetRunStatusResponse:
334
+ """Get the status of a run."""
335
+ log(DEBUG, "ServerAppIoServicer.GetRunStatus")
336
+ state = self.state_factory.state()
337
+
338
+ # Get run status from LinkState
339
+ run_statuses = state.get_run_status(set(request.run_ids))
340
+ run_status_dict = {
341
+ run_id: run_status_to_proto(run_status)
342
+ for run_id, run_status in run_statuses.items()
343
+ }
344
+ return GetRunStatusResponse(run_status_dict=run_status_dict)
345
+
287
346
 
288
347
  def _raise_if(validation_error: bool, detail: str) -> None:
289
348
  if validation_error:
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Disk based Flower File Storage."""
16
16
 
17
+
17
18
  import hashlib
18
19
  import json
19
20
  from pathlib import Path
@@ -158,4 +158,5 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
158
158
  return message_handler.get_fab(
159
159
  request=request,
160
160
  ffs=self.ffs_factory.ffs(),
161
+ state=self.state_factory.state(),
161
162
  )