flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (339)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +262 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +318 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +211 -130
  21. flwr/cli/new/new.py +123 -331
  22. flwr/cli/pull.py +10 -5
  23. flwr/cli/run/run.py +71 -29
  24. flwr/cli/run_utils.py +148 -0
  25. flwr/cli/stop.py +26 -8
  26. flwr/cli/supernode/ls.py +25 -12
  27. flwr/cli/supernode/register.py +9 -4
  28. flwr/cli/supernode/unregister.py +5 -3
  29. flwr/cli/utils.py +239 -16
  30. flwr/client/__init__.py +1 -1
  31. flwr/client/dpfedavg_numpy_client.py +4 -1
  32. flwr/client/grpc_adapter_client/connection.py +8 -9
  33. flwr/client/grpc_rere_client/connection.py +16 -14
  34. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  35. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  36. flwr/client/message_handler/message_handler.py +2 -2
  37. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  38. flwr/client/numpy_client.py +1 -1
  39. flwr/client/rest_client/connection.py +18 -18
  40. flwr/client/run_info_store.py +4 -5
  41. flwr/client/typing.py +1 -1
  42. flwr/clientapp/client_app.py +9 -10
  43. flwr/clientapp/mod/centraldp_mods.py +16 -17
  44. flwr/clientapp/mod/localdp_mod.py +8 -9
  45. flwr/clientapp/typing.py +1 -1
  46. flwr/clientapp/utils.py +3 -3
  47. flwr/common/address.py +1 -2
  48. flwr/common/args.py +3 -4
  49. flwr/common/config.py +13 -16
  50. flwr/common/constant.py +5 -2
  51. flwr/common/differential_privacy.py +3 -4
  52. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  53. flwr/common/exit/exit.py +15 -2
  54. flwr/common/exit/exit_code.py +19 -0
  55. flwr/common/exit/exit_handler.py +6 -2
  56. flwr/common/exit/signal_handler.py +5 -5
  57. flwr/common/grpc.py +6 -6
  58. flwr/common/inflatable_protobuf_utils.py +1 -1
  59. flwr/common/inflatable_utils.py +38 -21
  60. flwr/common/logger.py +19 -19
  61. flwr/common/message.py +4 -4
  62. flwr/common/object_ref.py +7 -7
  63. flwr/common/record/array.py +3 -3
  64. flwr/common/record/arrayrecord.py +18 -30
  65. flwr/common/record/configrecord.py +3 -3
  66. flwr/common/record/recorddict.py +5 -5
  67. flwr/common/record/typeddict.py +9 -2
  68. flwr/common/recorddict_compat.py +7 -10
  69. flwr/common/retry_invoker.py +20 -20
  70. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  71. flwr/common/serde.py +11 -4
  72. flwr/common/serde_utils.py +2 -2
  73. flwr/common/telemetry.py +9 -5
  74. flwr/common/typing.py +58 -37
  75. flwr/compat/client/app.py +38 -37
  76. flwr/compat/client/grpc_client/connection.py +11 -11
  77. flwr/compat/server/app.py +5 -6
  78. flwr/proto/appio_pb2.py +13 -3
  79. flwr/proto/appio_pb2.pyi +134 -65
  80. flwr/proto/appio_pb2_grpc.py +20 -0
  81. flwr/proto/appio_pb2_grpc.pyi +27 -0
  82. flwr/proto/clientappio_pb2.py +17 -7
  83. flwr/proto/clientappio_pb2.pyi +15 -0
  84. flwr/proto/clientappio_pb2_grpc.py +206 -40
  85. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  86. flwr/proto/control_pb2.py +71 -52
  87. flwr/proto/control_pb2.pyi +277 -111
  88. flwr/proto/control_pb2_grpc.py +249 -40
  89. flwr/proto/control_pb2_grpc.pyi +185 -52
  90. flwr/proto/error_pb2.py +13 -3
  91. flwr/proto/error_pb2.pyi +24 -6
  92. flwr/proto/error_pb2_grpc.py +20 -0
  93. flwr/proto/error_pb2_grpc.pyi +27 -0
  94. flwr/proto/fab_pb2.py +14 -4
  95. flwr/proto/fab_pb2.pyi +59 -31
  96. flwr/proto/fab_pb2_grpc.py +20 -0
  97. flwr/proto/fab_pb2_grpc.pyi +27 -0
  98. flwr/proto/federation_pb2.py +38 -0
  99. flwr/proto/federation_pb2.pyi +56 -0
  100. flwr/proto/federation_pb2_grpc.py +24 -0
  101. flwr/proto/federation_pb2_grpc.pyi +31 -0
  102. flwr/proto/fleet_pb2.py +24 -14
  103. flwr/proto/fleet_pb2.pyi +141 -61
  104. flwr/proto/fleet_pb2_grpc.py +189 -48
  105. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  106. flwr/proto/grpcadapter_pb2.py +14 -4
  107. flwr/proto/grpcadapter_pb2.pyi +38 -16
  108. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  109. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  110. flwr/proto/heartbeat_pb2.py +17 -7
  111. flwr/proto/heartbeat_pb2.pyi +51 -22
  112. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  113. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  114. flwr/proto/log_pb2.py +13 -3
  115. flwr/proto/log_pb2.pyi +34 -11
  116. flwr/proto/log_pb2_grpc.py +20 -0
  117. flwr/proto/log_pb2_grpc.pyi +27 -0
  118. flwr/proto/message_pb2.py +15 -5
  119. flwr/proto/message_pb2.pyi +154 -86
  120. flwr/proto/message_pb2_grpc.py +20 -0
  121. flwr/proto/message_pb2_grpc.pyi +27 -0
  122. flwr/proto/node_pb2.py +15 -5
  123. flwr/proto/node_pb2.pyi +50 -25
  124. flwr/proto/node_pb2_grpc.py +20 -0
  125. flwr/proto/node_pb2_grpc.pyi +27 -0
  126. flwr/proto/recorddict_pb2.py +13 -3
  127. flwr/proto/recorddict_pb2.pyi +184 -107
  128. flwr/proto/recorddict_pb2_grpc.py +20 -0
  129. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  130. flwr/proto/run_pb2.py +40 -31
  131. flwr/proto/run_pb2.pyi +158 -84
  132. flwr/proto/run_pb2_grpc.py +20 -0
  133. flwr/proto/run_pb2_grpc.pyi +27 -0
  134. flwr/proto/serverappio_pb2.py +13 -3
  135. flwr/proto/serverappio_pb2.pyi +32 -8
  136. flwr/proto/serverappio_pb2_grpc.py +246 -65
  137. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  138. flwr/proto/simulationio_pb2.py +16 -8
  139. flwr/proto/simulationio_pb2.pyi +15 -0
  140. flwr/proto/simulationio_pb2_grpc.py +162 -41
  141. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  142. flwr/proto/transport_pb2.py +20 -10
  143. flwr/proto/transport_pb2.pyi +249 -160
  144. flwr/proto/transport_pb2_grpc.py +35 -4
  145. flwr/proto/transport_pb2_grpc.pyi +38 -8
  146. flwr/server/app.py +39 -17
  147. flwr/server/client_manager.py +4 -5
  148. flwr/server/client_proxy.py +10 -11
  149. flwr/server/compat/app.py +4 -5
  150. flwr/server/compat/app_utils.py +2 -1
  151. flwr/server/compat/grid_client_proxy.py +10 -12
  152. flwr/server/compat/legacy_context.py +3 -4
  153. flwr/server/fleet_event_log_interceptor.py +2 -1
  154. flwr/server/grid/grid.py +2 -3
  155. flwr/server/grid/grpc_grid.py +10 -8
  156. flwr/server/grid/inmemory_grid.py +4 -4
  157. flwr/server/run_serverapp.py +2 -3
  158. flwr/server/server.py +34 -39
  159. flwr/server/server_app.py +7 -8
  160. flwr/server/server_config.py +1 -2
  161. flwr/server/serverapp/app.py +34 -28
  162. flwr/server/serverapp_components.py +4 -5
  163. flwr/server/strategy/aggregate.py +9 -8
  164. flwr/server/strategy/bulyan.py +13 -11
  165. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  166. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  167. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  168. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  169. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  170. flwr/server/strategy/fedadagrad.py +18 -14
  171. flwr/server/strategy/fedadam.py +16 -14
  172. flwr/server/strategy/fedavg.py +16 -17
  173. flwr/server/strategy/fedavg_android.py +15 -15
  174. flwr/server/strategy/fedavgm.py +21 -18
  175. flwr/server/strategy/fedmedian.py +2 -3
  176. flwr/server/strategy/fedopt.py +11 -10
  177. flwr/server/strategy/fedprox.py +10 -9
  178. flwr/server/strategy/fedtrimmedavg.py +12 -11
  179. flwr/server/strategy/fedxgb_bagging.py +13 -11
  180. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  181. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  182. flwr/server/strategy/fedyogi.py +16 -14
  183. flwr/server/strategy/krum.py +12 -11
  184. flwr/server/strategy/qfedavg.py +16 -15
  185. flwr/server/strategy/strategy.py +6 -9
  186. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  187. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  190. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  192. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  193. flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
  194. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  195. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  196. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  197. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  198. flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
  199. flwr/server/superlink/linkstate/linkstate.py +91 -43
  200. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  201. flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
  202. flwr/server/superlink/linkstate/utils.py +6 -6
  203. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  204. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  205. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  206. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  207. flwr/server/superlink/utils.py +4 -6
  208. flwr/server/typing.py +1 -1
  209. flwr/server/utils/tensorboard.py +15 -8
  210. flwr/server/workflow/default_workflows.py +5 -5
  211. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  213. flwr/serverapp/strategy/bulyan.py +16 -15
  214. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  215. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  216. flwr/serverapp/strategy/fedadagrad.py +10 -11
  217. flwr/serverapp/strategy/fedadam.py +10 -11
  218. flwr/serverapp/strategy/fedavg.py +9 -10
  219. flwr/serverapp/strategy/fedavgm.py +17 -16
  220. flwr/serverapp/strategy/fedmedian.py +2 -2
  221. flwr/serverapp/strategy/fedopt.py +10 -11
  222. flwr/serverapp/strategy/fedprox.py +7 -8
  223. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  224. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  225. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  226. flwr/serverapp/strategy/fedyogi.py +9 -11
  227. flwr/serverapp/strategy/krum.py +7 -7
  228. flwr/serverapp/strategy/multikrum.py +9 -9
  229. flwr/serverapp/strategy/qfedavg.py +17 -16
  230. flwr/serverapp/strategy/strategy.py +6 -9
  231. flwr/serverapp/strategy/strategy_utils.py +7 -8
  232. flwr/simulation/app.py +46 -42
  233. flwr/simulation/legacy_app.py +12 -12
  234. flwr/simulation/ray_transport/ray_actor.py +10 -11
  235. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  236. flwr/simulation/run_simulation.py +43 -43
  237. flwr/simulation/simulationio_connection.py +4 -4
  238. flwr/supercore/cli/flower_superexec.py +3 -4
  239. flwr/supercore/constant.py +34 -1
  240. flwr/supercore/corestate/corestate.py +24 -3
  241. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  242. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  243. flwr/supercore/ffs/disk_ffs.py +1 -2
  244. flwr/supercore/ffs/ffs.py +1 -2
  245. flwr/supercore/ffs/ffs_factory.py +1 -2
  246. flwr/{common → supercore}/heartbeat.py +20 -25
  247. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  248. flwr/supercore/object_store/object_store.py +1 -2
  249. flwr/supercore/object_store/object_store_factory.py +1 -2
  250. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  251. flwr/supercore/primitives/asymmetric.py +1 -1
  252. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  253. flwr/supercore/sqlite_mixin.py +37 -34
  254. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  255. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  256. flwr/supercore/superexec/run_superexec.py +9 -13
  257. flwr/supercore/utils.py +190 -0
  258. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  259. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  260. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  261. flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
  262. flwr/superlink/federation/federation_manager.py +64 -0
  263. flwr/superlink/federation/noop_federation_manager.py +71 -0
  264. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  265. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  266. flwr/superlink/servicer/control/control_grpc.py +7 -6
  267. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  268. flwr/superlink/servicer/control/control_servicer.py +190 -23
  269. flwr/supernode/cli/flower_supernode.py +58 -3
  270. flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
  271. flwr/supernode/nodestate/nodestate.py +52 -8
  272. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  273. flwr/supernode/runtime/run_clientapp.py +41 -22
  274. flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
  275. flwr/supernode/start_client_internal.py +165 -46
  276. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
  277. flwr-1.25.0.dist-info/RECORD +393 -0
  278. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  279. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  280. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  281. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  282. flwr/cli/new/templates/app/README.md.tpl +0 -37
  283. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  284. flwr/cli/new/templates/app/code/__init__.py +0 -15
  285. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  286. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  287. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  288. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  289. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  290. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  291. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  292. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  293. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  294. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  295. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  296. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  297. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  298. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  299. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  300. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  301. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  302. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  303. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  304. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  305. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  306. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  307. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  308. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  309. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  310. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  311. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  312. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  313. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  314. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  315. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  316. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  317. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  318. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  319. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
  320. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  321. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  322. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  323. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  324. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  325. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  326. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  327. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  328. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  329. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  330. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  331. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  332. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  333. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  334. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  335. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  336. flwr/supercore/object_store/utils.py +0 -43
  337. flwr-1.23.0.dist-info/RECORD +0 -439
  338. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  339. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
flwr/server/server.py CHANGED
@@ -19,7 +19,6 @@ import concurrent.futures
19
19
  import io
20
20
  import timeit
21
21
  from logging import INFO, WARN
22
- from typing import Optional, Union
23
22
 
24
23
  from flwr.common import (
25
24
  Code,
@@ -43,15 +42,15 @@ from .server_config import ServerConfig
43
42
 
44
43
  FitResultsAndFailures = tuple[
45
44
  list[tuple[ClientProxy, FitRes]],
46
- list[Union[tuple[ClientProxy, FitRes], BaseException]],
45
+ list[tuple[ClientProxy, FitRes] | BaseException],
47
46
  ]
48
47
  EvaluateResultsAndFailures = tuple[
49
48
  list[tuple[ClientProxy, EvaluateRes]],
50
- list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
49
+ list[tuple[ClientProxy, EvaluateRes] | BaseException],
51
50
  ]
52
51
  ReconnectResultsAndFailures = tuple[
53
52
  list[tuple[ClientProxy, DisconnectRes]],
54
- list[Union[tuple[ClientProxy, DisconnectRes], BaseException]],
53
+ list[tuple[ClientProxy, DisconnectRes] | BaseException],
55
54
  ]
56
55
 
57
56
 
@@ -62,16 +61,16 @@ class Server:
62
61
  self,
63
62
  *,
64
63
  client_manager: ClientManager,
65
- strategy: Optional[Strategy] = None,
64
+ strategy: Strategy | None = None,
66
65
  ) -> None:
67
66
  self._client_manager: ClientManager = client_manager
68
67
  self.parameters: Parameters = Parameters(
69
68
  tensors=[], tensor_type="numpy.ndarray"
70
69
  )
71
70
  self.strategy: Strategy = strategy if strategy is not None else FedAvg()
72
- self.max_workers: Optional[int] = None
71
+ self.max_workers: int | None = None
73
72
 
74
- def set_max_workers(self, max_workers: Optional[int]) -> None:
73
+ def set_max_workers(self, max_workers: int | None) -> None:
75
74
  """Set the max_workers used by ThreadPoolExecutor."""
76
75
  self.max_workers = max_workers
77
76
 
@@ -84,7 +83,7 @@ class Server:
84
83
  return self._client_manager
85
84
 
86
85
  # pylint: disable=too-many-locals
87
- def fit(self, num_rounds: int, timeout: Optional[float]) -> tuple[History, float]:
86
+ def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]:
88
87
  """Run federated averaging for a number of rounds."""
89
88
  history = History()
90
89
 
@@ -161,10 +160,8 @@ class Server:
161
160
  def evaluate_round(
162
161
  self,
163
162
  server_round: int,
164
- timeout: Optional[float],
165
- ) -> Optional[
166
- tuple[Optional[float], dict[str, Scalar], EvaluateResultsAndFailures]
167
- ]:
163
+ timeout: float | None,
164
+ ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None:
168
165
  """Validate current global model on a number of clients."""
169
166
  # Get clients and their respective instructions from strategy
170
167
  client_instructions = self.strategy.configure_evaluate(
@@ -198,7 +195,7 @@ class Server:
198
195
 
199
196
  # Aggregate the evaluation results
200
197
  aggregated_result: tuple[
201
- Optional[float],
198
+ float | None,
202
199
  dict[str, Scalar],
203
200
  ] = self.strategy.aggregate_evaluate(server_round, results, failures)
204
201
 
@@ -208,10 +205,8 @@ class Server:
208
205
  def fit_round(
209
206
  self,
210
207
  server_round: int,
211
- timeout: Optional[float],
212
- ) -> Optional[
213
- tuple[Optional[Parameters], dict[str, Scalar], FitResultsAndFailures]
214
- ]:
208
+ timeout: float | None,
209
+ ) -> tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None:
215
210
  """Perform a single round of federated averaging."""
216
211
  # Get clients and their respective instructions from strategy
217
212
  client_instructions = self.strategy.configure_fit(
@@ -246,14 +241,14 @@ class Server:
246
241
 
247
242
  # Aggregate training results
248
243
  aggregated_result: tuple[
249
- Optional[Parameters],
244
+ Parameters | None,
250
245
  dict[str, Scalar],
251
246
  ] = self.strategy.aggregate_fit(server_round, results, failures)
252
247
 
253
248
  parameters_aggregated, metrics_aggregated = aggregated_result
254
249
  return parameters_aggregated, metrics_aggregated, (results, failures)
255
250
 
256
- def disconnect_all_clients(self, timeout: Optional[float]) -> None:
251
+ def disconnect_all_clients(self, timeout: float | None) -> None:
257
252
  """Send shutdown signal to all clients."""
258
253
  all_clients = self._client_manager.all()
259
254
  clients = [all_clients[k] for k in all_clients.keys()]
@@ -266,11 +261,11 @@ class Server:
266
261
  )
267
262
 
268
263
  def _get_initial_parameters(
269
- self, server_round: int, timeout: Optional[float]
264
+ self, server_round: int, timeout: float | None
270
265
  ) -> Parameters:
271
266
  """Get initial parameters from one of the available clients."""
272
267
  # Server-side parameter initialization
273
- parameters: Optional[Parameters] = self.strategy.initialize_parameters(
268
+ parameters: Parameters | None = self.strategy.initialize_parameters(
274
269
  client_manager=self._client_manager
275
270
  )
276
271
  if parameters is not None:
@@ -297,8 +292,8 @@ class Server:
297
292
 
298
293
  def reconnect_clients(
299
294
  client_instructions: list[tuple[ClientProxy, ReconnectIns]],
300
- max_workers: Optional[int],
301
- timeout: Optional[float],
295
+ max_workers: int | None,
296
+ timeout: float | None,
302
297
  ) -> ReconnectResultsAndFailures:
303
298
  """Instruct clients to disconnect and never reconnect."""
304
299
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -313,7 +308,7 @@ def reconnect_clients(
313
308
 
314
309
  # Gather results
315
310
  results: list[tuple[ClientProxy, DisconnectRes]] = []
316
- failures: list[Union[tuple[ClientProxy, DisconnectRes], BaseException]] = []
311
+ failures: list[tuple[ClientProxy, DisconnectRes] | BaseException] = []
317
312
  for future in finished_fs:
318
313
  failure = future.exception()
319
314
  if failure is not None:
@@ -327,7 +322,7 @@ def reconnect_clients(
327
322
  def reconnect_client(
328
323
  client: ClientProxy,
329
324
  reconnect: ReconnectIns,
330
- timeout: Optional[float],
325
+ timeout: float | None,
331
326
  ) -> tuple[ClientProxy, DisconnectRes]:
332
327
  """Instruct client to disconnect and (optionally) reconnect later."""
333
328
  disconnect = client.reconnect(
@@ -340,8 +335,8 @@ def reconnect_client(
340
335
 
341
336
  def fit_clients(
342
337
  client_instructions: list[tuple[ClientProxy, FitIns]],
343
- max_workers: Optional[int],
344
- timeout: Optional[float],
338
+ max_workers: int | None,
339
+ timeout: float | None,
345
340
  group_id: int,
346
341
  ) -> FitResultsAndFailures:
347
342
  """Refine parameters concurrently on all selected clients."""
@@ -357,7 +352,7 @@ def fit_clients(
357
352
 
358
353
  # Gather results
359
354
  results: list[tuple[ClientProxy, FitRes]] = []
360
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = []
355
+ failures: list[tuple[ClientProxy, FitRes] | BaseException] = []
361
356
  for future in finished_fs:
362
357
  _handle_finished_future_after_fit(
363
358
  future=future, results=results, failures=failures
@@ -366,7 +361,7 @@ def fit_clients(
366
361
 
367
362
 
368
363
  def fit_client(
369
- client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int
364
+ client: ClientProxy, ins: FitIns, timeout: float | None, group_id: int
370
365
  ) -> tuple[ClientProxy, FitRes]:
371
366
  """Refine parameters on a single client."""
372
367
  fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
@@ -376,7 +371,7 @@ def fit_client(
376
371
  def _handle_finished_future_after_fit(
377
372
  future: concurrent.futures.Future, # type: ignore
378
373
  results: list[tuple[ClientProxy, FitRes]],
379
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
374
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
380
375
  ) -> None:
381
376
  """Convert finished future into either a result or a failure."""
382
377
  # Check if there was an exception
@@ -400,8 +395,8 @@ def _handle_finished_future_after_fit(
400
395
 
401
396
  def evaluate_clients(
402
397
  client_instructions: list[tuple[ClientProxy, EvaluateIns]],
403
- max_workers: Optional[int],
404
- timeout: Optional[float],
398
+ max_workers: int | None,
399
+ timeout: float | None,
405
400
  group_id: int,
406
401
  ) -> EvaluateResultsAndFailures:
407
402
  """Evaluate parameters concurrently on all selected clients."""
@@ -417,7 +412,7 @@ def evaluate_clients(
417
412
 
418
413
  # Gather results
419
414
  results: list[tuple[ClientProxy, EvaluateRes]] = []
420
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = []
415
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = []
421
416
  for future in finished_fs:
422
417
  _handle_finished_future_after_evaluate(
423
418
  future=future, results=results, failures=failures
@@ -428,7 +423,7 @@ def evaluate_clients(
428
423
  def evaluate_client(
429
424
  client: ClientProxy,
430
425
  ins: EvaluateIns,
431
- timeout: Optional[float],
426
+ timeout: float | None,
432
427
  group_id: int,
433
428
  ) -> tuple[ClientProxy, EvaluateRes]:
434
429
  """Evaluate parameters on a single client."""
@@ -439,7 +434,7 @@ def evaluate_client(
439
434
  def _handle_finished_future_after_evaluate(
440
435
  future: concurrent.futures.Future, # type: ignore
441
436
  results: list[tuple[ClientProxy, EvaluateRes]],
442
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
437
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
443
438
  ) -> None:
444
439
  """Convert finished future into either a result or a failure."""
445
440
  # Check if there was an exception
@@ -462,10 +457,10 @@ def _handle_finished_future_after_evaluate(
462
457
 
463
458
 
464
459
  def init_defaults(
465
- server: Optional[Server],
466
- config: Optional[ServerConfig],
467
- strategy: Optional[Strategy],
468
- client_manager: Optional[ClientManager],
460
+ server: Server | None,
461
+ config: ServerConfig | None,
462
+ strategy: Strategy | None,
463
+ client_manager: ClientManager | None,
469
464
  ) -> tuple[Server, ServerConfig]:
470
465
  """Create server instance if none was given."""
471
466
  if server is None:
flwr/server/server_app.py CHANGED
@@ -16,9 +16,8 @@
16
16
 
17
17
 
18
18
  import inspect
19
- from collections.abc import Iterator
19
+ from collections.abc import Callable, Iterator
20
20
  from contextlib import contextmanager
21
- from typing import Callable, Optional
22
21
 
23
22
  from flwr.common import Context
24
23
  from flwr.common.logger import warn_deprecated_feature_with_example
@@ -118,11 +117,11 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
118
117
  # pylint: disable=too-many-arguments,too-many-positional-arguments
119
118
  def __init__(
120
119
  self,
121
- server: Optional[Server] = None,
122
- config: Optional[ServerConfig] = None,
123
- strategy: Optional[Strategy] = None,
124
- client_manager: Optional[ClientManager] = None,
125
- server_fn: Optional[ServerFn] = None,
120
+ server: Server | None = None,
121
+ config: ServerConfig | None = None,
122
+ strategy: Strategy | None = None,
123
+ client_manager: ClientManager | None = None,
124
+ server_fn: ServerFn | None = None,
126
125
  ) -> None:
127
126
  if any([server, config, strategy, client_manager]):
128
127
  warn_deprecated_feature_with_example(
@@ -148,7 +147,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
148
147
  self._strategy = strategy
149
148
  self._client_manager = client_manager
150
149
  self._server_fn = server_fn
151
- self._main: Optional[ServerAppCallable] = None
150
+ self._main: ServerAppCallable | None = None
152
151
  self._lifespan = _empty_lifespan
153
152
 
154
153
  def __call__(self, grid: Grid, context: Context) -> None:
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Optional
20
19
 
21
20
 
22
21
  @dataclass
@@ -28,7 +27,7 @@ class ServerConfig:
28
27
  """
29
28
 
30
29
  num_rounds: int = 1
31
- round_timeout: Optional[float] = None
30
+ round_timeout: float | None = None
32
31
 
33
32
  def __repr__(self) -> str:
34
33
  """Return the string representation of the ServerConfig."""
@@ -19,7 +19,8 @@ import argparse
19
19
  from logging import DEBUG, ERROR, INFO
20
20
  from pathlib import Path
21
21
  from queue import Queue
22
- from typing import Optional
22
+
23
+ import grpc
23
24
 
24
25
  from flwr.app.exception import AppExitException
25
26
  from flwr.cli.config_utils import get_fab_metadata
@@ -38,8 +39,7 @@ from flwr.common.constant import (
38
39
  Status,
39
40
  SubStatus,
40
41
  )
41
- from flwr.common.exit import ExitCode, add_exit_handler, flwr_exit
42
- from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
42
+ from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
43
43
  from flwr.common.logger import (
44
44
  log,
45
45
  mirror_output_to_queue,
@@ -66,6 +66,7 @@ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
66
66
  from flwr.server.grid.grpc_grid import GrpcGrid
67
67
  from flwr.server.run_serverapp import run as run_
68
68
  from flwr.supercore.app_utils import start_parent_process_monitor
69
+ from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
69
70
  from flwr.supercore.superexec.plugin import ServerAppExecPlugin
70
71
  from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
71
72
 
@@ -73,7 +74,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
73
74
  def flwr_serverapp() -> None:
74
75
  """Run process-isolated Flower ServerApp."""
75
76
  # Capture stdout/stderr
76
- log_queue: Queue[Optional[str]] = Queue()
77
+ log_queue: Queue[str | None] = Queue()
77
78
  mirror_output_to_queue(log_queue)
78
79
 
79
80
  args = _parse_args_run_flwr_serverapp().parse_args()
@@ -120,21 +121,22 @@ def flwr_serverapp() -> None:
120
121
 
121
122
  def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
122
123
  serverappio_api_address: str,
123
- log_queue: Queue[Optional[str]],
124
+ log_queue: Queue[str | None],
124
125
  token: str,
125
- flwr_dir: Optional[str] = None,
126
- certificates: Optional[bytes] = None,
127
- parent_pid: Optional[int] = None,
126
+ flwr_dir: str | None = None,
127
+ certificates: bytes | None = None,
128
+ parent_pid: int | None = None,
128
129
  ) -> None:
129
130
  """Run Flower ServerApp process."""
130
131
  # Monitor the main process in case of SIGKILL
131
132
  if parent_pid is not None:
132
133
  start_parent_process_monitor(parent_pid)
133
134
 
134
- # Resolve directory where FABs are installed
135
+ # Initialize variables for exit handler
135
136
  flwr_dir_ = get_flwr_dir(flwr_dir)
136
137
  log_uploader = None
137
138
  hash_run_id = None
139
+ run = None
138
140
  run_status = None
139
141
  heartbeat_sender = None
140
142
  grid = None
@@ -143,7 +145,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
143
145
 
144
146
  def on_exit() -> None:
145
147
  # Stop heartbeat sender
146
- if heartbeat_sender:
148
+ if heartbeat_sender and heartbeat_sender.is_running:
147
149
  heartbeat_sender.stop()
148
150
 
149
151
  # Stop log uploader for this run and upload final logs
@@ -151,7 +153,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
151
153
  stop_log_uploader(log_queue, log_uploader)
152
154
 
153
155
  # Update run status
154
- if run_status and grid:
156
+ if run and run_status and grid:
155
157
  run_status_proto = run_status_to_proto(run_status)
156
158
  grid._stub.UpdateRunStatus(
157
159
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
@@ -161,7 +163,12 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
161
163
  if grid:
162
164
  grid.close()
163
165
 
164
- add_exit_handler(on_exit)
166
+ # Register signal handlers for graceful shutdown
167
+ register_signal_handlers(
168
+ event_type=EventType.FLWR_SERVERAPP_RUN_LEAVE,
169
+ exit_message="Run stopped by user.",
170
+ exit_handlers=[on_exit],
171
+ )
165
172
 
166
173
  try:
167
174
  # Initialize the GrpcGrid
@@ -171,9 +178,14 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
171
178
  )
172
179
 
173
180
  # Pull ServerAppInputs from LinkState
174
- req = PullAppInputsRequest(token=token)
175
- log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
176
- res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
181
+ try:
182
+ log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
183
+ req = PullAppInputsRequest(token=token)
184
+ res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
185
+ except grpc.RpcError as ex:
186
+ if ex.code() == grpc.StatusCode.FAILED_PRECONDITION:
187
+ raise RuntimeError("Failed to start the run.") from ex
188
+ raise
177
189
  context = context_from_proto(res.context)
178
190
  run = run_from_proto(res.run)
179
191
  fab = fab_from_proto(res.fab)
@@ -214,25 +226,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
214
226
  app_path,
215
227
  )
216
228
 
217
- # Change status to Running
218
- run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
219
- grid._stub.UpdateRunStatus(
220
- UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
221
- )
222
-
223
229
  event(
224
230
  EventType.FLWR_SERVERAPP_RUN_ENTER,
225
231
  event_details={"run-id-hash": hash_run_id},
226
232
  )
227
233
 
228
234
  # Set up heartbeat sender
229
- heartbeat_fn = get_grpc_app_heartbeat_fn(
230
- grid._stub,
231
- run.run_id,
232
- failure_message="Heartbeat failed unexpectedly. The SuperLink could "
233
- "not find the provided run ID, or the run status is invalid.",
235
+ heartbeat_sender = HeartbeatSender(
236
+ make_app_heartbeat_fn_grpc(grid._stub, token)
234
237
  )
235
- heartbeat_sender = HeartbeatSender(heartbeat_fn)
236
238
  heartbeat_sender.start()
237
239
 
238
240
  # Load and run the ServerApp with the Grid
@@ -256,11 +258,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
256
258
  # Raised when the run is already stopped by the user
257
259
  except RunNotRunningException:
258
260
  log(INFO, "")
259
- log(INFO, "Run ID %s stopped.", run.run_id)
261
+ log(INFO, "Run ID %s stopped.", run.run_id) # type: ignore[union-attr]
260
262
  log(INFO, "")
261
263
  run_status = None
262
264
  # No need to update the exit code since this is expected behavior
263
265
 
266
+ except RuntimeError:
267
+ log(ERROR, "Failed to start run.")
268
+ exit_code = ExitCode.SERVERAPP_RUN_START_REJECTED
269
+
264
270
  except Exception as ex: # pylint: disable=broad-exception-caught
265
271
  exc_entity = "ServerApp"
266
272
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Optional
20
19
 
21
20
  from .client_manager import ClientManager
22
21
  from .server import Server
@@ -46,7 +45,7 @@ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
46
45
  will be used.
47
46
  """
48
47
 
49
- server: Optional[Server] = None
50
- config: Optional[ServerConfig] = None
51
- strategy: Optional[Strategy] = None
52
- client_manager: Optional[ClientManager] = None
48
+ server: Server | None = None
49
+ config: ServerConfig | None = None
50
+ strategy: Strategy | None = None
51
+ client_manager: ClientManager | None = None
@@ -15,8 +15,9 @@
15
15
  """Aggregation functions for strategy implementations."""
16
16
  # mypy: disallow_untyped_calls=False
17
17
 
18
+ from collections.abc import Callable
18
19
  from functools import partial, reduce
19
- from typing import Any, Callable, Union
20
+ from typing import Any
20
21
 
21
22
  import numpy as np
22
23
 
@@ -37,7 +38,7 @@ def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
37
38
  # Compute average weights of each layer
38
39
  weights_prime: NDArrays = [
39
40
  reduce(np.add, layer_updates) / num_examples_total
40
- for layer_updates in zip(*weighted_weights)
41
+ for layer_updates in zip(*weighted_weights, strict=True)
41
42
  ]
42
43
  return weights_prime
43
44
 
@@ -53,7 +54,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
53
54
  )
54
55
 
55
56
  def _try_inplace(
56
- x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
57
+ x: NDArray, y: NDArray | np.float64, np_binary_op: np.ufunc
57
58
  ) -> NDArray:
58
59
  return ( # type: ignore[no-any-return]
59
60
  np_binary_op(x, y, out=x)
@@ -75,7 +76,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
75
76
  )
76
77
  params = [
77
78
  reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
78
- for layer_updates in zip(params, res)
79
+ for layer_updates in zip(params, res, strict=True)
79
80
  ]
80
81
 
81
82
  return params
@@ -88,7 +89,7 @@ def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
88
89
 
89
90
  # Compute median weight of each layer
90
91
  median_w: NDArrays = [
91
- np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
92
+ np.median(np.asarray(layer), axis=0) for layer in zip(*weights, strict=True)
92
93
  ]
93
94
  return median_w
94
95
 
@@ -235,7 +236,7 @@ def aggregate_qffl(
235
236
  for j in range(1, len(deltas)):
236
237
  tmp += scaled_deltas[j][i]
237
238
  updates.append(tmp)
238
- new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
239
+ new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates, strict=True)]
239
240
  return new_parameters
240
241
 
241
242
 
@@ -287,7 +288,7 @@ def aggregate_trimmed_avg(
287
288
 
288
289
  trimmed_w: NDArrays = [
289
290
  _trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
290
- for layer in zip(*weights)
291
+ for layer in zip(*weights, strict=True)
291
292
  ]
292
293
 
293
294
  return trimmed_w
@@ -299,7 +300,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
299
300
  return False
300
301
  return all(
301
302
  np.array_equal(layer_weights1, layer_weights2)
302
- for layer_weights1, layer_weights2 in zip(weights1, weights2)
303
+ for layer_weights1, layer_weights2 in zip(weights1, weights2, strict=True)
303
304
  )
304
305
 
305
306
 
@@ -18,8 +18,9 @@ Paper: arxiv.org/abs/1802.07927
18
18
  """
19
19
 
20
20
 
21
+ from collections.abc import Callable
21
22
  from logging import WARNING
22
- from typing import Any, Callable, Optional, Union
23
+ from typing import Any
23
24
 
24
25
  from flwr.common import (
25
26
  FitRes,
@@ -84,18 +85,19 @@ class Bulyan(FedAvg):
84
85
  min_evaluate_clients: int = 2,
85
86
  min_available_clients: int = 2,
86
87
  num_malicious_clients: int = 0,
87
- evaluate_fn: Optional[
88
+ evaluate_fn: (
88
89
  Callable[
89
90
  [int, NDArrays, dict[str, Scalar]],
90
- Optional[tuple[float, dict[str, Scalar]]],
91
+ tuple[float, dict[str, Scalar]] | None,
91
92
  ]
92
- ] = None,
93
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
94
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
93
+ | None
94
+ ) = None,
95
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
96
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
95
97
  accept_failures: bool = True,
96
- initial_parameters: Optional[Parameters] = None,
97
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
+ initial_parameters: Parameters | None = None,
99
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
99
101
  first_aggregation_rule: Callable = aggregate_krum, # type: ignore
100
102
  **aggregation_rule_kwargs: Any,
101
103
  ) -> None:
@@ -126,8 +128,8 @@ class Bulyan(FedAvg):
126
128
  self,
127
129
  server_round: int,
128
130
  results: list[tuple[ClientProxy, FitRes]],
129
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
130
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
131
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
132
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
131
133
  """Aggregate fit results using Bulyan."""
132
134
  if not results:
133
135
  return None, {}