flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (301)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +34 -1
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/__init__.py +15 -6
  9. flwr/cli/auth_plugin/auth_plugin.py +94 -0
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
  12. flwr/cli/build.py +166 -53
  13. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
  14. flwr/cli/config_utils.py +101 -13
  15. flwr/cli/federation/__init__.py +24 -0
  16. flwr/cli/federation/ls.py +140 -0
  17. flwr/cli/federation/show.py +317 -0
  18. flwr/cli/install.py +91 -13
  19. flwr/cli/log.py +54 -11
  20. flwr/cli/login/login.py +41 -27
  21. flwr/cli/ls.py +177 -133
  22. flwr/cli/new/new.py +175 -40
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  31. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  34. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  35. flwr/cli/pull.py +12 -7
  36. flwr/cli/run/run.py +82 -31
  37. flwr/cli/run_utils.py +130 -0
  38. flwr/cli/stop.py +27 -9
  39. flwr/cli/supernode/__init__.py +25 -0
  40. flwr/cli/supernode/ls.py +268 -0
  41. flwr/cli/supernode/register.py +190 -0
  42. flwr/cli/supernode/unregister.py +140 -0
  43. flwr/cli/utils.py +464 -81
  44. flwr/client/__init__.py +2 -1
  45. flwr/client/dpfedavg_numpy_client.py +4 -1
  46. flwr/client/grpc_adapter_client/connection.py +12 -15
  47. flwr/client/grpc_rere_client/connection.py +68 -41
  48. flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
  49. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
  50. flwr/client/message_handler/message_handler.py +2 -2
  51. flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
  52. flwr/client/numpy_client.py +1 -1
  53. flwr/client/rest_client/connection.py +94 -51
  54. flwr/client/run_info_store.py +4 -5
  55. flwr/client/typing.py +1 -1
  56. flwr/clientapp/__init__.py +1 -2
  57. flwr/{client → clientapp}/client_app.py +9 -10
  58. flwr/clientapp/mod/centraldp_mods.py +16 -17
  59. flwr/clientapp/mod/localdp_mod.py +8 -9
  60. flwr/clientapp/typing.py +1 -1
  61. flwr/{client/clientapp → clientapp}/utils.py +4 -4
  62. flwr/common/address.py +1 -2
  63. flwr/common/args.py +3 -4
  64. flwr/common/config.py +13 -16
  65. flwr/common/constant.py +56 -13
  66. flwr/common/differential_privacy.py +3 -4
  67. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  68. flwr/common/exit/exit.py +15 -2
  69. flwr/common/exit/exit_code.py +39 -10
  70. flwr/common/exit/exit_handler.py +6 -2
  71. flwr/common/exit/signal_handler.py +5 -5
  72. flwr/common/grpc.py +6 -6
  73. flwr/common/inflatable_protobuf_utils.py +1 -1
  74. flwr/common/inflatable_utils.py +48 -31
  75. flwr/common/logger.py +19 -19
  76. flwr/common/message.py +4 -4
  77. flwr/common/object_ref.py +7 -7
  78. flwr/common/record/array.py +6 -6
  79. flwr/common/record/arrayrecord.py +18 -21
  80. flwr/common/record/configrecord.py +3 -3
  81. flwr/common/record/recorddict.py +5 -5
  82. flwr/common/record/typeddict.py +9 -2
  83. flwr/common/recorddict_compat.py +7 -10
  84. flwr/common/retry_invoker.py +20 -20
  85. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  86. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  87. flwr/common/serde.py +9 -6
  88. flwr/common/serde_utils.py +2 -2
  89. flwr/common/telemetry.py +9 -5
  90. flwr/common/typing.py +59 -43
  91. flwr/compat/client/app.py +39 -38
  92. flwr/compat/client/grpc_client/connection.py +13 -13
  93. flwr/compat/server/app.py +5 -6
  94. flwr/proto/appio_pb2.py +13 -3
  95. flwr/proto/appio_pb2.pyi +134 -65
  96. flwr/proto/appio_pb2_grpc.py +20 -0
  97. flwr/proto/appio_pb2_grpc.pyi +27 -0
  98. flwr/proto/clientappio_pb2.py +17 -7
  99. flwr/proto/clientappio_pb2.pyi +15 -0
  100. flwr/proto/clientappio_pb2_grpc.py +206 -40
  101. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  102. flwr/proto/control_pb2.py +72 -40
  103. flwr/proto/control_pb2.pyi +319 -87
  104. flwr/proto/control_pb2_grpc.py +339 -28
  105. flwr/proto/control_pb2_grpc.pyi +209 -37
  106. flwr/proto/error_pb2.py +13 -3
  107. flwr/proto/error_pb2.pyi +24 -6
  108. flwr/proto/error_pb2_grpc.py +20 -0
  109. flwr/proto/error_pb2_grpc.pyi +27 -0
  110. flwr/proto/fab_pb2.py +24 -10
  111. flwr/proto/fab_pb2.pyi +68 -20
  112. flwr/proto/fab_pb2_grpc.py +20 -0
  113. flwr/proto/fab_pb2_grpc.pyi +27 -0
  114. flwr/proto/federation_pb2.py +38 -0
  115. flwr/proto/federation_pb2.pyi +56 -0
  116. flwr/proto/federation_pb2_grpc.py +24 -0
  117. flwr/proto/federation_pb2_grpc.pyi +31 -0
  118. flwr/proto/fleet_pb2.py +45 -27
  119. flwr/proto/fleet_pb2.pyi +186 -70
  120. flwr/proto/fleet_pb2_grpc.py +277 -66
  121. flwr/proto/fleet_pb2_grpc.pyi +201 -55
  122. flwr/proto/grpcadapter_pb2.py +14 -4
  123. flwr/proto/grpcadapter_pb2.pyi +38 -16
  124. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  125. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  126. flwr/proto/heartbeat_pb2.py +17 -7
  127. flwr/proto/heartbeat_pb2.pyi +51 -22
  128. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  129. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  130. flwr/proto/log_pb2.py +13 -3
  131. flwr/proto/log_pb2.pyi +34 -11
  132. flwr/proto/log_pb2_grpc.py +20 -0
  133. flwr/proto/log_pb2_grpc.pyi +27 -0
  134. flwr/proto/message_pb2.py +15 -5
  135. flwr/proto/message_pb2.pyi +154 -86
  136. flwr/proto/message_pb2_grpc.py +20 -0
  137. flwr/proto/message_pb2_grpc.pyi +27 -0
  138. flwr/proto/node_pb2.py +16 -4
  139. flwr/proto/node_pb2.pyi +77 -4
  140. flwr/proto/node_pb2_grpc.py +20 -0
  141. flwr/proto/node_pb2_grpc.pyi +27 -0
  142. flwr/proto/recorddict_pb2.py +13 -3
  143. flwr/proto/recorddict_pb2.pyi +184 -107
  144. flwr/proto/recorddict_pb2_grpc.py +20 -0
  145. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  146. flwr/proto/run_pb2.py +40 -31
  147. flwr/proto/run_pb2.pyi +149 -84
  148. flwr/proto/run_pb2_grpc.py +20 -0
  149. flwr/proto/run_pb2_grpc.pyi +27 -0
  150. flwr/proto/serverappio_pb2.py +13 -3
  151. flwr/proto/serverappio_pb2.pyi +32 -8
  152. flwr/proto/serverappio_pb2_grpc.py +246 -65
  153. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  154. flwr/proto/simulationio_pb2.py +16 -8
  155. flwr/proto/simulationio_pb2.pyi +15 -0
  156. flwr/proto/simulationio_pb2_grpc.py +162 -41
  157. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  158. flwr/proto/transport_pb2.py +20 -10
  159. flwr/proto/transport_pb2.pyi +249 -160
  160. flwr/proto/transport_pb2_grpc.py +35 -4
  161. flwr/proto/transport_pb2_grpc.pyi +38 -8
  162. flwr/server/app.py +173 -127
  163. flwr/server/client_manager.py +4 -5
  164. flwr/server/client_proxy.py +10 -11
  165. flwr/server/compat/app.py +4 -5
  166. flwr/server/compat/app_utils.py +2 -1
  167. flwr/server/compat/grid_client_proxy.py +10 -12
  168. flwr/server/compat/legacy_context.py +3 -4
  169. flwr/server/fleet_event_log_interceptor.py +2 -1
  170. flwr/server/grid/grid.py +2 -3
  171. flwr/server/grid/grpc_grid.py +10 -8
  172. flwr/server/grid/inmemory_grid.py +4 -4
  173. flwr/server/run_serverapp.py +2 -3
  174. flwr/server/server.py +34 -39
  175. flwr/server/server_app.py +7 -8
  176. flwr/server/server_config.py +1 -2
  177. flwr/server/serverapp/app.py +34 -28
  178. flwr/server/serverapp_components.py +4 -5
  179. flwr/server/strategy/aggregate.py +9 -8
  180. flwr/server/strategy/bulyan.py +13 -11
  181. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  182. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  183. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  184. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  185. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  186. flwr/server/strategy/fedadagrad.py +18 -14
  187. flwr/server/strategy/fedadam.py +16 -14
  188. flwr/server/strategy/fedavg.py +16 -17
  189. flwr/server/strategy/fedavg_android.py +15 -15
  190. flwr/server/strategy/fedavgm.py +21 -18
  191. flwr/server/strategy/fedmedian.py +2 -3
  192. flwr/server/strategy/fedopt.py +11 -10
  193. flwr/server/strategy/fedprox.py +10 -9
  194. flwr/server/strategy/fedtrimmedavg.py +12 -11
  195. flwr/server/strategy/fedxgb_bagging.py +13 -11
  196. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  197. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  198. flwr/server/strategy/fedyogi.py +16 -14
  199. flwr/server/strategy/krum.py +12 -11
  200. flwr/server/strategy/qfedavg.py +16 -15
  201. flwr/server/strategy/strategy.py +6 -9
  202. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
  203. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  206. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
  208. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
  209. flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
  210. flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
  211. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  212. flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
  213. flwr/server/superlink/fleet/vce/vce_api.py +32 -13
  214. flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
  215. flwr/server/superlink/linkstate/linkstate.py +161 -62
  216. flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
  217. flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
  218. flwr/server/superlink/linkstate/utils.py +9 -60
  219. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  220. flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
  221. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  222. flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
  223. flwr/server/superlink/utils.py +4 -6
  224. flwr/server/typing.py +1 -1
  225. flwr/server/utils/tensorboard.py +15 -8
  226. flwr/server/utils/validator.py +2 -3
  227. flwr/server/workflow/default_workflows.py +5 -5
  228. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  229. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
  230. flwr/serverapp/strategy/bulyan.py +16 -15
  231. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  232. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  233. flwr/serverapp/strategy/fedadagrad.py +10 -11
  234. flwr/serverapp/strategy/fedadam.py +10 -11
  235. flwr/serverapp/strategy/fedavg.py +9 -10
  236. flwr/serverapp/strategy/fedavgm.py +17 -16
  237. flwr/serverapp/strategy/fedmedian.py +2 -2
  238. flwr/serverapp/strategy/fedopt.py +10 -11
  239. flwr/serverapp/strategy/fedprox.py +7 -8
  240. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  241. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  242. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  243. flwr/serverapp/strategy/fedyogi.py +9 -11
  244. flwr/serverapp/strategy/krum.py +7 -7
  245. flwr/serverapp/strategy/multikrum.py +9 -9
  246. flwr/serverapp/strategy/qfedavg.py +17 -16
  247. flwr/serverapp/strategy/strategy.py +6 -9
  248. flwr/serverapp/strategy/strategy_utils.py +7 -8
  249. flwr/simulation/app.py +46 -42
  250. flwr/simulation/legacy_app.py +12 -12
  251. flwr/simulation/ray_transport/ray_actor.py +11 -12
  252. flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
  253. flwr/simulation/run_simulation.py +44 -43
  254. flwr/simulation/simulationio_connection.py +4 -4
  255. flwr/supercore/cli/flower_superexec.py +3 -4
  256. flwr/supercore/constant.py +52 -0
  257. flwr/supercore/corestate/corestate.py +24 -3
  258. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  259. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  260. flwr/supercore/ffs/disk_ffs.py +1 -2
  261. flwr/supercore/ffs/ffs.py +1 -2
  262. flwr/supercore/ffs/ffs_factory.py +1 -2
  263. flwr/{common → supercore}/heartbeat.py +20 -25
  264. flwr/supercore/object_store/in_memory_object_store.py +1 -6
  265. flwr/supercore/object_store/object_store.py +1 -2
  266. flwr/supercore/object_store/object_store_factory.py +27 -8
  267. flwr/supercore/object_store/sqlite_object_store.py +253 -0
  268. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  269. flwr/supercore/primitives/asymmetric.py +117 -0
  270. flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
  271. flwr/supercore/sqlite_mixin.py +159 -0
  272. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  273. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  274. flwr/supercore/superexec/run_superexec.py +9 -13
  275. flwr/supercore/utils.py +20 -0
  276. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  277. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  278. flwr/superlink/auth_plugin/auth_plugin.py +88 -0
  279. flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
  280. flwr/superlink/federation/__init__.py +24 -0
  281. flwr/superlink/federation/federation_manager.py +64 -0
  282. flwr/superlink/federation/noop_federation_manager.py +71 -0
  283. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
  284. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  285. flwr/superlink/servicer/control/control_grpc.py +18 -17
  286. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  287. flwr/superlink/servicer/control/control_servicer.py +239 -63
  288. flwr/supernode/cli/flower_supernode.py +74 -26
  289. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  290. flwr/supernode/nodestate/nodestate.py +7 -8
  291. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  292. flwr/supernode/runtime/run_clientapp.py +43 -24
  293. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  294. flwr/supernode/start_client_internal.py +175 -51
  295. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  296. flwr-1.24.0.dist-info/RECORD +454 -0
  297. flwr/common/auth_plugin/auth_plugin.py +0 -149
  298. flwr/supercore/object_store/utils.py +0 -43
  299. flwr-1.22.0.dist-info/RECORD +0 -428
  300. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  301. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
22
 
23
- from typing import Callable, Optional, Union
23
+ from collections.abc import Callable
24
24
 
25
25
  import numpy as np
26
26
 
@@ -87,16 +87,17 @@ class FedAdagrad(FedOpt):
87
87
  min_fit_clients: int = 2,
88
88
  min_evaluate_clients: int = 2,
89
89
  min_available_clients: int = 2,
90
- evaluate_fn: Optional[
90
+ evaluate_fn: (
91
91
  Callable[
92
92
  [int, NDArrays, dict[str, Scalar]],
93
- Optional[tuple[float, dict[str, Scalar]]],
93
+ tuple[float, dict[str, Scalar]] | None,
94
94
  ]
95
- ] = None,
96
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
97
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
98
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
99
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
95
+ | None
96
+ ) = None,
97
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
98
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
99
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
100
101
  accept_failures: bool = True,
101
102
  initial_parameters: Parameters,
102
103
  eta: float = 1e-1,
@@ -132,8 +133,8 @@ class FedAdagrad(FedOpt):
132
133
  self,
133
134
  server_round: int,
134
135
  results: list[tuple[ClientProxy, FitRes]],
135
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
136
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
136
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
137
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
137
138
  """Aggregate fit results using weighted average."""
138
139
  fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
139
140
  server_round=server_round, results=results, failures=failures
@@ -145,7 +146,8 @@ class FedAdagrad(FedOpt):
145
146
 
146
147
  # Adagrad
147
148
  delta_t: NDArrays = [
148
- x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
149
+ x - y
150
+ for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
149
151
  ]
150
152
 
151
153
  # m_t
@@ -153,17 +155,19 @@ class FedAdagrad(FedOpt):
153
155
  self.m_t = [np.zeros_like(x) for x in delta_t]
154
156
  self.m_t = [
155
157
  np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
156
- for x, y in zip(self.m_t, delta_t)
158
+ for x, y in zip(self.m_t, delta_t, strict=True)
157
159
  ]
158
160
 
159
161
  # v_t
160
162
  if not self.v_t:
161
163
  self.v_t = [np.zeros_like(x) for x in delta_t]
162
- self.v_t = [x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t)]
164
+ self.v_t = [
165
+ x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t, strict=True)
166
+ ]
163
167
 
164
168
  new_weights = [
165
169
  x + self.eta * y / (np.sqrt(z) + self.tau)
166
- for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
170
+ for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
167
171
  ]
168
172
 
169
173
  self.current_weights = new_weights
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
22
 
23
- from typing import Callable, Optional, Union
23
+ from collections.abc import Callable
24
24
 
25
25
  import numpy as np
26
26
 
@@ -91,18 +91,19 @@ class FedAdam(FedOpt):
91
91
  min_fit_clients: int = 2,
92
92
  min_evaluate_clients: int = 2,
93
93
  min_available_clients: int = 2,
94
- evaluate_fn: Optional[
94
+ evaluate_fn: (
95
95
  Callable[
96
96
  [int, NDArrays, dict[str, Scalar]],
97
- Optional[tuple[float, dict[str, Scalar]]],
97
+ tuple[float, dict[str, Scalar]] | None,
98
98
  ]
99
- ] = None,
100
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
101
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
99
+ | None
100
+ ) = None,
101
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
102
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
102
103
  accept_failures: bool = True,
103
104
  initial_parameters: Parameters,
104
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
105
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
105
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
106
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
106
107
  eta: float = 1e-1,
107
108
  eta_l: float = 1e-1,
108
109
  beta_1: float = 0.9,
@@ -138,8 +139,8 @@ class FedAdam(FedOpt):
138
139
  self,
139
140
  server_round: int,
140
141
  results: list[tuple[ClientProxy, FitRes]],
141
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
142
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
142
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
143
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
143
144
  """Aggregate fit results using weighted average."""
144
145
  fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
145
146
  server_round=server_round, results=results, failures=failures
@@ -151,7 +152,8 @@ class FedAdam(FedOpt):
151
152
 
152
153
  # Adam
153
154
  delta_t: NDArrays = [
154
- x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
155
+ x - y
156
+ for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
155
157
  ]
156
158
 
157
159
  # m_t
@@ -159,7 +161,7 @@ class FedAdam(FedOpt):
159
161
  self.m_t = [np.zeros_like(x) for x in delta_t]
160
162
  self.m_t = [
161
163
  np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
162
- for x, y in zip(self.m_t, delta_t)
164
+ for x, y in zip(self.m_t, delta_t, strict=True)
163
165
  ]
164
166
 
165
167
  # v_t
@@ -167,7 +169,7 @@ class FedAdam(FedOpt):
167
169
  self.v_t = [np.zeros_like(x) for x in delta_t]
168
170
  self.v_t = [
169
171
  self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
170
- for x, y in zip(self.v_t, delta_t)
172
+ for x, y in zip(self.v_t, delta_t, strict=True)
171
173
  ]
172
174
 
173
175
  # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
@@ -182,7 +184,7 @@ class FedAdam(FedOpt):
182
184
 
183
185
  new_weights = [
184
186
  x + eta_norm * y / (np.sqrt(z) + self.tau)
185
- for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
187
+ for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
186
188
  ]
187
189
 
188
190
  self.current_weights = new_weights
@@ -18,8 +18,8 @@ Paper: arxiv.org/abs/1602.05629
18
18
  """
19
19
 
20
20
 
21
+ from collections.abc import Callable
21
22
  from logging import WARNING
22
- from typing import Callable, Optional, Union
23
23
 
24
24
  from flwr.common import (
25
25
  EvaluateIns,
@@ -97,18 +97,19 @@ class FedAvg(Strategy):
97
97
  min_fit_clients: int = 2,
98
98
  min_evaluate_clients: int = 2,
99
99
  min_available_clients: int = 2,
100
- evaluate_fn: Optional[
100
+ evaluate_fn: (
101
101
  Callable[
102
102
  [int, NDArrays, dict[str, Scalar]],
103
- Optional[tuple[float, dict[str, Scalar]]],
103
+ tuple[float, dict[str, Scalar]] | None,
104
104
  ]
105
- ] = None,
106
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
107
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
105
+ | None
106
+ ) = None,
107
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
108
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
108
109
  accept_failures: bool = True,
109
- initial_parameters: Optional[Parameters] = None,
110
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
111
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
110
+ initial_parameters: Parameters | None = None,
111
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
112
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
112
113
  inplace: bool = True,
113
114
  ) -> None:
114
115
  super().__init__()
@@ -148,9 +149,7 @@ class FedAvg(Strategy):
148
149
  num_clients = int(num_available_clients * self.fraction_evaluate)
149
150
  return max(num_clients, self.min_evaluate_clients), self.min_available_clients
150
151
 
151
- def initialize_parameters(
152
- self, client_manager: ClientManager
153
- ) -> Optional[Parameters]:
152
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
154
153
  """Initialize global model parameters."""
155
154
  initial_parameters = self.initial_parameters
156
155
  self.initial_parameters = None # Don't keep initial parameters in memory
@@ -158,7 +157,7 @@ class FedAvg(Strategy):
158
157
 
159
158
  def evaluate(
160
159
  self, server_round: int, parameters: Parameters
161
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
160
+ ) -> tuple[float, dict[str, Scalar]] | None:
162
161
  """Evaluate model parameters using an evaluation function."""
163
162
  if self.evaluate_fn is None:
164
163
  # No evaluation function provided
@@ -221,8 +220,8 @@ class FedAvg(Strategy):
221
220
  self,
222
221
  server_round: int,
223
222
  results: list[tuple[ClientProxy, FitRes]],
224
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
225
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
223
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
224
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
226
225
  """Aggregate fit results using weighted average."""
227
226
  if not results:
228
227
  return None, {}
@@ -257,8 +256,8 @@ class FedAvg(Strategy):
257
256
  self,
258
257
  server_round: int,
259
258
  results: list[tuple[ClientProxy, EvaluateRes]],
260
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
261
- ) -> tuple[Optional[float], dict[str, Scalar]]:
259
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
260
+ ) -> tuple[float | None, dict[str, Scalar]]:
262
261
  """Aggregate evaluation losses using weighted average."""
263
262
  if not results:
264
263
  return None, {}
@@ -18,7 +18,8 @@ Paper: arxiv.org/abs/1602.05629
18
18
  """
19
19
 
20
20
 
21
- from typing import Callable, Optional, Union, cast
21
+ from collections.abc import Callable
22
+ from typing import cast
22
23
 
23
24
  import numpy as np
24
25
 
@@ -79,16 +80,17 @@ class FedAvgAndroid(Strategy):
79
80
  min_fit_clients: int = 2,
80
81
  min_evaluate_clients: int = 2,
81
82
  min_available_clients: int = 2,
82
- evaluate_fn: Optional[
83
+ evaluate_fn: (
83
84
  Callable[
84
85
  [int, NDArrays, dict[str, Scalar]],
85
- Optional[tuple[float, dict[str, Scalar]]],
86
+ tuple[float, dict[str, Scalar]] | None,
86
87
  ]
87
- ] = None,
88
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
89
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
88
+ | None
89
+ ) = None,
90
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
91
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
90
92
  accept_failures: bool = True,
91
- initial_parameters: Optional[Parameters] = None,
93
+ initial_parameters: Parameters | None = None,
92
94
  ) -> None:
93
95
  super().__init__()
94
96
  self.min_fit_clients = min_fit_clients
@@ -117,9 +119,7 @@ class FedAvgAndroid(Strategy):
117
119
  num_clients = int(num_available_clients * self.fraction_evaluate)
118
120
  return max(num_clients, self.min_evaluate_clients), self.min_available_clients
119
121
 
120
- def initialize_parameters(
121
- self, client_manager: ClientManager
122
- ) -> Optional[Parameters]:
122
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
123
123
  """Initialize global model parameters."""
124
124
  initial_parameters = self.initial_parameters
125
125
  self.initial_parameters = None # Don't keep initial parameters in memory
@@ -127,7 +127,7 @@ class FedAvgAndroid(Strategy):
127
127
 
128
128
  def evaluate(
129
129
  self, server_round: int, parameters: Parameters
130
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
130
+ ) -> tuple[float, dict[str, Scalar]] | None:
131
131
  """Evaluate model parameters using an evaluation function."""
132
132
  if self.evaluate_fn is None:
133
133
  # No evaluation function provided
@@ -190,8 +190,8 @@ class FedAvgAndroid(Strategy):
190
190
  self,
191
191
  server_round: int,
192
192
  results: list[tuple[ClientProxy, FitRes]],
193
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
194
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
193
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
194
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
195
195
  """Aggregate fit results using weighted average."""
196
196
  if not results:
197
197
  return None, {}
@@ -209,8 +209,8 @@ class FedAvgAndroid(Strategy):
209
209
  self,
210
210
  server_round: int,
211
211
  results: list[tuple[ClientProxy, EvaluateRes]],
212
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
213
- ) -> tuple[Optional[float], dict[str, Scalar]]:
212
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
213
+ ) -> tuple[float | None, dict[str, Scalar]]:
214
214
  """Aggregate evaluation losses using weighted average."""
215
215
  if not results:
216
216
  return None, {}
@@ -18,8 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
18
18
  """
19
19
 
20
20
 
21
+ from collections.abc import Callable
21
22
  from logging import WARNING
22
- from typing import Callable, Optional, Union
23
23
 
24
24
  from flwr.common import (
25
25
  FitRes,
@@ -82,18 +82,19 @@ class FedAvgM(FedAvg):
82
82
  min_fit_clients: int = 2,
83
83
  min_evaluate_clients: int = 2,
84
84
  min_available_clients: int = 2,
85
- evaluate_fn: Optional[
85
+ evaluate_fn: (
86
86
  Callable[
87
87
  [int, NDArrays, dict[str, Scalar]],
88
- Optional[tuple[float, dict[str, Scalar]]],
88
+ tuple[float, dict[str, Scalar]] | None,
89
89
  ]
90
- ] = None,
91
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
92
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
90
+ | None
91
+ ) = None,
92
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
93
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
93
94
  accept_failures: bool = True,
94
- initial_parameters: Optional[Parameters] = None,
95
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
96
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
95
+ initial_parameters: Parameters | None = None,
96
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
97
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
97
98
  server_learning_rate: float = 1.0,
98
99
  server_momentum: float = 0.0,
99
100
  ) -> None:
@@ -116,16 +117,14 @@ class FedAvgM(FedAvg):
116
117
  self.server_opt: bool = (self.server_momentum != 0.0) or (
117
118
  self.server_learning_rate != 1.0
118
119
  )
119
- self.momentum_vector: Optional[NDArrays] = None
120
+ self.momentum_vector: NDArrays | None = None
120
121
 
121
122
  def __repr__(self) -> str:
122
123
  """Compute a string representation of the strategy."""
123
124
  rep = f"FedAvgM(accept_failures={self.accept_failures})"
124
125
  return rep
125
126
 
126
- def initialize_parameters(
127
- self, client_manager: ClientManager
128
- ) -> Optional[Parameters]:
127
+ def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
129
128
  """Initialize global model parameters."""
130
129
  return self.initial_parameters
131
130
 
@@ -133,8 +132,8 @@ class FedAvgM(FedAvg):
133
132
  self,
134
133
  server_round: int,
135
134
  results: list[tuple[ClientProxy, FitRes]],
136
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
137
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
135
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
136
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
138
137
  """Aggregate fit results using weighted average."""
139
138
  if not results:
140
139
  return None, {}
@@ -161,7 +160,9 @@ class FedAvgM(FedAvg):
161
160
  pseudo_gradient: NDArrays = [
162
161
  x - y
163
162
  for x, y in zip(
164
- parameters_to_ndarrays(self.initial_parameters), fedavg_result
163
+ parameters_to_ndarrays(self.initial_parameters),
164
+ fedavg_result,
165
+ strict=True,
165
166
  )
166
167
  ]
167
168
  if self.server_momentum > 0.0:
@@ -171,7 +172,9 @@ class FedAvgM(FedAvg):
171
172
  ), "Momentum should have been created on round 1."
172
173
  self.momentum_vector = [
173
174
  self.server_momentum * x + y
174
- for x, y in zip(self.momentum_vector, pseudo_gradient)
175
+ for x, y in zip(
176
+ self.momentum_vector, pseudo_gradient, strict=True
177
+ )
175
178
  ]
176
179
  else:
177
180
  self.momentum_vector = pseudo_gradient
@@ -182,7 +185,7 @@ class FedAvgM(FedAvg):
182
185
  # SGD
183
186
  fedavg_result = [
184
187
  x - self.server_learning_rate * y
185
- for x, y in zip(initial_weights, pseudo_gradient)
188
+ for x, y in zip(initial_weights, pseudo_gradient, strict=True)
186
189
  ]
187
190
  # Update current weights
188
191
  self.initial_parameters = ndarrays_to_parameters(fedavg_result)
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
19
19
 
20
20
 
21
21
  from logging import WARNING
22
- from typing import Optional, Union
23
22
 
24
23
  from flwr.common import (
25
24
  FitRes,
@@ -47,8 +46,8 @@ class FedMedian(FedAvg):
47
46
  self,
48
47
  server_round: int,
49
48
  results: list[tuple[ClientProxy, FitRes]],
50
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
51
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
49
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
50
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
52
51
  """Aggregate fit results using median."""
53
52
  if not results:
54
53
  return None, {}
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/2003.00295
18
18
  """
19
19
 
20
20
 
21
- from typing import Callable, Optional
21
+ from collections.abc import Callable
22
22
 
23
23
  from flwr.common import (
24
24
  MetricsAggregationFn,
@@ -84,18 +84,19 @@ class FedOpt(FedAvg):
84
84
  min_fit_clients: int = 2,
85
85
  min_evaluate_clients: int = 2,
86
86
  min_available_clients: int = 2,
87
- evaluate_fn: Optional[
87
+ evaluate_fn: (
88
88
  Callable[
89
89
  [int, NDArrays, dict[str, Scalar]],
90
- Optional[tuple[float, dict[str, Scalar]]],
90
+ tuple[float, dict[str, Scalar]] | None,
91
91
  ]
92
- ] = None,
93
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
94
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
92
+ | None
93
+ ) = None,
94
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
95
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
95
96
  accept_failures: bool = True,
96
97
  initial_parameters: Parameters,
97
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
98
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
99
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
99
100
  eta: float = 1e-1,
100
101
  eta_l: float = 1e-1,
101
102
  beta_1: float = 0.0,
@@ -122,8 +123,8 @@ class FedOpt(FedAvg):
122
123
  self.tau = tau
123
124
  self.beta_1 = beta_1
124
125
  self.beta_2 = beta_2
125
- self.m_t: Optional[NDArrays] = None
126
- self.v_t: Optional[NDArrays] = None
126
+ self.m_t: NDArrays | None = None
127
+ self.v_t: NDArrays | None = None
127
128
 
128
129
  def __repr__(self) -> str:
129
130
  """Compute a string representation of the strategy."""
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/1812.06127
18
18
  """
19
19
 
20
20
 
21
- from typing import Callable, Optional
21
+ from collections.abc import Callable
22
22
 
23
23
  from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar
24
24
  from flwr.server.client_manager import ClientManager
@@ -111,18 +111,19 @@ class FedProx(FedAvg):
111
111
  min_fit_clients: int = 2,
112
112
  min_evaluate_clients: int = 2,
113
113
  min_available_clients: int = 2,
114
- evaluate_fn: Optional[
114
+ evaluate_fn: (
115
115
  Callable[
116
116
  [int, NDArrays, dict[str, Scalar]],
117
- Optional[tuple[float, dict[str, Scalar]]],
117
+ tuple[float, dict[str, Scalar]] | None,
118
118
  ]
119
- ] = None,
120
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
121
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
119
+ | None
120
+ ) = None,
121
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
122
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
122
123
  accept_failures: bool = True,
123
- initial_parameters: Optional[Parameters] = None,
124
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
125
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
124
+ initial_parameters: Parameters | None = None,
125
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
126
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
126
127
  proximal_mu: float,
127
128
  ) -> None:
128
129
  super().__init__(
@@ -16,8 +16,8 @@
16
16
 
17
17
  Paper: arxiv.org/abs/1803.01498
18
18
  """
19
+ from collections.abc import Callable
19
20
  from logging import WARNING
20
- from typing import Callable, Optional, Union
21
21
 
22
22
  from flwr.common import (
23
23
  FitRes,
@@ -76,18 +76,19 @@ class FedTrimmedAvg(FedAvg):
76
76
  min_fit_clients: int = 2,
77
77
  min_evaluate_clients: int = 2,
78
78
  min_available_clients: int = 2,
79
- evaluate_fn: Optional[
79
+ evaluate_fn: (
80
80
  Callable[
81
81
  [int, NDArrays, dict[str, Scalar]],
82
- Optional[tuple[float, dict[str, Scalar]]],
82
+ tuple[float, dict[str, Scalar]] | None,
83
83
  ]
84
- ] = None,
85
- on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
86
- on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
84
+ | None
85
+ ) = None,
86
+ on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
87
+ on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
87
88
  accept_failures: bool = True,
88
- initial_parameters: Optional[Parameters] = None,
89
- fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
90
- evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
89
+ initial_parameters: Parameters | None = None,
90
+ fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
91
+ evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
91
92
  beta: float = 0.2,
92
93
  ) -> None:
93
94
  super().__init__(
@@ -115,8 +116,8 @@ class FedTrimmedAvg(FedAvg):
115
116
  self,
116
117
  server_round: int,
117
118
  results: list[tuple[ClientProxy, FitRes]],
118
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
119
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
119
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
120
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
120
121
  """Aggregate fit results using trimmed average."""
121
122
  if not results:
122
123
  return None, {}
@@ -16,8 +16,9 @@
16
16
 
17
17
 
18
18
  import json
19
+ from collections.abc import Callable
19
20
  from logging import WARNING
20
- from typing import Any, Callable, Optional, Union, cast
21
+ from typing import Any, cast
21
22
 
22
23
  from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
23
24
  from flwr.common.logger import log
@@ -32,16 +33,17 @@ class FedXgbBagging(FedAvg):
32
33
  # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
33
34
  def __init__(
34
35
  self,
35
- evaluate_function: Optional[
36
+ evaluate_function: (
36
37
  Callable[
37
38
  [int, Parameters, dict[str, Scalar]],
38
- Optional[tuple[float, dict[str, Scalar]]],
39
+ tuple[float, dict[str, Scalar]] | None,
39
40
  ]
40
- ] = None,
41
+ | None
42
+ ) = None,
41
43
  **kwargs: Any,
42
44
  ):
43
45
  self.evaluate_function = evaluate_function
44
- self.global_model: Optional[bytes] = None
46
+ self.global_model: bytes | None = None
45
47
  super().__init__(**kwargs)
46
48
 
47
49
  def __repr__(self) -> str:
@@ -53,8 +55,8 @@ class FedXgbBagging(FedAvg):
53
55
  self,
54
56
  server_round: int,
55
57
  results: list[tuple[ClientProxy, FitRes]],
56
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
57
- ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
58
+ failures: list[tuple[ClientProxy, FitRes] | BaseException],
59
+ ) -> tuple[Parameters | None, dict[str, Scalar]]:
58
60
  """Aggregate fit results using bagging."""
59
61
  if not results:
60
62
  return None, {}
@@ -80,8 +82,8 @@ class FedXgbBagging(FedAvg):
80
82
  self,
81
83
  server_round: int,
82
84
  results: list[tuple[ClientProxy, EvaluateRes]],
83
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
84
- ) -> tuple[Optional[float], dict[str, Scalar]]:
85
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
86
+ ) -> tuple[float | None, dict[str, Scalar]]:
85
87
  """Aggregate evaluation metrics using average."""
86
88
  if not results:
87
89
  return None, {}
@@ -101,7 +103,7 @@ class FedXgbBagging(FedAvg):
101
103
 
102
104
  def evaluate(
103
105
  self, server_round: int, parameters: Parameters
104
- ) -> Optional[tuple[float, dict[str, Scalar]]]:
106
+ ) -> tuple[float, dict[str, Scalar]] | None:
105
107
  """Evaluate model parameters using an evaluation function."""
106
108
  if self.evaluate_function is None:
107
109
  # No evaluation function provided
@@ -114,7 +116,7 @@ class FedXgbBagging(FedAvg):
114
116
 
115
117
 
116
118
  def aggregate(
117
- bst_prev_org: Optional[bytes],
119
+ bst_prev_org: bytes | None,
118
120
  bst_curr_org: bytes,
119
121
  ) -> bytes:
120
122
  """Conduct bagging aggregation for given trees."""