flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (301)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +34 -1
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/__init__.py +15 -6
  9. flwr/cli/auth_plugin/auth_plugin.py +94 -0
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
  12. flwr/cli/build.py +166 -53
  13. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
  14. flwr/cli/config_utils.py +101 -13
  15. flwr/cli/federation/__init__.py +24 -0
  16. flwr/cli/federation/ls.py +140 -0
  17. flwr/cli/federation/show.py +317 -0
  18. flwr/cli/install.py +91 -13
  19. flwr/cli/log.py +54 -11
  20. flwr/cli/login/login.py +41 -27
  21. flwr/cli/ls.py +177 -133
  22. flwr/cli/new/new.py +175 -40
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  31. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  34. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  35. flwr/cli/pull.py +12 -7
  36. flwr/cli/run/run.py +82 -31
  37. flwr/cli/run_utils.py +130 -0
  38. flwr/cli/stop.py +27 -9
  39. flwr/cli/supernode/__init__.py +25 -0
  40. flwr/cli/supernode/ls.py +268 -0
  41. flwr/cli/supernode/register.py +190 -0
  42. flwr/cli/supernode/unregister.py +140 -0
  43. flwr/cli/utils.py +464 -81
  44. flwr/client/__init__.py +2 -1
  45. flwr/client/dpfedavg_numpy_client.py +4 -1
  46. flwr/client/grpc_adapter_client/connection.py +12 -15
  47. flwr/client/grpc_rere_client/connection.py +68 -41
  48. flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
  49. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
  50. flwr/client/message_handler/message_handler.py +2 -2
  51. flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
  52. flwr/client/numpy_client.py +1 -1
  53. flwr/client/rest_client/connection.py +94 -51
  54. flwr/client/run_info_store.py +4 -5
  55. flwr/client/typing.py +1 -1
  56. flwr/clientapp/__init__.py +1 -2
  57. flwr/{client → clientapp}/client_app.py +9 -10
  58. flwr/clientapp/mod/centraldp_mods.py +16 -17
  59. flwr/clientapp/mod/localdp_mod.py +8 -9
  60. flwr/clientapp/typing.py +1 -1
  61. flwr/{client/clientapp → clientapp}/utils.py +4 -4
  62. flwr/common/address.py +1 -2
  63. flwr/common/args.py +3 -4
  64. flwr/common/config.py +13 -16
  65. flwr/common/constant.py +56 -13
  66. flwr/common/differential_privacy.py +3 -4
  67. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  68. flwr/common/exit/exit.py +15 -2
  69. flwr/common/exit/exit_code.py +39 -10
  70. flwr/common/exit/exit_handler.py +6 -2
  71. flwr/common/exit/signal_handler.py +5 -5
  72. flwr/common/grpc.py +6 -6
  73. flwr/common/inflatable_protobuf_utils.py +1 -1
  74. flwr/common/inflatable_utils.py +48 -31
  75. flwr/common/logger.py +19 -19
  76. flwr/common/message.py +4 -4
  77. flwr/common/object_ref.py +7 -7
  78. flwr/common/record/array.py +6 -6
  79. flwr/common/record/arrayrecord.py +18 -21
  80. flwr/common/record/configrecord.py +3 -3
  81. flwr/common/record/recorddict.py +5 -5
  82. flwr/common/record/typeddict.py +9 -2
  83. flwr/common/recorddict_compat.py +7 -10
  84. flwr/common/retry_invoker.py +20 -20
  85. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  86. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  87. flwr/common/serde.py +9 -6
  88. flwr/common/serde_utils.py +2 -2
  89. flwr/common/telemetry.py +9 -5
  90. flwr/common/typing.py +59 -43
  91. flwr/compat/client/app.py +39 -38
  92. flwr/compat/client/grpc_client/connection.py +13 -13
  93. flwr/compat/server/app.py +5 -6
  94. flwr/proto/appio_pb2.py +13 -3
  95. flwr/proto/appio_pb2.pyi +134 -65
  96. flwr/proto/appio_pb2_grpc.py +20 -0
  97. flwr/proto/appio_pb2_grpc.pyi +27 -0
  98. flwr/proto/clientappio_pb2.py +17 -7
  99. flwr/proto/clientappio_pb2.pyi +15 -0
  100. flwr/proto/clientappio_pb2_grpc.py +206 -40
  101. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  102. flwr/proto/control_pb2.py +72 -40
  103. flwr/proto/control_pb2.pyi +319 -87
  104. flwr/proto/control_pb2_grpc.py +339 -28
  105. flwr/proto/control_pb2_grpc.pyi +209 -37
  106. flwr/proto/error_pb2.py +13 -3
  107. flwr/proto/error_pb2.pyi +24 -6
  108. flwr/proto/error_pb2_grpc.py +20 -0
  109. flwr/proto/error_pb2_grpc.pyi +27 -0
  110. flwr/proto/fab_pb2.py +24 -10
  111. flwr/proto/fab_pb2.pyi +68 -20
  112. flwr/proto/fab_pb2_grpc.py +20 -0
  113. flwr/proto/fab_pb2_grpc.pyi +27 -0
  114. flwr/proto/federation_pb2.py +38 -0
  115. flwr/proto/federation_pb2.pyi +56 -0
  116. flwr/proto/federation_pb2_grpc.py +24 -0
  117. flwr/proto/federation_pb2_grpc.pyi +31 -0
  118. flwr/proto/fleet_pb2.py +45 -27
  119. flwr/proto/fleet_pb2.pyi +186 -70
  120. flwr/proto/fleet_pb2_grpc.py +277 -66
  121. flwr/proto/fleet_pb2_grpc.pyi +201 -55
  122. flwr/proto/grpcadapter_pb2.py +14 -4
  123. flwr/proto/grpcadapter_pb2.pyi +38 -16
  124. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  125. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  126. flwr/proto/heartbeat_pb2.py +17 -7
  127. flwr/proto/heartbeat_pb2.pyi +51 -22
  128. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  129. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  130. flwr/proto/log_pb2.py +13 -3
  131. flwr/proto/log_pb2.pyi +34 -11
  132. flwr/proto/log_pb2_grpc.py +20 -0
  133. flwr/proto/log_pb2_grpc.pyi +27 -0
  134. flwr/proto/message_pb2.py +15 -5
  135. flwr/proto/message_pb2.pyi +154 -86
  136. flwr/proto/message_pb2_grpc.py +20 -0
  137. flwr/proto/message_pb2_grpc.pyi +27 -0
  138. flwr/proto/node_pb2.py +16 -4
  139. flwr/proto/node_pb2.pyi +77 -4
  140. flwr/proto/node_pb2_grpc.py +20 -0
  141. flwr/proto/node_pb2_grpc.pyi +27 -0
  142. flwr/proto/recorddict_pb2.py +13 -3
  143. flwr/proto/recorddict_pb2.pyi +184 -107
  144. flwr/proto/recorddict_pb2_grpc.py +20 -0
  145. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  146. flwr/proto/run_pb2.py +40 -31
  147. flwr/proto/run_pb2.pyi +149 -84
  148. flwr/proto/run_pb2_grpc.py +20 -0
  149. flwr/proto/run_pb2_grpc.pyi +27 -0
  150. flwr/proto/serverappio_pb2.py +13 -3
  151. flwr/proto/serverappio_pb2.pyi +32 -8
  152. flwr/proto/serverappio_pb2_grpc.py +246 -65
  153. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  154. flwr/proto/simulationio_pb2.py +16 -8
  155. flwr/proto/simulationio_pb2.pyi +15 -0
  156. flwr/proto/simulationio_pb2_grpc.py +162 -41
  157. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  158. flwr/proto/transport_pb2.py +20 -10
  159. flwr/proto/transport_pb2.pyi +249 -160
  160. flwr/proto/transport_pb2_grpc.py +35 -4
  161. flwr/proto/transport_pb2_grpc.pyi +38 -8
  162. flwr/server/app.py +173 -127
  163. flwr/server/client_manager.py +4 -5
  164. flwr/server/client_proxy.py +10 -11
  165. flwr/server/compat/app.py +4 -5
  166. flwr/server/compat/app_utils.py +2 -1
  167. flwr/server/compat/grid_client_proxy.py +10 -12
  168. flwr/server/compat/legacy_context.py +3 -4
  169. flwr/server/fleet_event_log_interceptor.py +2 -1
  170. flwr/server/grid/grid.py +2 -3
  171. flwr/server/grid/grpc_grid.py +10 -8
  172. flwr/server/grid/inmemory_grid.py +4 -4
  173. flwr/server/run_serverapp.py +2 -3
  174. flwr/server/server.py +34 -39
  175. flwr/server/server_app.py +7 -8
  176. flwr/server/server_config.py +1 -2
  177. flwr/server/serverapp/app.py +34 -28
  178. flwr/server/serverapp_components.py +4 -5
  179. flwr/server/strategy/aggregate.py +9 -8
  180. flwr/server/strategy/bulyan.py +13 -11
  181. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  182. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  183. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  184. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  185. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  186. flwr/server/strategy/fedadagrad.py +18 -14
  187. flwr/server/strategy/fedadam.py +16 -14
  188. flwr/server/strategy/fedavg.py +16 -17
  189. flwr/server/strategy/fedavg_android.py +15 -15
  190. flwr/server/strategy/fedavgm.py +21 -18
  191. flwr/server/strategy/fedmedian.py +2 -3
  192. flwr/server/strategy/fedopt.py +11 -10
  193. flwr/server/strategy/fedprox.py +10 -9
  194. flwr/server/strategy/fedtrimmedavg.py +12 -11
  195. flwr/server/strategy/fedxgb_bagging.py +13 -11
  196. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  197. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  198. flwr/server/strategy/fedyogi.py +16 -14
  199. flwr/server/strategy/krum.py +12 -11
  200. flwr/server/strategy/qfedavg.py +16 -15
  201. flwr/server/strategy/strategy.py +6 -9
  202. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
  203. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  206. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
  208. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
  209. flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
  210. flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
  211. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  212. flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
  213. flwr/server/superlink/fleet/vce/vce_api.py +32 -13
  214. flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
  215. flwr/server/superlink/linkstate/linkstate.py +161 -62
  216. flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
  217. flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
  218. flwr/server/superlink/linkstate/utils.py +9 -60
  219. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  220. flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
  221. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  222. flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
  223. flwr/server/superlink/utils.py +4 -6
  224. flwr/server/typing.py +1 -1
  225. flwr/server/utils/tensorboard.py +15 -8
  226. flwr/server/utils/validator.py +2 -3
  227. flwr/server/workflow/default_workflows.py +5 -5
  228. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  229. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
  230. flwr/serverapp/strategy/bulyan.py +16 -15
  231. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  232. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  233. flwr/serverapp/strategy/fedadagrad.py +10 -11
  234. flwr/serverapp/strategy/fedadam.py +10 -11
  235. flwr/serverapp/strategy/fedavg.py +9 -10
  236. flwr/serverapp/strategy/fedavgm.py +17 -16
  237. flwr/serverapp/strategy/fedmedian.py +2 -2
  238. flwr/serverapp/strategy/fedopt.py +10 -11
  239. flwr/serverapp/strategy/fedprox.py +7 -8
  240. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  241. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  242. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  243. flwr/serverapp/strategy/fedyogi.py +9 -11
  244. flwr/serverapp/strategy/krum.py +7 -7
  245. flwr/serverapp/strategy/multikrum.py +9 -9
  246. flwr/serverapp/strategy/qfedavg.py +17 -16
  247. flwr/serverapp/strategy/strategy.py +6 -9
  248. flwr/serverapp/strategy/strategy_utils.py +7 -8
  249. flwr/simulation/app.py +46 -42
  250. flwr/simulation/legacy_app.py +12 -12
  251. flwr/simulation/ray_transport/ray_actor.py +11 -12
  252. flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
  253. flwr/simulation/run_simulation.py +44 -43
  254. flwr/simulation/simulationio_connection.py +4 -4
  255. flwr/supercore/cli/flower_superexec.py +3 -4
  256. flwr/supercore/constant.py +52 -0
  257. flwr/supercore/corestate/corestate.py +24 -3
  258. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  259. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  260. flwr/supercore/ffs/disk_ffs.py +1 -2
  261. flwr/supercore/ffs/ffs.py +1 -2
  262. flwr/supercore/ffs/ffs_factory.py +1 -2
  263. flwr/{common → supercore}/heartbeat.py +20 -25
  264. flwr/supercore/object_store/in_memory_object_store.py +1 -6
  265. flwr/supercore/object_store/object_store.py +1 -2
  266. flwr/supercore/object_store/object_store_factory.py +27 -8
  267. flwr/supercore/object_store/sqlite_object_store.py +253 -0
  268. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  269. flwr/supercore/primitives/asymmetric.py +117 -0
  270. flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
  271. flwr/supercore/sqlite_mixin.py +159 -0
  272. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  273. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  274. flwr/supercore/superexec/run_superexec.py +9 -13
  275. flwr/supercore/utils.py +20 -0
  276. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  277. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  278. flwr/superlink/auth_plugin/auth_plugin.py +88 -0
  279. flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
  280. flwr/superlink/federation/__init__.py +24 -0
  281. flwr/superlink/federation/federation_manager.py +64 -0
  282. flwr/superlink/federation/noop_federation_manager.py +71 -0
  283. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
  284. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  285. flwr/superlink/servicer/control/control_grpc.py +18 -17
  286. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  287. flwr/superlink/servicer/control/control_servicer.py +239 -63
  288. flwr/supernode/cli/flower_supernode.py +74 -26
  289. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  290. flwr/supernode/nodestate/nodestate.py +7 -8
  291. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  292. flwr/supernode/runtime/run_clientapp.py +43 -24
  293. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  294. flwr/supernode/start_client_internal.py +175 -51
  295. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  296. flwr-1.24.0.dist-info/RECORD +454 -0
  297. flwr/common/auth_plugin/auth_plugin.py +0 -149
  298. flwr/supercore/object_store/utils.py +0 -43
  299. flwr-1.22.0.dist-info/RECORD +0 -428
  300. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  301. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -18,10 +18,9 @@ Paper: arxiv.org/abs/1802.07927
18
18
  """
19
19
 
20
20
 
21
- from collections import OrderedDict
22
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
23
22
  from logging import INFO, WARN
24
- from typing import Callable, Optional, cast
23
+ from typing import cast
25
24
 
26
25
  import numpy as np
27
26
 
@@ -104,15 +103,15 @@ class Bulyan(FedAvg):
104
103
  weighted_by_key: str = "num-examples",
105
104
  arrayrecord_key: str = "arrays",
106
105
  configrecord_key: str = "config",
107
- train_metrics_aggr_fn: Optional[
108
- Callable[[list[RecordDict], str], MetricRecord]
109
- ] = None,
110
- evaluate_metrics_aggr_fn: Optional[
111
- Callable[[list[RecordDict], str], MetricRecord]
112
- ] = None,
113
- selection_rule: Optional[
114
- Callable[[list[RecordDict], int, int], list[RecordDict]]
115
- ] = None,
106
+ train_metrics_aggr_fn: (
107
+ Callable[[list[RecordDict], str], MetricRecord] | None
108
+ ) = None,
109
+ evaluate_metrics_aggr_fn: (
110
+ Callable[[list[RecordDict], str], MetricRecord] | None
111
+ ) = None,
112
+ selection_rule: (
113
+ Callable[[list[RecordDict], int, int], list[RecordDict]] | None
114
+ ) = None,
116
115
  ) -> None:
117
116
  super().__init__(
118
117
  fraction_train=fraction_train,
@@ -140,7 +139,7 @@ class Bulyan(FedAvg):
140
139
  self,
141
140
  server_round: int,
142
141
  replies: Iterable[Message],
143
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
142
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
144
143
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
145
144
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
146
145
 
@@ -175,7 +174,9 @@ class Bulyan(FedAvg):
175
174
  ]
176
175
 
177
176
  # Compute median
178
- median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]
177
+ median_ndarrays = [
178
+ np.median(arr, axis=0) for arr in zip(*selected_ndarrays, strict=True)
179
+ ]
179
180
 
180
181
  # Aggregate the beta closest weights element-wise
181
182
  aggregated_ndarrays = aggregate_n_closest_weights(
@@ -184,7 +185,7 @@ class Bulyan(FedAvg):
184
185
 
185
186
  # Convert to ArrayRecord
186
187
  arrays = ArrayRecord(
187
- OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
188
+ dict(zip(array_keys, map(Array, aggregated_ndarrays), strict=True))
188
189
  )
189
190
 
190
191
  # Aggregate MetricRecords
@@ -19,10 +19,8 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
19
19
 
20
20
  import math
21
21
  from abc import ABC
22
- from collections import OrderedDict
23
22
  from collections.abc import Iterable
24
23
  from logging import INFO
25
- from typing import Optional
26
24
 
27
25
  import numpy as np
28
26
 
@@ -53,7 +51,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
53
51
  initial_clipping_norm: float = 0.1,
54
52
  target_clipped_quantile: float = 0.5,
55
53
  clip_norm_lr: float = 0.2,
56
- clipped_count_stddev: Optional[float] = None,
54
+ clipped_count_stddev: float | None = None,
57
55
  ) -> None:
58
56
  super().__init__()
59
57
 
@@ -96,7 +94,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
96
94
  add_gaussian_noise_inplace(nds, stdv)
97
95
  log(INFO, "aggregate_fit: central DP noise with %.4f stdev added", stdv)
98
96
  return ArrayRecord(
99
- OrderedDict({k: Array(v) for k, v in zip(aggregated.keys(), nds)})
97
+ {k: Array(v) for k, v in zip(aggregated.keys(), nds, strict=True)}
100
98
  )
101
99
 
102
100
  def _noisy_fraction(self, count: int, total: int) -> float:
@@ -115,7 +113,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
115
113
 
116
114
  def aggregate_evaluate(
117
115
  self, server_round: int, replies: Iterable[Message]
118
- ) -> Optional[MetricRecord]:
116
+ ) -> MetricRecord | None:
119
117
  """Aggregate MetricRecords in the received Messages."""
120
118
  return self.strategy.aggregate_evaluate(server_round, replies)
121
119
 
@@ -136,7 +134,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
136
134
  initial_clipping_norm: float = 0.1,
137
135
  target_clipped_quantile: float = 0.5,
138
136
  clip_norm_lr: float = 0.2,
139
- clipped_count_stddev: Optional[float] = None,
137
+ clipped_count_stddev: float | None = None,
140
138
  ) -> None:
141
139
  super().__init__(
142
140
  strategy,
@@ -171,7 +169,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
171
169
 
172
170
  def aggregate_train(
173
171
  self, server_round: int, replies: Iterable[Message]
174
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
172
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
175
173
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
176
174
  if not validate_replies(replies, self.num_sampled_clients):
177
175
  return None, None
@@ -184,16 +182,19 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
184
182
  for arr_name, record in reply.content.array_records.items():
185
183
  reply_nd = record.to_numpy_ndarrays()
186
184
  model_update = [
187
- np.subtract(x, y) for (x, y) in zip(reply_nd, current_nd)
185
+ np.subtract(x, y)
186
+ for (x, y) in zip(reply_nd, current_nd, strict=True)
188
187
  ]
189
188
  norm_bit = adaptive_clip_inputs_inplace(
190
189
  model_update, self.clipping_norm
191
190
  )
192
191
  clipped_indicator_count += int(norm_bit)
193
192
  # reconstruct array using clipped contribution from current round
194
- restored = [c + u for c, u in zip(current_nd, model_update)]
193
+ restored = [
194
+ c + u for c, u in zip(current_nd, model_update, strict=True)
195
+ ]
195
196
  reply.content[arr_name] = ArrayRecord(
196
- OrderedDict({k: Array(v) for k, v in zip(record.keys(), restored)})
197
+ {k: Array(v) for k, v in zip(record.keys(), restored, strict=True)}
197
198
  )
198
199
  log(
199
200
  INFO,
@@ -287,7 +288,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
287
288
 
288
289
  def aggregate_train(
289
290
  self, server_round: int, replies: Iterable[Message]
290
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
291
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
291
292
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
292
293
  if not validate_replies(replies, self.num_sampled_clients):
293
294
  return None, None
@@ -17,11 +17,10 @@
17
17
  Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
18
18
  """
19
19
 
20
+
20
21
  from abc import ABC
21
- from collections import OrderedDict
22
22
  from collections.abc import Iterable
23
23
  from logging import INFO, WARNING
24
- from typing import Optional
25
24
 
26
25
  from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
27
26
  from flwr.common.differential_privacy import (
@@ -112,12 +111,12 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
112
111
  )
113
112
 
114
113
  return ArrayRecord(
115
- OrderedDict(
116
- {
117
- k: Array(v)
118
- for k, v in zip(aggregated_arrays.keys(), aggregated_ndarrays)
119
- }
120
- )
114
+ {
115
+ k: Array(v)
116
+ for k, v in zip(
117
+ aggregated_arrays.keys(), aggregated_ndarrays, strict=True
118
+ )
119
+ }
121
120
  )
122
121
 
123
122
  def configure_evaluate(
@@ -130,7 +129,7 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
130
129
  self,
131
130
  server_round: int,
132
131
  replies: Iterable[Message],
133
- ) -> Optional[MetricRecord]:
132
+ ) -> MetricRecord | None:
134
133
  """Aggregate MetricRecords in the received Messages."""
135
134
  return self.strategy.aggregate_evaluate(server_round, replies)
136
135
 
@@ -199,7 +198,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
199
198
  self,
200
199
  server_round: int,
201
200
  replies: Iterable[Message],
202
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
201
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
203
202
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
204
203
  if not validate_replies(replies, self.num_sampled_clients):
205
204
  return None, None
@@ -217,9 +216,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
217
216
  )
218
217
  # Replace content while preserving keys
219
218
  reply.content[arr_name] = ArrayRecord(
220
- OrderedDict(
221
- {k: Array(v) for k, v in zip(record.keys(), reply_ndarrays)}
222
- )
219
+ dict(zip(record.keys(), map(Array, reply_ndarrays), strict=True))
223
220
  )
224
221
  log(
225
222
  INFO,
@@ -302,7 +299,7 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
302
299
  self,
303
300
  server_round: int,
304
301
  replies: Iterable[Message],
305
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
302
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
306
303
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
307
304
  if not validate_replies(replies, self.num_sampled_clients):
308
305
  return None, None
@@ -19,9 +19,8 @@ Adaptive Federated Optimization using Adagrad.
19
19
  Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
- from collections import OrderedDict
23
- from collections.abc import Iterable
24
- from typing import Callable, Optional
22
+
23
+ from collections.abc import Callable, Iterable
25
24
 
26
25
  import numpy as np
27
26
 
@@ -90,12 +89,12 @@ class FedAdagrad(FedOpt):
90
89
  weighted_by_key: str = "num-examples",
91
90
  arrayrecord_key: str = "arrays",
92
91
  configrecord_key: str = "config",
93
- train_metrics_aggr_fn: Optional[
94
- Callable[[list[RecordDict], str], MetricRecord]
95
- ] = None,
96
- evaluate_metrics_aggr_fn: Optional[
97
- Callable[[list[RecordDict], str], MetricRecord]
98
- ] = None,
92
+ train_metrics_aggr_fn: (
93
+ Callable[[list[RecordDict], str], MetricRecord] | None
94
+ ) = None,
95
+ evaluate_metrics_aggr_fn: (
96
+ Callable[[list[RecordDict], str], MetricRecord] | None
97
+ ) = None,
99
98
  eta: float = 1e-1,
100
99
  eta_l: float = 1e-1,
101
100
  tau: float = 1e-3,
@@ -122,7 +121,7 @@ class FedAdagrad(FedOpt):
122
121
  self,
123
122
  server_round: int,
124
123
  replies: Iterable[Message],
125
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
124
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
126
125
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
127
126
  aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
128
127
  server_round, replies
@@ -154,6 +153,6 @@ class FedAdagrad(FedOpt):
154
153
  }
155
154
 
156
155
  return (
157
- ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
156
+ ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
158
157
  aggregated_metrics,
159
158
  )
@@ -19,9 +19,8 @@
19
19
  Paper: arxiv.org/abs/2003.00295
20
20
  """
21
21
 
22
- from collections import OrderedDict
23
- from collections.abc import Iterable
24
- from typing import Callable, Optional
22
+
23
+ from collections.abc import Callable, Iterable
25
24
 
26
25
  import numpy as np
27
26
 
@@ -94,12 +93,12 @@ class FedAdam(FedOpt):
94
93
  weighted_by_key: str = "num-examples",
95
94
  arrayrecord_key: str = "arrays",
96
95
  configrecord_key: str = "config",
97
- train_metrics_aggr_fn: Optional[
98
- Callable[[list[RecordDict], str], MetricRecord]
99
- ] = None,
100
- evaluate_metrics_aggr_fn: Optional[
101
- Callable[[list[RecordDict], str], MetricRecord]
102
- ] = None,
96
+ train_metrics_aggr_fn: (
97
+ Callable[[list[RecordDict], str], MetricRecord] | None
98
+ ) = None,
99
+ evaluate_metrics_aggr_fn: (
100
+ Callable[[list[RecordDict], str], MetricRecord] | None
101
+ ) = None,
103
102
  eta: float = 1e-1,
104
103
  eta_l: float = 1e-1,
105
104
  beta_1: float = 0.9,
@@ -128,7 +127,7 @@ class FedAdam(FedOpt):
128
127
  self,
129
128
  server_round: int,
130
129
  replies: Iterable[Message],
131
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
130
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
132
131
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
133
132
  aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
134
133
  server_round, replies
@@ -173,6 +172,6 @@ class FedAdam(FedOpt):
173
172
  }
174
173
 
175
174
  return (
176
- ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
175
+ ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
177
176
  aggregated_metrics,
178
177
  )
@@ -15,9 +15,8 @@
15
15
  """Flower message-based FedAvg strategy."""
16
16
 
17
17
 
18
- from collections.abc import Iterable
18
+ from collections.abc import Callable, Iterable
19
19
  from logging import INFO, WARNING
20
- from typing import Callable, Optional
21
20
 
22
21
  from flwr.common import (
23
22
  ArrayRecord,
@@ -91,12 +90,12 @@ class FedAvg(Strategy):
91
90
  weighted_by_key: str = "num-examples",
92
91
  arrayrecord_key: str = "arrays",
93
92
  configrecord_key: str = "config",
94
- train_metrics_aggr_fn: Optional[
95
- Callable[[list[RecordDict], str], MetricRecord]
96
- ] = None,
97
- evaluate_metrics_aggr_fn: Optional[
98
- Callable[[list[RecordDict], str], MetricRecord]
99
- ] = None,
93
+ train_metrics_aggr_fn: (
94
+ Callable[[list[RecordDict], str], MetricRecord] | None
95
+ ) = None,
96
+ evaluate_metrics_aggr_fn: (
97
+ Callable[[list[RecordDict], str], MetricRecord] | None
98
+ ) = None,
100
99
  ) -> None:
101
100
  self.fraction_train = fraction_train
102
101
  self.fraction_evaluate = fraction_evaluate
@@ -251,7 +250,7 @@ class FedAvg(Strategy):
251
250
  self,
252
251
  server_round: int,
253
252
  replies: Iterable[Message],
254
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
253
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
255
254
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
256
255
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
257
256
 
@@ -304,7 +303,7 @@ class FedAvg(Strategy):
304
303
  self,
305
304
  server_round: int,
306
305
  replies: Iterable[Message],
307
- ) -> Optional[MetricRecord]:
306
+ ) -> MetricRecord | None:
308
307
  """Aggregate MetricRecords in the received Messages."""
309
308
  valid_replies, _ = self._check_and_log_replies(replies, is_train=False)
310
309
 
@@ -18,10 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
18
18
  """
19
19
 
20
20
 
21
- from collections import OrderedDict
22
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
23
22
  from logging import INFO
24
- from typing import Callable, Optional
25
23
 
26
24
  from flwr.common import (
27
25
  Array,
@@ -93,12 +91,12 @@ class FedAvgM(FedAvg):
93
91
  weighted_by_key: str = "num-examples",
94
92
  arrayrecord_key: str = "arrays",
95
93
  configrecord_key: str = "config",
96
- train_metrics_aggr_fn: Optional[
97
- Callable[[list[RecordDict], str], MetricRecord]
98
- ] = None,
99
- evaluate_metrics_aggr_fn: Optional[
100
- Callable[[list[RecordDict], str], MetricRecord]
101
- ] = None,
94
+ train_metrics_aggr_fn: (
95
+ Callable[[list[RecordDict], str], MetricRecord] | None
96
+ ) = None,
97
+ evaluate_metrics_aggr_fn: (
98
+ Callable[[list[RecordDict], str], MetricRecord] | None
99
+ ) = None,
102
100
  server_learning_rate: float = 1.0,
103
101
  server_momentum: float = 0.0,
104
102
  ) -> None:
@@ -119,8 +117,8 @@ class FedAvgM(FedAvg):
119
117
  self.server_opt: bool = (self.server_momentum != 0.0) or (
120
118
  self.server_learning_rate != 1.0
121
119
  )
122
- self.current_arrays: Optional[ArrayRecord] = None
123
- self.momentum_vector: Optional[NDArrays] = None
120
+ self.current_arrays: ArrayRecord | None = None
121
+ self.momentum_vector: NDArrays | None = None
124
122
 
125
123
  def summary(self) -> None:
126
124
  """Log summary configuration of the strategy."""
@@ -143,7 +141,7 @@ class FedAvgM(FedAvg):
143
141
  self,
144
142
  server_round: int,
145
143
  replies: Iterable[Message],
146
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
144
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
147
145
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
148
146
  # Call FedAvg aggregate_train to perform validation and aggregation
149
147
  aggregated_arrays, aggregated_metrics = super().aggregate_train(
@@ -168,7 +166,8 @@ class FedAvgM(FedAvg):
168
166
 
169
167
  # Remember that updates are the opposite of gradients
170
168
  pseudo_gradient = [
171
- old - new for new, old in zip(aggregated_ndarrays, ndarrays)
169
+ old - new
170
+ for new, old in zip(aggregated_ndarrays, ndarrays, strict=True)
172
171
  ]
173
172
  if self.server_momentum > 0.0:
174
173
  if self.momentum_vector is None:
@@ -177,7 +176,9 @@ class FedAvgM(FedAvg):
177
176
  else:
178
177
  self.momentum_vector = [
179
178
  self.server_momentum * mv + pg
180
- for mv, pg in zip(self.momentum_vector, pseudo_gradient)
179
+ for mv, pg in zip(
180
+ self.momentum_vector, pseudo_gradient, strict=True
181
+ )
181
182
  ]
182
183
 
183
184
  # No nesterov for now
@@ -186,10 +187,10 @@ class FedAvgM(FedAvg):
186
187
  # SGD and convert back to ArrayRecord
187
188
  updated_array_list = [
188
189
  Array(old - self.server_learning_rate * pg)
189
- for old, pg in zip(ndarrays, pseudo_gradient)
190
+ for old, pg in zip(ndarrays, pseudo_gradient, strict=True)
190
191
  ]
191
192
  aggregated_arrays = ArrayRecord(
192
- OrderedDict(zip(array_keys, updated_array_list))
193
+ dict(zip(array_keys, updated_array_list, strict=True))
193
194
  )
194
195
 
195
196
  # Update current weights
@@ -19,7 +19,7 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
19
19
 
20
20
 
21
21
  from collections.abc import Iterable
22
- from typing import Optional, cast
22
+ from typing import cast
23
23
 
24
24
  import numpy as np
25
25
 
@@ -72,7 +72,7 @@ class FedMedian(FedAvg):
72
72
  self,
73
73
  server_round: int,
74
74
  replies: Iterable[Message],
75
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
75
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
76
76
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
77
77
  # Call FedAvg aggregate_train to perform validation and aggregation
78
78
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
@@ -17,9 +17,8 @@
17
17
  Paper: arxiv.org/abs/2003.00295
18
18
  """
19
19
 
20
- from collections.abc import Iterable
20
+ from collections.abc import Callable, Iterable
21
21
  from logging import INFO
22
- from typing import Callable, Optional
23
22
 
24
23
  import numpy as np
25
24
 
@@ -101,12 +100,12 @@ class FedOpt(FedAvg):
101
100
  weighted_by_key: str = "num-examples",
102
101
  arrayrecord_key: str = "arrays",
103
102
  configrecord_key: str = "config",
104
- train_metrics_aggr_fn: Optional[
105
- Callable[[list[RecordDict], str], MetricRecord]
106
- ] = None,
107
- evaluate_metrics_aggr_fn: Optional[
108
- Callable[[list[RecordDict], str], MetricRecord]
109
- ] = None,
103
+ train_metrics_aggr_fn: (
104
+ Callable[[list[RecordDict], str], MetricRecord] | None
105
+ ) = None,
106
+ evaluate_metrics_aggr_fn: (
107
+ Callable[[list[RecordDict], str], MetricRecord] | None
108
+ ) = None,
110
109
  eta: float = 1e-1,
111
110
  eta_l: float = 1e-1,
112
111
  beta_1: float = 0.0,
@@ -125,14 +124,14 @@ class FedOpt(FedAvg):
125
124
  train_metrics_aggr_fn=train_metrics_aggr_fn,
126
125
  evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
127
126
  )
128
- self.current_arrays: Optional[dict[str, NDArray]] = None
127
+ self.current_arrays: dict[str, NDArray] | None = None
129
128
  self.eta = eta
130
129
  self.eta_l = eta_l
131
130
  self.tau = tau
132
131
  self.beta_1 = beta_1
133
132
  self.beta_2 = beta_2
134
- self.m_t: Optional[dict[str, NDArray]] = None
135
- self.v_t: Optional[dict[str, NDArray]] = None
133
+ self.m_t: dict[str, NDArray] | None = None
134
+ self.v_t: dict[str, NDArray] | None = None
136
135
 
137
136
  def summary(self) -> None:
138
137
  """Log summary configuration of the strategy."""
@@ -18,9 +18,8 @@ Paper: arxiv.org/abs/1812.06127
18
18
  """
19
19
 
20
20
 
21
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
22
22
  from logging import INFO, WARN
23
- from typing import Callable, Optional
24
23
 
25
24
  from flwr.common import (
26
25
  ArrayRecord,
@@ -130,12 +129,12 @@ class FedProx(FedAvg):
130
129
  weighted_by_key: str = "num-examples",
131
130
  arrayrecord_key: str = "arrays",
132
131
  configrecord_key: str = "config",
133
- train_metrics_aggr_fn: Optional[
134
- Callable[[list[RecordDict], str], MetricRecord]
135
- ] = None,
136
- evaluate_metrics_aggr_fn: Optional[
137
- Callable[[list[RecordDict], str], MetricRecord]
138
- ] = None,
132
+ train_metrics_aggr_fn: (
133
+ Callable[[list[RecordDict], str], MetricRecord] | None
134
+ ) = None,
135
+ evaluate_metrics_aggr_fn: (
136
+ Callable[[list[RecordDict], str], MetricRecord] | None
137
+ ) = None,
139
138
  proximal_mu: float = 0.0,
140
139
  ) -> None:
141
140
  super().__init__(
@@ -18,9 +18,9 @@ Paper: arxiv.org/abs/1803.01498
18
18
  """
19
19
 
20
20
 
21
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
22
22
  from logging import INFO
23
- from typing import Callable, Optional, cast
23
+ from typing import cast
24
24
 
25
25
  import numpy as np
26
26
 
@@ -83,12 +83,12 @@ class FedTrimmedAvg(FedAvg):
83
83
  weighted_by_key: str = "num-examples",
84
84
  arrayrecord_key: str = "arrays",
85
85
  configrecord_key: str = "config",
86
- train_metrics_aggr_fn: Optional[
87
- Callable[[list[RecordDict], str], MetricRecord]
88
- ] = None,
89
- evaluate_metrics_aggr_fn: Optional[
90
- Callable[[list[RecordDict], str], MetricRecord]
91
- ] = None,
86
+ train_metrics_aggr_fn: (
87
+ Callable[[list[RecordDict], str], MetricRecord] | None
88
+ ) = None,
89
+ evaluate_metrics_aggr_fn: (
90
+ Callable[[list[RecordDict], str], MetricRecord] | None
91
+ ) = None,
92
92
  beta: float = 0.2,
93
93
  ) -> None:
94
94
  super().__init__(
@@ -115,7 +115,7 @@ class FedTrimmedAvg(FedAvg):
115
115
  self,
116
116
  server_round: int,
117
117
  replies: Iterable[Message],
118
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
118
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
119
119
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
120
120
  # Call FedAvg aggregate_train to perform validation and aggregation
121
121
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower message-based FedXgbBagging strategy."""
16
16
  from collections.abc import Iterable
17
- from typing import Optional, cast
17
+ from typing import cast
18
18
 
19
19
  import numpy as np
20
20
 
@@ -65,7 +65,7 @@ class FedXgbBagging(FedAvg):
65
65
  average using the provided weight factor key.
66
66
  """
67
67
 
68
- current_bst: Optional[bytes] = None
68
+ current_bst: bytes | None = None
69
69
 
70
70
  def _ensure_single_array(self, arrays: ArrayRecord) -> None:
71
71
  """Check that ensures there's only one Array in the ArrayRecord."""
@@ -89,7 +89,7 @@ class FedXgbBagging(FedAvg):
89
89
  self,
90
90
  server_round: int,
91
91
  replies: Iterable[Message],
92
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
92
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
93
93
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
94
94
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
95
95
 
@@ -15,9 +15,9 @@
15
15
  """Flower message-based FedXgbCyclic strategy."""
16
16
 
17
17
 
18
- from collections.abc import Iterable
18
+ from collections.abc import Callable, Iterable
19
19
  from logging import INFO
20
- from typing import Callable, Optional, cast
20
+ from typing import cast
21
21
 
22
22
  from flwr.common import (
23
23
  ArrayRecord,
@@ -78,12 +78,12 @@ class FedXgbCyclic(FedAvg):
78
78
  weighted_by_key: str = "num-examples",
79
79
  arrayrecord_key: str = "arrays",
80
80
  configrecord_key: str = "config",
81
- train_metrics_aggr_fn: Optional[
82
- Callable[[list[RecordDict], str], MetricRecord]
83
- ] = None,
84
- evaluate_metrics_aggr_fn: Optional[
85
- Callable[[list[RecordDict], str], MetricRecord]
86
- ] = None,
81
+ train_metrics_aggr_fn: (
82
+ Callable[[list[RecordDict], str], MetricRecord] | None
83
+ ) = None,
84
+ evaluate_metrics_aggr_fn: (
85
+ Callable[[list[RecordDict], str], MetricRecord] | None
86
+ ) = None,
87
87
  ) -> None:
88
88
  super().__init__(
89
89
  fraction_train=fraction_train,
@@ -184,7 +184,7 @@ class FedXgbCyclic(FedAvg):
184
184
  self,
185
185
  server_round: int,
186
186
  replies: Iterable[Message],
187
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
187
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
188
188
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
189
189
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
190
190