flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.15.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (311) hide show
  1. flwr/cli/app.py +16 -2
  2. flwr/cli/build.py +181 -0
  3. flwr/cli/cli_user_auth_interceptor.py +90 -0
  4. flwr/cli/config_utils.py +343 -0
  5. flwr/cli/example.py +4 -1
  6. flwr/cli/install.py +253 -0
  7. flwr/cli/log.py +182 -0
  8. flwr/{server/superlink/state → cli/login}/__init__.py +4 -10
  9. flwr/cli/login/login.py +88 -0
  10. flwr/cli/ls.py +327 -0
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +210 -66
  13. flwr/cli/new/templates/app/.gitignore.tpl +163 -0
  14. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  15. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  16. flwr/cli/new/templates/app/README.flowertune.md.tpl +66 -0
  17. flwr/cli/new/templates/app/README.md.tpl +16 -32
  18. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  19. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  20. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  21. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.jax.py.tpl +50 -0
  23. flwr/cli/new/templates/app/code/client.mlx.py.tpl +73 -0
  24. flwr/cli/new/templates/app/code/client.numpy.py.tpl +7 -7
  25. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +30 -21
  26. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +63 -0
  27. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +57 -1
  28. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  29. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  30. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +126 -0
  31. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +87 -0
  32. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +78 -0
  33. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +94 -0
  34. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  35. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  36. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  37. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +38 -0
  38. flwr/cli/new/templates/app/code/server.jax.py.tpl +26 -0
  39. flwr/cli/new/templates/app/code/server.mlx.py.tpl +31 -0
  40. flwr/cli/new/templates/app/code/server.numpy.py.tpl +22 -9
  41. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  42. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +36 -0
  43. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  44. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  45. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +102 -0
  46. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  47. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  48. flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
  49. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +29 -24
  50. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
  51. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  52. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  53. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  54. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +68 -0
  55. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +46 -0
  56. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +35 -0
  57. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  58. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  59. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  60. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +35 -0
  61. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  62. flwr/cli/run/__init__.py +1 -0
  63. flwr/cli/run/run.py +212 -34
  64. flwr/cli/stop.py +130 -0
  65. flwr/cli/utils.py +240 -5
  66. flwr/client/__init__.py +3 -2
  67. flwr/client/app.py +432 -255
  68. flwr/client/client.py +1 -11
  69. flwr/client/client_app.py +74 -13
  70. flwr/client/clientapp/__init__.py +22 -0
  71. flwr/client/clientapp/app.py +259 -0
  72. flwr/client/clientapp/clientappio_servicer.py +244 -0
  73. flwr/client/clientapp/utils.py +115 -0
  74. flwr/client/dpfedavg_numpy_client.py +7 -8
  75. flwr/client/grpc_adapter_client/__init__.py +15 -0
  76. flwr/client/grpc_adapter_client/connection.py +98 -0
  77. flwr/client/grpc_client/connection.py +21 -7
  78. flwr/client/grpc_rere_client/__init__.py +1 -1
  79. flwr/client/grpc_rere_client/client_interceptor.py +176 -0
  80. flwr/client/grpc_rere_client/connection.py +163 -56
  81. flwr/client/grpc_rere_client/grpc_adapter.py +167 -0
  82. flwr/client/heartbeat.py +74 -0
  83. flwr/client/message_handler/__init__.py +1 -1
  84. flwr/client/message_handler/message_handler.py +10 -11
  85. flwr/client/mod/__init__.py +5 -5
  86. flwr/client/mod/centraldp_mods.py +4 -2
  87. flwr/client/mod/comms_mods.py +5 -4
  88. flwr/client/mod/localdp_mod.py +10 -5
  89. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  90. flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
  91. flwr/client/mod/utils.py +2 -4
  92. flwr/client/nodestate/__init__.py +26 -0
  93. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  94. flwr/client/nodestate/nodestate.py +31 -0
  95. flwr/client/nodestate/nodestate_factory.py +38 -0
  96. flwr/client/numpy_client.py +8 -31
  97. flwr/client/rest_client/__init__.py +1 -1
  98. flwr/client/rest_client/connection.py +199 -176
  99. flwr/client/run_info_store.py +112 -0
  100. flwr/client/supernode/__init__.py +24 -0
  101. flwr/client/supernode/app.py +321 -0
  102. flwr/client/typing.py +1 -0
  103. flwr/common/__init__.py +17 -11
  104. flwr/common/address.py +47 -3
  105. flwr/common/args.py +153 -0
  106. flwr/common/auth_plugin/__init__.py +24 -0
  107. flwr/common/auth_plugin/auth_plugin.py +121 -0
  108. flwr/common/config.py +243 -0
  109. flwr/common/constant.py +132 -1
  110. flwr/common/context.py +32 -2
  111. flwr/common/date.py +22 -4
  112. flwr/common/differential_privacy.py +2 -2
  113. flwr/common/dp.py +2 -4
  114. flwr/common/exit_handlers.py +3 -3
  115. flwr/common/grpc.py +164 -5
  116. flwr/common/logger.py +230 -12
  117. flwr/common/message.py +191 -106
  118. flwr/common/object_ref.py +179 -44
  119. flwr/common/pyproject.py +1 -0
  120. flwr/common/record/__init__.py +2 -1
  121. flwr/common/record/configsrecord.py +58 -18
  122. flwr/common/record/metricsrecord.py +57 -17
  123. flwr/common/record/parametersrecord.py +88 -20
  124. flwr/common/record/recordset.py +153 -30
  125. flwr/common/record/typeddict.py +30 -55
  126. flwr/common/recordset_compat.py +31 -12
  127. flwr/common/retry_invoker.py +123 -30
  128. flwr/common/secure_aggregation/__init__.py +1 -1
  129. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  130. flwr/common/secure_aggregation/crypto/shamir.py +11 -11
  131. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +68 -4
  132. flwr/common/secure_aggregation/ndarrays_arithmetic.py +17 -17
  133. flwr/common/secure_aggregation/quantization.py +8 -8
  134. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  135. flwr/common/secure_aggregation/secaggplus_utils.py +10 -12
  136. flwr/common/serde.py +298 -19
  137. flwr/common/telemetry.py +65 -29
  138. flwr/common/typing.py +120 -19
  139. flwr/common/version.py +17 -3
  140. flwr/proto/clientappio_pb2.py +45 -0
  141. flwr/proto/clientappio_pb2.pyi +132 -0
  142. flwr/proto/clientappio_pb2_grpc.py +135 -0
  143. flwr/proto/clientappio_pb2_grpc.pyi +53 -0
  144. flwr/proto/exec_pb2.py +62 -0
  145. flwr/proto/exec_pb2.pyi +212 -0
  146. flwr/proto/exec_pb2_grpc.py +237 -0
  147. flwr/proto/exec_pb2_grpc.pyi +93 -0
  148. flwr/proto/fab_pb2.py +31 -0
  149. flwr/proto/fab_pb2.pyi +65 -0
  150. flwr/proto/fab_pb2_grpc.py +4 -0
  151. flwr/proto/fab_pb2_grpc.pyi +4 -0
  152. flwr/proto/fleet_pb2.py +42 -23
  153. flwr/proto/fleet_pb2.pyi +123 -1
  154. flwr/proto/fleet_pb2_grpc.py +170 -0
  155. flwr/proto/fleet_pb2_grpc.pyi +61 -0
  156. flwr/proto/grpcadapter_pb2.py +32 -0
  157. flwr/proto/grpcadapter_pb2.pyi +43 -0
  158. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  159. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  160. flwr/proto/log_pb2.py +29 -0
  161. flwr/proto/log_pb2.pyi +39 -0
  162. flwr/proto/log_pb2_grpc.py +4 -0
  163. flwr/proto/log_pb2_grpc.pyi +4 -0
  164. flwr/proto/message_pb2.py +41 -0
  165. flwr/proto/message_pb2.pyi +128 -0
  166. flwr/proto/message_pb2_grpc.py +4 -0
  167. flwr/proto/message_pb2_grpc.pyi +4 -0
  168. flwr/proto/node_pb2.py +1 -1
  169. flwr/proto/recordset_pb2.py +35 -33
  170. flwr/proto/recordset_pb2.pyi +40 -14
  171. flwr/proto/run_pb2.py +64 -0
  172. flwr/proto/run_pb2.pyi +268 -0
  173. flwr/proto/run_pb2_grpc.py +4 -0
  174. flwr/proto/run_pb2_grpc.pyi +4 -0
  175. flwr/proto/serverappio_pb2.py +52 -0
  176. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +62 -20
  177. flwr/proto/serverappio_pb2_grpc.py +410 -0
  178. flwr/proto/serverappio_pb2_grpc.pyi +160 -0
  179. flwr/proto/simulationio_pb2.py +38 -0
  180. flwr/proto/simulationio_pb2.pyi +65 -0
  181. flwr/proto/simulationio_pb2_grpc.py +239 -0
  182. flwr/proto/simulationio_pb2_grpc.pyi +94 -0
  183. flwr/proto/task_pb2.py +7 -8
  184. flwr/proto/task_pb2.pyi +8 -5
  185. flwr/proto/transport_pb2.py +8 -8
  186. flwr/proto/transport_pb2.pyi +9 -6
  187. flwr/server/__init__.py +2 -10
  188. flwr/server/app.py +579 -402
  189. flwr/server/client_manager.py +8 -6
  190. flwr/server/compat/app.py +6 -62
  191. flwr/server/compat/app_utils.py +14 -8
  192. flwr/server/compat/driver_client_proxy.py +25 -58
  193. flwr/server/compat/legacy_context.py +5 -4
  194. flwr/server/driver/__init__.py +2 -0
  195. flwr/server/driver/driver.py +36 -131
  196. flwr/server/driver/grpc_driver.py +217 -81
  197. flwr/server/driver/inmemory_driver.py +182 -0
  198. flwr/server/history.py +28 -29
  199. flwr/server/run_serverapp.py +15 -126
  200. flwr/server/server.py +50 -44
  201. flwr/server/server_app.py +59 -10
  202. flwr/server/serverapp/__init__.py +22 -0
  203. flwr/server/serverapp/app.py +256 -0
  204. flwr/server/serverapp_components.py +52 -0
  205. flwr/server/strategy/__init__.py +2 -2
  206. flwr/server/strategy/aggregate.py +37 -23
  207. flwr/server/strategy/bulyan.py +9 -9
  208. flwr/server/strategy/dp_adaptive_clipping.py +25 -25
  209. flwr/server/strategy/dp_fixed_clipping.py +23 -22
  210. flwr/server/strategy/dpfedavg_adaptive.py +8 -8
  211. flwr/server/strategy/dpfedavg_fixed.py +13 -12
  212. flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
  213. flwr/server/strategy/fedadagrad.py +9 -9
  214. flwr/server/strategy/fedadam.py +20 -10
  215. flwr/server/strategy/fedavg.py +16 -16
  216. flwr/server/strategy/fedavg_android.py +17 -17
  217. flwr/server/strategy/fedavgm.py +9 -9
  218. flwr/server/strategy/fedmedian.py +5 -5
  219. flwr/server/strategy/fedopt.py +6 -6
  220. flwr/server/strategy/fedprox.py +7 -7
  221. flwr/server/strategy/fedtrimmedavg.py +8 -8
  222. flwr/server/strategy/fedxgb_bagging.py +12 -12
  223. flwr/server/strategy/fedxgb_cyclic.py +10 -10
  224. flwr/server/strategy/fedxgb_nn_avg.py +6 -6
  225. flwr/server/strategy/fedyogi.py +9 -9
  226. flwr/server/strategy/krum.py +9 -9
  227. flwr/server/strategy/qfedavg.py +16 -16
  228. flwr/server/strategy/strategy.py +10 -10
  229. flwr/server/superlink/driver/__init__.py +2 -2
  230. flwr/server/superlink/driver/serverappio_grpc.py +61 -0
  231. flwr/server/superlink/driver/serverappio_servicer.py +363 -0
  232. flwr/server/superlink/ffs/__init__.py +24 -0
  233. flwr/server/superlink/ffs/disk_ffs.py +108 -0
  234. flwr/server/superlink/ffs/ffs.py +79 -0
  235. flwr/server/superlink/ffs/ffs_factory.py +47 -0
  236. flwr/server/superlink/fleet/__init__.py +1 -1
  237. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  238. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +162 -0
  239. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  240. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +4 -2
  241. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -2
  242. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  243. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -154
  244. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  245. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +120 -13
  246. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +228 -0
  247. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  248. flwr/server/superlink/fleet/message_handler/message_handler.py +153 -9
  249. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  250. flwr/server/superlink/fleet/rest_rere/rest_api.py +119 -81
  251. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  252. flwr/server/superlink/fleet/vce/backend/__init__.py +4 -4
  253. flwr/server/superlink/fleet/vce/backend/backend.py +8 -9
  254. flwr/server/superlink/fleet/vce/backend/raybackend.py +87 -68
  255. flwr/server/superlink/fleet/vce/vce_api.py +208 -146
  256. flwr/server/superlink/linkstate/__init__.py +28 -0
  257. flwr/server/superlink/linkstate/in_memory_linkstate.py +581 -0
  258. flwr/server/superlink/linkstate/linkstate.py +389 -0
  259. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +19 -10
  260. flwr/server/superlink/linkstate/sqlite_linkstate.py +1236 -0
  261. flwr/server/superlink/linkstate/utils.py +389 -0
  262. flwr/server/superlink/simulation/__init__.py +15 -0
  263. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  264. flwr/server/superlink/simulation/simulationio_servicer.py +186 -0
  265. flwr/server/superlink/utils.py +65 -0
  266. flwr/server/typing.py +2 -0
  267. flwr/server/utils/__init__.py +1 -1
  268. flwr/server/utils/tensorboard.py +5 -5
  269. flwr/server/utils/validator.py +31 -11
  270. flwr/server/workflow/default_workflows.py +70 -26
  271. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  272. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +40 -27
  273. flwr/simulation/__init__.py +12 -5
  274. flwr/simulation/app.py +247 -315
  275. flwr/simulation/legacy_app.py +402 -0
  276. flwr/simulation/ray_transport/__init__.py +1 -1
  277. flwr/simulation/ray_transport/ray_actor.py +42 -67
  278. flwr/simulation/ray_transport/ray_client_proxy.py +37 -17
  279. flwr/simulation/ray_transport/utils.py +1 -0
  280. flwr/simulation/run_simulation.py +306 -163
  281. flwr/simulation/simulationio_connection.py +89 -0
  282. flwr/superexec/__init__.py +15 -0
  283. flwr/superexec/app.py +59 -0
  284. flwr/superexec/deployment.py +188 -0
  285. flwr/superexec/exec_grpc.py +80 -0
  286. flwr/superexec/exec_servicer.py +231 -0
  287. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  288. flwr/superexec/executor.py +96 -0
  289. flwr/superexec/simulation.py +124 -0
  290. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/METADATA +33 -26
  291. flwr_nightly-1.15.0.dev20250114.dist-info/RECORD +328 -0
  292. flwr_nightly-1.15.0.dev20250114.dist-info/entry_points.txt +12 -0
  293. flwr/cli/flower_toml.py +0 -140
  294. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  295. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  296. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  297. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  298. flwr/client/node_state.py +0 -48
  299. flwr/client/node_state_tests.py +0 -65
  300. flwr/proto/driver_pb2.py +0 -44
  301. flwr/proto/driver_pb2_grpc.py +0 -169
  302. flwr/proto/driver_pb2_grpc.pyi +0 -66
  303. flwr/server/superlink/driver/driver_grpc.py +0 -54
  304. flwr/server/superlink/driver/driver_servicer.py +0 -129
  305. flwr/server/superlink/state/in_memory_state.py +0 -230
  306. flwr/server/superlink/state/sqlite_state.py +0 -630
  307. flwr/server/superlink/state/state.py +0 -154
  308. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  309. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  310. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/LICENSE +0 -0
  311. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,17 +15,37 @@
15
15
  """Experimental REST API server."""
16
16
 
17
17
 
18
+ from __future__ import annotations
19
+
18
20
  import sys
21
+ from collections.abc import Awaitable
22
+ from typing import Callable, TypeVar, cast
23
+
24
+ from google.protobuf.message import Message as GrpcMessage
19
25
 
20
26
  from flwr.common.constant import MISSING_EXTRA_REST
27
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
21
28
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
29
  CreateNodeRequest,
30
+ CreateNodeResponse,
23
31
  DeleteNodeRequest,
32
+ DeleteNodeResponse,
33
+ PingRequest,
34
+ PingResponse,
35
+ PullMessagesRequest,
36
+ PullMessagesResponse,
24
37
  PullTaskInsRequest,
38
+ PullTaskInsResponse,
39
+ PushMessagesRequest,
40
+ PushMessagesResponse,
25
41
  PushTaskResRequest,
42
+ PushTaskResResponse,
26
43
  )
44
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
45
+ from flwr.server.superlink.ffs.ffs import Ffs
46
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
27
47
  from flwr.server.superlink.fleet.message_handler import message_handler
28
- from flwr.server.superlink.state import State
48
+ from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
29
49
 
30
50
  try:
31
51
  from starlette.applications import Starlette
@@ -38,125 +58,143 @@ except ModuleNotFoundError:
38
58
  sys.exit(MISSING_EXTRA_REST)
39
59
 
40
60
 
41
- async def create_node(request: Request) -> Response:
42
- """Create Node."""
43
- _check_headers(request.headers)
61
+ GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
62
+ GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
44
63
 
45
- # Get the request body as raw bytes
46
- create_node_request_bytes: bytes = await request.body()
64
+ GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
65
+ RestEndPoint = Callable[[Request], Awaitable[Response]]
47
66
 
48
- # Deserialize ProtoBuf
49
- create_node_request_proto = CreateNodeRequest()
50
- create_node_request_proto.ParseFromString(create_node_request_bytes)
51
67
 
52
- # Get state from app
53
- state: State = app.state.STATE_FACTORY.state()
68
+ def rest_request_response(
69
+ grpc_request_type: type[GrpcRequest],
70
+ ) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
71
+ """Convert an async gRPC-based function into a RESTful HTTP endpoint."""
54
72
 
55
- # Handle message
56
- create_node_response_proto = message_handler.create_node(
57
- request=create_node_request_proto, state=state
58
- )
73
+ def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
74
+ async def wrapper(request: Request) -> Response:
75
+ _check_headers(request.headers)
59
76
 
60
- # Return serialized ProtoBuf
61
- create_node_response_bytes = create_node_response_proto.SerializeToString()
62
- return Response(
63
- status_code=200,
64
- content=create_node_response_bytes,
65
- headers={"Content-Type": "application/protobuf"},
66
- )
77
+ # Get the request body as raw bytes
78
+ grpc_req_bytes: bytes = await request.body()
67
79
 
80
+ # Deserialize ProtoBuf
81
+ grpc_req = grpc_request_type.FromString(grpc_req_bytes)
82
+ grpc_res = await func(grpc_req)
83
+ return Response(
84
+ status_code=200,
85
+ content=grpc_res.SerializeToString(),
86
+ headers={"Content-Type": "application/protobuf"},
87
+ )
68
88
 
69
- async def delete_node(request: Request) -> Response:
70
- """Delete Node Id."""
71
- _check_headers(request.headers)
89
+ return wrapper
72
90
 
73
- # Get the request body as raw bytes
74
- delete_node_request_bytes: bytes = await request.body()
91
+ return decorator
75
92
 
76
- # Deserialize ProtoBuf
77
- delete_node_request_proto = DeleteNodeRequest()
78
- delete_node_request_proto.ParseFromString(delete_node_request_bytes)
79
93
 
94
+ @rest_request_response(CreateNodeRequest)
95
+ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
96
+ """Create Node."""
80
97
  # Get state from app
81
- state: State = app.state.STATE_FACTORY.state()
98
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
82
99
 
83
100
  # Handle message
84
- delete_node_response_proto = message_handler.delete_node(
85
- request=delete_node_request_proto, state=state
86
- )
101
+ return message_handler.create_node(request=request, state=state)
102
+
87
103
 
88
- # Return serialized ProtoBuf
89
- delete_node_response_bytes = delete_node_response_proto.SerializeToString()
90
- return Response(
91
- status_code=200,
92
- content=delete_node_response_bytes,
93
- headers={"Content-Type": "application/protobuf"},
94
- )
104
+ @rest_request_response(DeleteNodeRequest)
105
+ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
106
+ """Delete Node Id."""
107
+ # Get state from app
108
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
109
+
110
+ # Handle message
111
+ return message_handler.delete_node(request=request, state=state)
95
112
 
96
113
 
97
- async def pull_task_ins(request: Request) -> Response:
114
+ @rest_request_response(PullTaskInsRequest)
115
+ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
98
116
  """Pull TaskIns."""
99
- _check_headers(request.headers)
117
+ # Get state from app
118
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
100
119
 
101
- # Get the request body as raw bytes
102
- pull_task_ins_request_bytes: bytes = await request.body()
120
+ # Handle message
121
+ return message_handler.pull_task_ins(request=request, state=state)
103
122
 
104
- # Deserialize ProtoBuf
105
- pull_task_ins_request_proto = PullTaskInsRequest()
106
- pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
107
123
 
124
+ @rest_request_response(PullMessagesRequest)
125
+ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
126
+ """Pull Messages."""
108
127
  # Get state from app
109
- state: State = app.state.STATE_FACTORY.state()
128
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
110
129
 
111
130
  # Handle message
112
- pull_task_ins_response_proto = message_handler.pull_task_ins(
113
- request=pull_task_ins_request_proto,
114
- state=state,
115
- )
131
+ return message_handler.pull_messages(request=request, state=state)
116
132
 
117
- # Return serialized ProtoBuf
118
- pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
119
- return Response(
120
- status_code=200,
121
- content=pull_task_ins_response_bytes,
122
- headers={"Content-Type": "application/protobuf"},
123
- )
124
133
 
125
-
126
- async def push_task_res(request: Request) -> Response: # Check if token is needed here
134
+ # Check if token is needed here
135
+ @rest_request_response(PushTaskResRequest)
136
+ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
127
137
  """Push TaskRes."""
128
- _check_headers(request.headers)
138
+ # Get state from app
139
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
129
140
 
130
- # Get the request body as raw bytes
131
- push_task_res_request_bytes: bytes = await request.body()
141
+ # Handle message
142
+ return message_handler.push_task_res(request=request, state=state)
143
+
144
+
145
+ @rest_request_response(PushMessagesRequest)
146
+ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
147
+ """Push Messages."""
148
+ # Get state from app
149
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
150
+
151
+ # Handle message
152
+ return message_handler.push_messages(request=request, state=state)
153
+
154
+
155
+ @rest_request_response(PingRequest)
156
+ async def ping(request: PingRequest) -> PingResponse:
157
+ """Ping."""
158
+ # Get state from app
159
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
160
+
161
+ # Handle message
162
+ return message_handler.ping(request=request, state=state)
132
163
 
133
- # Deserialize ProtoBuf
134
- push_task_res_request_proto = PushTaskResRequest()
135
- push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
136
164
 
165
+ @rest_request_response(GetRunRequest)
166
+ async def get_run(request: GetRunRequest) -> GetRunResponse:
167
+ """GetRun."""
137
168
  # Get state from app
138
- state: State = app.state.STATE_FACTORY.state()
169
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
139
170
 
140
171
  # Handle message
141
- push_task_res_response_proto = message_handler.push_task_res(
142
- request=push_task_res_request_proto,
143
- state=state,
144
- )
172
+ return message_handler.get_run(request=request, state=state)
173
+
145
174
 
146
- # Return serialized ProtoBuf
147
- push_task_res_response_bytes = push_task_res_response_proto.SerializeToString()
148
- return Response(
149
- status_code=200,
150
- content=push_task_res_response_bytes,
151
- headers={"Content-Type": "application/protobuf"},
152
- )
175
+ @rest_request_response(GetFabRequest)
176
+ async def get_fab(request: GetFabRequest) -> GetFabResponse:
177
+ """GetFab."""
178
+ # Get ffs from app
179
+ ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
180
+
181
+ # Get state from app
182
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
183
+
184
+ # Handle message
185
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
153
186
 
154
187
 
155
188
  routes = [
156
189
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
157
190
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
158
191
  Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
192
+ Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
159
193
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
194
+ Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
195
+ Route("/api/v0/fleet/ping", ping, methods=["POST"]),
196
+ Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
197
+ Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
160
198
  ]
161
199
 
162
200
  app: Starlette = Starlette(
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [
@@ -14,18 +14,18 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
- from typing import Dict, Type
19
19
 
20
20
  from .backend import Backend, BackendConfig
21
21
 
22
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
23
23
 
24
24
  # Mapping of supported backends
25
- supported_backends: Dict[str, Type[Backend]] = {}
25
+ supported_backends: dict[str, type[Backend]] = {}
26
26
 
27
27
  # To log backend-specific error message when chosen backend isn't available
28
- error_messages_backends: Dict[str, str] = {}
28
+ error_messages_backends: dict[str, str] = {}
29
29
 
30
30
  if is_ray_installed:
31
31
  from .raybackend import RayBackend
@@ -38,7 +38,7 @@ else:
38
38
 
39
39
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
40
40
 
41
- pip install -U flwr["simulation"]
41
+ pip install -U "flwr[simulation]"
42
42
  """
43
43
 
44
44
 
@@ -16,25 +16,25 @@
16
16
 
17
17
 
18
18
  from abc import ABC, abstractmethod
19
- from typing import Callable, Dict, Tuple
19
+ from typing import Callable
20
20
 
21
21
  from flwr.client.client_app import ClientApp
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.message import Message
24
24
  from flwr.common.typing import ConfigsRecordValues
25
25
 
26
- BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]]
26
+ BackendConfig = dict[str, dict[str, ConfigsRecordValues]]
27
27
 
28
28
 
29
29
  class Backend(ABC):
30
30
  """Abstract base class for a Simulation Engine Backend."""
31
31
 
32
- def __init__(self, backend_config: BackendConfig, work_dir: str) -> None:
32
+ def __init__(self, backend_config: BackendConfig) -> None:
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- async def build(self) -> None:
37
- """Build backend asynchronously.
36
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
37
+ """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
40
40
  to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,14 +54,13 @@ class Backend(ABC):
54
54
  """Report whether a backend worker is idle and can therefore run a ClientApp."""
55
55
 
56
56
  @abstractmethod
57
- async def terminate(self) -> None:
57
+ def terminate(self) -> None:
58
58
  """Terminate backend."""
59
59
 
60
60
  @abstractmethod
61
- async def process_message(
61
+ def process_message(
62
62
  self,
63
- app: Callable[[], ClientApp],
64
63
  message: Message,
65
64
  context: Context,
66
- ) -> Tuple[Message, Context]:
65
+ ) -> tuple[Message, Context]:
67
66
  """Submit a job to the backend."""
@@ -14,26 +14,26 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
- import pathlib
18
- from logging import ERROR, INFO
19
- from typing import Callable, Dict, List, Tuple, Union
17
+
18
+ import sys
19
+ from logging import DEBUG, ERROR
20
+ from typing import Callable, Optional, Union
20
21
 
21
22
  import ray
22
23
 
23
- from flwr.client.client_app import ClientApp, LoadClientAppError
24
+ from flwr.client.client_app import ClientApp
25
+ from flwr.common.constant import PARTITION_ID_KEY
24
26
  from flwr.common.context import Context
25
27
  from flwr.common.logger import log
26
28
  from flwr.common.message import Message
27
- from flwr.simulation.ray_transport.ray_actor import (
28
- BasicActorPool,
29
- ClientAppActor,
30
- init_ray,
31
- )
29
+ from flwr.common.typing import ConfigsRecordValues
30
+ from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
32
31
  from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
33
32
 
34
33
  from .backend import Backend, BackendConfig
35
34
 
36
- ClientResourcesDict = Dict[str, Union[int, float]]
35
+ ClientResourcesDict = dict[str, Union[int, float]]
36
+ ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]]
37
37
 
38
38
 
39
39
  class RayBackend(Backend):
@@ -42,51 +42,24 @@ class RayBackend(Backend):
42
42
  def __init__(
43
43
  self,
44
44
  backend_config: BackendConfig,
45
- work_dir: str,
46
45
  ) -> None:
47
46
  """Prepare RayBackend by initialising Ray and creating the ActorPool."""
48
- log(INFO, "Initialising: %s", self.__class__.__name__)
49
- log(INFO, "Backend config: %s", backend_config)
50
-
51
- if not pathlib.Path(work_dir).exists():
52
- raise ValueError(f"Specified work_dir {work_dir} does not exist.")
47
+ log(DEBUG, "Initialising: %s", self.__class__.__name__)
48
+ log(DEBUG, "Backend config: %s", backend_config)
53
49
 
54
- # Init ray and append working dir if needed
55
- runtime_env = (
56
- self._configure_runtime_env(work_dir=work_dir) if work_dir else None
57
- )
58
- init_ray(runtime_env=runtime_env)
50
+ # Initialise ray
51
+ self.init_args_key = "init_args"
52
+ self.init_ray(backend_config)
59
53
 
60
54
  # Validate client resources
61
55
  self.client_resources_key = "client_resources"
56
+ self.client_resources = self._validate_client_resources(config=backend_config)
62
57
 
63
- # Create actor pool
64
- use_tf = backend_config.get("tensorflow", False)
65
- actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}
66
-
67
- client_resources = self._validate_client_resources(config=backend_config)
68
- self.pool = BasicActorPool(
69
- actor_type=ClientAppActor,
70
- client_resources=client_resources,
71
- actor_kwargs=actor_kwargs,
72
- )
73
-
74
- def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
75
- """Return list of files/subdirectories to exclude relative to work_dir.
58
+ # Validate actor arguments
59
+ self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
60
+ self.pool: Optional[BasicActorPool] = None
76
61
 
77
- Without this, Ray will push everything to the Ray Cluster.
78
- """
79
- runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}
80
-
81
- excludes = []
82
- path = pathlib.Path(work_dir)
83
- for p in path.rglob("*"):
84
- # Exclude files need to be relative to the working_dir
85
- if p.is_file() and not str(p).endswith(".py"):
86
- excludes.append(str(p.relative_to(path)))
87
- runtime_env["excludes"] = excludes
88
-
89
- return runtime_env
62
+ self.app_fn: Optional[Callable[[], ClientApp]] = None
90
63
 
91
64
  def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
92
65
  client_resources_config = config.get(self.client_resources_key)
@@ -109,7 +82,7 @@ class RayBackend(Backend):
109
82
  else:
110
83
  client_resources = {"num_cpus": 2, "num_gpus": 0.0}
111
84
  log(
112
- INFO,
85
+ DEBUG,
113
86
  "`%s` not specified in backend config. Applying default setting: %s",
114
87
  self.client_resources_key,
115
88
  client_resources,
@@ -117,59 +90,105 @@ class RayBackend(Backend):
117
90
 
118
91
  return client_resources
119
92
 
93
+ def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict:
94
+ actor_args_config = config.get("actor", False)
95
+ actor_args: ActorArgsDict = {}
96
+ if actor_args_config:
97
+ use_tf = actor_args.get("tensorflow", False)
98
+ if use_tf:
99
+ actor_args["on_actor_init_fn"] = enable_tf_gpu_growth
100
+ return actor_args
101
+
102
+ def init_ray(self, backend_config: BackendConfig) -> None:
103
+ """Initialises Ray if not already initialised."""
104
+ if not ray.is_initialized():
105
+ ray_init_args: dict[
106
+ str,
107
+ ConfigsRecordValues,
108
+ ] = {}
109
+
110
+ if backend_config.get(self.init_args_key):
111
+ for k, v in backend_config[self.init_args_key].items():
112
+ ray_init_args[k] = v
113
+ ray.init(
114
+ runtime_env={"env_vars": {"PYTHONPATH": ":".join(sys.path)}},
115
+ **ray_init_args,
116
+ )
117
+
120
118
  @property
121
119
  def num_workers(self) -> int:
122
120
  """Return number of actors in pool."""
123
- return self.pool.num_actors
121
+ return self.pool.num_actors if self.pool else 0
124
122
 
125
123
  def is_worker_idle(self) -> bool:
126
124
  """Report whether the pool has idle actors."""
127
- return self.pool.is_actor_available()
125
+ return self.pool.is_actor_available() if self.pool else False
128
126
 
129
- async def build(self) -> None:
127
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
130
128
  """Build pool of Ray actors that this backend will submit jobs to."""
131
- await self.pool.add_actors_to_pool(self.pool.actors_capacity)
132
- log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)
129
+ # Create Actor Pool
130
+ try:
131
+ self.pool = BasicActorPool(
132
+ actor_type=ClientAppActor,
133
+ client_resources=self.client_resources,
134
+ actor_kwargs=self.actor_kwargs,
135
+ )
136
+ except Exception as ex:
137
+ raise ex
138
+
139
+ self.pool.add_actors_to_pool(self.pool.actors_capacity)
140
+ # Set ClientApp callable that ray actors will use
141
+ self.app_fn = app_fn
142
+ log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
143
 
134
- async def process_message(
144
+ def process_message(
135
145
  self,
136
- app: Callable[[], ClientApp],
137
146
  message: Message,
138
147
  context: Context,
139
- ) -> Tuple[Message, Context]:
148
+ ) -> tuple[Message, Context]:
140
149
  """Run ClientApp that process a given message.
141
150
 
142
151
  Return output message and updated context.
143
152
  """
144
- partition_id = message.metadata.partition_id
153
+ partition_id = context.node_config[PARTITION_ID_KEY]
154
+
155
+ if self.pool is None:
156
+ raise ValueError("The actor pool is empty, unfit to process messages.")
157
+
158
+ if self.app_fn is None:
159
+ raise ValueError(
160
+ "Unspecified function to load a `ClientApp`. "
161
+ "Call the backend's `build()` method before processing messages."
162
+ )
145
163
 
146
164
  try:
147
- # Submite a task to the pool
148
- future = await self.pool.submit(
165
+ # Submit a task to the pool
166
+ future = self.pool.submit(
149
167
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
- (app, message, str(partition_id), context),
168
+ (self.app_fn, message, str(partition_id), context),
151
169
  )
152
170
 
153
- await future
154
-
155
171
  # Fetch result
156
172
  (
157
173
  out_mssg,
158
174
  updated_context,
159
- ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
175
+ ) = self.pool.fetch_result_and_return_actor_to_pool(future)
160
176
 
161
177
  return out_mssg, updated_context
162
178
 
163
- except LoadClientAppError as load_ex:
179
+ except Exception as ex:
164
180
  log(
165
181
  ERROR,
166
182
  "An exception was raised when processing a message by %s",
167
183
  self.__class__.__name__,
168
184
  )
169
- raise load_ex
185
+ # add actor back into pool
186
+ self.pool.add_actor_back_to_pool(future)
187
+ raise ex
170
188
 
171
- async def terminate(self) -> None:
189
+ def terminate(self) -> None:
172
190
  """Terminate all actors in actor pool."""
173
- await self.pool.terminate_all_actors()
191
+ if self.pool:
192
+ self.pool.terminate_all_actors()
174
193
  ray.shutdown()
175
- log(INFO, "Terminated %s", self.__class__.__name__)
194
+ log(DEBUG, "Terminated %s", self.__class__.__name__)