flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (292)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/federation/__init__.py +24 -0
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +317 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +170 -130
  21. flwr/cli/new/new.py +33 -50
  22. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  30. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  33. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  34. flwr/cli/pull.py +10 -5
  35. flwr/cli/run/run.py +77 -30
  36. flwr/cli/run_utils.py +130 -0
  37. flwr/cli/stop.py +25 -7
  38. flwr/cli/supernode/ls.py +16 -8
  39. flwr/cli/supernode/register.py +9 -4
  40. flwr/cli/supernode/unregister.py +5 -3
  41. flwr/cli/utils.py +376 -16
  42. flwr/client/__init__.py +1 -1
  43. flwr/client/dpfedavg_numpy_client.py +4 -1
  44. flwr/client/grpc_adapter_client/connection.py +6 -7
  45. flwr/client/grpc_rere_client/connection.py +10 -11
  46. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  47. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  48. flwr/client/message_handler/message_handler.py +2 -2
  49. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  50. flwr/client/numpy_client.py +1 -1
  51. flwr/client/rest_client/connection.py +12 -14
  52. flwr/client/run_info_store.py +4 -5
  53. flwr/client/typing.py +1 -1
  54. flwr/clientapp/client_app.py +9 -10
  55. flwr/clientapp/mod/centraldp_mods.py +16 -17
  56. flwr/clientapp/mod/localdp_mod.py +8 -9
  57. flwr/clientapp/typing.py +1 -1
  58. flwr/clientapp/utils.py +3 -3
  59. flwr/common/address.py +1 -2
  60. flwr/common/args.py +3 -4
  61. flwr/common/config.py +13 -16
  62. flwr/common/constant.py +5 -2
  63. flwr/common/differential_privacy.py +3 -4
  64. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  65. flwr/common/exit/exit.py +15 -2
  66. flwr/common/exit/exit_code.py +19 -0
  67. flwr/common/exit/exit_handler.py +6 -2
  68. flwr/common/exit/signal_handler.py +5 -5
  69. flwr/common/grpc.py +6 -6
  70. flwr/common/inflatable_protobuf_utils.py +1 -1
  71. flwr/common/inflatable_utils.py +38 -21
  72. flwr/common/logger.py +19 -19
  73. flwr/common/message.py +4 -4
  74. flwr/common/object_ref.py +7 -7
  75. flwr/common/record/array.py +3 -3
  76. flwr/common/record/arrayrecord.py +18 -30
  77. flwr/common/record/configrecord.py +3 -3
  78. flwr/common/record/recorddict.py +5 -5
  79. flwr/common/record/typeddict.py +9 -2
  80. flwr/common/recorddict_compat.py +7 -10
  81. flwr/common/retry_invoker.py +20 -20
  82. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  83. flwr/common/serde.py +5 -4
  84. flwr/common/serde_utils.py +2 -2
  85. flwr/common/telemetry.py +9 -5
  86. flwr/common/typing.py +52 -37
  87. flwr/compat/client/app.py +38 -37
  88. flwr/compat/client/grpc_client/connection.py +11 -11
  89. flwr/compat/server/app.py +5 -6
  90. flwr/proto/appio_pb2.py +13 -3
  91. flwr/proto/appio_pb2.pyi +134 -65
  92. flwr/proto/appio_pb2_grpc.py +20 -0
  93. flwr/proto/appio_pb2_grpc.pyi +27 -0
  94. flwr/proto/clientappio_pb2.py +17 -7
  95. flwr/proto/clientappio_pb2.pyi +15 -0
  96. flwr/proto/clientappio_pb2_grpc.py +206 -40
  97. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  98. flwr/proto/control_pb2.py +71 -52
  99. flwr/proto/control_pb2.pyi +277 -111
  100. flwr/proto/control_pb2_grpc.py +249 -40
  101. flwr/proto/control_pb2_grpc.pyi +185 -52
  102. flwr/proto/error_pb2.py +13 -3
  103. flwr/proto/error_pb2.pyi +24 -6
  104. flwr/proto/error_pb2_grpc.py +20 -0
  105. flwr/proto/error_pb2_grpc.pyi +27 -0
  106. flwr/proto/fab_pb2.py +14 -4
  107. flwr/proto/fab_pb2.pyi +59 -31
  108. flwr/proto/fab_pb2_grpc.py +20 -0
  109. flwr/proto/fab_pb2_grpc.pyi +27 -0
  110. flwr/proto/federation_pb2.py +38 -0
  111. flwr/proto/federation_pb2.pyi +56 -0
  112. flwr/proto/federation_pb2_grpc.py +24 -0
  113. flwr/proto/federation_pb2_grpc.pyi +31 -0
  114. flwr/proto/fleet_pb2.py +14 -4
  115. flwr/proto/fleet_pb2.pyi +137 -61
  116. flwr/proto/fleet_pb2_grpc.py +189 -48
  117. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  118. flwr/proto/grpcadapter_pb2.py +14 -4
  119. flwr/proto/grpcadapter_pb2.pyi +38 -16
  120. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  121. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  122. flwr/proto/heartbeat_pb2.py +17 -7
  123. flwr/proto/heartbeat_pb2.pyi +51 -22
  124. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  125. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  126. flwr/proto/log_pb2.py +13 -3
  127. flwr/proto/log_pb2.pyi +34 -11
  128. flwr/proto/log_pb2_grpc.py +20 -0
  129. flwr/proto/log_pb2_grpc.pyi +27 -0
  130. flwr/proto/message_pb2.py +15 -5
  131. flwr/proto/message_pb2.pyi +154 -86
  132. flwr/proto/message_pb2_grpc.py +20 -0
  133. flwr/proto/message_pb2_grpc.pyi +27 -0
  134. flwr/proto/node_pb2.py +15 -5
  135. flwr/proto/node_pb2.pyi +50 -25
  136. flwr/proto/node_pb2_grpc.py +20 -0
  137. flwr/proto/node_pb2_grpc.pyi +27 -0
  138. flwr/proto/recorddict_pb2.py +13 -3
  139. flwr/proto/recorddict_pb2.pyi +184 -107
  140. flwr/proto/recorddict_pb2_grpc.py +20 -0
  141. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  142. flwr/proto/run_pb2.py +40 -31
  143. flwr/proto/run_pb2.pyi +149 -84
  144. flwr/proto/run_pb2_grpc.py +20 -0
  145. flwr/proto/run_pb2_grpc.pyi +27 -0
  146. flwr/proto/serverappio_pb2.py +13 -3
  147. flwr/proto/serverappio_pb2.pyi +32 -8
  148. flwr/proto/serverappio_pb2_grpc.py +246 -65
  149. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  150. flwr/proto/simulationio_pb2.py +16 -8
  151. flwr/proto/simulationio_pb2.pyi +15 -0
  152. flwr/proto/simulationio_pb2_grpc.py +162 -41
  153. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  154. flwr/proto/transport_pb2.py +20 -10
  155. flwr/proto/transport_pb2.pyi +249 -160
  156. flwr/proto/transport_pb2_grpc.py +35 -4
  157. flwr/proto/transport_pb2_grpc.pyi +38 -8
  158. flwr/server/app.py +38 -17
  159. flwr/server/client_manager.py +4 -5
  160. flwr/server/client_proxy.py +10 -11
  161. flwr/server/compat/app.py +4 -5
  162. flwr/server/compat/app_utils.py +2 -1
  163. flwr/server/compat/grid_client_proxy.py +10 -12
  164. flwr/server/compat/legacy_context.py +3 -4
  165. flwr/server/fleet_event_log_interceptor.py +2 -1
  166. flwr/server/grid/grid.py +2 -3
  167. flwr/server/grid/grpc_grid.py +10 -8
  168. flwr/server/grid/inmemory_grid.py +4 -4
  169. flwr/server/run_serverapp.py +2 -3
  170. flwr/server/server.py +34 -39
  171. flwr/server/server_app.py +7 -8
  172. flwr/server/server_config.py +1 -2
  173. flwr/server/serverapp/app.py +34 -28
  174. flwr/server/serverapp_components.py +4 -5
  175. flwr/server/strategy/aggregate.py +9 -8
  176. flwr/server/strategy/bulyan.py +13 -11
  177. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  178. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  179. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  180. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  181. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  182. flwr/server/strategy/fedadagrad.py +18 -14
  183. flwr/server/strategy/fedadam.py +16 -14
  184. flwr/server/strategy/fedavg.py +16 -17
  185. flwr/server/strategy/fedavg_android.py +15 -15
  186. flwr/server/strategy/fedavgm.py +21 -18
  187. flwr/server/strategy/fedmedian.py +2 -3
  188. flwr/server/strategy/fedopt.py +11 -10
  189. flwr/server/strategy/fedprox.py +10 -9
  190. flwr/server/strategy/fedtrimmedavg.py +12 -11
  191. flwr/server/strategy/fedxgb_bagging.py +13 -11
  192. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  193. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  194. flwr/server/strategy/fedyogi.py +16 -14
  195. flwr/server/strategy/krum.py +12 -11
  196. flwr/server/strategy/qfedavg.py +16 -15
  197. flwr/server/strategy/strategy.py +6 -9
  198. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  199. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  200. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  201. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  202. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  203. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  204. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  205. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
  206. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  207. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  208. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  209. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  210. flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
  211. flwr/server/superlink/linkstate/linkstate.py +59 -43
  212. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  213. flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
  214. flwr/server/superlink/linkstate/utils.py +6 -6
  215. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  216. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  217. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  218. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  219. flwr/server/superlink/utils.py +4 -6
  220. flwr/server/typing.py +1 -1
  221. flwr/server/utils/tensorboard.py +15 -8
  222. flwr/server/workflow/default_workflows.py +5 -5
  223. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  224. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  225. flwr/serverapp/strategy/bulyan.py +16 -15
  226. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  227. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  228. flwr/serverapp/strategy/fedadagrad.py +10 -11
  229. flwr/serverapp/strategy/fedadam.py +10 -11
  230. flwr/serverapp/strategy/fedavg.py +9 -10
  231. flwr/serverapp/strategy/fedavgm.py +17 -16
  232. flwr/serverapp/strategy/fedmedian.py +2 -2
  233. flwr/serverapp/strategy/fedopt.py +10 -11
  234. flwr/serverapp/strategy/fedprox.py +7 -8
  235. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  236. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  237. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  238. flwr/serverapp/strategy/fedyogi.py +9 -11
  239. flwr/serverapp/strategy/krum.py +7 -7
  240. flwr/serverapp/strategy/multikrum.py +9 -9
  241. flwr/serverapp/strategy/qfedavg.py +17 -16
  242. flwr/serverapp/strategy/strategy.py +6 -9
  243. flwr/serverapp/strategy/strategy_utils.py +7 -8
  244. flwr/simulation/app.py +46 -42
  245. flwr/simulation/legacy_app.py +12 -12
  246. flwr/simulation/ray_transport/ray_actor.py +10 -11
  247. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  248. flwr/simulation/run_simulation.py +43 -43
  249. flwr/simulation/simulationio_connection.py +4 -4
  250. flwr/supercore/cli/flower_superexec.py +3 -4
  251. flwr/supercore/constant.py +31 -1
  252. flwr/supercore/corestate/corestate.py +24 -3
  253. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  254. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  255. flwr/supercore/ffs/disk_ffs.py +1 -2
  256. flwr/supercore/ffs/ffs.py +1 -2
  257. flwr/supercore/ffs/ffs_factory.py +1 -2
  258. flwr/{common → supercore}/heartbeat.py +20 -25
  259. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  260. flwr/supercore/object_store/object_store.py +1 -2
  261. flwr/supercore/object_store/object_store_factory.py +1 -2
  262. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  263. flwr/supercore/primitives/asymmetric.py +1 -1
  264. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  265. flwr/supercore/sqlite_mixin.py +37 -34
  266. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  267. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  268. flwr/supercore/superexec/run_superexec.py +9 -13
  269. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  270. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  271. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  272. flwr/superlink/federation/__init__.py +24 -0
  273. flwr/superlink/federation/federation_manager.py +64 -0
  274. flwr/superlink/federation/noop_federation_manager.py +71 -0
  275. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  276. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  277. flwr/superlink/servicer/control/control_grpc.py +5 -6
  278. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  279. flwr/superlink/servicer/control/control_servicer.py +102 -18
  280. flwr/supernode/cli/flower_supernode.py +58 -3
  281. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  282. flwr/supernode/nodestate/nodestate.py +7 -8
  283. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  284. flwr/supernode/runtime/run_clientapp.py +41 -22
  285. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  286. flwr/supernode/start_client_internal.py +158 -42
  287. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  288. flwr-1.24.0.dist-info/RECORD +454 -0
  289. flwr/supercore/object_store/utils.py +0 -43
  290. flwr-1.23.0.dist-info/RECORD +0 -439
  291. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  292. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -20,7 +20,6 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
20
20
 
21
21
  import math
22
22
  from logging import INFO, WARNING
23
- from typing import Optional, Union
24
23
 
25
24
  import numpy as np
26
25
 
@@ -97,7 +96,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
97
96
  initial_clipping_norm: float = 0.1,
98
97
  target_clipped_quantile: float = 0.5,
99
98
  clip_norm_lr: float = 0.2,
100
- clipped_count_stddev: Optional[float] = None,
99
+ clipped_count_stddev: float | None = None,
101
100
  ) -> None:
102
101
  super().__init__()
103
102
 
@@ -148,9 +147,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
148
147
  rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
149
148
  return rep
150
149
 
151
- def initialize_parameters(
152
- self, client_manager: ClientManager
153
- ) -> Optional[Parameters]:
150
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
154
151
  """Initialize global model parameters using given strategy."""
155
152
  return self.strategy.initialize_parameters(client_manager)
156
153
 
@@ -173,8 +170,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
173
170
  self,
174
171
  server_round: int,
175
172
  results: list[tuple[ClientProxy, FitRes]],
176
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
177
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
173
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
174
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
178
175
  """Aggregate training results and update clip norms."""
179
176
  if failures:
180
177
  return None, {}
@@ -192,7 +189,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
192
189
  param = parameters_to_ndarrays(res.parameters)
193
190
  # Compute and clip update
194
191
  model_update = [
195
- np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
192
+ np.subtract(x, y)
193
+ for (x, y) in zip(param, self.current_round_params, strict=True)
196
194
  ]
197
195
 
198
196
  norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
@@ -246,14 +244,14 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
246
244
  self,
247
245
  server_round: int,
248
246
  results: list[tuple[ClientProxy, EvaluateRes]],
249
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
250
- ) -> tuple[Optional[float], dict[str, Scalar]]:
247
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
248
+ ) -> tuple[float | None, dict[str, Scalar]]:
251
249
  """Aggregate evaluation losses using the given strategy."""
252
250
  return self.strategy.aggregate_evaluate(server_round, results, failures)
253
251
 
254
252
  def evaluate(
255
253
  self, server_round: int, parameters: Parameters
256
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
254
+ ) -> tuple[float, dict[str, Scalar]] | None:
257
255
  """Evaluate model parameters using an evaluation function from the strategy."""
258
256
  return self.strategy.evaluate(server_round, parameters)
259
257
 
@@ -316,7 +314,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
316
314
  initial_clipping_norm: float = 0.1,
317
315
  target_clipped_quantile: float = 0.5,
318
316
  clip_norm_lr: float = 0.2,
319
- clipped_count_stddev: Optional[float] = None,
317
+ clipped_count_stddev: float | None = None,
320
318
  ) -> None:
321
319
  super().__init__()
322
320
 
@@ -364,9 +362,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
364
362
  rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
365
363
  return rep
366
364
 
367
- def initialize_parameters(
368
- self, client_manager: ClientManager
369
- ) -> Optional[Parameters]:
365
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
370
366
  """Initialize global model parameters using given strategy."""
371
367
  return self.strategy.initialize_parameters(client_manager)
372
368
 
@@ -395,8 +391,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
395
391
  self,
396
392
  server_round: int,
397
393
  results: list[tuple[ClientProxy, FitRes]],
398
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
399
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
394
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
395
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
400
396
  """Aggregate training results and update clip norms."""
401
397
  if failures:
402
398
  return None, {}
@@ -458,13 +454,13 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
458
454
  self,
459
455
  server_round: int,
460
456
  results: list[tuple[ClientProxy, EvaluateRes]],
461
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
462
- ) -> tuple[Optional[float], dict[str, Scalar]]:
457
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
458
+ ) -> tuple[float | None, dict[str, Scalar]]:
463
459
  """Aggregate evaluation losses using the given strategy."""
464
460
  return self.strategy.aggregate_evaluate(server_round, results, failures)
465
461
 
466
462
  def evaluate(
467
463
  self, server_round: int, parameters: Parameters
468
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
464
+ ) -> tuple[float, dict[str, Scalar]] | None:
469
465
  """Evaluate model parameters using an evaluation function from the strategy."""
470
466
  return self.strategy.evaluate(server_round, parameters)
@@ -19,7 +19,6 @@ Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
19
19
 
20
20
 
21
21
  from logging import INFO, WARNING
22
- from typing import Optional, Union
23
22
 
24
23
  from flwr.common import (
25
24
  EvaluateIns,
@@ -109,9 +108,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
109
108
  rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
110
109
  return rep
111
110
 
112
- def initialize_parameters(
113
- self, client_manager: ClientManager
114
- ) -> Optional[Parameters]:
111
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
115
112
  """Initialize global model parameters using given strategy."""
116
113
  return self.strategy.initialize_parameters(client_manager)
117
114
 
@@ -134,8 +131,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
134
131
  self,
135
132
  server_round: int,
136
133
  results: list[tuple[ClientProxy, FitRes]],
137
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
138
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
134
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
135
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
139
136
  """Compute the updates, clip, and pass them for aggregation.
140
137
 
141
138
  Afterward, add noise to the aggregated parameters.
@@ -192,14 +189,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
192
189
  self,
193
190
  server_round: int,
194
191
  results: list[tuple[ClientProxy, EvaluateRes]],
195
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
196
- ) -> tuple[Optional[float], dict[str, Scalar]]:
192
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
193
+ ) -> tuple[float | None, dict[str, Scalar]]:
197
194
  """Aggregate evaluation losses using the given strategy."""
198
195
  return self.strategy.aggregate_evaluate(server_round, results, failures)
199
196
 
200
197
  def evaluate(
201
198
  self, server_round: int, parameters: Parameters
202
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
199
+ ) -> tuple[float, dict[str, Scalar]] | None:
203
200
  """Evaluate model parameters using an evaluation function from the strategy."""
204
201
  return self.strategy.evaluate(server_round, parameters)
205
202
 
@@ -277,9 +274,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
277
274
  rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
278
275
  return rep
279
276
 
280
- def initialize_parameters(
281
- self, client_manager: ClientManager
282
- ) -> Optional[Parameters]:
277
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
283
278
  """Initialize global model parameters using given strategy."""
284
279
  return self.strategy.initialize_parameters(client_manager)
285
280
 
@@ -308,8 +303,8 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
308
303
  self,
309
304
  server_round: int,
310
305
  results: list[tuple[ClientProxy, FitRes]],
311
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
312
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
306
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
307
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
313
308
  """Add noise to the aggregated parameters."""
314
309
  if failures:
315
310
  return None, {}
@@ -349,13 +344,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
349
344
  self,
350
345
  server_round: int,
351
346
  results: list[tuple[ClientProxy, EvaluateRes]],
352
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
353
- ) -> tuple[Optional[float], dict[str, Scalar]]:
347
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
348
+ ) -> tuple[float | None, dict[str, Scalar]]:
354
349
  """Aggregate evaluation losses using the given strategy."""
355
350
  return self.strategy.aggregate_evaluate(server_round, results, failures)
356
351
 
357
352
  def evaluate(
358
353
  self, server_round: int, parameters: Parameters
359
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
354
+ ) -> tuple[float, dict[str, Scalar]] | None:
360
355
  """Evaluate model parameters using an evaluation function from the strategy."""
361
356
  return self.strategy.evaluate(server_round, parameters)
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1905.03871.pdf
19
19
 
20
20
 
21
21
  import math
22
- from typing import Optional, Union
23
22
 
24
23
  import numpy as np
25
24
 
@@ -49,7 +48,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
49
48
  server_side_noising: bool = True,
50
49
  clip_norm_lr: float = 0.2,
51
50
  clip_norm_target_quantile: float = 0.5,
52
- clip_count_stddev: Optional[float] = None,
51
+ clip_count_stddev: float | None = None,
53
52
  ) -> None:
54
53
  warn_deprecated_feature("`DPFedAvgAdaptive` wrapper")
55
54
  super().__init__(
@@ -119,8 +118,8 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
119
118
  self,
120
119
  server_round: int,
121
120
  results: list[tuple[ClientProxy, FitRes]],
122
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
123
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
121
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
122
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
124
123
  """Aggregate training results as in DPFedAvgFixed and update clip norms."""
125
124
  if failures:
126
125
  return None, {}
@@ -18,8 +18,6 @@ Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
20
 
21
- from typing import Optional, Union
22
-
23
21
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
24
22
  from flwr.common.dp import add_gaussian_noise
25
23
  from flwr.common.logger import warn_deprecated_feature
@@ -72,9 +70,7 @@ class DPFedAvgFixed(Strategy):
72
70
  self.noise_multiplier * self.clip_norm / (self.num_sampled_clients ** (0.5))
73
71
  )
74
72
 
75
- def initialize_parameters(
76
- self, client_manager: ClientManager
77
- ) -> Optional[Parameters]:
73
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
78
74
  """Initialize global model parameters using given strategy."""
79
75
  return self.strategy.initialize_parameters(client_manager)
80
76
 
@@ -149,8 +145,8 @@ class DPFedAvgFixed(Strategy):
149
145
  self,
150
146
  server_round: int,
151
147
  results: list[tuple[ClientProxy, FitRes]],
152
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
153
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
148
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
149
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
154
150
  """Aggregate training results using unweighted aggregation."""
155
151
  if failures:
156
152
  return None, {}
@@ -170,13 +166,13 @@ class DPFedAvgFixed(Strategy):
170
166
  self,
171
167
  server_round: int,
172
168
  results: list[tuple[ClientProxy, EvaluateRes]],
173
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
174
- ) -> tuple[Optional[float], dict[str, Scalar]]:
169
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
170
+ ) -> tuple[float | None, dict[str, Scalar]]:
175
171
  """Aggregate evaluation losses using the given strategy."""
176
172
  return self.strategy.aggregate_evaluate(server_round, results, failures)
177
173
 
178
174
  def evaluate(
179
175
  self, server_round: int, parameters: Parameters
180
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
176
+ ) -> tuple[float, dict[str, Scalar]] | None:
181
177
  """Evaluate model parameters using an evaluation function from the strategy."""
182
178
  return self.strategy.evaluate(server_round, parameters)
@@ -15,8 +15,8 @@
15
15
  """Fault-tolerant variant of FedAvg strategy."""
16
16
 
17
17
 
18
+ from collections.abc import Callable
18
19
  from logging import WARNING
19
- from typing import Callable, Optional, Union
20
20
 
21
21
  from flwr.common import (
22
22
  EvaluateRes,
@@ -47,19 +47,20 @@ class FaultTolerantFedAvg(FedAvg):
47
47
  min_fit_clients: int = 1,
48
48
  min_evaluate_clients: int = 1,
49
49
  min_available_clients: int = 1,
50
- evaluate_fn: Optional[
50
+ evaluate_fn: (
51
51
  Callable[
52
52
  [int, NDArrays, dict[str, Scalar]],
53
- Optional[tuple[float, dict[str, Scalar]]],
53
+ tuple[float, dict[str, Scalar]] | None,
54
54
  ]
55
- ] = None,
56
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
57
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
55
+ | None
56
+ ) = None,
57
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
58
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
58
59
  min_completion_rate_fit: float = 0.5,
59
60
  min_completion_rate_evaluate: float = 0.5,
60
- initial_parameters: Optional[Parameters] = None,
61
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
62
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
61
+ initial_parameters: Parameters | None = None,
62
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
63
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
63
64
  ) -> None:
64
65
  super().__init__(
65
66
  fraction_fit=fraction_fit,
@@ -86,8 +87,8 @@ class FaultTolerantFedAvg(FedAvg):
86
87
  self,
87
88
  server_round: int,
88
89
  results: list[tuple[ClientProxy, FitRes]],
89
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
90
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
90
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
91
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
91
92
  """Aggregate fit results using weighted average."""
92
93
  if not results:
93
94
  return None, {}
@@ -118,8 +119,8 @@ class FaultTolerantFedAvg(FedAvg):
118
119
  self,
119
120
  server_round: int,
120
121
  results: list[tuple[ClientProxy, EvaluateRes]],
121
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
122
- ) -> tuple[Optional[float], dict[str, Scalar]]:
122
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
123
+ ) -> tuple[float | None, dict[str, Scalar]]:
123
124
  """Aggregate evaluation losses using weighted average."""
124
125
  if not results:
125
126
  return None, {}
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
22
 
23
- from typing import Callable, Optional, Union
23
+ from collections.abc import Callable
24
24
 
25
25
  import numpy as np
26
26
 
@@ -87,16 +87,17 @@ class FedAdagrad(FedOpt):
87
87
  min_fit_clients: int = 2,
88
88
  min_evaluate_clients: int = 2,
89
89
  min_available_clients: int = 2,
90
- evaluate_fn: Optional[
90
+ evaluate_fn: (
91
91
  Callable[
92
92
  [int, NDArrays, dict[str, Scalar]],
93
- Optional[tuple[float, dict[str, Scalar]]],
93
+ tuple[float, dict[str, Scalar]] | None,
94
94
  ]
95
- ] = None,
96
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
97
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
98
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
99
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
95
+ | None
96
+ ) = None,
97
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
98
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
99
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
101
  accept_failures: bool = True,
101
102
  initial_parameters: Parameters,
102
103
  eta: float = 1e-1,
@@ -132,8 +133,8 @@ class FedAdagrad(FedOpt):
132
133
  self,
133
134
  server_round: int,
134
135
  results: list[tuple[ClientProxy, FitRes]],
135
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
136
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
136
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
137
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
137
138
  """Aggregate fit results using weighted average."""
138
139
  fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
139
140
  server_round=server_round, results=results, failures=failures
@@ -145,7 +146,8 @@ class FedAdagrad(FedOpt):
145
146
 
146
147
  # Adagrad
147
148
  delta_t: NDArrays = [
148
- x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
149
+ x - y
150
+ for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
149
151
  ]
150
152
 
151
153
  # m_t
@@ -153,17 +155,19 @@ class FedAdagrad(FedOpt):
153
155
  self.m_t = [np.zeros_like(x) for x in delta_t]
154
156
  self.m_t = [
155
157
  np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
156
- for x, y in zip(self.m_t, delta_t)
158
+ for x, y in zip(self.m_t, delta_t, strict=True)
157
159
  ]
158
160
 
159
161
  # v_t
160
162
  if not self.v_t:
161
163
  self.v_t = [np.zeros_like(x) for x in delta_t]
162
- self.v_t = [x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t)]
164
+ self.v_t = [
165
+ x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t, strict=True)
166
+ ]
163
167
 
164
168
  new_weights = [
165
169
  x + self.eta * y / (np.sqrt(z) + self.tau)
166
- for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
170
+ for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
167
171
  ]
168
172
 
169
173
  self.current_weights = new_weights
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
22
 
23
- from typing import Callable, Optional, Union
23
+ from collections.abc import Callable
24
24
 
25
25
  import numpy as np
26
26
 
@@ -91,18 +91,19 @@ class FedAdam(FedOpt):
91
91
  min_fit_clients: int = 2,
92
92
  min_evaluate_clients: int = 2,
93
93
  min_available_clients: int = 2,
94
- evaluate_fn: Optional[
94
+ evaluate_fn: (
95
95
  Callable[
96
96
  [int, NDArrays, dict[str, Scalar]],
97
- Optional[tuple[float, dict[str, Scalar]]],
97
+ tuple[float, dict[str, Scalar]] | None,
98
98
  ]
99
- ] = None,
100
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
101
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
99
+ | None
100
+ ) = None,
101
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
102
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
102
103
  accept_failures: bool = True,
103
104
  initial_parameters: Parameters,
104
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
105
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
105
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
106
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
106
107
  eta: float = 1e-1,
107
108
  eta_l: float = 1e-1,
108
109
  beta_1: float = 0.9,
@@ -138,8 +139,8 @@ class FedAdam(FedOpt):
138
139
  self,
139
140
  server_round: int,
140
141
  results: list[tuple[ClientProxy, FitRes]],
141
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
142
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
142
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
143
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
143
144
  """Aggregate fit results using weighted average."""
144
145
  fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
145
146
  server_round=server_round, results=results, failures=failures
@@ -151,7 +152,8 @@ class FedAdam(FedOpt):
151
152
 
152
153
  # Adam
153
154
  delta_t: NDArrays = [
154
- x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
155
+ x - y
156
+ for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
155
157
  ]
156
158
 
157
159
  # m_t
@@ -159,7 +161,7 @@ class FedAdam(FedOpt):
159
161
  self.m_t = [np.zeros_like(x) for x in delta_t]
160
162
  self.m_t = [
161
163
  np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
162
- for x, y in zip(self.m_t, delta_t)
164
+ for x, y in zip(self.m_t, delta_t, strict=True)
163
165
  ]
164
166
 
165
167
  # v_t
@@ -167,7 +169,7 @@ class FedAdam(FedOpt):
167
169
  self.v_t = [np.zeros_like(x) for x in delta_t]
168
170
  self.v_t = [
169
171
  self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
170
- for x, y in zip(self.v_t, delta_t)
172
+ for x, y in zip(self.v_t, delta_t, strict=True)
171
173
  ]
172
174
 
173
175
  # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
@@ -182,7 +184,7 @@ class FedAdam(FedOpt):
182
184
 
183
185
  new_weights = [
184
186
  x + eta_norm * y / (np.sqrt(z) + self.tau)
185
- for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
187
+ for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
186
188
  ]
187
189
 
188
190
  self.current_weights = new_weights
@@ -18,8 +18,8 @@ Paper: arxiv.org/abs/1602.05629
18
18
  """
19
19
 
20
20
 
21
+ from collections.abc import Callable
21
22
  from logging import WARNING
22
- from typing import Callable, Optional, Union
23
23
 
24
24
  from flwr.common import (
25
25
  EvaluateIns,
@@ -97,18 +97,19 @@ class FedAvg(Strategy):
97
97
  min_fit_clients: int = 2,
98
98
  min_evaluate_clients: int = 2,
99
99
  min_available_clients: int = 2,
100
- evaluate_fn: Optional[
100
+ evaluate_fn: (
101
101
  Callable[
102
102
  [int, NDArrays, dict[str, Scalar]],
103
- Optional[tuple[float, dict[str, Scalar]]],
103
+ tuple[float, dict[str, Scalar]] | None,
104
104
  ]
105
- ] = None,
106
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
107
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
105
+ | None
106
+ ) = None,
107
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
108
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
108
109
  accept_failures: bool = True,
109
- initial_parameters: Optional[Parameters] = None,
110
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
111
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
110
+ initial_parameters: Parameters | None = None,
111
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
112
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
112
113
  inplace: bool = True,
113
114
  ) -> None:
114
115
  super().__init__()
@@ -148,9 +149,7 @@ class FedAvg(Strategy):
148
149
  num_clients = int(num_available_clients * self.fraction_evaluate)
149
150
  return max(num_clients, self.min_evaluate_clients), self.min_available_clients
150
151
 
151
- def initialize_parameters(
152
- self, client_manager: ClientManager
153
- ) -> Optional[Parameters]:
152
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
154
153
  """Initialize global model parameters."""
155
154
  initial_parameters = self.initial_parameters
156
155
  self.initial_parameters = None # Don't keep initial parameters in memory
@@ -158,7 +157,7 @@ class FedAvg(Strategy):
158
157
 
159
158
  def evaluate(
160
159
  self, server_round: int, parameters: Parameters
161
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
160
+ ) -> tuple[float, dict[str, Scalar]] | None:
162
161
  """Evaluate model parameters using an evaluation function."""
163
162
  if self.evaluate_fn is None:
164
163
  # No evaluation function provided
@@ -221,8 +220,8 @@ class FedAvg(Strategy):
221
220
  self,
222
221
  server_round: int,
223
222
  results: list[tuple[ClientProxy, FitRes]],
224
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
225
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
223
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
224
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
226
225
  """Aggregate fit results using weighted average."""
227
226
  if not results:
228
227
  return None, {}
@@ -257,8 +256,8 @@ class FedAvg(Strategy):
257
256
  self,
258
257
  server_round: int,
259
258
  results: list[tuple[ClientProxy, EvaluateRes]],
260
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
261
- ) -> tuple[Optional[float], dict[str, Scalar]]:
259
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
260
+ ) -> tuple[float | None, dict[str, Scalar]]:
262
261
  """Aggregate evaluation losses using weighted average."""
263
262
  if not results:
264
263
  return None, {}
@@ -18,7 +18,8 @@ Paper: arxiv.org/abs/1602.05629
18
18
  """
19
19
 
20
20
 
21
- from typing import Callable, Optional, Union, cast
21
+ from collections.abc import Callable
22
+ from typing import cast
22
23
 
23
24
  import numpy as np
24
25
 
@@ -79,16 +80,17 @@ class FedAvgAndroid(Strategy):
79
80
  min_fit_clients: int = 2,
80
81
  min_evaluate_clients: int = 2,
81
82
  min_available_clients: int = 2,
82
- evaluate_fn: Optional[
83
+ evaluate_fn: (
83
84
  Callable[
84
85
  [int, NDArrays, dict[str, Scalar]],
85
- Optional[tuple[float, dict[str, Scalar]]],
86
+ tuple[float, dict[str, Scalar]] | None,
86
87
  ]
87
- ] = None,
88
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
89
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
88
+ | None
89
+ ) = None,
90
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
91
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
90
92
  accept_failures: bool = True,
91
- initial_parameters: Optional[Parameters] = None,
93
+ initial_parameters: Parameters | None = None,
92
94
  ) -> None:
93
95
  super().__init__()
94
96
  self.min_fit_clients = min_fit_clients
@@ -117,9 +119,7 @@ class FedAvgAndroid(Strategy):
117
119
  num_clients = int(num_available_clients * self.fraction_evaluate)
118
120
  return max(num_clients, self.min_evaluate_clients), self.min_available_clients
119
121
 
120
- def initialize_parameters(
121
- self, client_manager: ClientManager
122
- ) -> Optional[Parameters]:
122
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
123
123
  """Initialize global model parameters."""
124
124
  initial_parameters = self.initial_parameters
125
125
  self.initial_parameters = None # Don't keep initial parameters in memory
@@ -127,7 +127,7 @@ class FedAvgAndroid(Strategy):
127
127
 
128
128
  def evaluate(
129
129
  self, server_round: int, parameters: Parameters
130
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
130
+ ) -> tuple[float, dict[str, Scalar]] | None:
131
131
  """Evaluate model parameters using an evaluation function."""
132
132
  if self.evaluate_fn is None:
133
133
  # No evaluation function provided
@@ -190,8 +190,8 @@ class FedAvgAndroid(Strategy):
190
190
  self,
191
191
  server_round: int,
192
192
  results: list[tuple[ClientProxy, FitRes]],
193
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
194
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
193
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
194
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
195
195
  """Aggregate fit results using weighted average."""
196
196
  if not results:
197
197
  return None, {}
@@ -209,8 +209,8 @@ class FedAvgAndroid(Strategy):
209
209
  self,
210
210
  server_round: int,
211
211
  results: list[tuple[ClientProxy, EvaluateRes]],
212
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
213
- ) -> tuple[Optional[float], dict[str, Scalar]]:
212
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
213
+ ) -> tuple[float | None, dict[str, Scalar]]:
214
214
  """Aggregate evaluation losses using weighted average."""
215
215
  if not results:
216
216
  return None, {}