flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (301)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +34 -1
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/__init__.py +15 -6
  9. flwr/cli/auth_plugin/auth_plugin.py +94 -0
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
  12. flwr/cli/build.py +166 -53
  13. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
  14. flwr/cli/config_utils.py +101 -13
  15. flwr/cli/federation/__init__.py +24 -0
  16. flwr/cli/federation/ls.py +140 -0
  17. flwr/cli/federation/show.py +317 -0
  18. flwr/cli/install.py +91 -13
  19. flwr/cli/log.py +54 -11
  20. flwr/cli/login/login.py +41 -27
  21. flwr/cli/ls.py +177 -133
  22. flwr/cli/new/new.py +175 -40
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  31. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  34. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  35. flwr/cli/pull.py +12 -7
  36. flwr/cli/run/run.py +82 -31
  37. flwr/cli/run_utils.py +130 -0
  38. flwr/cli/stop.py +27 -9
  39. flwr/cli/supernode/__init__.py +25 -0
  40. flwr/cli/supernode/ls.py +268 -0
  41. flwr/cli/supernode/register.py +190 -0
  42. flwr/cli/supernode/unregister.py +140 -0
  43. flwr/cli/utils.py +464 -81
  44. flwr/client/__init__.py +2 -1
  45. flwr/client/dpfedavg_numpy_client.py +4 -1
  46. flwr/client/grpc_adapter_client/connection.py +12 -15
  47. flwr/client/grpc_rere_client/connection.py +68 -41
  48. flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
  49. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
  50. flwr/client/message_handler/message_handler.py +2 -2
  51. flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
  52. flwr/client/numpy_client.py +1 -1
  53. flwr/client/rest_client/connection.py +94 -51
  54. flwr/client/run_info_store.py +4 -5
  55. flwr/client/typing.py +1 -1
  56. flwr/clientapp/__init__.py +1 -2
  57. flwr/{client → clientapp}/client_app.py +9 -10
  58. flwr/clientapp/mod/centraldp_mods.py +16 -17
  59. flwr/clientapp/mod/localdp_mod.py +8 -9
  60. flwr/clientapp/typing.py +1 -1
  61. flwr/{client/clientapp → clientapp}/utils.py +4 -4
  62. flwr/common/address.py +1 -2
  63. flwr/common/args.py +3 -4
  64. flwr/common/config.py +13 -16
  65. flwr/common/constant.py +56 -13
  66. flwr/common/differential_privacy.py +3 -4
  67. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  68. flwr/common/exit/exit.py +15 -2
  69. flwr/common/exit/exit_code.py +39 -10
  70. flwr/common/exit/exit_handler.py +6 -2
  71. flwr/common/exit/signal_handler.py +5 -5
  72. flwr/common/grpc.py +6 -6
  73. flwr/common/inflatable_protobuf_utils.py +1 -1
  74. flwr/common/inflatable_utils.py +48 -31
  75. flwr/common/logger.py +19 -19
  76. flwr/common/message.py +4 -4
  77. flwr/common/object_ref.py +7 -7
  78. flwr/common/record/array.py +6 -6
  79. flwr/common/record/arrayrecord.py +18 -21
  80. flwr/common/record/configrecord.py +3 -3
  81. flwr/common/record/recorddict.py +5 -5
  82. flwr/common/record/typeddict.py +9 -2
  83. flwr/common/recorddict_compat.py +7 -10
  84. flwr/common/retry_invoker.py +20 -20
  85. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  86. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  87. flwr/common/serde.py +9 -6
  88. flwr/common/serde_utils.py +2 -2
  89. flwr/common/telemetry.py +9 -5
  90. flwr/common/typing.py +59 -43
  91. flwr/compat/client/app.py +39 -38
  92. flwr/compat/client/grpc_client/connection.py +13 -13
  93. flwr/compat/server/app.py +5 -6
  94. flwr/proto/appio_pb2.py +13 -3
  95. flwr/proto/appio_pb2.pyi +134 -65
  96. flwr/proto/appio_pb2_grpc.py +20 -0
  97. flwr/proto/appio_pb2_grpc.pyi +27 -0
  98. flwr/proto/clientappio_pb2.py +17 -7
  99. flwr/proto/clientappio_pb2.pyi +15 -0
  100. flwr/proto/clientappio_pb2_grpc.py +206 -40
  101. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  102. flwr/proto/control_pb2.py +72 -40
  103. flwr/proto/control_pb2.pyi +319 -87
  104. flwr/proto/control_pb2_grpc.py +339 -28
  105. flwr/proto/control_pb2_grpc.pyi +209 -37
  106. flwr/proto/error_pb2.py +13 -3
  107. flwr/proto/error_pb2.pyi +24 -6
  108. flwr/proto/error_pb2_grpc.py +20 -0
  109. flwr/proto/error_pb2_grpc.pyi +27 -0
  110. flwr/proto/fab_pb2.py +24 -10
  111. flwr/proto/fab_pb2.pyi +68 -20
  112. flwr/proto/fab_pb2_grpc.py +20 -0
  113. flwr/proto/fab_pb2_grpc.pyi +27 -0
  114. flwr/proto/federation_pb2.py +38 -0
  115. flwr/proto/federation_pb2.pyi +56 -0
  116. flwr/proto/federation_pb2_grpc.py +24 -0
  117. flwr/proto/federation_pb2_grpc.pyi +31 -0
  118. flwr/proto/fleet_pb2.py +45 -27
  119. flwr/proto/fleet_pb2.pyi +186 -70
  120. flwr/proto/fleet_pb2_grpc.py +277 -66
  121. flwr/proto/fleet_pb2_grpc.pyi +201 -55
  122. flwr/proto/grpcadapter_pb2.py +14 -4
  123. flwr/proto/grpcadapter_pb2.pyi +38 -16
  124. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  125. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  126. flwr/proto/heartbeat_pb2.py +17 -7
  127. flwr/proto/heartbeat_pb2.pyi +51 -22
  128. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  129. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  130. flwr/proto/log_pb2.py +13 -3
  131. flwr/proto/log_pb2.pyi +34 -11
  132. flwr/proto/log_pb2_grpc.py +20 -0
  133. flwr/proto/log_pb2_grpc.pyi +27 -0
  134. flwr/proto/message_pb2.py +15 -5
  135. flwr/proto/message_pb2.pyi +154 -86
  136. flwr/proto/message_pb2_grpc.py +20 -0
  137. flwr/proto/message_pb2_grpc.pyi +27 -0
  138. flwr/proto/node_pb2.py +16 -4
  139. flwr/proto/node_pb2.pyi +77 -4
  140. flwr/proto/node_pb2_grpc.py +20 -0
  141. flwr/proto/node_pb2_grpc.pyi +27 -0
  142. flwr/proto/recorddict_pb2.py +13 -3
  143. flwr/proto/recorddict_pb2.pyi +184 -107
  144. flwr/proto/recorddict_pb2_grpc.py +20 -0
  145. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  146. flwr/proto/run_pb2.py +40 -31
  147. flwr/proto/run_pb2.pyi +149 -84
  148. flwr/proto/run_pb2_grpc.py +20 -0
  149. flwr/proto/run_pb2_grpc.pyi +27 -0
  150. flwr/proto/serverappio_pb2.py +13 -3
  151. flwr/proto/serverappio_pb2.pyi +32 -8
  152. flwr/proto/serverappio_pb2_grpc.py +246 -65
  153. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  154. flwr/proto/simulationio_pb2.py +16 -8
  155. flwr/proto/simulationio_pb2.pyi +15 -0
  156. flwr/proto/simulationio_pb2_grpc.py +162 -41
  157. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  158. flwr/proto/transport_pb2.py +20 -10
  159. flwr/proto/transport_pb2.pyi +249 -160
  160. flwr/proto/transport_pb2_grpc.py +35 -4
  161. flwr/proto/transport_pb2_grpc.pyi +38 -8
  162. flwr/server/app.py +173 -127
  163. flwr/server/client_manager.py +4 -5
  164. flwr/server/client_proxy.py +10 -11
  165. flwr/server/compat/app.py +4 -5
  166. flwr/server/compat/app_utils.py +2 -1
  167. flwr/server/compat/grid_client_proxy.py +10 -12
  168. flwr/server/compat/legacy_context.py +3 -4
  169. flwr/server/fleet_event_log_interceptor.py +2 -1
  170. flwr/server/grid/grid.py +2 -3
  171. flwr/server/grid/grpc_grid.py +10 -8
  172. flwr/server/grid/inmemory_grid.py +4 -4
  173. flwr/server/run_serverapp.py +2 -3
  174. flwr/server/server.py +34 -39
  175. flwr/server/server_app.py +7 -8
  176. flwr/server/server_config.py +1 -2
  177. flwr/server/serverapp/app.py +34 -28
  178. flwr/server/serverapp_components.py +4 -5
  179. flwr/server/strategy/aggregate.py +9 -8
  180. flwr/server/strategy/bulyan.py +13 -11
  181. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  182. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  183. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  184. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  185. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  186. flwr/server/strategy/fedadagrad.py +18 -14
  187. flwr/server/strategy/fedadam.py +16 -14
  188. flwr/server/strategy/fedavg.py +16 -17
  189. flwr/server/strategy/fedavg_android.py +15 -15
  190. flwr/server/strategy/fedavgm.py +21 -18
  191. flwr/server/strategy/fedmedian.py +2 -3
  192. flwr/server/strategy/fedopt.py +11 -10
  193. flwr/server/strategy/fedprox.py +10 -9
  194. flwr/server/strategy/fedtrimmedavg.py +12 -11
  195. flwr/server/strategy/fedxgb_bagging.py +13 -11
  196. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  197. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  198. flwr/server/strategy/fedyogi.py +16 -14
  199. flwr/server/strategy/krum.py +12 -11
  200. flwr/server/strategy/qfedavg.py +16 -15
  201. flwr/server/strategy/strategy.py +6 -9
  202. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
  203. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  206. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
  208. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
  209. flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
  210. flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
  211. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  212. flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
  213. flwr/server/superlink/fleet/vce/vce_api.py +32 -13
  214. flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
  215. flwr/server/superlink/linkstate/linkstate.py +161 -62
  216. flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
  217. flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
  218. flwr/server/superlink/linkstate/utils.py +9 -60
  219. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  220. flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
  221. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  222. flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
  223. flwr/server/superlink/utils.py +4 -6
  224. flwr/server/typing.py +1 -1
  225. flwr/server/utils/tensorboard.py +15 -8
  226. flwr/server/utils/validator.py +2 -3
  227. flwr/server/workflow/default_workflows.py +5 -5
  228. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  229. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
  230. flwr/serverapp/strategy/bulyan.py +16 -15
  231. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  232. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  233. flwr/serverapp/strategy/fedadagrad.py +10 -11
  234. flwr/serverapp/strategy/fedadam.py +10 -11
  235. flwr/serverapp/strategy/fedavg.py +9 -10
  236. flwr/serverapp/strategy/fedavgm.py +17 -16
  237. flwr/serverapp/strategy/fedmedian.py +2 -2
  238. flwr/serverapp/strategy/fedopt.py +10 -11
  239. flwr/serverapp/strategy/fedprox.py +7 -8
  240. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  241. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  242. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  243. flwr/serverapp/strategy/fedyogi.py +9 -11
  244. flwr/serverapp/strategy/krum.py +7 -7
  245. flwr/serverapp/strategy/multikrum.py +9 -9
  246. flwr/serverapp/strategy/qfedavg.py +17 -16
  247. flwr/serverapp/strategy/strategy.py +6 -9
  248. flwr/serverapp/strategy/strategy_utils.py +7 -8
  249. flwr/simulation/app.py +46 -42
  250. flwr/simulation/legacy_app.py +12 -12
  251. flwr/simulation/ray_transport/ray_actor.py +11 -12
  252. flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
  253. flwr/simulation/run_simulation.py +44 -43
  254. flwr/simulation/simulationio_connection.py +4 -4
  255. flwr/supercore/cli/flower_superexec.py +3 -4
  256. flwr/supercore/constant.py +52 -0
  257. flwr/supercore/corestate/corestate.py +24 -3
  258. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  259. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  260. flwr/supercore/ffs/disk_ffs.py +1 -2
  261. flwr/supercore/ffs/ffs.py +1 -2
  262. flwr/supercore/ffs/ffs_factory.py +1 -2
  263. flwr/{common → supercore}/heartbeat.py +20 -25
  264. flwr/supercore/object_store/in_memory_object_store.py +1 -6
  265. flwr/supercore/object_store/object_store.py +1 -2
  266. flwr/supercore/object_store/object_store_factory.py +27 -8
  267. flwr/supercore/object_store/sqlite_object_store.py +253 -0
  268. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  269. flwr/supercore/primitives/asymmetric.py +117 -0
  270. flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
  271. flwr/supercore/sqlite_mixin.py +159 -0
  272. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  273. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  274. flwr/supercore/superexec/run_superexec.py +9 -13
  275. flwr/supercore/utils.py +20 -0
  276. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  277. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  278. flwr/superlink/auth_plugin/auth_plugin.py +88 -0
  279. flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
  280. flwr/superlink/federation/__init__.py +24 -0
  281. flwr/superlink/federation/federation_manager.py +64 -0
  282. flwr/superlink/federation/noop_federation_manager.py +71 -0
  283. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
  284. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  285. flwr/superlink/servicer/control/control_grpc.py +18 -17
  286. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  287. flwr/superlink/servicer/control/control_servicer.py +239 -63
  288. flwr/supernode/cli/flower_supernode.py +74 -26
  289. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  290. flwr/supernode/nodestate/nodestate.py +7 -8
  291. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  292. flwr/supernode/runtime/run_clientapp.py +43 -24
  293. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  294. flwr/supernode/start_client_internal.py +175 -51
  295. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  296. flwr-1.24.0.dist-info/RECORD +454 -0
  297. flwr/common/auth_plugin/auth_plugin.py +0 -149
  298. flwr/supercore/object_store/utils.py +0 -43
  299. flwr-1.22.0.dist-info/RECORD +0 -428
  300. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  301. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,8 @@ import argparse
19
19
  from logging import DEBUG, ERROR, INFO
20
20
  from pathlib import Path
21
21
  from queue import Queue
22
- from typing import Optional
22
+
23
+ import grpc
23
24
 
24
25
  from flwr.app.exception import AppExitException
25
26
  from flwr.cli.config_utils import get_fab_metadata
@@ -38,8 +39,7 @@ from flwr.common.constant import (
38
39
  Status,
39
40
  SubStatus,
40
41
  )
41
- from flwr.common.exit import ExitCode, add_exit_handler, flwr_exit
42
- from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
42
+ from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
43
43
  from flwr.common.logger import (
44
44
  log,
45
45
  mirror_output_to_queue,
@@ -66,6 +66,7 @@ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
66
66
  from flwr.server.grid.grpc_grid import GrpcGrid
67
67
  from flwr.server.run_serverapp import run as run_
68
68
  from flwr.supercore.app_utils import start_parent_process_monitor
69
+ from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
69
70
  from flwr.supercore.superexec.plugin import ServerAppExecPlugin
70
71
  from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
71
72
 
@@ -73,7 +74,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
73
74
  def flwr_serverapp() -> None:
74
75
  """Run process-isolated Flower ServerApp."""
75
76
  # Capture stdout/stderr
76
- log_queue: Queue[Optional[str]] = Queue()
77
+ log_queue: Queue[str | None] = Queue()
77
78
  mirror_output_to_queue(log_queue)
78
79
 
79
80
  args = _parse_args_run_flwr_serverapp().parse_args()
@@ -120,21 +121,22 @@ def flwr_serverapp() -> None:
120
121
 
121
122
  def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
122
123
  serverappio_api_address: str,
123
- log_queue: Queue[Optional[str]],
124
+ log_queue: Queue[str | None],
124
125
  token: str,
125
- flwr_dir: Optional[str] = None,
126
- certificates: Optional[bytes] = None,
127
- parent_pid: Optional[int] = None,
126
+ flwr_dir: str | None = None,
127
+ certificates: bytes | None = None,
128
+ parent_pid: int | None = None,
128
129
  ) -> None:
129
130
  """Run Flower ServerApp process."""
130
131
  # Monitor the main process in case of SIGKILL
131
132
  if parent_pid is not None:
132
133
  start_parent_process_monitor(parent_pid)
133
134
 
134
- # Resolve directory where FABs are installed
135
+ # Initialize variables for exit handler
135
136
  flwr_dir_ = get_flwr_dir(flwr_dir)
136
137
  log_uploader = None
137
138
  hash_run_id = None
139
+ run = None
138
140
  run_status = None
139
141
  heartbeat_sender = None
140
142
  grid = None
@@ -143,7 +145,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
143
145
 
144
146
  def on_exit() -> None:
145
147
  # Stop heartbeat sender
146
- if heartbeat_sender:
148
+ if heartbeat_sender and heartbeat_sender.is_running:
147
149
  heartbeat_sender.stop()
148
150
 
149
151
  # Stop log uploader for this run and upload final logs
@@ -151,7 +153,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
151
153
  stop_log_uploader(log_queue, log_uploader)
152
154
 
153
155
  # Update run status
154
- if run_status and grid:
156
+ if run and run_status and grid:
155
157
  run_status_proto = run_status_to_proto(run_status)
156
158
  grid._stub.UpdateRunStatus(
157
159
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
@@ -161,7 +163,12 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
161
163
  if grid:
162
164
  grid.close()
163
165
 
164
- add_exit_handler(on_exit)
166
+ # Register signal handlers for graceful shutdown
167
+ register_signal_handlers(
168
+ event_type=EventType.FLWR_SERVERAPP_RUN_LEAVE,
169
+ exit_message="Run stopped by user.",
170
+ exit_handlers=[on_exit],
171
+ )
165
172
 
166
173
  try:
167
174
  # Initialize the GrpcGrid
@@ -171,9 +178,14 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
171
178
  )
172
179
 
173
180
  # Pull ServerAppInputs from LinkState
174
- req = PullAppInputsRequest(token=token)
175
- log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
176
- res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
181
+ try:
182
+ log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
183
+ req = PullAppInputsRequest(token=token)
184
+ res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
185
+ except grpc.RpcError as ex:
186
+ if ex.code() == grpc.StatusCode.FAILED_PRECONDITION:
187
+ raise RuntimeError("Failed to start the run.") from ex
188
+ raise
177
189
  context = context_from_proto(res.context)
178
190
  run = run_from_proto(res.run)
179
191
  fab = fab_from_proto(res.fab)
@@ -214,25 +226,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
214
226
  app_path,
215
227
  )
216
228
 
217
- # Change status to Running
218
- run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
219
- grid._stub.UpdateRunStatus(
220
- UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
221
- )
222
-
223
229
  event(
224
230
  EventType.FLWR_SERVERAPP_RUN_ENTER,
225
231
  event_details={"run-id-hash": hash_run_id},
226
232
  )
227
233
 
228
234
  # Set up heartbeat sender
229
- heartbeat_fn = get_grpc_app_heartbeat_fn(
230
- grid._stub,
231
- run.run_id,
232
- failure_message="Heartbeat failed unexpectedly. The SuperLink could "
233
- "not find the provided run ID, or the run status is invalid.",
235
+ heartbeat_sender = HeartbeatSender(
236
+ make_app_heartbeat_fn_grpc(grid._stub, token)
234
237
  )
235
- heartbeat_sender = HeartbeatSender(heartbeat_fn)
236
238
  heartbeat_sender.start()
237
239
 
238
240
  # Load and run the ServerApp with the Grid
@@ -256,11 +258,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
256
258
  # Raised when the run is already stopped by the user
257
259
  except RunNotRunningException:
258
260
  log(INFO, "")
259
- log(INFO, "Run ID %s stopped.", run.run_id)
261
+ log(INFO, "Run ID %s stopped.", run.run_id) # type: ignore[union-attr]
260
262
  log(INFO, "")
261
263
  run_status = None
262
264
  # No need to update the exit code since this is expected behavior
263
265
 
266
+ except RuntimeError:
267
+ log(ERROR, "Failed to start run.")
268
+ exit_code = ExitCode.SERVERAPP_RUN_START_REJECTED
269
+
264
270
  except Exception as ex: # pylint: disable=broad-exception-caught
265
271
  exc_entity = "ServerApp"
266
272
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Optional
20
19
 
21
20
  from .client_manager import ClientManager
22
21
  from .server import Server
@@ -46,7 +45,7 @@ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
46
45
  will be used.
47
46
  """
48
47
 
49
- server: Optional[Server] = None
50
- config: Optional[ServerConfig] = None
51
- strategy: Optional[Strategy] = None
52
- client_manager: Optional[ClientManager] = None
48
+ server: Server | None = None
49
+ config: ServerConfig | None = None
50
+ strategy: Strategy | None = None
51
+ client_manager: ClientManager | None = None
@@ -15,8 +15,9 @@
15
15
  """Aggregation functions for strategy implementations."""
16
16
  # mypy: disallow_untyped_calls=False
17
17
 
18
+ from collections.abc import Callable
18
19
  from functools import partial, reduce
19
- from typing import Any, Callable, Union
20
+ from typing import Any
20
21
 
21
22
  import numpy as np
22
23
 
@@ -37,7 +38,7 @@ def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
37
38
  # Compute average weights of each layer
38
39
  weights_prime: NDArrays = [
39
40
  reduce(np.add, layer_updates) / num_examples_total
40
- for layer_updates in zip(*weighted_weights)
41
+ for layer_updates in zip(*weighted_weights, strict=True)
41
42
  ]
42
43
  return weights_prime
43
44
 
@@ -53,7 +54,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
53
54
  )
54
55
 
55
56
  def _try_inplace(
56
- x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
57
+ x: NDArray, y: NDArray | np.float64, np_binary_op: np.ufunc
57
58
  ) -> NDArray:
58
59
  return ( # type: ignore[no-any-return]
59
60
  np_binary_op(x, y, out=x)
@@ -75,7 +76,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
75
76
  )
76
77
  params = [
77
78
  reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
78
- for layer_updates in zip(params, res)
79
+ for layer_updates in zip(params, res, strict=True)
79
80
  ]
80
81
 
81
82
  return params
@@ -88,7 +89,7 @@ def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
88
89
 
89
90
  # Compute median weight of each layer
90
91
  median_w: NDArrays = [
91
- np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
92
+ np.median(np.asarray(layer), axis=0) for layer in zip(*weights, strict=True)
92
93
  ]
93
94
  return median_w
94
95
 
@@ -235,7 +236,7 @@ def aggregate_qffl(
235
236
  for j in range(1, len(deltas)):
236
237
  tmp += scaled_deltas[j][i]
237
238
  updates.append(tmp)
238
- new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
239
+ new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates, strict=True)]
239
240
  return new_parameters
240
241
 
241
242
 
@@ -287,7 +288,7 @@ def aggregate_trimmed_avg(
287
288
 
288
289
  trimmed_w: NDArrays = [
289
290
  _trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
290
- for layer in zip(*weights)
291
+ for layer in zip(*weights, strict=True)
291
292
  ]
292
293
 
293
294
  return trimmed_w
@@ -299,7 +300,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
299
300
  return False
300
301
  return all(
301
302
  np.array_equal(layer_weights1, layer_weights2)
302
- for layer_weights1, layer_weights2 in zip(weights1, weights2)
303
+ for layer_weights1, layer_weights2 in zip(weights1, weights2, strict=True)
303
304
  )
304
305
 
305
306
 
@@ -18,8 +18,9 @@ Paper: arxiv.org/abs/1802.07927
18
18
  """
19
19
 
20
20
 
21
+ from collections.abc import Callable
21
22
  from logging import WARNING
22
- from typing import Any, Callable, Optional, Union
23
+ from typing import Any
23
24
 
24
25
  from flwr.common import (
25
26
  FitRes,
@@ -84,18 +85,19 @@ class Bulyan(FedAvg):
84
85
  min_evaluate_clients: int = 2,
85
86
  min_available_clients: int = 2,
86
87
  num_malicious_clients: int = 0,
87
- evaluate_fn: Optional[
88
+ evaluate_fn: (
88
89
  Callable[
89
90
  [int, NDArrays, dict[str, Scalar]],
90
- Optional[tuple[float, dict[str, Scalar]]],
91
+ tuple[float, dict[str, Scalar]] | None,
91
92
  ]
92
- ] = None,
93
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
94
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
93
+ | None
94
+ ) = None,
95
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
96
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
95
97
  accept_failures: bool = True,
96
- initial_parameters: Optional[Parameters] = None,
97
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
+ initial_parameters: Parameters | None = None,
99
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
99
101
  first_aggregation_rule: Callable = aggregate_krum, # type: ignore
100
102
  **aggregation_rule_kwargs: Any,
101
103
  ) -> None:
@@ -126,8 +128,8 @@ class Bulyan(FedAvg):
126
128
  self,
127
129
  server_round: int,
128
130
  results: list[tuple[ClientProxy, FitRes]],
129
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
130
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
131
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
132
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
131
133
  """Aggregate fit results using Bulyan."""
132
134
  if not results:
133
135
  return None, {}
@@ -20,7 +20,6 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
20
20
 
21
21
  import math
22
22
  from logging import INFO, WARNING
23
- from typing import Optional, Union
24
23
 
25
24
  import numpy as np
26
25
 
@@ -97,7 +96,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
97
96
  initial_clipping_norm: float = 0.1,
98
97
  target_clipped_quantile: float = 0.5,
99
98
  clip_norm_lr: float = 0.2,
100
- clipped_count_stddev: Optional[float] = None,
99
+ clipped_count_stddev: float | None = None,
101
100
  ) -> None:
102
101
  super().__init__()
103
102
 
@@ -148,9 +147,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
148
147
  rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
149
148
  return rep
150
149
 
151
- def initialize_parameters(
152
- self, client_manager: ClientManager
153
- ) -> Optional[Parameters]:
150
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
154
151
  """Initialize global model parameters using given strategy."""
155
152
  return self.strategy.initialize_parameters(client_manager)
156
153
 
@@ -173,8 +170,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
173
170
  self,
174
171
  server_round: int,
175
172
  results: list[tuple[ClientProxy, FitRes]],
176
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
177
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
173
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
174
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
178
175
  """Aggregate training results and update clip norms."""
179
176
  if failures:
180
177
  return None, {}
@@ -192,7 +189,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
192
189
  param = parameters_to_ndarrays(res.parameters)
193
190
  # Compute and clip update
194
191
  model_update = [
195
- np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
192
+ np.subtract(x, y)
193
+ for (x, y) in zip(param, self.current_round_params, strict=True)
196
194
  ]
197
195
 
198
196
  norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
@@ -246,14 +244,14 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
246
244
  self,
247
245
  server_round: int,
248
246
  results: list[tuple[ClientProxy, EvaluateRes]],
249
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
250
- ) -> tuple[Optional[float], dict[str, Scalar]]:
247
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
248
+ ) -> tuple[float | None, dict[str, Scalar]]:
251
249
  """Aggregate evaluation losses using the given strategy."""
252
250
  return self.strategy.aggregate_evaluate(server_round, results, failures)
253
251
 
254
252
  def evaluate(
255
253
  self, server_round: int, parameters: Parameters
256
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
254
+ ) -> tuple[float, dict[str, Scalar]] | None:
257
255
  """Evaluate model parameters using an evaluation function from the strategy."""
258
256
  return self.strategy.evaluate(server_round, parameters)
259
257
 
@@ -316,7 +314,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
316
314
  initial_clipping_norm: float = 0.1,
317
315
  target_clipped_quantile: float = 0.5,
318
316
  clip_norm_lr: float = 0.2,
319
- clipped_count_stddev: Optional[float] = None,
317
+ clipped_count_stddev: float | None = None,
320
318
  ) -> None:
321
319
  super().__init__()
322
320
 
@@ -364,9 +362,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
364
362
  rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
365
363
  return rep
366
364
 
367
- def initialize_parameters(
368
- self, client_manager: ClientManager
369
- ) -> Optional[Parameters]:
365
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
370
366
  """Initialize global model parameters using given strategy."""
371
367
  return self.strategy.initialize_parameters(client_manager)
372
368
 
@@ -395,8 +391,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
395
391
  self,
396
392
  server_round: int,
397
393
  results: list[tuple[ClientProxy, FitRes]],
398
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
399
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
394
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
395
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
400
396
  """Aggregate training results and update clip norms."""
401
397
  if failures:
402
398
  return None, {}
@@ -458,13 +454,13 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
458
454
  self,
459
455
  server_round: int,
460
456
  results: list[tuple[ClientProxy, EvaluateRes]],
461
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
462
- ) -> tuple[Optional[float], dict[str, Scalar]]:
457
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
458
+ ) -> tuple[float | None, dict[str, Scalar]]:
463
459
  """Aggregate evaluation losses using the given strategy."""
464
460
  return self.strategy.aggregate_evaluate(server_round, results, failures)
465
461
 
466
462
  def evaluate(
467
463
  self, server_round: int, parameters: Parameters
468
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
464
+ ) -> tuple[float, dict[str, Scalar]] | None:
469
465
  """Evaluate model parameters using an evaluation function from the strategy."""
470
466
  return self.strategy.evaluate(server_round, parameters)
@@ -19,7 +19,6 @@ Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
19
19
 
20
20
 
21
21
  from logging import INFO, WARNING
22
- from typing import Optional, Union
23
22
 
24
23
  from flwr.common import (
25
24
  EvaluateIns,
@@ -109,9 +108,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
109
108
  rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
110
109
  return rep
111
110
 
112
- def initialize_parameters(
113
- self, client_manager: ClientManager
114
- ) -> Optional[Parameters]:
111
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
115
112
  """Initialize global model parameters using given strategy."""
116
113
  return self.strategy.initialize_parameters(client_manager)
117
114
 
@@ -134,8 +131,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
134
131
  self,
135
132
  server_round: int,
136
133
  results: list[tuple[ClientProxy, FitRes]],
137
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
138
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
134
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
135
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
139
136
  """Compute the updates, clip, and pass them for aggregation.
140
137
 
141
138
  Afterward, add noise to the aggregated parameters.
@@ -192,14 +189,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
192
189
  self,
193
190
  server_round: int,
194
191
  results: list[tuple[ClientProxy, EvaluateRes]],
195
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
196
- ) -> tuple[Optional[float], dict[str, Scalar]]:
192
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
193
+ ) -> tuple[float | None, dict[str, Scalar]]:
197
194
  """Aggregate evaluation losses using the given strategy."""
198
195
  return self.strategy.aggregate_evaluate(server_round, results, failures)
199
196
 
200
197
  def evaluate(
201
198
  self, server_round: int, parameters: Parameters
202
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
199
+ ) -> tuple[float, dict[str, Scalar]] | None:
203
200
  """Evaluate model parameters using an evaluation function from the strategy."""
204
201
  return self.strategy.evaluate(server_round, parameters)
205
202
 
@@ -277,9 +274,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
277
274
  rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
278
275
  return rep
279
276
 
280
- def initialize_parameters(
281
- self, client_manager: ClientManager
282
- ) -> Optional[Parameters]:
277
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
283
278
  """Initialize global model parameters using given strategy."""
284
279
  return self.strategy.initialize_parameters(client_manager)
285
280
 
@@ -308,8 +303,8 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
308
303
  self,
309
304
  server_round: int,
310
305
  results: list[tuple[ClientProxy, FitRes]],
311
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
312
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
306
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
307
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
313
308
  """Add noise to the aggregated parameters."""
314
309
  if failures:
315
310
  return None, {}
@@ -349,13 +344,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
349
344
  self,
350
345
  server_round: int,
351
346
  results: list[tuple[ClientProxy, EvaluateRes]],
352
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
353
- ) -> tuple[Optional[float], dict[str, Scalar]]:
347
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
348
+ ) -> tuple[float | None, dict[str, Scalar]]:
354
349
  """Aggregate evaluation losses using the given strategy."""
355
350
  return self.strategy.aggregate_evaluate(server_round, results, failures)
356
351
 
357
352
  def evaluate(
358
353
  self, server_round: int, parameters: Parameters
359
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
354
+ ) -> tuple[float, dict[str, Scalar]] | None:
360
355
  """Evaluate model parameters using an evaluation function from the strategy."""
361
356
  return self.strategy.evaluate(server_round, parameters)
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1905.03871.pdf
19
19
 
20
20
 
21
21
  import math
22
- from typing import Optional, Union
23
22
 
24
23
  import numpy as np
25
24
 
@@ -49,7 +48,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
49
48
  server_side_noising: bool = True,
50
49
  clip_norm_lr: float = 0.2,
51
50
  clip_norm_target_quantile: float = 0.5,
52
- clip_count_stddev: Optional[float] = None,
51
+ clip_count_stddev: float | None = None,
53
52
  ) -> None:
54
53
  warn_deprecated_feature("`DPFedAvgAdaptive` wrapper")
55
54
  super().__init__(
@@ -119,8 +118,8 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
119
118
  self,
120
119
  server_round: int,
121
120
  results: list[tuple[ClientProxy, FitRes]],
122
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
123
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
121
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
122
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
124
123
  """Aggregate training results as in DPFedAvgFixed and update clip norms."""
125
124
  if failures:
126
125
  return None, {}
@@ -18,8 +18,6 @@ Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
20
 
21
- from typing import Optional, Union
22
-
23
21
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
24
22
  from flwr.common.dp import add_gaussian_noise
25
23
  from flwr.common.logger import warn_deprecated_feature
@@ -72,9 +70,7 @@ class DPFedAvgFixed(Strategy):
72
70
  self.noise_multiplier * self.clip_norm / (self.num_sampled_clients ** (0.5))
73
71
  )
74
72
 
75
- def initialize_parameters(
76
- self, client_manager: ClientManager
77
- ) -> Optional[Parameters]:
73
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
78
74
  """Initialize global model parameters using given strategy."""
79
75
  return self.strategy.initialize_parameters(client_manager)
80
76
 
@@ -149,8 +145,8 @@ class DPFedAvgFixed(Strategy):
149
145
  self,
150
146
  server_round: int,
151
147
  results: list[tuple[ClientProxy, FitRes]],
152
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
153
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
148
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
149
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
154
150
  """Aggregate training results using unweighted aggregation."""
155
151
  if failures:
156
152
  return None, {}
@@ -170,13 +166,13 @@ class DPFedAvgFixed(Strategy):
170
166
  self,
171
167
  server_round: int,
172
168
  results: list[tuple[ClientProxy, EvaluateRes]],
173
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
174
- ) -> tuple[Optional[float], dict[str, Scalar]]:
169
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
170
+ ) -> tuple[float | None, dict[str, Scalar]]:
175
171
  """Aggregate evaluation losses using the given strategy."""
176
172
  return self.strategy.aggregate_evaluate(server_round, results, failures)
177
173
 
178
174
  def evaluate(
179
175
  self, server_round: int, parameters: Parameters
180
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
176
+ ) -> tuple[float, dict[str, Scalar]] | None:
181
177
  """Evaluate model parameters using an evaluation function from the strategy."""
182
178
  return self.strategy.evaluate(server_round, parameters)
@@ -15,8 +15,8 @@
15
15
  """Fault-tolerant variant of FedAvg strategy."""
16
16
 
17
17
 
18
+ from collections.abc import Callable
18
19
  from logging import WARNING
19
- from typing import Callable, Optional, Union
20
20
 
21
21
  from flwr.common import (
22
22
  EvaluateRes,
@@ -47,19 +47,20 @@ class FaultTolerantFedAvg(FedAvg):
47
47
  min_fit_clients: int = 1,
48
48
  min_evaluate_clients: int = 1,
49
49
  min_available_clients: int = 1,
50
- evaluate_fn: Optional[
50
+ evaluate_fn: (
51
51
  Callable[
52
52
  [int, NDArrays, dict[str, Scalar]],
53
- Optional[tuple[float, dict[str, Scalar]]],
53
+ tuple[float, dict[str, Scalar]] | None,
54
54
  ]
55
- ] = None,
56
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
57
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
55
+ | None
56
+ ) = None,
57
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
58
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
58
59
  min_completion_rate_fit: float = 0.5,
59
60
  min_completion_rate_evaluate: float = 0.5,
60
- initial_parameters: Optional[Parameters] = None,
61
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
62
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
61
+ initial_parameters: Parameters | None = None,
62
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
63
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
63
64
  ) -> None:
64
65
  super().__init__(
65
66
  fraction_fit=fraction_fit,
@@ -86,8 +87,8 @@ class FaultTolerantFedAvg(FedAvg):
86
87
  self,
87
88
  server_round: int,
88
89
  results: list[tuple[ClientProxy, FitRes]],
89
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
90
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
90
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
91
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
91
92
  """Aggregate fit results using weighted average."""
92
93
  if not results:
93
94
  return None, {}
@@ -118,8 +119,8 @@ class FaultTolerantFedAvg(FedAvg):
118
119
  self,
119
120
  server_round: int,
120
121
  results: list[tuple[ClientProxy, EvaluateRes]],
121
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
122
- ) -> tuple[Optional[float], dict[str, Scalar]]:
122
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
123
+ ) -> tuple[float | None, dict[str, Scalar]]:
123
124
  """Aggregate evaluation losses using weighted average."""
124
125
  if not results:
125
126
  return None, {}