flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241216__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the package registry's page for this release for more details.

Files changed (101):
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +11 -31
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +25 -55
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  15. flwr/cli/run/__init__.py +1 -0
  16. flwr/cli/run/run.py +17 -39
  17. flwr/cli/stop.py +129 -0
  18. flwr/cli/utils.py +96 -1
  19. flwr/client/app.py +14 -3
  20. flwr/client/client.py +1 -0
  21. flwr/client/clientapp/app.py +4 -1
  22. flwr/client/clientapp/utils.py +1 -0
  23. flwr/client/grpc_adapter_client/connection.py +1 -1
  24. flwr/client/grpc_client/connection.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +13 -7
  26. flwr/client/message_handler/message_handler.py +1 -0
  27. flwr/client/mod/comms_mods.py +1 -0
  28. flwr/client/mod/localdp_mod.py +1 -1
  29. flwr/client/nodestate/__init__.py +1 -0
  30. flwr/client/nodestate/nodestate.py +1 -0
  31. flwr/client/nodestate/nodestate_factory.py +1 -0
  32. flwr/client/rest_client/connection.py +3 -3
  33. flwr/client/supernode/app.py +1 -0
  34. flwr/common/address.py +1 -0
  35. flwr/common/args.py +1 -0
  36. flwr/common/auth_plugin/__init__.py +24 -0
  37. flwr/common/auth_plugin/auth_plugin.py +111 -0
  38. flwr/common/config.py +3 -1
  39. flwr/common/constant.py +6 -1
  40. flwr/common/logger.py +17 -1
  41. flwr/common/message.py +1 -0
  42. flwr/common/object_ref.py +57 -54
  43. flwr/common/pyproject.py +1 -0
  44. flwr/common/record/__init__.py +1 -0
  45. flwr/common/record/parametersrecord.py +1 -0
  46. flwr/common/retry_invoker.py +77 -0
  47. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  48. flwr/common/telemetry.py +2 -1
  49. flwr/common/typing.py +12 -0
  50. flwr/common/version.py +1 -0
  51. flwr/proto/exec_pb2.py +27 -3
  52. flwr/proto/exec_pb2.pyi +103 -0
  53. flwr/proto/exec_pb2_grpc.py +102 -0
  54. flwr/proto/exec_pb2_grpc.pyi +39 -0
  55. flwr/proto/fab_pb2.py +4 -4
  56. flwr/proto/fab_pb2.pyi +4 -1
  57. flwr/proto/serverappio_pb2.py +18 -18
  58. flwr/proto/serverappio_pb2.pyi +8 -2
  59. flwr/proto/serverappio_pb2_grpc.py +34 -0
  60. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  61. flwr/proto/simulationio_pb2.py +2 -2
  62. flwr/proto/simulationio_pb2_grpc.py +34 -0
  63. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  64. flwr/server/app.py +52 -1
  65. flwr/server/compat/app_utils.py +7 -1
  66. flwr/server/driver/grpc_driver.py +11 -63
  67. flwr/server/driver/inmemory_driver.py +5 -1
  68. flwr/server/serverapp/app.py +9 -2
  69. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  70. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  71. flwr/server/superlink/driver/serverappio_servicer.py +72 -22
  72. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  73. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  74. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  77. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  78. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  79. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  81. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  82. flwr/server/superlink/linkstate/linkstate.py +13 -2
  83. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  84. flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
  85. flwr/server/superlink/utils.py +65 -0
  86. flwr/simulation/app.py +1 -0
  87. flwr/simulation/ray_transport/ray_actor.py +1 -0
  88. flwr/simulation/ray_transport/utils.py +1 -0
  89. flwr/simulation/run_simulation.py +1 -15
  90. flwr/simulation/simulationio_connection.py +3 -0
  91. flwr/superexec/app.py +1 -0
  92. flwr/superexec/deployment.py +1 -0
  93. flwr/superexec/exec_grpc.py +19 -1
  94. flwr/superexec/exec_servicer.py +76 -2
  95. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  96. flwr/superexec/executor.py +1 -0
  97. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/METADATA +8 -7
  98. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/RECORD +101 -93
  99. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/LICENSE +0 -0
  100. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/WHEEL +0 -0
  101. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
19
 
19
20
  from .backend import Backend, BackendConfig
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+
17
18
  import sys
18
19
  from logging import DEBUG, ERROR
19
20
  from typing import Callable, Optional, Union
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
265
265
  for task_res in task_res_found:
266
266
  task_res.task.delivered_at = delivered_at
267
267
 
268
- # Cleanup
269
- self._force_delete_tasks_by_ids(set(ret.keys()))
270
-
271
268
  return list(ret.values())
272
269
 
273
- def delete_tasks(self, task_ids: set[UUID]) -> None:
274
- """Delete all delivered TaskIns/TaskRes pairs."""
275
- task_ins_to_be_deleted: set[UUID] = set()
276
- task_res_to_be_deleted: set[UUID] = set()
277
-
278
- with self.lock:
279
- for task_ins_id in task_ids:
280
- # Find the task_id of the matching task_res
281
- for task_res_id, task_res in self.task_res_store.items():
282
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
283
- continue
284
- if task_res.task.delivered_at == "":
285
- continue
286
-
287
- task_ins_to_be_deleted.add(task_ins_id)
288
- task_res_to_be_deleted.add(task_res_id)
289
-
290
- for task_id in task_ins_to_be_deleted:
291
- del self.task_ins_store[task_id]
292
- del self.task_ins_id_to_task_res_id[task_id]
293
- for task_id in task_res_to_be_deleted:
294
- del self.task_res_store[task_id]
295
-
296
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
- """Delete tasks based on a set of TaskIns IDs."""
298
- if not task_ids:
270
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
271
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
272
+ if not task_ins_ids:
299
273
  return
300
274
 
301
275
  with self.lock:
302
- for task_id in task_ids:
276
+ for task_id in task_ins_ids:
303
277
  # Delete TaskIns
304
278
  if task_id in self.task_ins_store:
305
279
  del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
308
282
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
283
  del self.task_res_store[task_res_id]
310
284
 
285
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
286
+ """Get all TaskIns IDs for the given run_id."""
287
+ task_id_list: set[UUID] = set()
288
+ with self.lock:
289
+ for task_id, task_ins in self.task_ins_store.items():
290
+ if task_ins.run_id == run_id:
291
+ task_id_list.add(task_id)
292
+
293
+ return task_id_list
294
+
311
295
  def num_task_ins(self) -> int:
312
296
  """Calculate the number of task_ins in store.
313
297
 
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
139
139
  """
140
140
 
141
141
  @abc.abstractmethod
142
- def delete_tasks(self, task_ids: set[UUID]) -> None:
143
- """Delete all delivered TaskIns/TaskRes pairs."""
142
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
143
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
144
+
145
+ Parameters
146
+ ----------
147
+ task_ins_ids : set[UUID]
148
+ A set of TaskIns IDs. For each ID in the set, the corresponding
149
+ TaskIns and its associated TaskRes will be deleted.
150
+ """
151
+
152
+ @abc.abstractmethod
153
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
154
+ """Get all TaskIns IDs for the given run_id."""
144
155
 
145
156
  @abc.abstractmethod
146
157
  def create_node(
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of the link state."""
16
16
 
17
+
17
18
  # pylint: disable=too-many-lines
18
19
 
19
20
  import json
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
566
567
  data: list[Any] = [delivered_at] + task_res_ids
567
568
  self.query(query, data)
568
569
 
569
- # Cleanup
570
- self._force_delete_tasks_by_ids(set(ret.keys()))
571
-
572
570
  return list(ret.values())
573
571
 
574
572
  def num_task_ins(self) -> int:
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
592
590
  result: dict[str, int] = rows[0]
593
591
  return result["num"]
594
592
 
595
- def delete_tasks(self, task_ids: set[UUID]) -> None:
596
- """Delete all delivered TaskIns/TaskRes pairs."""
597
- ids = list(task_ids)
598
- if len(ids) == 0:
599
- return None
593
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
594
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
595
+ if not task_ins_ids:
596
+ return
597
+ if self.conn is None:
598
+ raise AttributeError("LinkState not initialized")
600
599
 
601
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
602
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
600
+ placeholders = ",".join(["?"] * len(task_ins_ids))
601
+ data = tuple(str(task_id) for task_id in task_ins_ids)
603
602
 
604
- # 1. Query: Delete task_ins which have a delivered task_res
603
+ # Delete task_ins
605
604
  query_1 = f"""
606
605
  DELETE FROM task_ins
607
- WHERE delivered_at != ''
608
- AND task_id IN (
609
- SELECT ancestry
610
- FROM task_res
611
- WHERE ancestry IN ({placeholders})
612
- AND delivered_at != ''
613
- );
606
+ WHERE task_id IN ({placeholders});
614
607
  """
615
608
 
616
- # 2. Query: Delete delivered task_res to be run after 1. Query
609
+ # Delete task_res
617
610
  query_2 = f"""
618
611
  DELETE FROM task_res
619
- WHERE ancestry IN ({placeholders})
620
- AND delivered_at != '';
612
+ WHERE ancestry IN ({placeholders});
621
613
  """
622
614
 
623
- if self.conn is None:
624
- raise AttributeError("LinkState not intitialized")
625
-
626
615
  with self.conn:
627
616
  self.conn.execute(query_1, data)
628
617
  self.conn.execute(query_2, data)
629
618
 
630
- return None
631
-
632
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
633
- """Delete tasks based on a set of TaskIns IDs."""
634
- if not task_ids:
635
- return
619
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
620
+ """Get all TaskIns IDs for the given run_id."""
636
621
  if self.conn is None:
637
622
  raise AttributeError("LinkState not initialized")
638
623
 
639
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
640
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
641
-
642
- # Delete task_ins
643
- query_1 = f"""
644
- DELETE FROM task_ins
645
- WHERE task_id IN ({placeholders});
624
+ query = """
625
+ SELECT task_id
626
+ FROM task_ins
627
+ WHERE run_id = :run_id;
646
628
  """
647
629
 
648
- # Delete task_res
649
- query_2 = f"""
650
- DELETE FROM task_res
651
- WHERE ancestry IN ({placeholders});
652
- """
630
+ sint64_run_id = convert_uint64_to_sint64(run_id)
631
+ data = {"run_id": sint64_run_id}
653
632
 
654
633
  with self.conn:
655
- self.conn.execute(query_1, data)
656
- self.conn.execute(query_2, data)
634
+ rows = self.conn.execute(query, data).fetchall()
635
+
636
+ return {UUID(row["task_id"]) for row in rows}
657
637
 
658
638
  def create_node(
659
639
  self, ping_interval: float, public_key: Optional[bytes] = None
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SimulationIo API servicer."""
16
16
 
17
+
17
18
  import threading
18
19
  from logging import DEBUG, INFO
19
20
 
@@ -28,6 +29,7 @@ from flwr.common.serde import (
28
29
  context_to_proto,
29
30
  fab_to_proto,
30
31
  run_status_from_proto,
32
+ run_status_to_proto,
31
33
  run_to_proto,
32
34
  )
33
35
  from flwr.common.typing import Fab, RunStatus
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
39
41
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
42
  GetFederationOptionsRequest,
41
43
  GetFederationOptionsResponse,
44
+ GetRunStatusRequest,
45
+ GetRunStatusResponse,
42
46
  UpdateRunStatusRequest,
43
47
  UpdateRunStatusResponse,
44
48
  )
@@ -122,6 +126,22 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
122
126
  )
123
127
  return UpdateRunStatusResponse()
124
128
 
129
+ def GetRunStatus(
130
+ self, request: GetRunStatusRequest, context: ServicerContext
131
+ ) -> GetRunStatusResponse:
132
+ """Get status of requested runs."""
133
+ log(DEBUG, "SimultionIoServicer.GetRunStatus")
134
+ state = self.state_factory.state()
135
+
136
+ statuses = state.get_run_status(set(request.run_ids))
137
+
138
+ return GetRunStatusResponse(
139
+ run_status_dict={
140
+ run_id: run_status_to_proto(status)
141
+ for run_id, status in statuses.items()
142
+ }
143
+ )
144
+
125
145
  def PushLogs(
126
146
  self, request: PushLogsRequest, context: grpc.ServicerContext
127
147
  ) -> PushLogsResponse:
@@ -0,0 +1,65 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SuperLink utilities."""
16
+
17
+
18
+ from typing import Union
19
+
20
+ import grpc
21
+
22
+ from flwr.common.constant import Status, SubStatus
23
+ from flwr.common.typing import RunStatus
24
+ from flwr.server.superlink.linkstate import LinkState
25
+
26
+ _STATUS_TO_MSG = {
27
+ Status.PENDING: "Run is pending.",
28
+ Status.STARTING: "Run is starting.",
29
+ Status.RUNNING: "Run is running.",
30
+ Status.FINISHED: "Run is finished.",
31
+ }
32
+
33
+
34
+ def check_abort(
35
+ run_id: int,
36
+ abort_status_list: list[str],
37
+ state: LinkState,
38
+ ) -> Union[str, None]:
39
+ """Check if the status of the provided `run_id` is in `abort_status_list`."""
40
+ run_status: RunStatus = state.get_run_status({run_id})[run_id]
41
+
42
+ if run_status.status in abort_status_list:
43
+ msg = _STATUS_TO_MSG[run_status.status]
44
+ if run_status.sub_status == SubStatus.STOPPED:
45
+ msg += " Stopped by user."
46
+ return msg
47
+
48
+ return None
49
+
50
+
51
+ def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
52
+ """Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
53
+ if msg is not None:
54
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
55
+
56
+
57
+ def abort_if(
58
+ run_id: int,
59
+ abort_status_list: list[str],
60
+ state: LinkState,
61
+ context: grpc.ServicerContext,
62
+ ) -> None:
63
+ """Abort context if status of the provided `run_id` is in `abort_status_list`."""
64
+ msg = check_abort(run_id, abort_status_list, state)
65
+ abort_grpc_context(msg, context)
flwr/simulation/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower Simulation process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray-based Flower Actor and ActorPool implementation."""
16
16
 
17
+
17
18
  import threading
18
19
  from abc import ABC
19
20
  from logging import DEBUG, ERROR, WARNING
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Utilities for Actors in the Virtual Client Engine."""
16
16
 
17
+
17
18
  import traceback
18
19
  import warnings
19
20
  from logging import ERROR
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower Simulation."""
16
16
 
17
+
17
18
  import argparse
18
19
  import asyncio
19
20
  import json
@@ -23,7 +24,6 @@ import threading
23
24
  import traceback
24
25
  from logging import DEBUG, ERROR, INFO, WARNING
25
26
  from pathlib import Path
26
- from time import sleep
27
27
  from typing import Any, Optional
28
28
 
29
29
  from flwr.cli.config_utils import load_and_validate
@@ -135,7 +135,6 @@ def run_simulation_from_cli() -> None:
135
135
  app_dir=args.app,
136
136
  run=run,
137
137
  enable_tf_gpu_growth=args.enable_tf_gpu_growth,
138
- delay_start=args.delay_start,
139
138
  verbose_logging=args.verbose,
140
139
  server_app_run_config=fused_config,
141
140
  is_app=True,
@@ -308,7 +307,6 @@ def _main_loop(
308
307
  enable_tf_gpu_growth: bool,
309
308
  run: Run,
310
309
  exit_event: EventType,
311
- delay_start: int,
312
310
  flwr_dir: Optional[str] = None,
313
311
  client_app: Optional[ClientApp] = None,
314
312
  client_app_attr: Optional[str] = None,
@@ -353,9 +351,6 @@ def _main_loop(
353
351
  run_id=run.run_id,
354
352
  )
355
353
 
356
- # Buffer time so the `ServerApp` in separate thread is ready
357
- log(DEBUG, "Buffer time delay: %ds", delay_start)
358
- sleep(delay_start)
359
354
  # Start Simulation Engine
360
355
  vce.start_vce(
361
356
  num_supernodes=num_supernodes,
@@ -404,7 +399,6 @@ def _run_simulation(
404
399
  flwr_dir: Optional[str] = None,
405
400
  run: Optional[Run] = None,
406
401
  enable_tf_gpu_growth: bool = False,
407
- delay_start: int = 5,
408
402
  verbose_logging: bool = False,
409
403
  is_app: bool = False,
410
404
  ) -> None:
@@ -459,7 +453,6 @@ def _run_simulation(
459
453
  enable_tf_gpu_growth,
460
454
  run,
461
455
  exit_event,
462
- delay_start,
463
456
  flwr_dir,
464
457
  client_app,
465
458
  client_app_attr,
@@ -537,13 +530,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
537
530
  "Read more about how `tf.config.experimental.set_memory_growth()` works in "
538
531
  "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
539
532
  )
540
- parser.add_argument(
541
- "--delay-start",
542
- type=int,
543
- default=3,
544
- help="Buffer time (in seconds) to delay the start the simulation engine after "
545
- "the `ServerApp`, which runs in a separate thread, has been launched.",
546
- )
547
533
  parser.add_argument(
548
534
  "--verbose",
549
535
  action="store_true",
@@ -23,6 +23,7 @@ import grpc
23
23
  from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
24
24
  from flwr.common.grpc import create_channel
25
25
  from flwr.common.logger import log
26
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
26
27
  from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
27
28
 
28
29
 
@@ -48,6 +49,7 @@ class SimulationIoConnection:
48
49
  self._cert = root_certificates
49
50
  self._grpc_stub: Optional[SimulationIoStub] = None
50
51
  self._channel: Optional[grpc.Channel] = None
52
+ self._retry_invoker = _make_simple_grpc_retry_invoker()
51
53
 
52
54
  @property
53
55
  def _is_connected(self) -> bool:
@@ -72,6 +74,7 @@ class SimulationIoConnection:
72
74
  root_certificates=self._cert,
73
75
  )
74
76
  self._grpc_stub = SimulationIoStub(self._channel)
77
+ _wrap_stub(self._grpc_stub, self._retry_invoker)
75
78
  log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
76
79
 
77
80
  def _disconnect(self) -> None:
flwr/superexec/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower SuperExec app."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import INFO
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Deployment engine executor."""
16
16
 
17
+
17
18
  import hashlib
18
19
  from logging import ERROR, INFO
19
20
  from pathlib import Path
@@ -14,18 +14,22 @@
14
14
  # ==============================================================================
15
15
  """SuperExec gRPC API."""
16
16
 
17
+
18
+ from collections.abc import Sequence
17
19
  from logging import INFO
18
20
  from typing import Optional
19
21
 
20
22
  import grpc
21
23
 
22
24
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
25
+ from flwr.common.auth_plugin import ExecAuthPlugin
23
26
  from flwr.common.logger import log
24
27
  from flwr.common.typing import UserConfig
25
28
  from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
26
29
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
27
30
  from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
28
31
  from flwr.server.superlink.linkstate import LinkStateFactory
32
+ from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
29
33
 
30
34
  from .exec_servicer import ExecServicer
31
35
  from .executor import Executor
@@ -39,6 +43,7 @@ def run_exec_api_grpc(
39
43
  ffs_factory: FfsFactory,
40
44
  certificates: Optional[tuple[bytes, bytes, bytes]],
41
45
  config: UserConfig,
46
+ auth_plugin: Optional[ExecAuthPlugin] = None,
42
47
  ) -> grpc.Server:
43
48
  """Run Exec API (gRPC, request-response)."""
44
49
  executor.set_config(config)
@@ -47,16 +52,29 @@ def run_exec_api_grpc(
47
52
  linkstate_factory=state_factory,
48
53
  ffs_factory=ffs_factory,
49
54
  executor=executor,
55
+ auth_plugin=auth_plugin,
50
56
  )
57
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
58
+ if auth_plugin is not None:
59
+ interceptors = [ExecUserAuthInterceptor(auth_plugin)]
51
60
  exec_add_servicer_to_server_fn = add_ExecServicer_to_server
52
61
  exec_grpc_server = generic_create_grpc_server(
53
62
  servicer_and_add_fn=(exec_servicer, exec_add_servicer_to_server_fn),
54
63
  server_address=address,
55
64
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
56
65
  certificates=certificates,
66
+ interceptors=interceptors,
57
67
  )
58
68
 
59
- log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
69
+ if auth_plugin is None:
70
+ log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
71
+ else:
72
+ log(
73
+ INFO,
74
+ "Flower Deployment Engine: Starting Exec API with user "
75
+ "authentication on %s",
76
+ address,
77
+ )
60
78
  exec_grpc_server.start()
61
79
 
62
80
  return exec_grpc_server
@@ -18,24 +18,33 @@
18
18
  import time
19
19
  from collections.abc import Generator
20
20
  from logging import ERROR, INFO
21
- from typing import Any
21
+ from typing import Any, Optional
22
+ from uuid import UUID
22
23
 
23
24
  import grpc
24
25
 
25
26
  from flwr.common import now
26
- from flwr.common.constant import LOG_STREAM_INTERVAL, Status
27
+ from flwr.common.auth_plugin import ExecAuthPlugin
28
+ from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
27
29
  from flwr.common.logger import log
28
30
  from flwr.common.serde import (
29
31
  configs_record_from_proto,
30
32
  run_to_proto,
31
33
  user_config_from_proto,
32
34
  )
35
+ from flwr.common.typing import RunStatus
33
36
  from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
34
37
  from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
38
+ GetAuthTokensRequest,
39
+ GetAuthTokensResponse,
40
+ GetLoginDetailsRequest,
41
+ GetLoginDetailsResponse,
35
42
  ListRunsRequest,
36
43
  ListRunsResponse,
37
44
  StartRunRequest,
38
45
  StartRunResponse,
46
+ StopRunRequest,
47
+ StopRunResponse,
39
48
  StreamLogsRequest,
40
49
  StreamLogsResponse,
41
50
  )
@@ -53,11 +62,13 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
53
62
  linkstate_factory: LinkStateFactory,
54
63
  ffs_factory: FfsFactory,
55
64
  executor: Executor,
65
+ auth_plugin: Optional[ExecAuthPlugin] = None,
56
66
  ) -> None:
57
67
  self.linkstate_factory = linkstate_factory
58
68
  self.ffs_factory = ffs_factory
59
69
  self.executor = executor
60
70
  self.executor.initialize(linkstate_factory, ffs_factory)
71
+ self.auth_plugin = auth_plugin
61
72
 
62
73
  def StartRun(
63
74
  self, request: StartRunRequest, context: grpc.ServicerContext
@@ -126,6 +137,69 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
126
137
  # Handle `flwr ls --run-id <run_id>`
127
138
  return _create_list_runs_response({request.run_id}, state)
128
139
 
140
+ def StopRun(
141
+ self, request: StopRunRequest, context: grpc.ServicerContext
142
+ ) -> StopRunResponse:
143
+ """Stop a given run ID."""
144
+ log(INFO, "ExecServicer.StopRun")
145
+ state = self.linkstate_factory.state()
146
+
147
+ # Exit if `run_id` not found
148
+ if not state.get_run(request.run_id):
149
+ context.abort(
150
+ grpc.StatusCode.NOT_FOUND, f"Run ID {request.run_id} not found"
151
+ )
152
+
153
+ run_status = state.get_run_status({request.run_id})[request.run_id]
154
+ if run_status.status == Status.FINISHED:
155
+ context.abort(
156
+ grpc.StatusCode.FAILED_PRECONDITION,
157
+ f"Run ID {request.run_id} is already finished",
158
+ )
159
+
160
+ update_success = state.update_run_status(
161
+ run_id=request.run_id,
162
+ new_status=RunStatus(Status.FINISHED, SubStatus.STOPPED, ""),
163
+ )
164
+
165
+ if update_success:
166
+ task_ids: set[UUID] = state.get_task_ids_from_run_id(request.run_id)
167
+
168
+ # Delete TaskIns and TaskRes for the `run_id`
169
+ state.delete_tasks(task_ids)
170
+
171
+ return StopRunResponse(success=update_success)
172
+
173
+ def GetLoginDetails(
174
+ self, request: GetLoginDetailsRequest, context: grpc.ServicerContext
175
+ ) -> GetLoginDetailsResponse:
176
+ """Start login."""
177
+ log(INFO, "ExecServicer.GetLoginDetails")
178
+ if self.auth_plugin is None:
179
+ context.abort(
180
+ grpc.StatusCode.UNIMPLEMENTED,
181
+ "ExecServicer initialized without user authentication",
182
+ )
183
+ raise grpc.RpcError() # This line is unreachable
184
+ return GetLoginDetailsResponse(
185
+ login_details=self.auth_plugin.get_login_details()
186
+ )
187
+
188
+ def GetAuthTokens(
189
+ self, request: GetAuthTokensRequest, context: grpc.ServicerContext
190
+ ) -> GetAuthTokensResponse:
191
+ """Get auth token."""
192
+ log(INFO, "ExecServicer.GetAuthTokens")
193
+ if self.auth_plugin is None:
194
+ context.abort(
195
+ grpc.StatusCode.UNIMPLEMENTED,
196
+ "ExecServicer initialized without user authentication",
197
+ )
198
+ raise grpc.RpcError() # This line is unreachable
199
+ return GetAuthTokensResponse(
200
+ auth_tokens=self.auth_plugin.get_auth_tokens(dict(request.auth_details))
201
+ )
202
+
129
203
 
130
204
  def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse:
131
205
  """Create response for `flwr ls --runs` and `flwr ls --run-id <run_id>`."""