flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the registry's advisory for this release for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -15,19 +15,32 @@
15
15
  """Fleet Simulation Engine API."""
16
16
 
17
17
 
18
- import asyncio
19
18
  import json
19
+ import threading
20
+ import time
20
21
  import traceback
22
+ from concurrent.futures import ThreadPoolExecutor
21
23
  from logging import DEBUG, ERROR, INFO, WARN
22
- from typing import Callable, Dict, List, Optional
24
+ from pathlib import Path
25
+ from queue import Empty, Queue
26
+ from time import sleep
27
+ from typing import Callable, Dict, Optional
23
28
 
24
- from flwr.client.client_app import ClientApp, LoadClientAppError
29
+ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
25
30
  from flwr.client.node_state import NodeState
31
+ from flwr.client.supernode.app import _get_load_client_app_fn
32
+ from flwr.common.constant import (
33
+ NUM_PARTITIONS_KEY,
34
+ PARTITION_ID_KEY,
35
+ PING_MAX_INTERVAL,
36
+ ErrorCode,
37
+ )
26
38
  from flwr.common.logger import log
27
- from flwr.common.object_ref import load_app
39
+ from flwr.common.message import Error
28
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
29
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
30
- from flwr.server.superlink.state import StateFactory
41
+ from flwr.common.typing import Run
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
+ from flwr.server.superlink.state import State, StateFactory
31
44
 
32
45
  from .backend import Backend, error_messages_backends, supported_backends
33
46
 
@@ -41,39 +54,63 @@ def _register_nodes(
41
54
  nodes_mapping: NodeToPartitionMapping = {}
42
55
  state = state_factory.state()
43
56
  for i in range(num_nodes):
44
- node_id = state.create_node()
57
+ node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
45
58
  nodes_mapping[node_id] = i
46
- log(INFO, "Registered %i nodes", len(nodes_mapping))
59
+ log(DEBUG, "Registered %i nodes", len(nodes_mapping))
47
60
  return nodes_mapping
48
61
 
49
62
 
63
+ def _register_node_states(
64
+ nodes_mapping: NodeToPartitionMapping,
65
+ run: Run,
66
+ app_dir: Optional[str] = None,
67
+ ) -> Dict[int, NodeState]:
68
+ """Create NodeState objects and pre-register the context for the run."""
69
+ node_states: Dict[int, NodeState] = {}
70
+ num_partitions = len(set(nodes_mapping.values()))
71
+ for node_id, partition_id in nodes_mapping.items():
72
+ node_states[node_id] = NodeState(
73
+ node_id=node_id,
74
+ node_config={
75
+ PARTITION_ID_KEY: partition_id,
76
+ NUM_PARTITIONS_KEY: num_partitions,
77
+ },
78
+ )
79
+
80
+ # Pre-register Context objects
81
+ node_states[node_id].register_context(
82
+ run_id=run.run_id, run=run, app_dir=app_dir
83
+ )
84
+
85
+ return node_states
86
+
87
+
50
88
  # pylint: disable=too-many-arguments,too-many-locals
51
- async def worker(
89
+ def worker(
52
90
  app_fn: Callable[[], ClientApp],
53
- queue: "asyncio.Queue[TaskIns]",
91
+ taskins_queue: "Queue[TaskIns]",
92
+ taskres_queue: "Queue[TaskRes]",
54
93
  node_states: Dict[int, NodeState],
55
- state_factory: StateFactory,
56
- nodes_mapping: NodeToPartitionMapping,
57
94
  backend: Backend,
95
+ f_stop: threading.Event,
58
96
  ) -> None:
59
97
  """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
60
- state = state_factory.state()
61
- while True:
98
+ while not f_stop.is_set():
99
+ out_mssg = None
62
100
  try:
63
- task_ins: TaskIns = await queue.get()
101
+ # Fetch from queue with timeout. We use a timeout so
102
+ # the stopping event can be evaluated even when the queue is empty.
103
+ task_ins: TaskIns = taskins_queue.get(timeout=1.0)
64
104
  node_id = task_ins.task.consumer.node_id
65
105
 
66
- # Register and retrieve runstate
67
- node_states[node_id].register_context(run_id=task_ins.run_id)
106
+ # Retrieve context
68
107
  context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
69
108
 
70
109
  # Convert TaskIns to Message
71
110
  message = message_from_taskins(task_ins)
72
- # Set partition_id
73
- message.metadata.partition_id = nodes_mapping[node_id]
74
111
 
75
112
  # Let backend process message
76
- out_mssg, updated_context = await backend.process_message(
113
+ out_mssg, updated_context = backend.process_message(
77
114
  app_fn, message, context
78
115
  )
79
116
 
@@ -81,85 +118,74 @@ async def worker(
81
118
  node_states[node_id].update_context(
82
119
  task_ins.run_id, context=updated_context
83
120
  )
84
-
85
- # Convert to TaskRes
86
- task_res = message_to_taskres(out_mssg)
87
- # Store TaskRes in state
88
- state.store_task_res(task_res)
89
-
90
- except asyncio.CancelledError as e:
91
- log(DEBUG, "Async worker: %s", e)
92
- break
93
-
94
- except LoadClientAppError as app_ex:
95
- log(ERROR, "Async worker: %s", app_ex)
96
- log(ERROR, traceback.format_exc())
97
- raise
98
-
121
+ except Empty:
122
+ # An exception raised if queue.get times out
123
+ pass
124
+ # Exceptions aren't raised but reported as an error message
99
125
  except Exception as ex: # pylint: disable=broad-exception-caught
100
126
  log(ERROR, ex)
101
127
  log(ERROR, traceback.format_exc())
102
- break
103
128
 
129
+ if isinstance(ex, ClientAppException):
130
+ e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
131
+ elif isinstance(ex, LoadClientAppError):
132
+ e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
133
+ else:
134
+ e_code = ErrorCode.UNKNOWN
104
135
 
105
- async def add_taskins_to_queue(
106
- queue: "asyncio.Queue[TaskIns]",
107
- state_factory: StateFactory,
136
+ reason = str(type(ex)) + ":<'" + str(ex) + "'>"
137
+ out_mssg = message.create_error_reply(
138
+ error=Error(code=e_code, reason=reason)
139
+ )
140
+
141
+ finally:
142
+ if out_mssg:
143
+ # Convert to TaskRes
144
+ task_res = message_to_taskres(out_mssg)
145
+ # Store TaskRes in state
146
+ task_res.task.pushed_at = time.time()
147
+ taskres_queue.put(task_res)
148
+
149
+
150
+ def add_taskins_to_queue(
151
+ state: State,
152
+ queue: "Queue[TaskIns]",
108
153
  nodes_mapping: NodeToPartitionMapping,
109
- backend: Backend,
110
- consumers: List["asyncio.Task[None]"],
111
- f_stop: asyncio.Event,
154
+ f_stop: threading.Event,
112
155
  ) -> None:
113
- """Retrieve TaskIns and add it to the queue."""
114
- state = state_factory.state()
115
- num_initial_consumers = len(consumers)
156
+ """Put TaskIns in a queue from State."""
116
157
  while not f_stop.is_set():
117
158
  for node_id in nodes_mapping.keys():
118
- task_ins = state.get_task_ins(node_id=node_id, limit=1)
119
- if task_ins:
120
- await queue.put(task_ins[0])
121
-
122
- # Count consumers that are running
123
- num_active = sum(not (cc.done()) for cc in consumers)
124
-
125
- # Alert if number of consumers decreased by half
126
- if num_active < num_initial_consumers // 2:
127
- log(
128
- WARN,
129
- "Number of active workers has more than halved: (%i/%i active)",
130
- num_active,
131
- num_initial_consumers,
132
- )
159
+ task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
160
+ for task_ins in task_ins_list:
161
+ queue.put(task_ins)
162
+ sleep(0.1)
133
163
 
134
- # Break if consumers died
135
- if num_active == 0:
136
- raise RuntimeError("All workers have died. Ending Simulation.")
137
164
 
138
- # Log some stats
139
- log(
140
- DEBUG,
141
- "Simulation Engine stats: "
142
- "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
143
- num_active,
144
- num_initial_consumers,
145
- backend.__class__.__name__,
146
- backend.num_workers,
147
- queue.qsize(),
148
- )
149
- await asyncio.sleep(1.0)
150
- log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
165
+ def put_taskres_into_state(
166
+ state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
167
+ ) -> None:
168
+ """Put TaskRes into State from a queue."""
169
+ while not f_stop.is_set():
170
+ try:
171
+ taskres = queue.get(timeout=1.0)
172
+ state.store_task_res(taskres)
173
+ except Empty:
174
+ # queue is empty when timeout was triggered
175
+ pass
151
176
 
152
177
 
153
- async def run(
178
+ def run_api(
154
179
  app_fn: Callable[[], ClientApp],
155
180
  backend_fn: Callable[[], Backend],
156
181
  nodes_mapping: NodeToPartitionMapping,
157
182
  state_factory: StateFactory,
158
183
  node_states: Dict[int, NodeState],
159
- f_stop: asyncio.Event,
184
+ f_stop: threading.Event,
160
185
  ) -> None:
161
- """Run the VCE async."""
162
- queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
186
+ """Run the VCE."""
187
+ taskins_queue: "Queue[TaskIns]" = Queue()
188
+ taskres_queue: "Queue[TaskRes]" = Queue()
163
189
 
164
190
  try:
165
191
 
@@ -167,29 +193,48 @@ async def run(
167
193
  backend = backend_fn()
168
194
 
169
195
  # Build backend
170
- await backend.build()
196
+ backend.build()
171
197
 
172
198
  # Add workers (they submit Messages to Backend)
173
- worker_tasks = [
174
- asyncio.create_task(
175
- worker(
176
- app_fn, queue, node_states, state_factory, nodes_mapping, backend
177
- )
178
- )
179
- for _ in range(backend.num_workers)
180
- ]
181
- # Create producer (adds TaskIns into Queue)
182
- producer = asyncio.create_task(
183
- add_taskins_to_queue(
184
- queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
185
- )
199
+ state = state_factory.state()
200
+
201
+ extractor_th = threading.Thread(
202
+ target=add_taskins_to_queue,
203
+ args=(
204
+ state,
205
+ taskins_queue,
206
+ nodes_mapping,
207
+ f_stop,
208
+ ),
186
209
  )
210
+ extractor_th.start()
211
+
212
+ injector_th = threading.Thread(
213
+ target=put_taskres_into_state,
214
+ args=(
215
+ state,
216
+ taskres_queue,
217
+ f_stop,
218
+ ),
219
+ )
220
+ injector_th.start()
221
+
222
+ with ThreadPoolExecutor() as executor:
223
+ _ = [
224
+ executor.submit(
225
+ worker,
226
+ app_fn,
227
+ taskins_queue,
228
+ taskres_queue,
229
+ node_states,
230
+ backend,
231
+ f_stop,
232
+ )
233
+ for _ in range(backend.num_workers)
234
+ ]
187
235
 
188
- # Wait for producer to finish
189
- # The producer runs forever until f_stop is set or until
190
- # all worker (consumer) coroutines are completed. Workers
191
- # also run forever and only end if an exception is raised.
192
- await asyncio.gather(producer)
236
+ extractor_th.join()
237
+ injector_th.join()
193
238
 
194
239
  except Exception as ex:
195
240
 
@@ -204,26 +249,21 @@ async def run(
204
249
  raise RuntimeError("Simulation Engine crashed.") from ex
205
250
 
206
251
  finally:
207
- # Produced task terminated, now cancel worker tasks
208
- for w_t in worker_tasks:
209
- _ = w_t.cancel()
210
-
211
- while not all(w_t.done() for w_t in worker_tasks):
212
- log(DEBUG, "Terminating async workers...")
213
- await asyncio.sleep(0.5)
214
-
215
- await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
216
252
 
217
253
  # Terminate backend
218
- await backend.terminate()
254
+ backend.terminate()
219
255
 
220
256
 
221
- # pylint: disable=too-many-arguments,unused-argument,too-many-locals
257
+ # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
258
+ # pylint: disable=too-many-statements
222
259
  def start_vce(
223
260
  backend_name: str,
224
261
  backend_config_json_stream: str,
225
262
  app_dir: str,
226
- f_stop: asyncio.Event,
263
+ is_app: bool,
264
+ f_stop: threading.Event,
265
+ run: Run,
266
+ flwr_dir: Optional[str] = None,
227
267
  client_app: Optional[ClientApp] = None,
228
268
  client_app_attr: Optional[str] = None,
229
269
  num_supernodes: Optional[int] = None,
@@ -259,6 +299,7 @@ def start_vce(
259
299
  # Use mapping constructed externally. This also means nodes
260
300
  # have previously being registered.
261
301
  nodes_mapping = existing_nodes_mapping
302
+ app_dir = str(Path(app_dir).absolute())
262
303
 
263
304
  if not state_factory:
264
305
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
@@ -273,12 +314,12 @@ def start_vce(
273
314
  )
274
315
 
275
316
  # Construct mapping of NodeStates
276
- node_states: Dict[int, NodeState] = {}
277
- for node_id in nodes_mapping:
278
- node_states[node_id] = NodeState()
317
+ node_states = _register_node_states(
318
+ nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
319
+ )
279
320
 
280
321
  # Load backend config
281
- log(INFO, "Supported backends: %s", list(supported_backends.keys()))
322
+ log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
282
323
  backend_config = json.loads(backend_config_json_stream)
283
324
 
284
325
  try:
@@ -298,20 +339,18 @@ def start_vce(
298
339
 
299
340
  def backend_fn() -> Backend:
300
341
  """Instantiate a Backend."""
301
- return backend_type(backend_config, work_dir=app_dir)
302
-
303
- log(INFO, "client_app_attr = %s", client_app_attr)
342
+ return backend_type(backend_config)
304
343
 
305
344
  # Load ClientApp if needed
306
345
  def _load() -> ClientApp:
307
346
 
308
347
  if client_app_attr:
309
- app: ClientApp = load_app(client_app_attr, LoadClientAppError)
310
-
311
- if not isinstance(app, ClientApp):
312
- raise LoadClientAppError(
313
- f"Attribute {client_app_attr} is not of type {ClientApp}",
314
- ) from None
348
+ app = _get_load_client_app_fn(
349
+ default_app_ref=client_app_attr,
350
+ app_path=app_dir,
351
+ flwr_dir=flwr_dir,
352
+ multi_app=False,
353
+ )(run.fab_id, run.fab_version)
315
354
 
316
355
  if client_app:
317
356
  app = client_app
@@ -319,8 +358,21 @@ def start_vce(
319
358
 
320
359
  app_fn = _load
321
360
 
322
- asyncio.run(
323
- run(
361
+ try:
362
+ # Test if ClientApp can be loaded
363
+ client_app = app_fn()
364
+
365
+ # Cache `ClientApp`
366
+ if client_app_attr:
367
+ # Now wrap the loaded ClientApp in a dummy function
368
+ # this prevent unnecesary low-level loading of ClientApp
369
+ def _load_client_app() -> ClientApp:
370
+ return client_app
371
+
372
+ app_fn = _load_client_app
373
+
374
+ # Run main simulation loop
375
+ run_api(
324
376
  app_fn,
325
377
  backend_fn,
326
378
  nodes_mapping,
@@ -328,4 +380,15 @@ def start_vce(
328
380
  node_states,
329
381
  f_stop,
330
382
  )
331
- )
383
+ except LoadClientAppError as loadapp_ex:
384
+ f_stop_delay = 10
385
+ log(
386
+ ERROR,
387
+ "LoadClientAppError exception encountered. Terminating simulation in %is",
388
+ f_stop_delay,
389
+ )
390
+ time.sleep(f_stop_delay)
391
+ f_stop.set() # set termination event
392
+ raise loadapp_ex
393
+ except Exception as ex:
394
+ raise ex
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.