flwr 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +18 -36
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +205 -106
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  23. flwr/cli/run/__init__.py +1 -0
  24. flwr/cli/run/run.py +89 -39
  25. flwr/cli/stop.py +130 -0
  26. flwr/cli/utils.py +172 -8
  27. flwr/client/app.py +14 -3
  28. flwr/client/client.py +1 -32
  29. flwr/client/clientapp/app.py +4 -1
  30. flwr/client/clientapp/utils.py +1 -0
  31. flwr/client/grpc_adapter_client/connection.py +1 -1
  32. flwr/client/grpc_client/connection.py +1 -1
  33. flwr/client/grpc_rere_client/connection.py +13 -7
  34. flwr/client/message_handler/message_handler.py +1 -2
  35. flwr/client/mod/comms_mods.py +1 -0
  36. flwr/client/mod/localdp_mod.py +1 -1
  37. flwr/client/nodestate/__init__.py +1 -0
  38. flwr/client/nodestate/nodestate.py +1 -0
  39. flwr/client/nodestate/nodestate_factory.py +1 -0
  40. flwr/client/numpy_client.py +0 -44
  41. flwr/client/rest_client/connection.py +3 -3
  42. flwr/client/supernode/app.py +2 -2
  43. flwr/common/address.py +1 -0
  44. flwr/common/args.py +1 -0
  45. flwr/common/auth_plugin/__init__.py +24 -0
  46. flwr/common/auth_plugin/auth_plugin.py +111 -0
  47. flwr/common/config.py +3 -1
  48. flwr/common/constant.py +17 -1
  49. flwr/common/logger.py +40 -0
  50. flwr/common/message.py +1 -0
  51. flwr/common/object_ref.py +57 -54
  52. flwr/common/pyproject.py +1 -0
  53. flwr/common/record/__init__.py +1 -0
  54. flwr/common/record/parametersrecord.py +1 -0
  55. flwr/common/retry_invoker.py +77 -0
  56. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  57. flwr/common/telemetry.py +15 -4
  58. flwr/common/typing.py +12 -0
  59. flwr/common/version.py +1 -0
  60. flwr/proto/exec_pb2.py +38 -14
  61. flwr/proto/exec_pb2.pyi +107 -2
  62. flwr/proto/exec_pb2_grpc.py +102 -0
  63. flwr/proto/exec_pb2_grpc.pyi +39 -0
  64. flwr/proto/fab_pb2.py +4 -4
  65. flwr/proto/fab_pb2.pyi +4 -1
  66. flwr/proto/serverappio_pb2.py +18 -18
  67. flwr/proto/serverappio_pb2.pyi +8 -2
  68. flwr/proto/serverappio_pb2_grpc.py +34 -0
  69. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  70. flwr/proto/simulationio_pb2.py +2 -2
  71. flwr/proto/simulationio_pb2_grpc.py +34 -0
  72. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  73. flwr/server/app.py +54 -2
  74. flwr/server/compat/app_utils.py +7 -1
  75. flwr/server/driver/grpc_driver.py +11 -63
  76. flwr/server/driver/inmemory_driver.py +5 -1
  77. flwr/server/run_serverapp.py +8 -9
  78. flwr/server/serverapp/app.py +25 -3
  79. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  80. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  81. flwr/server/superlink/driver/serverappio_servicer.py +82 -23
  82. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  84. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  85. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  86. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
  87. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  88. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  89. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  90. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  91. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
  93. flwr/server/superlink/linkstate/linkstate.py +17 -2
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
  95. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  96. flwr/server/superlink/utils.py +65 -0
  97. flwr/simulation/app.py +16 -4
  98. flwr/simulation/ray_transport/ray_actor.py +1 -0
  99. flwr/simulation/ray_transport/utils.py +1 -0
  100. flwr/simulation/run_simulation.py +36 -22
  101. flwr/simulation/simulationio_connection.py +3 -0
  102. flwr/superexec/app.py +1 -0
  103. flwr/superexec/deployment.py +1 -0
  104. flwr/superexec/exec_grpc.py +19 -1
  105. flwr/superexec/exec_servicer.py +76 -2
  106. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  107. flwr/superexec/executor.py +1 -0
  108. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -7
  109. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
  110. flwr/proto/common_pb2.py +0 -36
  111. flwr/proto/common_pb2.pyi +0 -121
  112. flwr/proto/common_pb2_grpc.py +0 -4
  113. flwr/proto/common_pb2_grpc.pyi +0 -4
  114. flwr/proto/control_pb2.py +0 -27
  115. flwr/proto/control_pb2.pyi +0 -7
  116. flwr/proto/control_pb2_grpc.py +0 -135
  117. flwr/proto/control_pb2_grpc.pyi +0 -53
  118. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
  119. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
  120. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
139
139
  """
140
140
 
141
141
  @abc.abstractmethod
142
- def delete_tasks(self, task_ids: set[UUID]) -> None:
143
- """Delete all delivered TaskIns/TaskRes pairs."""
142
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
143
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
144
+
145
+ Parameters
146
+ ----------
147
+ task_ins_ids : set[UUID]
148
+ A set of TaskIns IDs. For each ID in the set, the corresponding
149
+ TaskIns and its associated TaskRes will be deleted.
150
+ """
151
+
152
+ @abc.abstractmethod
153
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
154
+ """Get all TaskIns IDs for the given run_id."""
144
155
 
145
156
  @abc.abstractmethod
146
157
  def create_node(
@@ -273,6 +284,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
273
284
  def get_server_public_key(self) -> Optional[bytes]:
274
285
  """Retrieve `server_public_key` in urlsafe bytes."""
275
286
 
287
+ @abc.abstractmethod
288
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
289
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
290
+
276
291
  @abc.abstractmethod
277
292
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
278
293
  """Store a set of `node_public_keys` in the link state."""
@@ -14,12 +14,12 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of the link state."""
16
16
 
17
+
17
18
  # pylint: disable=too-many-lines
18
19
 
19
20
  import json
20
21
  import re
21
22
  import sqlite3
22
- import threading
23
23
  import time
24
24
  from collections.abc import Sequence
25
25
  from logging import DEBUG, ERROR, WARNING
@@ -183,7 +183,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
183
183
  """
184
184
  self.database_path = database_path
185
185
  self.conn: Optional[sqlite3.Connection] = None
186
- self.lock = threading.RLock()
187
186
 
188
187
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
189
188
  """Create tables if they don't exist yet.
@@ -216,7 +215,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
216
215
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
217
216
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
218
217
  res = cur.execute("SELECT name FROM sqlite_schema;")
219
-
220
218
  return res.fetchall()
221
219
 
222
220
  def query(
@@ -569,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
569
567
  data: list[Any] = [delivered_at] + task_res_ids
570
568
  self.query(query, data)
571
569
 
572
- # Cleanup
573
- self._force_delete_tasks_by_ids(set(ret.keys()))
574
-
575
570
  return list(ret.values())
576
571
 
577
572
  def num_task_ins(self) -> int:
@@ -595,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
595
590
  result: dict[str, int] = rows[0]
596
591
  return result["num"]
597
592
 
598
- def delete_tasks(self, task_ids: set[UUID]) -> None:
599
- """Delete all delivered TaskIns/TaskRes pairs."""
600
- ids = list(task_ids)
601
- if len(ids) == 0:
602
- return None
593
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
594
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
595
+ if not task_ins_ids:
596
+ return
597
+ if self.conn is None:
598
+ raise AttributeError("LinkState not initialized")
603
599
 
604
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
605
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
600
+ placeholders = ",".join(["?"] * len(task_ins_ids))
601
+ data = tuple(str(task_id) for task_id in task_ins_ids)
606
602
 
607
- # 1. Query: Delete task_ins which have a delivered task_res
603
+ # Delete task_ins
608
604
  query_1 = f"""
609
605
  DELETE FROM task_ins
610
- WHERE delivered_at != ''
611
- AND task_id IN (
612
- SELECT ancestry
613
- FROM task_res
614
- WHERE ancestry IN ({placeholders})
615
- AND delivered_at != ''
616
- );
606
+ WHERE task_id IN ({placeholders});
617
607
  """
618
608
 
619
- # 2. Query: Delete delivered task_res to be run after 1. Query
609
+ # Delete task_res
620
610
  query_2 = f"""
621
611
  DELETE FROM task_res
622
- WHERE ancestry IN ({placeholders})
623
- AND delivered_at != '';
612
+ WHERE ancestry IN ({placeholders});
624
613
  """
625
614
 
626
- if self.conn is None:
627
- raise AttributeError("LinkState not intitialized")
628
-
629
615
  with self.conn:
630
616
  self.conn.execute(query_1, data)
631
617
  self.conn.execute(query_2, data)
632
618
 
633
- return None
634
-
635
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
636
- """Delete tasks based on a set of TaskIns IDs."""
637
- if not task_ids:
638
- return
619
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
620
+ """Get all TaskIns IDs for the given run_id."""
639
621
  if self.conn is None:
640
622
  raise AttributeError("LinkState not initialized")
641
623
 
642
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
643
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
644
-
645
- # Delete task_ins
646
- query_1 = f"""
647
- DELETE FROM task_ins
648
- WHERE task_id IN ({placeholders});
624
+ query = """
625
+ SELECT task_id
626
+ FROM task_ins
627
+ WHERE run_id = :run_id;
649
628
  """
650
629
 
651
- # Delete task_res
652
- query_2 = f"""
653
- DELETE FROM task_res
654
- WHERE ancestry IN ({placeholders});
655
- """
630
+ sint64_run_id = convert_uint64_to_sint64(run_id)
631
+ data = {"run_id": sint64_run_id}
656
632
 
657
633
  with self.conn:
658
- self.conn.execute(query_1, data)
659
- self.conn.execute(query_2, data)
634
+ rows = self.conn.execute(query, data).fetchall()
635
+
636
+ return {UUID(row["task_id"]) for row in rows}
660
637
 
661
638
  def create_node(
662
639
  self, ping_interval: float, public_key: Optional[bytes] = None
@@ -784,8 +761,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
784
761
  "federation_options, pending_at, starting_at, running_at, finished_at, "
785
762
  "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
786
763
  )
787
- if fab_hash:
788
- fab_id, fab_version = "", ""
789
764
  override_config_json = json.dumps(override_config)
790
765
  data = [
791
766
  sint64_run_id,
@@ -843,6 +818,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
843
818
  public_key = None
844
819
  return public_key
845
820
 
821
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
822
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
823
+ queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
824
+ for query in queries:
825
+ self.query(query)
826
+
846
827
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
847
828
  """Store a set of `node_public_keys` in the link state."""
848
829
  query = "INSERT INTO public_key (public_key) VALUES (?)"
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SimulationIo API servicer."""
16
16
 
17
+
17
18
  import threading
18
19
  from logging import DEBUG, INFO
19
20
 
@@ -28,6 +29,7 @@ from flwr.common.serde import (
28
29
  context_to_proto,
29
30
  fab_to_proto,
30
31
  run_status_from_proto,
32
+ run_status_to_proto,
31
33
  run_to_proto,
32
34
  )
33
35
  from flwr.common.typing import Fab, RunStatus
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
39
41
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
42
  GetFederationOptionsRequest,
41
43
  GetFederationOptionsResponse,
44
+ GetRunStatusRequest,
45
+ GetRunStatusResponse,
42
46
  UpdateRunStatusRequest,
43
47
  UpdateRunStatusResponse,
44
48
  )
@@ -50,6 +54,7 @@ from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
50
54
  )
51
55
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
52
56
  from flwr.server.superlink.linkstate import LinkStateFactory
57
+ from flwr.server.superlink.utils import abort_if
53
58
 
54
59
 
55
60
  class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
@@ -106,6 +111,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
106
111
  """Push Simulation process outputs."""
107
112
  log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
108
113
  state = self.state_factory.state()
114
+
115
+ # Abort if the run is not running
116
+ abort_if(
117
+ request.run_id,
118
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
119
+ state,
120
+ context,
121
+ )
122
+
109
123
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
110
124
  return PushSimulationOutputsResponse()
111
125
 
@@ -116,12 +130,31 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
116
130
  log(DEBUG, "SimultionIoServicer.UpdateRunStatus")
117
131
  state = self.state_factory.state()
118
132
 
133
+ # Abort if the run is finished
134
+ abort_if(request.run_id, [Status.FINISHED], state, context)
135
+
119
136
  # Update the run status
120
137
  state.update_run_status(
121
138
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
122
139
  )
123
140
  return UpdateRunStatusResponse()
124
141
 
142
+ def GetRunStatus(
143
+ self, request: GetRunStatusRequest, context: ServicerContext
144
+ ) -> GetRunStatusResponse:
145
+ """Get status of requested runs."""
146
+ log(DEBUG, "SimultionIoServicer.GetRunStatus")
147
+ state = self.state_factory.state()
148
+
149
+ statuses = state.get_run_status(set(request.run_ids))
150
+
151
+ return GetRunStatusResponse(
152
+ run_status_dict={
153
+ run_id: run_status_to_proto(status)
154
+ for run_id, status in statuses.items()
155
+ }
156
+ )
157
+
125
158
  def PushLogs(
126
159
  self, request: PushLogsRequest, context: grpc.ServicerContext
127
160
  ) -> PushLogsResponse:
@@ -0,0 +1,65 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SuperLink utilities."""
16
+
17
+
18
+ from typing import Union
19
+
20
+ import grpc
21
+
22
+ from flwr.common.constant import Status, SubStatus
23
+ from flwr.common.typing import RunStatus
24
+ from flwr.server.superlink.linkstate import LinkState
25
+
26
+ _STATUS_TO_MSG = {
27
+ Status.PENDING: "Run is pending.",
28
+ Status.STARTING: "Run is starting.",
29
+ Status.RUNNING: "Run is running.",
30
+ Status.FINISHED: "Run is finished.",
31
+ }
32
+
33
+
34
+ def check_abort(
35
+ run_id: int,
36
+ abort_status_list: list[str],
37
+ state: LinkState,
38
+ ) -> Union[str, None]:
39
+ """Check if the status of the provided `run_id` is in `abort_status_list`."""
40
+ run_status: RunStatus = state.get_run_status({run_id})[run_id]
41
+
42
+ if run_status.status in abort_status_list:
43
+ msg = _STATUS_TO_MSG[run_status.status]
44
+ if run_status.sub_status == SubStatus.STOPPED:
45
+ msg += " Stopped by user."
46
+ return msg
47
+
48
+ return None
49
+
50
+
51
+ def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
52
+ """Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
53
+ if msg is not None:
54
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
55
+
56
+
57
+ def abort_if(
58
+ run_id: int,
59
+ abort_status_list: list[str],
60
+ state: LinkState,
61
+ context: grpc.ServicerContext,
62
+ ) -> None:
63
+ """Abort context if status of the provided `run_id` is in `abort_status_list`."""
64
+ msg = check_abort(run_id, abort_status_list, state)
65
+ abort_grpc_context(msg, context)
flwr/simulation/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower Simulation process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -23,7 +24,8 @@ from typing import Optional
23
24
 
24
25
  from flwr.cli.config_utils import get_fab_metadata
25
26
  from flwr.cli.install import install_from_fab
26
- from flwr.common import EventType
27
+ from flwr.cli.utils import get_sha256_hash
28
+ from flwr.common import EventType, event
27
29
  from flwr.common.args import add_args_flwr_app_common
28
30
  from flwr.common.config import (
29
31
  get_flwr_dir,
@@ -47,6 +49,7 @@ from flwr.common.logger import (
47
49
  from flwr.common.serde import (
48
50
  configs_record_from_proto,
49
51
  context_from_proto,
52
+ context_to_proto,
50
53
  fab_from_proto,
51
54
  run_from_proto,
52
55
  run_status_to_proto,
@@ -200,8 +203,17 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
200
203
  verbose: bool = fed_opt.get("verbose", False)
201
204
  enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)
202
205
 
206
+ event(
207
+ EventType.FLWR_SIMULATION_RUN_ENTER,
208
+ event_details={
209
+ "backend": "ray",
210
+ "num-supernodes": num_supernodes,
211
+ "run-id-hash": get_sha256_hash(run.run_id),
212
+ },
213
+ )
214
+
203
215
  # Launch the simulation
204
- _run_simulation(
216
+ updated_context = _run_simulation(
205
217
  server_app_attr=server_app_attr,
206
218
  client_app_attr=client_app_attr,
207
219
  num_supernodes=num_supernodes,
@@ -212,11 +224,11 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
212
224
  verbose_logging=verbose,
213
225
  server_app_run_config=fused_config,
214
226
  is_app=True,
215
- exit_event=EventType.CLI_FLOWER_SIMULATION_LEAVE,
227
+ exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
216
228
  )
217
229
 
218
230
  # Send resulting context
219
- context_proto = None # context_to_proto(updated_context)
231
+ context_proto = context_to_proto(updated_context)
220
232
  out_req = PushSimulationOutputsRequest(
221
233
  run_id=run.run_id, context=context_proto
222
234
  )
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray-based Flower Actor and ActorPool implementation."""
16
16
 
17
+
17
18
  import threading
18
19
  from abc import ABC
19
20
  from logging import DEBUG, ERROR, WARNING
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Utilities for Actors in the Virtual Client Engine."""
16
16
 
17
+
17
18
  import traceback
18
19
  import warnings
19
20
  from logging import ERROR
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower Simulation."""
16
16
 
17
+
17
18
  import argparse
18
19
  import asyncio
19
20
  import json
@@ -23,10 +24,11 @@ import threading
23
24
  import traceback
24
25
  from logging import DEBUG, ERROR, INFO, WARNING
25
26
  from pathlib import Path
26
- from time import sleep
27
+ from queue import Empty, Queue
27
28
  from typing import Any, Optional
28
29
 
29
30
  from flwr.cli.config_utils import load_and_validate
31
+ from flwr.cli.utils import get_sha256_hash
30
32
  from flwr.client import ClientApp
31
33
  from flwr.common import Context, EventType, RecordSet, event, log, now
32
34
  from flwr.common.config import get_fused_config_from_dir, parse_config_args
@@ -126,7 +128,7 @@ def run_simulation_from_cli() -> None:
126
128
  run = Run.create_empty(run_id)
127
129
  run.override_config = override_config
128
130
 
129
- _run_simulation(
131
+ _ = _run_simulation(
130
132
  server_app_attr=server_app_attr,
131
133
  client_app_attr=client_app_attr,
132
134
  num_supernodes=args.num_supernodes,
@@ -135,7 +137,6 @@ def run_simulation_from_cli() -> None:
135
137
  app_dir=args.app,
136
138
  run=run,
137
139
  enable_tf_gpu_growth=args.enable_tf_gpu_growth,
138
- delay_start=args.delay_start,
139
140
  verbose_logging=args.verbose,
140
141
  server_app_run_config=fused_config,
141
142
  is_app=True,
@@ -207,7 +208,7 @@ def run_simulation(
207
208
  "\n\tflwr.simulation.run_simulationt(...)",
208
209
  )
209
210
 
210
- _run_simulation(
211
+ _ = _run_simulation(
211
212
  num_supernodes=num_supernodes,
212
213
  client_app=client_app,
213
214
  server_app=server_app,
@@ -230,6 +231,7 @@ def run_serverapp_th(
230
231
  has_exception: threading.Event,
231
232
  enable_tf_gpu_growth: bool,
232
233
  run_id: int,
234
+ ctx_queue: "Queue[Context]",
233
235
  ) -> threading.Thread:
234
236
  """Run SeverApp in a thread."""
235
237
 
@@ -242,6 +244,7 @@ def run_serverapp_th(
242
244
  _server_app_run_config: UserConfig,
243
245
  _server_app_attr: Optional[str],
244
246
  _server_app: Optional[ServerApp],
247
+ _ctx_queue: "Queue[Context]",
245
248
  ) -> None:
246
249
  """Run SeverApp, after check if GPU memory growth has to be set.
247
250
 
@@ -262,13 +265,14 @@ def run_serverapp_th(
262
265
  )
263
266
 
264
267
  # Run ServerApp
265
- _run(
268
+ updated_context = _run(
266
269
  driver=_driver,
267
270
  context=context,
268
271
  server_app_dir=_server_app_dir,
269
272
  server_app_attr=_server_app_attr,
270
273
  loaded_server_app=_server_app,
271
274
  )
275
+ _ctx_queue.put(updated_context)
272
276
  except Exception as ex: # pylint: disable=broad-exception-caught
273
277
  log(ERROR, "ServerApp thread raised an exception: %s", ex)
274
278
  log(ERROR, traceback.format_exc())
@@ -292,6 +296,7 @@ def run_serverapp_th(
292
296
  server_app_run_config,
293
297
  server_app_attr,
294
298
  server_app,
299
+ ctx_queue,
295
300
  ),
296
301
  )
297
302
  serverapp_th.start()
@@ -308,14 +313,13 @@ def _main_loop(
308
313
  enable_tf_gpu_growth: bool,
309
314
  run: Run,
310
315
  exit_event: EventType,
311
- delay_start: int,
312
316
  flwr_dir: Optional[str] = None,
313
317
  client_app: Optional[ClientApp] = None,
314
318
  client_app_attr: Optional[str] = None,
315
319
  server_app: Optional[ServerApp] = None,
316
320
  server_app_attr: Optional[str] = None,
317
321
  server_app_run_config: Optional[UserConfig] = None,
318
- ) -> None:
322
+ ) -> Context:
319
323
  """Start ServerApp on a separate thread, then launch Simulation Engine."""
320
324
  # Initialize StateFactory
321
325
  state_factory = LinkStateFactory(":flwr-in-memory-state:")
@@ -325,6 +329,13 @@ def _main_loop(
325
329
  server_app_thread_has_exception = threading.Event()
326
330
  serverapp_th = None
327
331
  success = True
332
+ updated_context = Context(
333
+ run_id=run.run_id,
334
+ node_id=0,
335
+ node_config=UserConfig(),
336
+ state=RecordSet(),
337
+ run_config=UserConfig(),
338
+ )
328
339
  try:
329
340
  # Register run
330
341
  log(DEBUG, "Pre-registering run with id %s", run.run_id)
@@ -339,6 +350,7 @@ def _main_loop(
339
350
  # Initialize Driver
340
351
  driver = InMemoryDriver(state_factory=state_factory)
341
352
  driver.set_run(run_id=run.run_id)
353
+ output_context_queue: "Queue[Context]" = Queue()
342
354
 
343
355
  # Get and run ServerApp thread
344
356
  serverapp_th = run_serverapp_th(
@@ -351,11 +363,9 @@ def _main_loop(
351
363
  has_exception=server_app_thread_has_exception,
352
364
  enable_tf_gpu_growth=enable_tf_gpu_growth,
353
365
  run_id=run.run_id,
366
+ ctx_queue=output_context_queue,
354
367
  )
355
368
 
356
- # Buffer time so the `ServerApp` in separate thread is ready
357
- log(DEBUG, "Buffer time delay: %ds", delay_start)
358
- sleep(delay_start)
359
369
  # Start Simulation Engine
360
370
  vce.start_vce(
361
371
  num_supernodes=num_supernodes,
@@ -371,6 +381,11 @@ def _main_loop(
371
381
  flwr_dir=flwr_dir,
372
382
  )
373
383
 
384
+ updated_context = output_context_queue.get(timeout=3)
385
+
386
+ except Empty:
387
+ log(DEBUG, "Queue timeout. No context received.")
388
+
374
389
  except Exception as ex:
375
390
  log(ERROR, "An exception occurred !! %s", ex)
376
391
  log(ERROR, traceback.format_exc())
@@ -380,13 +395,20 @@ def _main_loop(
380
395
  finally:
381
396
  # Trigger stop event
382
397
  f_stop.set()
383
- event(exit_event, event_details={"success": success})
398
+ event(
399
+ exit_event,
400
+ event_details={
401
+ "run-id-hash": get_sha256_hash(run.run_id),
402
+ "success": success,
403
+ },
404
+ )
384
405
  if serverapp_th:
385
406
  serverapp_th.join()
386
407
  if server_app_thread_has_exception.is_set():
387
408
  raise RuntimeError("Exception in ServerApp thread")
388
409
 
389
410
  log(DEBUG, "Stopping Simulation Engine now.")
411
+ return updated_context
390
412
 
391
413
 
392
414
  # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
@@ -404,10 +426,9 @@ def _run_simulation(
404
426
  flwr_dir: Optional[str] = None,
405
427
  run: Optional[Run] = None,
406
428
  enable_tf_gpu_growth: bool = False,
407
- delay_start: int = 5,
408
429
  verbose_logging: bool = False,
409
430
  is_app: bool = False,
410
- ) -> None:
431
+ ) -> Context:
411
432
  """Launch the Simulation Engine."""
412
433
  if backend_config is None:
413
434
  backend_config = {}
@@ -459,7 +480,6 @@ def _run_simulation(
459
480
  enable_tf_gpu_growth,
460
481
  run,
461
482
  exit_event,
462
- delay_start,
463
483
  flwr_dir,
464
484
  client_app,
465
485
  client_app_attr,
@@ -487,7 +507,8 @@ def _run_simulation(
487
507
  # Set logger propagation to False to prevent duplicated log output in Colab.
488
508
  logger = set_logger_propagation(logger, False)
489
509
 
490
- _main_loop(*args)
510
+ updated_context = _main_loop(*args)
511
+ return updated_context
491
512
 
492
513
 
493
514
  def _parse_args_run_simulation() -> argparse.ArgumentParser:
@@ -537,13 +558,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
537
558
  "Read more about how `tf.config.experimental.set_memory_growth()` works in "
538
559
  "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
539
560
  )
540
- parser.add_argument(
541
- "--delay-start",
542
- type=int,
543
- default=3,
544
- help="Buffer time (in seconds) to delay the start the simulation engine after "
545
- "the `ServerApp`, which runs in a separate thread, has been launched.",
546
- )
547
561
  parser.add_argument(
548
562
  "--verbose",
549
563
  action="store_true",
@@ -23,6 +23,7 @@ import grpc
23
23
  from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
24
24
  from flwr.common.grpc import create_channel
25
25
  from flwr.common.logger import log
26
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
26
27
  from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
27
28
 
28
29
 
@@ -48,6 +49,7 @@ class SimulationIoConnection:
48
49
  self._cert = root_certificates
49
50
  self._grpc_stub: Optional[SimulationIoStub] = None
50
51
  self._channel: Optional[grpc.Channel] = None
52
+ self._retry_invoker = _make_simple_grpc_retry_invoker()
51
53
 
52
54
  @property
53
55
  def _is_connected(self) -> bool:
@@ -72,6 +74,7 @@ class SimulationIoConnection:
72
74
  root_certificates=self._cert,
73
75
  )
74
76
  self._grpc_stub = SimulationIoStub(self._channel)
77
+ _wrap_stub(self._grpc_stub, self._retry_invoker)
75
78
  log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
76
79
 
77
80
  def _disconnect(self) -> None:
flwr/superexec/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower SuperExec app."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import INFO
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Deployment engine executor."""
16
16
 
17
+
17
18
  import hashlib
18
19
  from logging import ERROR, INFO
19
20
  from pathlib import Path