flwr-1.23.0-py3-none-any.whl → flwr-1.24.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (292)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/federation/__init__.py +24 -0
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +317 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +170 -130
  21. flwr/cli/new/new.py +33 -50
  22. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  30. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  33. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  34. flwr/cli/pull.py +10 -5
  35. flwr/cli/run/run.py +77 -30
  36. flwr/cli/run_utils.py +130 -0
  37. flwr/cli/stop.py +25 -7
  38. flwr/cli/supernode/ls.py +16 -8
  39. flwr/cli/supernode/register.py +9 -4
  40. flwr/cli/supernode/unregister.py +5 -3
  41. flwr/cli/utils.py +376 -16
  42. flwr/client/__init__.py +1 -1
  43. flwr/client/dpfedavg_numpy_client.py +4 -1
  44. flwr/client/grpc_adapter_client/connection.py +6 -7
  45. flwr/client/grpc_rere_client/connection.py +10 -11
  46. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  47. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  48. flwr/client/message_handler/message_handler.py +2 -2
  49. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  50. flwr/client/numpy_client.py +1 -1
  51. flwr/client/rest_client/connection.py +12 -14
  52. flwr/client/run_info_store.py +4 -5
  53. flwr/client/typing.py +1 -1
  54. flwr/clientapp/client_app.py +9 -10
  55. flwr/clientapp/mod/centraldp_mods.py +16 -17
  56. flwr/clientapp/mod/localdp_mod.py +8 -9
  57. flwr/clientapp/typing.py +1 -1
  58. flwr/clientapp/utils.py +3 -3
  59. flwr/common/address.py +1 -2
  60. flwr/common/args.py +3 -4
  61. flwr/common/config.py +13 -16
  62. flwr/common/constant.py +5 -2
  63. flwr/common/differential_privacy.py +3 -4
  64. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  65. flwr/common/exit/exit.py +15 -2
  66. flwr/common/exit/exit_code.py +19 -0
  67. flwr/common/exit/exit_handler.py +6 -2
  68. flwr/common/exit/signal_handler.py +5 -5
  69. flwr/common/grpc.py +6 -6
  70. flwr/common/inflatable_protobuf_utils.py +1 -1
  71. flwr/common/inflatable_utils.py +38 -21
  72. flwr/common/logger.py +19 -19
  73. flwr/common/message.py +4 -4
  74. flwr/common/object_ref.py +7 -7
  75. flwr/common/record/array.py +3 -3
  76. flwr/common/record/arrayrecord.py +18 -30
  77. flwr/common/record/configrecord.py +3 -3
  78. flwr/common/record/recorddict.py +5 -5
  79. flwr/common/record/typeddict.py +9 -2
  80. flwr/common/recorddict_compat.py +7 -10
  81. flwr/common/retry_invoker.py +20 -20
  82. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  83. flwr/common/serde.py +5 -4
  84. flwr/common/serde_utils.py +2 -2
  85. flwr/common/telemetry.py +9 -5
  86. flwr/common/typing.py +52 -37
  87. flwr/compat/client/app.py +38 -37
  88. flwr/compat/client/grpc_client/connection.py +11 -11
  89. flwr/compat/server/app.py +5 -6
  90. flwr/proto/appio_pb2.py +13 -3
  91. flwr/proto/appio_pb2.pyi +134 -65
  92. flwr/proto/appio_pb2_grpc.py +20 -0
  93. flwr/proto/appio_pb2_grpc.pyi +27 -0
  94. flwr/proto/clientappio_pb2.py +17 -7
  95. flwr/proto/clientappio_pb2.pyi +15 -0
  96. flwr/proto/clientappio_pb2_grpc.py +206 -40
  97. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  98. flwr/proto/control_pb2.py +71 -52
  99. flwr/proto/control_pb2.pyi +277 -111
  100. flwr/proto/control_pb2_grpc.py +249 -40
  101. flwr/proto/control_pb2_grpc.pyi +185 -52
  102. flwr/proto/error_pb2.py +13 -3
  103. flwr/proto/error_pb2.pyi +24 -6
  104. flwr/proto/error_pb2_grpc.py +20 -0
  105. flwr/proto/error_pb2_grpc.pyi +27 -0
  106. flwr/proto/fab_pb2.py +14 -4
  107. flwr/proto/fab_pb2.pyi +59 -31
  108. flwr/proto/fab_pb2_grpc.py +20 -0
  109. flwr/proto/fab_pb2_grpc.pyi +27 -0
  110. flwr/proto/federation_pb2.py +38 -0
  111. flwr/proto/federation_pb2.pyi +56 -0
  112. flwr/proto/federation_pb2_grpc.py +24 -0
  113. flwr/proto/federation_pb2_grpc.pyi +31 -0
  114. flwr/proto/fleet_pb2.py +14 -4
  115. flwr/proto/fleet_pb2.pyi +137 -61
  116. flwr/proto/fleet_pb2_grpc.py +189 -48
  117. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  118. flwr/proto/grpcadapter_pb2.py +14 -4
  119. flwr/proto/grpcadapter_pb2.pyi +38 -16
  120. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  121. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  122. flwr/proto/heartbeat_pb2.py +17 -7
  123. flwr/proto/heartbeat_pb2.pyi +51 -22
  124. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  125. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  126. flwr/proto/log_pb2.py +13 -3
  127. flwr/proto/log_pb2.pyi +34 -11
  128. flwr/proto/log_pb2_grpc.py +20 -0
  129. flwr/proto/log_pb2_grpc.pyi +27 -0
  130. flwr/proto/message_pb2.py +15 -5
  131. flwr/proto/message_pb2.pyi +154 -86
  132. flwr/proto/message_pb2_grpc.py +20 -0
  133. flwr/proto/message_pb2_grpc.pyi +27 -0
  134. flwr/proto/node_pb2.py +15 -5
  135. flwr/proto/node_pb2.pyi +50 -25
  136. flwr/proto/node_pb2_grpc.py +20 -0
  137. flwr/proto/node_pb2_grpc.pyi +27 -0
  138. flwr/proto/recorddict_pb2.py +13 -3
  139. flwr/proto/recorddict_pb2.pyi +184 -107
  140. flwr/proto/recorddict_pb2_grpc.py +20 -0
  141. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  142. flwr/proto/run_pb2.py +40 -31
  143. flwr/proto/run_pb2.pyi +149 -84
  144. flwr/proto/run_pb2_grpc.py +20 -0
  145. flwr/proto/run_pb2_grpc.pyi +27 -0
  146. flwr/proto/serverappio_pb2.py +13 -3
  147. flwr/proto/serverappio_pb2.pyi +32 -8
  148. flwr/proto/serverappio_pb2_grpc.py +246 -65
  149. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  150. flwr/proto/simulationio_pb2.py +16 -8
  151. flwr/proto/simulationio_pb2.pyi +15 -0
  152. flwr/proto/simulationio_pb2_grpc.py +162 -41
  153. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  154. flwr/proto/transport_pb2.py +20 -10
  155. flwr/proto/transport_pb2.pyi +249 -160
  156. flwr/proto/transport_pb2_grpc.py +35 -4
  157. flwr/proto/transport_pb2_grpc.pyi +38 -8
  158. flwr/server/app.py +38 -17
  159. flwr/server/client_manager.py +4 -5
  160. flwr/server/client_proxy.py +10 -11
  161. flwr/server/compat/app.py +4 -5
  162. flwr/server/compat/app_utils.py +2 -1
  163. flwr/server/compat/grid_client_proxy.py +10 -12
  164. flwr/server/compat/legacy_context.py +3 -4
  165. flwr/server/fleet_event_log_interceptor.py +2 -1
  166. flwr/server/grid/grid.py +2 -3
  167. flwr/server/grid/grpc_grid.py +10 -8
  168. flwr/server/grid/inmemory_grid.py +4 -4
  169. flwr/server/run_serverapp.py +2 -3
  170. flwr/server/server.py +34 -39
  171. flwr/server/server_app.py +7 -8
  172. flwr/server/server_config.py +1 -2
  173. flwr/server/serverapp/app.py +34 -28
  174. flwr/server/serverapp_components.py +4 -5
  175. flwr/server/strategy/aggregate.py +9 -8
  176. flwr/server/strategy/bulyan.py +13 -11
  177. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  178. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  179. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  180. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  181. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  182. flwr/server/strategy/fedadagrad.py +18 -14
  183. flwr/server/strategy/fedadam.py +16 -14
  184. flwr/server/strategy/fedavg.py +16 -17
  185. flwr/server/strategy/fedavg_android.py +15 -15
  186. flwr/server/strategy/fedavgm.py +21 -18
  187. flwr/server/strategy/fedmedian.py +2 -3
  188. flwr/server/strategy/fedopt.py +11 -10
  189. flwr/server/strategy/fedprox.py +10 -9
  190. flwr/server/strategy/fedtrimmedavg.py +12 -11
  191. flwr/server/strategy/fedxgb_bagging.py +13 -11
  192. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  193. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  194. flwr/server/strategy/fedyogi.py +16 -14
  195. flwr/server/strategy/krum.py +12 -11
  196. flwr/server/strategy/qfedavg.py +16 -15
  197. flwr/server/strategy/strategy.py +6 -9
  198. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  199. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  200. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  201. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  202. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  203. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  204. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  205. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
  206. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  207. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  208. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  209. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  210. flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
  211. flwr/server/superlink/linkstate/linkstate.py +59 -43
  212. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  213. flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
  214. flwr/server/superlink/linkstate/utils.py +6 -6
  215. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  216. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  217. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  218. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  219. flwr/server/superlink/utils.py +4 -6
  220. flwr/server/typing.py +1 -1
  221. flwr/server/utils/tensorboard.py +15 -8
  222. flwr/server/workflow/default_workflows.py +5 -5
  223. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  224. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  225. flwr/serverapp/strategy/bulyan.py +16 -15
  226. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  227. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  228. flwr/serverapp/strategy/fedadagrad.py +10 -11
  229. flwr/serverapp/strategy/fedadam.py +10 -11
  230. flwr/serverapp/strategy/fedavg.py +9 -10
  231. flwr/serverapp/strategy/fedavgm.py +17 -16
  232. flwr/serverapp/strategy/fedmedian.py +2 -2
  233. flwr/serverapp/strategy/fedopt.py +10 -11
  234. flwr/serverapp/strategy/fedprox.py +7 -8
  235. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  236. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  237. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  238. flwr/serverapp/strategy/fedyogi.py +9 -11
  239. flwr/serverapp/strategy/krum.py +7 -7
  240. flwr/serverapp/strategy/multikrum.py +9 -9
  241. flwr/serverapp/strategy/qfedavg.py +17 -16
  242. flwr/serverapp/strategy/strategy.py +6 -9
  243. flwr/serverapp/strategy/strategy_utils.py +7 -8
  244. flwr/simulation/app.py +46 -42
  245. flwr/simulation/legacy_app.py +12 -12
  246. flwr/simulation/ray_transport/ray_actor.py +10 -11
  247. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  248. flwr/simulation/run_simulation.py +43 -43
  249. flwr/simulation/simulationio_connection.py +4 -4
  250. flwr/supercore/cli/flower_superexec.py +3 -4
  251. flwr/supercore/constant.py +31 -1
  252. flwr/supercore/corestate/corestate.py +24 -3
  253. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  254. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  255. flwr/supercore/ffs/disk_ffs.py +1 -2
  256. flwr/supercore/ffs/ffs.py +1 -2
  257. flwr/supercore/ffs/ffs_factory.py +1 -2
  258. flwr/{common → supercore}/heartbeat.py +20 -25
  259. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  260. flwr/supercore/object_store/object_store.py +1 -2
  261. flwr/supercore/object_store/object_store_factory.py +1 -2
  262. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  263. flwr/supercore/primitives/asymmetric.py +1 -1
  264. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  265. flwr/supercore/sqlite_mixin.py +37 -34
  266. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  267. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  268. flwr/supercore/superexec/run_superexec.py +9 -13
  269. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  270. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  271. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  272. flwr/superlink/federation/__init__.py +24 -0
  273. flwr/superlink/federation/federation_manager.py +64 -0
  274. flwr/superlink/federation/noop_federation_manager.py +71 -0
  275. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  276. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  277. flwr/superlink/servicer/control/control_grpc.py +5 -6
  278. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  279. flwr/superlink/servicer/control/control_servicer.py +102 -18
  280. flwr/supernode/cli/flower_supernode.py +58 -3
  281. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  282. flwr/supernode/nodestate/nodestate.py +7 -8
  283. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  284. flwr/supernode/runtime/run_clientapp.py +41 -22
  285. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  286. flwr/supernode/start_client_internal.py +158 -42
  287. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  288. flwr-1.24.0.dist-info/RECORD +454 -0
  289. flwr/supercore/object_store/utils.py +0 -43
  290. flwr-1.23.0.dist-info/RECORD +0 -439
  291. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  292. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -18,10 +18,8 @@ Paper: openreview.net/pdf?id=ByexElSYDr
18
18
  """
19
19
 
20
20
 
21
- from collections import OrderedDict
22
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
23
22
  from logging import INFO
24
- from typing import Callable, Optional
25
23
 
26
24
  import numpy as np
27
25
 
@@ -105,12 +103,12 @@ class QFedAvg(FedAvg):
105
103
  weighted_by_key: str = "num-examples",
106
104
  arrayrecord_key: str = "arrays",
107
105
  configrecord_key: str = "config",
108
- train_metrics_aggr_fn: Optional[
109
- Callable[[list[RecordDict], str], MetricRecord]
110
- ] = None,
111
- evaluate_metrics_aggr_fn: Optional[
112
- Callable[[list[RecordDict], str], MetricRecord]
113
- ] = None,
106
+ train_metrics_aggr_fn: (
107
+ Callable[[list[RecordDict], str], MetricRecord] | None
108
+ ) = None,
109
+ evaluate_metrics_aggr_fn: (
110
+ Callable[[list[RecordDict], str], MetricRecord] | None
111
+ ) = None,
114
112
  ) -> None:
115
113
  super().__init__(
116
114
  fraction_train=fraction_train,
@@ -127,7 +125,7 @@ class QFedAvg(FedAvg):
127
125
  self.q = q
128
126
  self.client_learning_rate = client_learning_rate
129
127
  self.train_loss_key = train_loss_key
130
- self.current_arrays: Optional[ArrayRecord] = None
128
+ self.current_arrays: ArrayRecord | None = None
131
129
 
132
130
  def summary(self) -> None:
133
131
  """Log summary configuration of the strategy."""
@@ -148,7 +146,7 @@ class QFedAvg(FedAvg):
148
146
  self,
149
147
  server_round: int,
150
148
  replies: Iterable[Message],
151
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
149
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
152
150
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
153
151
  # Call FedAvg aggregate_train to perform validation and aggregation
154
152
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
@@ -184,7 +182,7 @@ class QFedAvg(FedAvg):
184
182
  if sum_delta is None:
185
183
  sum_delta = delta
186
184
  else:
187
- sum_delta = [sd + d for sd, d in zip(sum_delta, delta)]
185
+ sum_delta = [sd + d for sd, d in zip(sum_delta, delta, strict=True)]
188
186
  sum_h += h
189
187
 
190
188
  # Compute new global weights and convert to Array type
@@ -192,7 +190,7 @@ class QFedAvg(FedAvg):
192
190
  assert sum_delta is not None # Make mypy happy
193
191
  array_list = [
194
192
  Array(np.asarray(gw - (d / sum_h)))
195
- for gw, d in zip(global_weights, sum_delta)
193
+ for gw, d in zip(global_weights, sum_delta, strict=True)
196
194
  ]
197
195
 
198
196
  # Aggregate MetricRecords
@@ -200,13 +198,16 @@ class QFedAvg(FedAvg):
200
198
  [msg.content for msg in valid_replies],
201
199
  self.weighted_by_key,
202
200
  )
203
- return ArrayRecord(OrderedDict(zip(array_keys, array_list))), metrics
201
+ return (
202
+ ArrayRecord(dict(zip(array_keys, array_list, strict=True))),
203
+ metrics,
204
+ )
204
205
 
205
206
 
206
207
  def get_train_loss(msg: Message, loss_key: str) -> float:
207
208
  """Extract training loss from a Message."""
208
209
  metrics = list(msg.content.metric_records.values())[0]
209
- if (loss := metrics.get(loss_key)) is None or not isinstance(loss, (int, float)):
210
+ if (loss := metrics.get(loss_key)) is None or not isinstance(loss, (int | float)):
210
211
  raise AggregationError(
211
212
  "Missing or invalid training loss. "
212
213
  f"The strategy expected a float value for the key '{loss_key}' "
@@ -236,7 +237,7 @@ def compute_delta_and_h(
236
237
  ) -> tuple[list[NDArray], float]:
237
238
  """Compute delta and h used in q-FedAvg aggregation."""
238
239
  # Compute gradient_k = L * (w - w_k)
239
- for gw, lw in zip(global_weights, local_weights):
240
+ for gw, lw in zip(global_weights, local_weights, strict=True):
240
241
  np.subtract(gw, lw, out=lw)
241
242
  lw *= L
242
243
  grad = local_weights # After in-place operations, local_weights is now grad
@@ -18,9 +18,8 @@
18
18
  import io
19
19
  import time
20
20
  from abc import ABC, abstractmethod
21
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
22
22
  from logging import INFO
23
- from typing import Callable, Optional
24
23
 
25
24
  from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
26
25
  from flwr.server import Grid
@@ -61,7 +60,7 @@ class Strategy(ABC):
61
60
  self,
62
61
  server_round: int,
63
62
  replies: Iterable[Message],
64
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
63
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
65
64
  """Aggregate training results from client nodes.
66
65
 
67
66
  Parameters
@@ -109,7 +108,7 @@ class Strategy(ABC):
109
108
  self,
110
109
  server_round: int,
111
110
  replies: Iterable[Message],
112
- ) -> Optional[MetricRecord]:
111
+ ) -> MetricRecord | None:
113
112
  """Aggregate evaluation metrics from client nodes.
114
113
 
115
114
  Parameters
@@ -138,11 +137,9 @@ class Strategy(ABC):
138
137
  initial_arrays: ArrayRecord,
139
138
  num_rounds: int = 3,
140
139
  timeout: float = 3600,
141
- train_config: Optional[ConfigRecord] = None,
142
- evaluate_config: Optional[ConfigRecord] = None,
143
- evaluate_fn: Optional[
144
- Callable[[int, ArrayRecord], Optional[MetricRecord]]
145
- ] = None,
140
+ train_config: ConfigRecord | None = None,
141
+ evaluate_config: ConfigRecord | None = None,
142
+ evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None,
146
143
  ) -> Result:
147
144
  """Execute the federated learning strategy.
148
145
 
@@ -17,10 +17,9 @@
17
17
 
18
18
  import json
19
19
  import random
20
- from collections import OrderedDict
21
20
  from logging import INFO
22
21
  from time import sleep
23
- from typing import Optional, cast
22
+ from typing import cast
24
23
 
25
24
  import numpy as np
26
25
 
@@ -49,8 +48,8 @@ def config_to_str(config: ConfigRecord) -> str:
49
48
  def log_strategy_start_info(
50
49
  num_rounds: int,
51
50
  arrays: ArrayRecord,
52
- train_config: Optional[ConfigRecord],
53
- evaluate_config: Optional[ConfigRecord],
51
+ train_config: ConfigRecord | None,
52
+ evaluate_config: ConfigRecord | None,
54
53
  ) -> None:
55
54
  """Log information about the strategy start."""
56
55
  log(INFO, "\t├── Number of rounds: %d", num_rounds)
@@ -92,7 +91,7 @@ def aggregate_arrayrecords(
92
91
  # Perform weighted aggregation
93
92
  aggregated_np_arrays: dict[str, NDArray] = {}
94
93
 
95
- for record, weight in zip(records, weight_factors):
94
+ for record, weight in zip(records, weight_factors, strict=True):
96
95
  for record_item in record.array_records.values():
97
96
  # aggregate in-place
98
97
  for key, value in record_item.items():
@@ -102,7 +101,7 @@ def aggregate_arrayrecords(
102
101
  aggregated_np_arrays[key] += value.numpy() * weight
103
102
 
104
103
  return ArrayRecord(
105
- OrderedDict({k: Array(np.asarray(v)) for k, v in aggregated_np_arrays.items()})
104
+ {k: Array(np.asarray(v)) for k, v in aggregated_np_arrays.items()}
106
105
  )
107
106
 
108
107
 
@@ -125,7 +124,7 @@ def aggregate_metricrecords(
125
124
  weight_factors = [w / total_weight for w in weights]
126
125
 
127
126
  aggregated_metrics = MetricRecord()
128
- for record, weight in zip(records, weight_factors):
127
+ for record, weight in zip(records, weight_factors, strict=True):
129
128
  for record_item in record.metric_records.values():
130
129
  # aggregate in-place
131
130
  for key, value in record_item.items():
@@ -142,7 +141,7 @@ def aggregate_metricrecords(
142
141
  current_list = cast(list[float], aggregated_metrics[key])
143
142
  aggregated_metrics[key] = [
144
143
  curr + val * weight
145
- for curr, val in zip(current_list, value)
144
+ for curr, val in zip(current_list, value, strict=True)
146
145
  ]
147
146
  else:
148
147
  current_value = cast(float, aggregated_metrics[key])
flwr/simulation/app.py CHANGED
@@ -18,7 +18,6 @@
18
18
  import argparse
19
19
  from logging import DEBUG, ERROR, INFO
20
20
  from queue import Queue
21
- from typing import Optional
22
21
 
23
22
  from flwr.cli.config_utils import get_fab_metadata
24
23
  from flwr.cli.install import install_from_fab
@@ -38,8 +37,7 @@ from flwr.common.constant import (
38
37
  Status,
39
38
  SubStatus,
40
39
  )
41
- from flwr.common.exit import ExitCode, flwr_exit
42
- from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
40
+ from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
43
41
  from flwr.common.logger import (
44
42
  log,
45
43
  mirror_output_to_queue,
@@ -71,6 +69,7 @@ from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
71
69
  from flwr.simulation.run_simulation import _run_simulation
72
70
  from flwr.simulation.simulationio_connection import SimulationIoConnection
73
71
  from flwr.supercore.app_utils import start_parent_process_monitor
72
+ from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
74
73
  from flwr.supercore.superexec.plugin import SimulationExecPlugin
75
74
  from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
76
75
 
@@ -78,7 +77,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
78
77
  def flwr_simulation() -> None:
79
78
  """Run process-isolated Flower Simulation."""
80
79
  # Capture stdout/stderr
81
- log_queue: Queue[Optional[str]] = Queue()
80
+ log_queue: Queue[str | None] = Queue()
82
81
  mirror_output_to_queue(log_queue)
83
82
 
84
83
  args = _parse_args_run_flwr_simulation().parse_args()
@@ -125,11 +124,11 @@ def flwr_simulation() -> None:
125
124
 
126
125
  def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
127
126
  simulationio_api_address: str,
128
- log_queue: Queue[Optional[str]],
127
+ log_queue: Queue[str | None],
129
128
  token: str,
130
- flwr_dir_: Optional[str] = None,
131
- certificates: Optional[bytes] = None,
132
- parent_pid: Optional[int] = None,
129
+ flwr_dir_: str | None = None,
130
+ certificates: bytes | None = None,
131
+ parent_pid: int | None = None,
133
132
  ) -> None:
134
133
  """Run Flower Simulation process."""
135
134
  # Start monitoring the parent process if a PID is provided
@@ -141,11 +140,35 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
141
140
  root_certificates=certificates,
142
141
  )
143
142
 
144
- # Resolve directory where FABs are installed
143
+ # Initialize variables for finally block
145
144
  flwr_dir = get_flwr_dir(flwr_dir_)
146
145
  log_uploader = None
147
146
  heartbeat_sender = None
147
+ run = None
148
148
  run_status = None
149
+ exit_code = ExitCode.SUCCESS
150
+
151
+ def on_exit() -> None:
152
+ # Stop heartbeat sender
153
+ if heartbeat_sender and heartbeat_sender.is_running:
154
+ heartbeat_sender.stop()
155
+
156
+ # Stop log uploader for this run and upload final logs
157
+ if log_uploader:
158
+ stop_log_uploader(log_queue, log_uploader)
159
+
160
+ # Update run status
161
+ if run and run_status:
162
+ run_status_proto = run_status_to_proto(run_status)
163
+ conn._stub.UpdateRunStatus(
164
+ UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
165
+ )
166
+
167
+ register_signal_handlers(
168
+ event_type=EventType.FLWR_SIMULATION_RUN_LEAVE,
169
+ exit_message="Run stopped by user.",
170
+ exit_handlers=[on_exit],
171
+ )
149
172
 
150
173
  try:
151
174
  # Pull SimulationInputs from LinkState
@@ -193,12 +216,6 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
193
216
  app_path,
194
217
  )
195
218
 
196
- # Change status to Running
197
- run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
198
- conn._stub.UpdateRunStatus(
199
- UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
200
- )
201
-
202
219
  # Pull Federation Options
203
220
  fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
204
221
  GetFederationOptionsRequest(run_id=run.run_id)
@@ -216,23 +233,20 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
216
233
  verbose: bool = fed_opt.get("verbose", False)
217
234
  enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)
218
235
 
236
+ run_id_hash = get_sha256_hash(run.run_id)
219
237
  event(
220
238
  EventType.FLWR_SIMULATION_RUN_ENTER,
221
239
  event_details={
222
240
  "backend": "ray",
223
241
  "num-supernodes": num_supernodes,
224
- "run-id-hash": get_sha256_hash(run.run_id),
242
+ "run-id-hash": run_id_hash,
225
243
  },
226
244
  )
227
245
 
228
246
  # Set up heartbeat sender
229
- heartbeat_fn = get_grpc_app_heartbeat_fn(
230
- conn._stub,
231
- run.run_id,
232
- failure_message="Heartbeat failed unexpectedly. The SuperLink could "
233
- "not find the provided run ID, or the run status is invalid.",
247
+ heartbeat_sender = HeartbeatSender(
248
+ make_app_heartbeat_fn_grpc(conn._stub, token)
234
249
  )
235
- heartbeat_sender = HeartbeatSender(heartbeat_fn)
236
250
  heartbeat_sender.start()
237
251
 
238
252
  # Launch the simulation
@@ -264,27 +278,17 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
264
278
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
265
279
  run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
266
280
 
267
- finally:
268
- # Stop heartbeat sender
269
- if heartbeat_sender:
270
- heartbeat_sender.stop()
271
-
272
- # Stop log uploader for this run and upload final logs
273
- if log_uploader:
274
- stop_log_uploader(log_queue, log_uploader)
281
+ # General exit code
282
+ exit_code = ExitCode.SIMULATION_EXCEPTION
275
283
 
276
- # Update run status
277
- if run_status:
278
- run_status_proto = run_status_to_proto(run_status)
279
- conn._stub.UpdateRunStatus(
280
- UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
281
- )
282
-
283
- # Clean up the Context if it exists
284
- try:
285
- del updated_context
286
- except NameError:
287
- pass
284
+ flwr_exit(
285
+ code=exit_code,
286
+ event_type=EventType.FLWR_SIMULATION_RUN_LEAVE,
287
+ event_details={
288
+ "run-id-hash": run_id_hash,
289
+ "success": exit_code == ExitCode.SUCCESS,
290
+ },
291
+ )
288
292
 
289
293
 
290
294
  def _parse_args_run_flwr_simulation() -> argparse.ArgumentParser:
@@ -22,7 +22,7 @@ import threading
22
22
  import traceback
23
23
  import warnings
24
24
  from logging import ERROR, INFO
25
- from typing import Any, Optional, Union
25
+ from typing import Any
26
26
 
27
27
  import ray
28
28
  from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
@@ -101,17 +101,17 @@ def start_simulation(
101
101
  *,
102
102
  client_fn: ClientFnExt,
103
103
  num_clients: int,
104
- clients_ids: Optional[list[str]] = None, # UNSUPPORTED, WILL BE REMOVED
105
- client_resources: Optional[dict[str, float]] = None,
106
- server: Optional[Server] = None,
107
- config: Optional[ServerConfig] = None,
108
- strategy: Optional[Strategy] = None,
109
- client_manager: Optional[ClientManager] = None,
110
- ray_init_args: Optional[dict[str, Any]] = None,
111
- keep_initialised: Optional[bool] = False,
104
+ clients_ids: list[str] | None = None, # UNSUPPORTED, WILL BE REMOVED
105
+ client_resources: dict[str, float] | None = None,
106
+ server: Server | None = None,
107
+ config: ServerConfig | None = None,
108
+ strategy: Strategy | None = None,
109
+ client_manager: ClientManager | None = None,
110
+ ray_init_args: dict[str, Any] | None = None,
111
+ keep_initialised: bool | None = False,
112
112
  actor_type: type[VirtualClientEngineActor] = ClientAppActor,
113
- actor_kwargs: Optional[dict[str, Any]] = None,
114
- actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT",
113
+ actor_kwargs: dict[str, Any] | None = None,
114
+ actor_scheduling: str | NodeAffinitySchedulingStrategy = "DEFAULT",
115
115
  ) -> History:
116
116
  """Start a Ray-based Flower simulation server.
117
117
 
@@ -219,7 +219,7 @@ def start_simulation(
219
219
  sys.exit()
220
220
 
221
221
  # Set logger propagation
222
- loop: Optional[asyncio.AbstractEventLoop] = None
222
+ loop: asyncio.AbstractEventLoop | None = None
223
223
  try:
224
224
  loop = asyncio.get_running_loop()
225
225
  except RuntimeError:
@@ -17,8 +17,9 @@
17
17
 
18
18
  import threading
19
19
  from abc import ABC
20
+ from collections.abc import Callable
20
21
  from logging import DEBUG, ERROR, WARNING
21
- from typing import Any, Callable, Optional, Union
22
+ from typing import Any
22
23
 
23
24
  import ray
24
25
  from ray import ObjectRef
@@ -76,13 +77,13 @@ class ClientAppActor(VirtualClientEngineActor):
76
77
  A function to execute upon actor initialization.
77
78
  """
78
79
 
79
- def __init__(self, on_actor_init_fn: Optional[Callable[[], None]] = None) -> None:
80
+ def __init__(self, on_actor_init_fn: Callable[[], None] | None = None) -> None:
80
81
  super().__init__()
81
82
  if on_actor_init_fn:
82
83
  on_actor_init_fn()
83
84
 
84
85
 
85
- def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) -> int:
86
+ def pool_size_from_resources(client_resources: dict[str, int | float]) -> int:
86
87
  """Calculate number of Actors that fit in the cluster.
87
88
 
88
89
  For this we consider the resources available on each node and those required per
@@ -166,8 +167,8 @@ class VirtualClientEngineActorPool(ActorPool):
166
167
  def __init__(
167
168
  self,
168
169
  create_actor_fn: Callable[[], type[VirtualClientEngineActor]],
169
- client_resources: dict[str, Union[int, float]],
170
- actor_list: Optional[list[type[VirtualClientEngineActor]]] = None,
170
+ client_resources: dict[str, int | float],
171
+ actor_list: list[type[VirtualClientEngineActor]] | None = None,
171
172
  ):
172
173
  self.client_resources = client_resources
173
174
  self.create_actor_fn = create_actor_fn
@@ -186,9 +187,7 @@ class VirtualClientEngineActorPool(ActorPool):
186
187
 
187
188
  # A dict that maps cid to another dict containing: a reference to the remote job
188
189
  # and its status (i.e. whether it is ready or not)
189
- self._cid_to_future: dict[
190
- str, dict[str, Union[bool, Optional[ObjectRef[Any]]]]
191
- ] = {}
190
+ self._cid_to_future: dict[str, dict[str, bool | ObjectRef[Any] | None]] = {}
192
191
  self.actor_to_remove: set[str] = set() # a set
193
192
  self.num_actors = len(actors)
194
193
 
@@ -353,7 +352,7 @@ class VirtualClientEngineActorPool(ActorPool):
353
352
 
354
353
  return True
355
354
 
356
- def process_unordered_future(self, timeout: Optional[float] = None) -> None:
355
+ def process_unordered_future(self, timeout: float | None = None) -> None:
357
356
  """Similar to parent's get_next_unordered() but without final ray.get()."""
358
357
  if not self.has_next(): # type: ignore
359
358
  raise StopIteration("No more results to get")
@@ -384,7 +383,7 @@ class VirtualClientEngineActorPool(ActorPool):
384
383
  actor.terminate.remote()
385
384
 
386
385
  def get_client_result(
387
- self, cid: str, timeout: Optional[float]
386
+ self, cid: str, timeout: float | None
388
387
  ) -> tuple[Message, Context]:
389
388
  """Get result from VirtualClient with specific cid."""
390
389
  # Loop until all jobs submitted to the pool are completed. Break early
@@ -407,7 +406,7 @@ class BasicActorPool:
407
406
  def __init__(
408
407
  self,
409
408
  actor_type: type[VirtualClientEngineActor],
410
- client_resources: dict[str, Union[int, float]],
409
+ client_resources: dict[str, int | float],
411
410
  actor_kwargs: dict[str, Any],
412
411
  ):
413
412
  self.client_resources = client_resources
@@ -17,7 +17,6 @@
17
17
 
18
18
  import traceback
19
19
  from logging import ERROR
20
- from typing import Optional
21
20
 
22
21
  from flwr import common
23
22
  from flwr.client import ClientFnExt
@@ -74,7 +73,7 @@ class RayActorClientProxy(ClientProxy):
74
73
  },
75
74
  )
76
75
 
77
- def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
76
+ def _submit_job(self, message: Message, timeout: float | None) -> Message:
78
77
  """Sumbit a message to the ActorPool."""
79
78
  run_id = message.metadata.run_id
80
79
 
@@ -114,8 +113,8 @@ class RayActorClientProxy(ClientProxy):
114
113
  self,
115
114
  recorddict: RecordDict,
116
115
  message_type: str,
117
- timeout: Optional[float],
118
- group_id: Optional[int],
116
+ timeout: float | None,
117
+ group_id: int | None,
119
118
  ) -> Message:
120
119
  """Wrap a RecordDict inside a Message."""
121
120
  return make_message(
@@ -136,8 +135,8 @@ class RayActorClientProxy(ClientProxy):
136
135
  def get_properties(
137
136
  self,
138
137
  ins: common.GetPropertiesIns,
139
- timeout: Optional[float],
140
- group_id: Optional[int],
138
+ timeout: float | None,
139
+ group_id: int | None,
141
140
  ) -> common.GetPropertiesRes:
142
141
  """Return client's properties."""
143
142
  recorddict = getpropertiesins_to_recorddict(ins)
@@ -155,8 +154,8 @@ class RayActorClientProxy(ClientProxy):
155
154
  def get_parameters(
156
155
  self,
157
156
  ins: common.GetParametersIns,
158
- timeout: Optional[float],
159
- group_id: Optional[int],
157
+ timeout: float | None,
158
+ group_id: int | None,
160
159
  ) -> common.GetParametersRes:
161
160
  """Return the current local model parameters."""
162
161
  recorddict = getparametersins_to_recorddict(ins)
@@ -172,7 +171,7 @@ class RayActorClientProxy(ClientProxy):
172
171
  return recorddict_to_getparametersres(message_out.content, keep_input=False)
173
172
 
174
173
  def fit(
175
- self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
174
+ self, ins: common.FitIns, timeout: float | None, group_id: int | None
176
175
  ) -> common.FitRes:
177
176
  """Train model parameters on the locally held dataset."""
178
177
  recorddict = fitins_to_recorddict(
@@ -190,7 +189,7 @@ class RayActorClientProxy(ClientProxy):
190
189
  return recorddict_to_fitres(message_out.content, keep_input=False)
191
190
 
192
191
  def evaluate(
193
- self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
192
+ self, ins: common.EvaluateIns, timeout: float | None, group_id: int | None
194
193
  ) -> common.EvaluateRes:
195
194
  """Evaluate model parameters on the locally held dataset."""
196
195
  recorddict = evaluateins_to_recorddict(
@@ -210,8 +209,8 @@ class RayActorClientProxy(ClientProxy):
210
209
  def reconnect(
211
210
  self,
212
211
  ins: common.ReconnectIns,
213
- timeout: Optional[float],
214
- group_id: Optional[int],
212
+ timeout: float | None,
213
+ group_id: int | None,
215
214
  ) -> common.DisconnectRes:
216
215
  """Disconnect and (optionally) reconnect later."""
217
216
  return common.DisconnectRes(reason="") # Nothing to do here (yet)