flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +15 -0
  3. flwr/app/error.py +68 -0
  4. flwr/app/metadata.py +223 -0
  5. flwr/cli/__init__.py +1 -1
  6. flwr/cli/app.py +21 -2
  7. flwr/cli/build.py +83 -58
  8. flwr/cli/cli_user_auth_interceptor.py +1 -1
  9. flwr/cli/config_utils.py +53 -17
  10. flwr/cli/example.py +1 -1
  11. flwr/cli/install.py +1 -1
  12. flwr/cli/log.py +4 -4
  13. flwr/cli/login/__init__.py +1 -1
  14. flwr/cli/login/login.py +15 -8
  15. flwr/cli/ls.py +16 -37
  16. flwr/cli/new/__init__.py +1 -1
  17. flwr/cli/new/new.py +4 -4
  18. flwr/cli/new/templates/__init__.py +1 -1
  19. flwr/cli/new/templates/app/__init__.py +1 -1
  20. flwr/cli/new/templates/app/code/__init__.py +1 -1
  21. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  22. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
  24. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  25. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  26. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  28. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  29. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  34. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  35. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  36. flwr/cli/run/__init__.py +1 -1
  37. flwr/cli/run/run.py +11 -19
  38. flwr/cli/stop.py +3 -3
  39. flwr/cli/utils.py +42 -17
  40. flwr/client/__init__.py +3 -3
  41. flwr/client/client.py +1 -1
  42. flwr/client/client_app.py +140 -138
  43. flwr/client/clientapp/__init__.py +1 -8
  44. flwr/client/clientapp/utils.py +1 -1
  45. flwr/client/dpfedavg_numpy_client.py +1 -1
  46. flwr/client/grpc_adapter_client/__init__.py +1 -1
  47. flwr/client/grpc_adapter_client/connection.py +5 -5
  48. flwr/client/grpc_rere_client/__init__.py +1 -1
  49. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  50. flwr/client/grpc_rere_client/connection.py +131 -61
  51. flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +2 -2
  54. flwr/client/mod/__init__.py +1 -1
  55. flwr/client/mod/centraldp_mods.py +1 -1
  56. flwr/client/mod/comms_mods.py +39 -20
  57. flwr/client/mod/localdp_mod.py +6 -6
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/numpy_client.py +1 -1
  63. flwr/client/rest_client/__init__.py +1 -1
  64. flwr/client/rest_client/connection.py +174 -68
  65. flwr/client/run_info_store.py +1 -1
  66. flwr/client/typing.py +1 -1
  67. flwr/clientapp/__init__.py +15 -0
  68. flwr/common/__init__.py +3 -3
  69. flwr/common/address.py +1 -1
  70. flwr/common/args.py +1 -1
  71. flwr/common/auth_plugin/__init__.py +3 -1
  72. flwr/common/auth_plugin/auth_plugin.py +30 -4
  73. flwr/common/config.py +1 -1
  74. flwr/common/constant.py +37 -8
  75. flwr/common/context.py +1 -1
  76. flwr/common/date.py +1 -1
  77. flwr/common/differential_privacy.py +1 -1
  78. flwr/common/differential_privacy_constants.py +1 -1
  79. flwr/common/dp.py +1 -1
  80. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  81. flwr/common/exit/exit.py +6 -6
  82. flwr/common/exit_handlers.py +31 -1
  83. flwr/common/grpc.py +1 -1
  84. flwr/common/heartbeat.py +165 -0
  85. flwr/common/inflatable.py +290 -0
  86. flwr/common/inflatable_grpc_utils.py +99 -0
  87. flwr/common/inflatable_rest_utils.py +99 -0
  88. flwr/common/inflatable_utils.py +341 -0
  89. flwr/common/logger.py +1 -1
  90. flwr/common/message.py +137 -252
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +3 -2
  95. flwr/common/record/array.py +323 -0
  96. flwr/common/record/arrayrecord.py +121 -243
  97. flwr/common/record/configrecord.py +71 -16
  98. flwr/common/record/conversion_utils.py +2 -2
  99. flwr/common/record/metricrecord.py +71 -20
  100. flwr/common/record/recorddict.py +207 -90
  101. flwr/common/record/typeddict.py +1 -1
  102. flwr/common/recorddict_compat.py +2 -2
  103. flwr/common/retry_invoker.py +15 -11
  104. flwr/common/secure_aggregation/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  106. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  107. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  108. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  109. flwr/common/secure_aggregation/quantization.py +1 -1
  110. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  111. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  112. flwr/common/serde.py +60 -184
  113. flwr/common/serde_utils.py +175 -0
  114. flwr/common/telemetry.py +2 -2
  115. flwr/common/typing.py +6 -4
  116. flwr/common/version.py +1 -1
  117. flwr/compat/__init__.py +15 -0
  118. flwr/compat/client/__init__.py +15 -0
  119. flwr/{client → compat/client}/app.py +71 -211
  120. flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
  121. flwr/{client → compat/client}/grpc_client/connection.py +13 -13
  122. flwr/compat/common/__init__.py +15 -0
  123. flwr/compat/server/__init__.py +15 -0
  124. flwr/compat/server/app.py +174 -0
  125. flwr/compat/simulation/__init__.py +15 -0
  126. flwr/proto/__init__.py +1 -1
  127. flwr/proto/fleet_pb2.py +32 -27
  128. flwr/proto/fleet_pb2.pyi +49 -35
  129. flwr/proto/fleet_pb2_grpc.py +117 -13
  130. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  131. flwr/proto/heartbeat_pb2.py +33 -0
  132. flwr/proto/heartbeat_pb2.pyi +66 -0
  133. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  134. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  135. flwr/proto/message_pb2.py +28 -11
  136. flwr/proto/message_pb2.pyi +125 -0
  137. flwr/proto/recorddict_pb2.py +16 -28
  138. flwr/proto/recorddict_pb2.pyi +46 -64
  139. flwr/proto/run_pb2.py +24 -32
  140. flwr/proto/run_pb2.pyi +4 -52
  141. flwr/proto/serverappio_pb2.py +32 -23
  142. flwr/proto/serverappio_pb2.pyi +45 -3
  143. flwr/proto/serverappio_pb2_grpc.py +138 -34
  144. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  145. flwr/proto/simulationio_pb2.py +12 -11
  146. flwr/proto/simulationio_pb2_grpc.py +35 -0
  147. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  148. flwr/server/__init__.py +2 -2
  149. flwr/server/app.py +69 -187
  150. flwr/server/client_manager.py +1 -1
  151. flwr/server/client_proxy.py +1 -1
  152. flwr/server/compat/__init__.py +1 -1
  153. flwr/server/compat/app.py +1 -1
  154. flwr/server/compat/app_utils.py +51 -29
  155. flwr/server/compat/legacy_context.py +1 -1
  156. flwr/server/criterion.py +1 -1
  157. flwr/server/fleet_event_log_interceptor.py +2 -2
  158. flwr/server/grid/grid.py +3 -3
  159. flwr/server/grid/grpc_grid.py +104 -34
  160. flwr/server/grid/inmemory_grid.py +5 -4
  161. flwr/server/history.py +1 -1
  162. flwr/server/run_serverapp.py +1 -1
  163. flwr/server/server.py +1 -1
  164. flwr/server/server_app.py +65 -58
  165. flwr/server/server_config.py +1 -1
  166. flwr/server/serverapp/__init__.py +1 -1
  167. flwr/server/serverapp/app.py +19 -1
  168. flwr/server/serverapp_components.py +1 -1
  169. flwr/server/strategy/__init__.py +1 -1
  170. flwr/server/strategy/aggregate.py +1 -1
  171. flwr/server/strategy/bulyan.py +2 -2
  172. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  173. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  174. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  175. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  176. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  177. flwr/server/strategy/fedadagrad.py +1 -1
  178. flwr/server/strategy/fedadam.py +1 -1
  179. flwr/server/strategy/fedavg.py +1 -1
  180. flwr/server/strategy/fedavg_android.py +1 -1
  181. flwr/server/strategy/fedavgm.py +1 -1
  182. flwr/server/strategy/fedmedian.py +1 -1
  183. flwr/server/strategy/fedopt.py +1 -1
  184. flwr/server/strategy/fedprox.py +1 -1
  185. flwr/server/strategy/fedtrimmedavg.py +1 -1
  186. flwr/server/strategy/fedxgb_bagging.py +1 -1
  187. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  188. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  189. flwr/server/strategy/fedyogi.py +1 -1
  190. flwr/server/strategy/krum.py +1 -1
  191. flwr/server/strategy/qfedavg.py +1 -1
  192. flwr/server/strategy/strategy.py +1 -1
  193. flwr/server/superlink/__init__.py +1 -1
  194. flwr/server/superlink/ffs/__init__.py +3 -1
  195. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  196. flwr/server/superlink/ffs/ffs.py +1 -1
  197. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  198. flwr/server/superlink/fleet/__init__.py +1 -1
  199. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  200. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
  201. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  202. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  203. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  206. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  208. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  209. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  210. flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
  211. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  212. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
  213. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  214. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  215. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  216. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  217. flwr/server/superlink/fleet/vce/vce_api.py +7 -4
  218. flwr/server/superlink/linkstate/__init__.py +1 -1
  219. flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
  220. flwr/server/superlink/linkstate/linkstate.py +54 -21
  221. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  222. flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
  223. flwr/server/superlink/linkstate/utils.py +34 -30
  224. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  225. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  226. flwr/server/superlink/simulation/__init__.py +1 -1
  227. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  228. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  229. flwr/server/superlink/utils.py +45 -3
  230. flwr/server/typing.py +1 -1
  231. flwr/server/utils/__init__.py +1 -1
  232. flwr/server/utils/tensorboard.py +1 -1
  233. flwr/server/utils/validator.py +3 -3
  234. flwr/server/workflow/__init__.py +1 -1
  235. flwr/server/workflow/constant.py +1 -1
  236. flwr/server/workflow/default_workflows.py +1 -1
  237. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  238. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  239. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  240. flwr/serverapp/__init__.py +15 -0
  241. flwr/simulation/__init__.py +1 -1
  242. flwr/simulation/app.py +18 -1
  243. flwr/simulation/legacy_app.py +1 -1
  244. flwr/simulation/ray_transport/__init__.py +1 -1
  245. flwr/simulation/ray_transport/ray_actor.py +1 -1
  246. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  247. flwr/simulation/ray_transport/utils.py +1 -1
  248. flwr/simulation/run_simulation.py +2 -2
  249. flwr/simulation/simulationio_connection.py +1 -1
  250. flwr/supercore/__init__.py +15 -0
  251. flwr/supercore/object_store/__init__.py +24 -0
  252. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  253. flwr/supercore/object_store/object_store.py +192 -0
  254. flwr/supercore/object_store/object_store_factory.py +44 -0
  255. flwr/superexec/__init__.py +1 -1
  256. flwr/superexec/app.py +1 -1
  257. flwr/superexec/deployment.py +7 -3
  258. flwr/superexec/exec_event_log_interceptor.py +4 -4
  259. flwr/superexec/exec_grpc.py +8 -4
  260. flwr/superexec/exec_servicer.py +126 -24
  261. flwr/superexec/exec_user_auth_interceptor.py +38 -9
  262. flwr/superexec/executor.py +5 -1
  263. flwr/superexec/simulation.py +8 -2
  264. flwr/superlink/__init__.py +15 -0
  265. flwr/{client/supernode → supernode}/__init__.py +1 -8
  266. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
  267. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
  268. flwr/supernode/cli/flwr_clientapp.py +81 -0
  269. flwr/{client → supernode}/nodestate/__init__.py +1 -1
  270. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  271. flwr/supernode/nodestate/nodestate.py +212 -0
  272. flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
  273. flwr/supernode/runtime/__init__.py +15 -0
  274. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
  275. flwr/supernode/servicer/__init__.py +15 -0
  276. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  277. flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
  278. flwr/supernode/start_client_internal.py +491 -0
  279. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
  280. flwr-1.19.0.dist-info/RECORD +365 -0
  281. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  282. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  283. flwr/client/heartbeat.py +0 -74
  284. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  285. flwr-1.17.0.dist-info/LICENSE +0 -202
  286. flwr-1.17.0.dist-info/RECORD +0 -333
flwr/common/exit/exit.py CHANGED
@@ -37,13 +37,13 @@ def flwr_exit(
37
37
  ) -> NoReturn:
38
38
  """Handle application exit with an optional message.
39
39
 
40
- The exit message logged and displayed will follow this structure:
40
+ The exit message logged and displayed will follow this structure::
41
41
 
42
- >>> Exit Code: <code>
43
- >>> <message>
44
- >>> <short-help-message>
45
- >>>
46
- >>> For more information, visit: <help-page-url>
42
+ Exit Code: <code>
43
+ <message>
44
+ <short-help-message>
45
+
46
+ For more information, visit: <help-page-url>
47
47
 
48
48
  - `<code>`: The unique exit code representing the termination reason.
49
49
  - `<message>`: Optional context or additional information about the exit.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@ SIGNAL_TO_EXIT_CODE: dict[int, int] = {
30
30
  signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
31
31
  signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
32
32
  }
33
# Callables invoked, in registration order, by the signal handler installed in
# `register_exit_handlers`, before any gRPC servers are stopped. Filled from the
# `exit_handlers` argument of `register_exit_handlers` and by `add_exit_handler`.
registered_exit_handlers: list[Callable[[], None]] = list()
33
34
 
34
35
  # SIGQUIT is not available on Windows
35
36
  if hasattr(signal, "SIGQUIT"):
@@ -41,6 +42,7 @@ def register_exit_handlers(
41
42
  exit_message: Optional[str] = None,
42
43
  grpc_servers: Optional[list[Server]] = None,
43
44
  bckg_threads: Optional[list[Thread]] = None,
45
+ exit_handlers: Optional[list[Callable[[], None]]] = None,
44
46
  ) -> None:
45
47
  """Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.
46
48
 
@@ -56,8 +58,12 @@ def register_exit_handlers(
56
58
  bckg_threads: Optional[List[Thread]] (default: None)
57
59
  An optional list of threads that need to be gracefully
58
60
  terminated before exiting.
61
+ exit_handlers: Optional[List[Callable[[], None]]] (default: None)
62
+ An optional list of exit handlers to be called before exiting.
63
+ Additional exit handlers can be added using `add_exit_handler`.
59
64
  """
60
65
  default_handlers: dict[int, Callable[[int, FrameType], None]] = {}
66
+ registered_exit_handlers.extend(exit_handlers or [])
61
67
 
62
68
  def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
63
69
  """Exit handler to be registered with `signal.signal`.
@@ -68,6 +74,9 @@ def register_exit_handlers(
68
74
  # Reset to default handler
69
75
  signal.signal(signalnum, default_handlers[signalnum]) # type: ignore
70
76
 
77
+ for handler in registered_exit_handlers:
78
+ handler()
79
+
71
80
  if grpc_servers is not None:
72
81
  for grpc_server in grpc_servers:
73
82
  grpc_server.stop(grace=1)
@@ -87,3 +96,24 @@ def register_exit_handlers(
87
96
  for sig in SIGNAL_TO_EXIT_CODE:
88
97
  default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
89
98
  default_handlers[sig] = default_handler # type: ignore
99
+
100
+
101
+ def add_exit_handler(exit_handler: Callable[[], None]) -> None:
102
+ """Add an exit handler to be called on graceful exit.
103
+
104
+ This function allows you to register additional exit handlers
105
+ that will be executed when the application exits gracefully,
106
+ if `register_exit_handlers` was called.
107
+
108
+ Parameters
109
+ ----------
110
+ exit_handler : Callable[[], None]
111
+ A callable that takes no arguments and performs cleanup or
112
+ other actions before the application exits.
113
+
114
+ Notes
115
+ -----
116
+ This method is not thread-safe, and it allows you to add the
117
+ same exit handler multiple times.
118
+ """
119
+ registered_exit_handlers.append(exit_handler)
flwr/common/grpc.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,165 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Heartbeat sender."""
16
+
17
+
18
+ import random
19
+ import threading
20
+ from typing import Callable, Union
21
+
22
+ import grpc
23
+
24
+ # pylint: disable=E0611
25
+ from flwr.proto.heartbeat_pb2 import SendAppHeartbeatRequest
26
+ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
27
+ from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
28
+
29
+ # pylint: enable=E0611
30
+ from .constant import (
31
+ HEARTBEAT_BASE_MULTIPLIER,
32
+ HEARTBEAT_CALL_TIMEOUT,
33
+ HEARTBEAT_DEFAULT_INTERVAL,
34
+ HEARTBEAT_RANDOM_RANGE,
35
+ )
36
+ from .retry_invoker import RetryInvoker, exponential
37
+
38
+
39
class HeartbeatFailure(Exception):
    """Signal that a single heartbeat attempt did not succeed."""


class HeartbeatSender:
    """Send heartbeats periodically from a daemon background thread.

    Every heartbeat is delegated to ``heartbeat_fn``. When an attempt fails
    (the function returns ``False``), it is retried with exponential backoff.

    Parameters
    ----------
    heartbeat_fn : Callable[[], bool]
        Callable that sends one heartbeat and returns ``True`` on success or
        ``False`` on failure. It should handle its own transport errors
        (e.g., gRPC errors) internally so that it always returns a boolean.
    """

    def __init__(
        self,
        heartbeat_fn: Callable[[], bool],
    ) -> None:
        stop_event = threading.Event()
        self.heartbeat_fn = heartbeat_fn
        self._stop_event = stop_event
        self._thread = threading.Thread(target=self._run, daemon=True)
        # Retry only on `HeartbeatFailure`, without a cap on attempts or total
        # time; waiting on the stop event lets `stop()` cut a backoff short.
        self._retry_invoker = RetryInvoker(
            lambda: exponential(max_delay=20),
            HeartbeatFailure,
            max_tries=None,
            max_time=None,
            wait_function=stop_event.wait,  # type: ignore
        )

    def start(self) -> None:
        """Launch the background heartbeat thread.

        Raises
        ------
        RuntimeError
            If the sender is already running, or if it was stopped before
            (a stopped sender cannot be restarted).
        """
        problem = None
        if self._thread.is_alive():
            problem = "Heartbeat sender is already running."
        elif self._stop_event.is_set():
            problem = "Cannot start a stopped heartbeat sender."
        if problem is not None:
            raise RuntimeError(problem)
        self._thread.start()

    def stop(self) -> None:
        """Ask the background thread to finish and wait until it has exited.

        Raises
        ------
        RuntimeError
            If the background thread is not alive.
        """
        worker = self._thread
        if not worker.is_alive():
            raise RuntimeError("Heartbeat sender is not running.")
        self._stop_event.set()
        worker.join()

    @property
    def is_running(self) -> bool:
        """Whether the thread is alive and no stop has been requested."""
        stop_requested = self._stop_event.is_set()
        return not stop_requested and self._thread.is_alive()

    def _run(self) -> None:
        """Send heartbeats in a loop until the stop event is set."""
        while not self._stop_event.is_set():
            # Failed attempts are retried by the retry invoker
            self._retry_invoker.invoke(self._heartbeat)
            # Sleep until the next heartbeat, waking up early if stopped
            self._stop_event.wait(self._next_interval())

    @staticmethod
    def _next_interval() -> float:
        """Return the jittered delay (seconds) before the next heartbeat.

        The delay equals
        ``(HEARTBEAT_DEFAULT_INTERVAL - HEARTBEAT_CALL_TIMEOUT)`` multiplied by
        ``HEARTBEAT_BASE_MULTIPLIER`` plus a random value drawn uniformly from
        ``HEARTBEAT_RANDOM_RANGE``.
        """
        jitter = random.uniform(*HEARTBEAT_RANDOM_RANGE)
        base: float = HEARTBEAT_DEFAULT_INTERVAL - HEARTBEAT_CALL_TIMEOUT
        return base * (HEARTBEAT_BASE_MULTIPLIER + jitter)

    def _heartbeat(self) -> None:
        """Send one heartbeat and raise `HeartbeatFailure` if it fails.

        Once the stop event is set, nothing is sent and the method returns
        normally, so no retry is triggered.
        """
        if self._stop_event.is_set():
            return
        if not self.heartbeat_fn():
            raise HeartbeatFailure
117
+
118
+
119
def get_grpc_app_heartbeat_fn(
    stub: Union[ServerAppIoStub, SimulationIoStub],
    run_id: int,
    *,
    failure_message: str,
) -> Callable[[], bool]:
    """Build a function that sends one app heartbeat over gRPC.

    This is meant for app heartbeats only; it is not used for node heartbeats.

    Parameters
    ----------
    stub : Union[ServerAppIoStub, SimulationIoStub]
        gRPC stub whose ``SendAppHeartbeat`` RPC is invoked.
    run_id : int
        The run ID placed in every heartbeat request.
    failure_message : str
        Message of the ``RuntimeError`` raised when the response reports
        ``success=False``.

    Returns
    -------
    Callable[[], bool]
        A zero-argument function (usable as ``HeartbeatSender.heartbeat_fn``).
        It returns ``True`` on success and ``False`` if the RPC fails with
        ``UNAVAILABLE`` or ``DEADLINE_EXCEEDED``. Any other ``grpc.RpcError`` is
        re-raised, and an unsuccessful response raises
        ``RuntimeError(failure_message)``.
    """
    # The request never changes, so it is built once and reused
    request = SendAppHeartbeatRequest(
        run_id=run_id, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
    )

    def send_heartbeat() -> bool:
        # Call the SendAppHeartbeat RPC of the given stub
        try:
            response = stub.SendAppHeartbeat(request)
        except grpc.RpcError as err:
            # Transient transport failures become False so the caller can retry
            if err.code() in (
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.DEADLINE_EXCEEDED,
            ):
                return False
            raise

        if response.success:
            return True
        raise RuntimeError(failure_message)

    return send_heartbeat
@@ -0,0 +1,290 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject base class."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import hashlib
21
+ import threading
22
+ from collections.abc import Iterator
23
+ from contextlib import contextmanager
24
+ from typing import TypeVar, cast
25
+
26
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
27
+
28
+ from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
29
+
30
+
31
class UnexpectedObjectContentError(Exception):
    """Raised when object content does not match the InflatableObject layout.

    The expected layout is a head, a body, and the values stored in the head.
    """

    def __init__(self, object_id: str, reason: str):
        message = (
            f"Object with ID '{object_id}' has an unexpected structure. "
            f"{reason}"
        )
        super().__init__(message)
39
+
40
+
41
# Thread-local state that lets `InflatableObject.object_id` reuse IDs already
# computed inside a `no_object_id_recompute` context.
_ctx = threading.local()


def _is_recompute_enabled() -> bool:
    """Return whether object IDs are recomputed in this thread (default True)."""
    try:
        return _ctx.recompute_object_id_enabled
    except AttributeError:
        return True


def _get_computed_object_ids() -> set[str]:
    """Return this thread's set of IDs computed while recomputing is disabled.

    If no set has been stored for this thread yet, a new, unstored empty set
    is returned.
    """
    try:
        return _ctx.computed_object_ids
    except AttributeError:
        return set()


@contextmanager
def no_object_id_recompute() -> Iterator[None]:
    """Disable recomputing object IDs in the current thread while active.

    On entry, a fresh empty set of computed IDs is installed. On exit, both
    the recompute flag and the previous set are restored, even if the body
    raised.
    """
    saved_state = (_is_recompute_enabled(), _get_computed_object_ids())
    _ctx.recompute_object_id_enabled = False
    _ctx.computed_object_ids = set()
    try:
        yield
    finally:
        _ctx.recompute_object_id_enabled, _ctx.computed_object_ids = saved_state
66
+
67
+
68
class InflatableObject:
    """Base class for objects that can be deflated to and inflated from bytes.

    Subclasses implement `deflate` and `inflate`. Objects that contain nested
    inflatable objects also override `children`. Subclasses that track
    modifications can override `is_dirty` so that `object_id` is not
    recomputed on every access.
    """

    def deflate(self) -> bytes:
        """Serialize the object to bytes (must be implemented by subclasses)."""
        raise NotImplementedError()

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> InflatableObject:
        """Reconstruct an object from its deflated bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Child InflatableObjects keyed by their object IDs. These children
            are needed to fully inflate the parent InflatableObject.

        Returns
        -------
        InflatableObject
            The inflated object.
        """
        raise NotImplementedError()

    @property
    def object_id(self) -> str:
        """Return the object ID (SHA-256 hex digest of the deflated object).

        The ID is recomputed when the object is dirty or no ID is cached yet.
        Inside a `no_object_id_recompute` context, an ID already computed in
        that context is returned directly, without consulting `is_dirty`.
        """
        recompute_enabled = _is_recompute_enabled()
        cached_id = self.__dict__.get("_object_id")

        # Fast path: within a no-recompute context, trust IDs computed there
        if not recompute_enabled and cached_id in _get_computed_object_ids():
            return cast(str, cached_id)

        if self.is_dirty or "_object_id" not in self.__dict__:
            cached_id = get_object_id(self.deflate())
            self.__dict__["_object_id"] = cached_id

        # Record the ID so later lookups in the same context skip deflation
        if not recompute_enabled:
            _get_computed_object_ids().add(cached_id)
        return cast(str, self.__dict__["_object_id"])

    @property
    def children(self) -> dict[str, InflatableObject] | None:
        """Return child objects keyed by object ID, or None if there are none."""
        return None

    @property
    def is_dirty(self) -> bool:
        """Return whether the content may have changed since the ID was computed.

        The base implementation always returns True, so outside a
        `no_object_id_recompute` context the object ID is recomputed on every
        access unless a subclass overrides this property.
        """
        return True


# Type variable bound to InflatableObject, used for `type[T]` class arguments
T = TypeVar("T", bound=InflatableObject)
135
+
136
+
137
def get_object_id(object_content: bytes) -> str:
    """Compute an object ID as the hex SHA-256 digest of the deflated content."""
    hasher = hashlib.sha256()
    hasher.update(object_content)
    return hasher.hexdigest()
140
+
141
+
142
def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
    """Return the object body after checking the object type in the head.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.
    cls : type[T]
        The class the content is expected to represent. Its ``__qualname__``
        must equal the object type recorded in the head.

    Returns
    -------
    bytes
        The object body.

    Raises
    ------
    ValueError
        If the recorded object type differs from ``cls.__qualname__``.
    """
    expected_type = cls.__qualname__
    actual_type = get_object_type_from_object_content(object_content)
    if actual_type != expected_type:
        raise ValueError(
            f"Class name ({expected_type}) and object type "
            f"({actual_type}) do not match."
        )
    return _get_object_body(object_content)
154
+
155
+
156
def add_header_to_object_body(object_body: bytes, obj: InflatableObject) -> bytes:
    """Prefix ``object_body`` with the header describing ``obj``.

    The result is ``<head> + HEAD_BODY_DIVIDER + object_body``. ``<head>`` is the
    UTF-8 encoding of three fields joined by ``HEAD_VALUE_DIVIDER``: the class
    ``__qualname__`` of ``obj``, the comma-separated IDs of its children (empty
    if it has none), and the body length in bytes.

    Parameters
    ----------
    object_body : bytes
        The serialized body of ``obj``.
    obj : InflatableObject
        The object the body belongs to; supplies the type name and child IDs.

    Returns
    -------
    bytes
        The full object content (head, divider and body).
    """
    # Join the fields directly instead of splicing the divider into a %-format
    # string, so characters in the divider can never be read as directives.
    head = HEAD_VALUE_DIVIDER.join(
        (
            obj.__class__.__qualname__,  # Type of object
            ",".join((obj.children or {}).keys()),  # IDs of child objects
            str(len(object_body)),  # Length of object body
        )
    )
    return b"".join((head.encode(encoding="utf-8"), HEAD_BODY_DIVIDER, object_body))
171
+
172
+
173
def _get_object_head(object_content: bytes) -> bytes:
    """Return the head of the object content (everything before the divider).

    Raises
    ------
    ValueError
        If ``object_content`` does not contain ``HEAD_BODY_DIVIDER``.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # `find` returns -1 when the divider is absent; slicing with -1 would
    # silently drop the last byte instead of reporting malformed content.
    if index < 0:
        raise ValueError("Object content does not contain a head/body divider.")
    return object_content[:index]
177
+
178
+
179
def _get_object_body(object_content: bytes) -> bytes:
    """Return the body of the object content (everything after the divider).

    Only the first occurrence of the divider separates head from body, so the
    body itself may contain divider bytes.

    Raises
    ------
    ValueError
        If ``object_content`` does not contain ``HEAD_BODY_DIVIDER``.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # Without this check a missing divider (index -1) would yield a bogus
    # slice of the content rather than an error.
    if index < 0:
        raise ValueError("Object content does not contain a head/body divider.")
    return object_content[index + len(HEAD_BODY_DIVIDER) :]
183
+
184
+
185
def is_valid_sha256_hash(object_id: str) -> bool:
    """Check if the given string is a valid SHA-256 hash.

    A valid hash consists of exactly 64 hexadecimal characters (``0-9``,
    ``a-f``, ``A-F``) with no prefix, sign, separator, or whitespace.

    Parameters
    ----------
    object_id : str
        The string to check.

    Returns
    -------
    bool
        ``True`` if the string is a valid SHA-256 hash, ``False`` otherwise.
    """
    # `int(s, 16)` is too lenient for this: it also accepts a "0x" prefix, a
    # leading sign, surrounding whitespace, underscores between digits and
    # non-ASCII digits. Validate every character explicitly instead.
    return len(object_id) == 64 and all(
        char in "0123456789abcdefABCDEF" for char in object_id
    )
206
+
207
+
208
def get_object_type_from_object_content(object_content: bytes) -> str:
    """Return the object type recorded in the head of the object content."""
    object_type, _, _ = get_object_head_values_from_object_content(object_content)
    return object_type
211
+
212
+
213
def get_object_children_ids_from_object_content(object_content: bytes) -> list[str]:
    """Return the child object IDs recorded in the head of the object content."""
    _, children_ids, _ = get_object_head_values_from_object_content(object_content)
    return children_ids
216
+
217
+
218
def get_object_body_len_from_object_content(object_content: bytes) -> int:
    """Return the body length recorded in the head of the object content."""
    *_, body_len = get_object_head_values_from_object_content(object_content)
    return body_len
221
+
222
+
223
def get_object_head_values_from_object_content(
    object_content: bytes,
) -> tuple[str, list[str], int]:
    """Parse the three values stored in the head of the object content.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.

    Returns
    -------
    tuple[str, list[str], int]
        A tuple containing:
        - The object type as a string.
        - A list of child object IDs as strings (empty when there are none).
        - The length of the object body as an integer.

    Raises
    ------
    ValueError
        If the head is not valid UTF-8, does not consist of exactly three
        values separated by ``HEAD_VALUE_DIVIDER``, or if the body length is
        not an integer.
    """
    head_text = _get_object_head(object_content).decode(encoding="utf-8")
    head_values = head_text.split(HEAD_VALUE_DIVIDER)
    object_type, joined_children_ids, body_len_text = head_values
    children_ids = [] if not joined_children_ids else joined_children_ids.split(",")
    return object_type, children_ids, int(body_len_text)
245
+
246
+
247
def get_descendant_object_ids(obj: InflatableObject) -> set[str]:
    """Return the object IDs of all objects nested under ``obj``.

    The ID of ``obj`` itself is not included.
    """
    nested_ids = set(get_all_nested_objects(obj))
    return nested_ids - {obj.object_id}
253
+
254
+
255
def get_all_nested_objects(obj: InflatableObject) -> dict[str, InflatableObject]:
    """Collect ``obj`` and every object nested under it, keyed by object ID.

    Entries are in post-order: each child appears before its parent, and
    ``obj`` itself comes last. An object reachable through several parents
    keeps the position of its first occurrence.
    """
    collected: dict[str, InflatableObject] = {}
    for child in (obj.children or {}).values():
        collected |= get_all_nested_objects(child)
    collected[obj.object_id] = obj
    return collected
269
+
270
+
271
def get_object_tree(obj: InflatableObject) -> ObjectTree:
    """Build an ``ObjectTree`` mirroring ``obj`` and its nested objects.

    Child subtrees appear in the iteration order of ``obj.children``.
    """
    subtrees = [get_object_tree(child) for child in (obj.children or {}).values()]
    return ObjectTree(object_id=obj.object_id, children=subtrees)
278
+
279
+
280
def iterate_object_tree(
    tree: ObjectTree,
) -> Iterator[ObjectTree]:
    """Iterate over the object tree and yield its nodes in post-order.

    Each ``ObjectTree`` node (not just its object ID) is yielded after all of
    its children have been yielded, and children are visited in order. The
    traversal uses an explicit stack, so arbitrarily deep trees do not hit
    Python's recursion limit.
    """
    # Each entry holds a node and whether its children were already pushed;
    # a node is yielded the second time it is popped.
    stack: list[tuple[ObjectTree, bool]] = [(tree, False)]
    while stack:
        node, children_pushed = stack.pop()
        if children_pushed:
            yield node
            continue
        stack.append((node, True))
        # Push children in reverse so the first child is processed first
        stack.extend((child, False) for child in reversed(list(node.children)))
@@ -0,0 +1,99 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject gRPC utils."""
16
+
17
+
18
+ from typing import Callable
19
+
20
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
21
+ PullObjectRequest,
22
+ PullObjectResponse,
23
+ PushObjectRequest,
24
+ PushObjectResponse,
25
+ )
26
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
+
28
+ from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
29
+
30
+
31
def make_pull_object_fn_grpc(
    pull_object_grpc: Callable[[PullObjectRequest], PullObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str], bytes]:
    """Return a function that fetches objects through a gRPC ``PullObject`` call.

    Parameters
    ----------
    pull_object_grpc : Callable[[PullObjectRequest], PullObjectResponse]
        The gRPC callable used to pull objects, e.g., `FleetStub.PullObject`.
    node : Node
        The node on whose behalf requests are sent.
    run_id : int
        The run ID attached to every request.

    Returns
    -------
    Callable[[str], bytes]
        A function mapping an object ID to the object's content as bytes.
        It raises `ObjectIdNotPreregisteredError` when the object ID was not
        pre-registered, and `ObjectUnavailableError` when the object exists but
        is not available yet.
    """

    def pull_object_fn(object_id: str) -> bytes:
        response: PullObjectResponse = pull_object_grpc(
            PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
        )
        # "Not found" takes precedence over "not available yet"
        if not response.object_found:
            raise ObjectIdNotPreregisteredError(object_id)
        if not response.object_available:
            raise ObjectUnavailableError(object_id)
        return response.object_content

    return pull_object_fn
65
+
66
+
67
def make_push_object_fn_grpc(
    push_object_grpc: Callable[[PushObjectRequest], PushObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str, bytes], None]:
    """Return a function that uploads objects through a gRPC ``PushObject`` call.

    Parameters
    ----------
    push_object_grpc : Callable[[PushObjectRequest], PushObjectResponse]
        The gRPC callable used to push objects, e.g., `FleetStub.PushObject`.
    node : Node
        The node on whose behalf requests are sent.
    run_id : int
        The run ID attached to every request.

    Returns
    -------
    Callable[[str, bytes], None]
        A function that takes an object ID and its content as bytes and sends
        them to the servicer. It raises `ObjectIdNotPreregisteredError` when the
        servicer did not store the object, i.e. the ID was not pre-registered.
    """

    def push_object_fn(object_id: str, object_content: bytes) -> None:
        response: PushObjectResponse = push_object_grpc(
            PushObjectRequest(
                node=node,
                run_id=run_id,
                object_id=object_id,
                object_content=object_content,
            )
        )
        if response.stored:
            return
        # A response with ``stored`` unset is reported as a non-pre-registered ID
        raise ObjectIdNotPreregisteredError(object_id)

    return push_object_fn