flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (292)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/federation/__init__.py +24 -0
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +317 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +170 -130
  21. flwr/cli/new/new.py +33 -50
  22. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  30. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  33. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  34. flwr/cli/pull.py +10 -5
  35. flwr/cli/run/run.py +77 -30
  36. flwr/cli/run_utils.py +130 -0
  37. flwr/cli/stop.py +25 -7
  38. flwr/cli/supernode/ls.py +16 -8
  39. flwr/cli/supernode/register.py +9 -4
  40. flwr/cli/supernode/unregister.py +5 -3
  41. flwr/cli/utils.py +376 -16
  42. flwr/client/__init__.py +1 -1
  43. flwr/client/dpfedavg_numpy_client.py +4 -1
  44. flwr/client/grpc_adapter_client/connection.py +6 -7
  45. flwr/client/grpc_rere_client/connection.py +10 -11
  46. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  47. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  48. flwr/client/message_handler/message_handler.py +2 -2
  49. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  50. flwr/client/numpy_client.py +1 -1
  51. flwr/client/rest_client/connection.py +12 -14
  52. flwr/client/run_info_store.py +4 -5
  53. flwr/client/typing.py +1 -1
  54. flwr/clientapp/client_app.py +9 -10
  55. flwr/clientapp/mod/centraldp_mods.py +16 -17
  56. flwr/clientapp/mod/localdp_mod.py +8 -9
  57. flwr/clientapp/typing.py +1 -1
  58. flwr/clientapp/utils.py +3 -3
  59. flwr/common/address.py +1 -2
  60. flwr/common/args.py +3 -4
  61. flwr/common/config.py +13 -16
  62. flwr/common/constant.py +5 -2
  63. flwr/common/differential_privacy.py +3 -4
  64. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  65. flwr/common/exit/exit.py +15 -2
  66. flwr/common/exit/exit_code.py +19 -0
  67. flwr/common/exit/exit_handler.py +6 -2
  68. flwr/common/exit/signal_handler.py +5 -5
  69. flwr/common/grpc.py +6 -6
  70. flwr/common/inflatable_protobuf_utils.py +1 -1
  71. flwr/common/inflatable_utils.py +38 -21
  72. flwr/common/logger.py +19 -19
  73. flwr/common/message.py +4 -4
  74. flwr/common/object_ref.py +7 -7
  75. flwr/common/record/array.py +3 -3
  76. flwr/common/record/arrayrecord.py +18 -30
  77. flwr/common/record/configrecord.py +3 -3
  78. flwr/common/record/recorddict.py +5 -5
  79. flwr/common/record/typeddict.py +9 -2
  80. flwr/common/recorddict_compat.py +7 -10
  81. flwr/common/retry_invoker.py +20 -20
  82. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  83. flwr/common/serde.py +5 -4
  84. flwr/common/serde_utils.py +2 -2
  85. flwr/common/telemetry.py +9 -5
  86. flwr/common/typing.py +52 -37
  87. flwr/compat/client/app.py +38 -37
  88. flwr/compat/client/grpc_client/connection.py +11 -11
  89. flwr/compat/server/app.py +5 -6
  90. flwr/proto/appio_pb2.py +13 -3
  91. flwr/proto/appio_pb2.pyi +134 -65
  92. flwr/proto/appio_pb2_grpc.py +20 -0
  93. flwr/proto/appio_pb2_grpc.pyi +27 -0
  94. flwr/proto/clientappio_pb2.py +17 -7
  95. flwr/proto/clientappio_pb2.pyi +15 -0
  96. flwr/proto/clientappio_pb2_grpc.py +206 -40
  97. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  98. flwr/proto/control_pb2.py +71 -52
  99. flwr/proto/control_pb2.pyi +277 -111
  100. flwr/proto/control_pb2_grpc.py +249 -40
  101. flwr/proto/control_pb2_grpc.pyi +185 -52
  102. flwr/proto/error_pb2.py +13 -3
  103. flwr/proto/error_pb2.pyi +24 -6
  104. flwr/proto/error_pb2_grpc.py +20 -0
  105. flwr/proto/error_pb2_grpc.pyi +27 -0
  106. flwr/proto/fab_pb2.py +14 -4
  107. flwr/proto/fab_pb2.pyi +59 -31
  108. flwr/proto/fab_pb2_grpc.py +20 -0
  109. flwr/proto/fab_pb2_grpc.pyi +27 -0
  110. flwr/proto/federation_pb2.py +38 -0
  111. flwr/proto/federation_pb2.pyi +56 -0
  112. flwr/proto/federation_pb2_grpc.py +24 -0
  113. flwr/proto/federation_pb2_grpc.pyi +31 -0
  114. flwr/proto/fleet_pb2.py +14 -4
  115. flwr/proto/fleet_pb2.pyi +137 -61
  116. flwr/proto/fleet_pb2_grpc.py +189 -48
  117. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  118. flwr/proto/grpcadapter_pb2.py +14 -4
  119. flwr/proto/grpcadapter_pb2.pyi +38 -16
  120. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  121. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  122. flwr/proto/heartbeat_pb2.py +17 -7
  123. flwr/proto/heartbeat_pb2.pyi +51 -22
  124. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  125. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  126. flwr/proto/log_pb2.py +13 -3
  127. flwr/proto/log_pb2.pyi +34 -11
  128. flwr/proto/log_pb2_grpc.py +20 -0
  129. flwr/proto/log_pb2_grpc.pyi +27 -0
  130. flwr/proto/message_pb2.py +15 -5
  131. flwr/proto/message_pb2.pyi +154 -86
  132. flwr/proto/message_pb2_grpc.py +20 -0
  133. flwr/proto/message_pb2_grpc.pyi +27 -0
  134. flwr/proto/node_pb2.py +15 -5
  135. flwr/proto/node_pb2.pyi +50 -25
  136. flwr/proto/node_pb2_grpc.py +20 -0
  137. flwr/proto/node_pb2_grpc.pyi +27 -0
  138. flwr/proto/recorddict_pb2.py +13 -3
  139. flwr/proto/recorddict_pb2.pyi +184 -107
  140. flwr/proto/recorddict_pb2_grpc.py +20 -0
  141. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  142. flwr/proto/run_pb2.py +40 -31
  143. flwr/proto/run_pb2.pyi +149 -84
  144. flwr/proto/run_pb2_grpc.py +20 -0
  145. flwr/proto/run_pb2_grpc.pyi +27 -0
  146. flwr/proto/serverappio_pb2.py +13 -3
  147. flwr/proto/serverappio_pb2.pyi +32 -8
  148. flwr/proto/serverappio_pb2_grpc.py +246 -65
  149. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  150. flwr/proto/simulationio_pb2.py +16 -8
  151. flwr/proto/simulationio_pb2.pyi +15 -0
  152. flwr/proto/simulationio_pb2_grpc.py +162 -41
  153. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  154. flwr/proto/transport_pb2.py +20 -10
  155. flwr/proto/transport_pb2.pyi +249 -160
  156. flwr/proto/transport_pb2_grpc.py +35 -4
  157. flwr/proto/transport_pb2_grpc.pyi +38 -8
  158. flwr/server/app.py +38 -17
  159. flwr/server/client_manager.py +4 -5
  160. flwr/server/client_proxy.py +10 -11
  161. flwr/server/compat/app.py +4 -5
  162. flwr/server/compat/app_utils.py +2 -1
  163. flwr/server/compat/grid_client_proxy.py +10 -12
  164. flwr/server/compat/legacy_context.py +3 -4
  165. flwr/server/fleet_event_log_interceptor.py +2 -1
  166. flwr/server/grid/grid.py +2 -3
  167. flwr/server/grid/grpc_grid.py +10 -8
  168. flwr/server/grid/inmemory_grid.py +4 -4
  169. flwr/server/run_serverapp.py +2 -3
  170. flwr/server/server.py +34 -39
  171. flwr/server/server_app.py +7 -8
  172. flwr/server/server_config.py +1 -2
  173. flwr/server/serverapp/app.py +34 -28
  174. flwr/server/serverapp_components.py +4 -5
  175. flwr/server/strategy/aggregate.py +9 -8
  176. flwr/server/strategy/bulyan.py +13 -11
  177. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  178. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  179. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  180. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  181. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  182. flwr/server/strategy/fedadagrad.py +18 -14
  183. flwr/server/strategy/fedadam.py +16 -14
  184. flwr/server/strategy/fedavg.py +16 -17
  185. flwr/server/strategy/fedavg_android.py +15 -15
  186. flwr/server/strategy/fedavgm.py +21 -18
  187. flwr/server/strategy/fedmedian.py +2 -3
  188. flwr/server/strategy/fedopt.py +11 -10
  189. flwr/server/strategy/fedprox.py +10 -9
  190. flwr/server/strategy/fedtrimmedavg.py +12 -11
  191. flwr/server/strategy/fedxgb_bagging.py +13 -11
  192. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  193. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  194. flwr/server/strategy/fedyogi.py +16 -14
  195. flwr/server/strategy/krum.py +12 -11
  196. flwr/server/strategy/qfedavg.py +16 -15
  197. flwr/server/strategy/strategy.py +6 -9
  198. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  199. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  200. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  201. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  202. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  203. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  204. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  205. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
  206. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  207. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  208. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  209. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  210. flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
  211. flwr/server/superlink/linkstate/linkstate.py +59 -43
  212. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  213. flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
  214. flwr/server/superlink/linkstate/utils.py +6 -6
  215. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  216. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  217. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  218. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  219. flwr/server/superlink/utils.py +4 -6
  220. flwr/server/typing.py +1 -1
  221. flwr/server/utils/tensorboard.py +15 -8
  222. flwr/server/workflow/default_workflows.py +5 -5
  223. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  224. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  225. flwr/serverapp/strategy/bulyan.py +16 -15
  226. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  227. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  228. flwr/serverapp/strategy/fedadagrad.py +10 -11
  229. flwr/serverapp/strategy/fedadam.py +10 -11
  230. flwr/serverapp/strategy/fedavg.py +9 -10
  231. flwr/serverapp/strategy/fedavgm.py +17 -16
  232. flwr/serverapp/strategy/fedmedian.py +2 -2
  233. flwr/serverapp/strategy/fedopt.py +10 -11
  234. flwr/serverapp/strategy/fedprox.py +7 -8
  235. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  236. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  237. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  238. flwr/serverapp/strategy/fedyogi.py +9 -11
  239. flwr/serverapp/strategy/krum.py +7 -7
  240. flwr/serverapp/strategy/multikrum.py +9 -9
  241. flwr/serverapp/strategy/qfedavg.py +17 -16
  242. flwr/serverapp/strategy/strategy.py +6 -9
  243. flwr/serverapp/strategy/strategy_utils.py +7 -8
  244. flwr/simulation/app.py +46 -42
  245. flwr/simulation/legacy_app.py +12 -12
  246. flwr/simulation/ray_transport/ray_actor.py +10 -11
  247. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  248. flwr/simulation/run_simulation.py +43 -43
  249. flwr/simulation/simulationio_connection.py +4 -4
  250. flwr/supercore/cli/flower_superexec.py +3 -4
  251. flwr/supercore/constant.py +31 -1
  252. flwr/supercore/corestate/corestate.py +24 -3
  253. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  254. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  255. flwr/supercore/ffs/disk_ffs.py +1 -2
  256. flwr/supercore/ffs/ffs.py +1 -2
  257. flwr/supercore/ffs/ffs_factory.py +1 -2
  258. flwr/{common → supercore}/heartbeat.py +20 -25
  259. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  260. flwr/supercore/object_store/object_store.py +1 -2
  261. flwr/supercore/object_store/object_store_factory.py +1 -2
  262. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  263. flwr/supercore/primitives/asymmetric.py +1 -1
  264. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  265. flwr/supercore/sqlite_mixin.py +37 -34
  266. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  267. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  268. flwr/supercore/superexec/run_superexec.py +9 -13
  269. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  270. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  271. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  272. flwr/superlink/federation/__init__.py +24 -0
  273. flwr/superlink/federation/federation_manager.py +64 -0
  274. flwr/superlink/federation/noop_federation_manager.py +71 -0
  275. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  276. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  277. flwr/superlink/servicer/control/control_grpc.py +5 -6
  278. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  279. flwr/superlink/servicer/control/control_servicer.py +102 -18
  280. flwr/supernode/cli/flower_supernode.py +58 -3
  281. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  282. flwr/supernode/nodestate/nodestate.py +7 -8
  283. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  284. flwr/supernode/runtime/run_clientapp.py +41 -22
  285. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  286. flwr/supernode/start_client_internal.py +158 -42
  287. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  288. flwr-1.24.0.dist-info/RECORD +454 -0
  289. flwr/supercore/object_store/utils.py +0 -43
  290. flwr-1.23.0.dist-info/RECORD +0 -439
  291. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  292. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from os import urandom
19
- from typing import Optional
20
19
 
21
20
  from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
22
21
  from flwr.common.constant import (
@@ -51,7 +50,8 @@ VALID_RUN_SUB_STATUSES = {
51
50
  }
52
51
  MESSAGE_UNAVAILABLE_ERROR_REASON = (
53
52
  "Error: Message Unavailable - The requested message could not be found in the "
54
- "database. It may have expired due to its TTL or never existed."
53
+ "database. It may have expired due to its TTL, been deleted because the "
54
+ "destination SuperNode was removed from the federation, or never existed."
55
55
  )
56
56
  REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
57
57
  "Error: Reply Message Unavailable - The reply message has expired."
@@ -63,7 +63,7 @@ NODE_UNAVAILABLE_ERROR_REASON = (
63
63
 
64
64
 
65
65
  def generate_rand_int_from_bytes(
66
- num_bytes: int, exclude: Optional[list[int]] = None
66
+ num_bytes: int, exclude: list[int] | None = None
67
67
  ) -> int:
68
68
  """Generate a random unsigned integer from `num_bytes` bytes.
69
69
 
@@ -257,7 +257,7 @@ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) ->
257
257
  def verify_message_ids(
258
258
  inquired_message_ids: set[str],
259
259
  found_message_ins_dict: dict[str, Message],
260
- current_time: Optional[float] = None,
260
+ current_time: float | None = None,
261
261
  update_set: bool = True,
262
262
  ) -> dict[str, Message]:
263
263
  """Verify found Messages and generate error Messages for invalid ones.
@@ -300,7 +300,7 @@ def verify_found_message_replies(
300
300
  inquired_message_ids: set[str],
301
301
  found_message_ins_dict: dict[str, Message],
302
302
  found_message_res_list: list[Message],
303
- current_time: Optional[float] = None,
303
+ current_time: float | None = None,
304
304
  update_set: bool = True,
305
305
  ) -> dict[str, Message]:
306
306
  """Verify found Message replies and generate error Message for invalid ones.
@@ -345,7 +345,7 @@ def check_node_availability_for_in_message(
345
345
  inquired_in_message_ids: set[str],
346
346
  found_in_message_dict: dict[str, Message],
347
347
  node_id_to_online_until: dict[int, float],
348
- current_time: Optional[float] = None,
348
+ current_time: float | None = None,
349
349
  update_set: bool = True,
350
350
  ) -> dict[str, Message]:
351
351
  """Check node availability for given Message and generate error reply Message if
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from logging import INFO
19
- from typing import Optional
20
19
 
21
20
  import grpc
22
21
 
@@ -38,7 +37,7 @@ def run_serverappio_api_grpc(
38
37
  state_factory: LinkStateFactory,
39
38
  ffs_factory: FfsFactory,
40
39
  objectstore_factory: ObjectStoreFactory,
41
- certificates: Optional[tuple[bytes, bytes, bytes]],
40
+ certificates: tuple[bytes, bytes, bytes] | None,
42
41
  ) -> grpc.Server:
43
42
  """Run ServerAppIo API (gRPC, request-response)."""
44
43
  # Create ServerAppIo API gRPC server
@@ -17,7 +17,6 @@
17
17
 
18
18
  import threading
19
19
  from logging import DEBUG, ERROR, INFO
20
- from typing import Optional
21
20
 
22
21
  import grpc
23
22
 
@@ -91,7 +90,6 @@ from flwr.server.superlink.utils import abort_if
91
90
  from flwr.server.utils.validator import validate_message
92
91
  from flwr.supercore.ffs import Ffs, FfsFactory
93
92
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
94
- from flwr.supercore.object_store.utils import store_mapping_and_register_objects
95
93
 
96
94
 
97
95
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -141,6 +139,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
141
139
  # Attempt to create a token for the provided run ID
142
140
  token = state.create_token(request.run_id)
143
141
 
142
+ # Transition the run to STARTING if token creation was successful
143
+ if token:
144
+ state.update_run_status(
145
+ run_id=request.run_id,
146
+ new_status=RunStatus(Status.STARTING, "", ""),
147
+ )
148
+
144
149
  # Return the token
145
150
  return RequestTokenResponse(token=token or "")
146
151
 
@@ -192,8 +197,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
192
197
  request_name="PushMessages",
193
198
  detail="`messages_list` must not be empty",
194
199
  )
195
- message_ids: list[Optional[str]] = []
196
- for message_proto in request.messages_list:
200
+ message_ids: list[str | None] = []
201
+ objects_to_push: set[str] = set()
202
+ for message_proto, object_tree in zip(
203
+ request.messages_list, request.message_object_trees, strict=True
204
+ ):
197
205
  message = message_from_proto(message_proto=message_proto)
198
206
  validation_errors = validate_message(message, is_reply_message=False)
199
207
  _raise_if(
@@ -206,13 +214,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
206
214
  request_name="PushMessages",
207
215
  detail="`Message.metadata` has mismatched `run_id`",
208
216
  )
209
- # Store
210
- message_id: Optional[str] = state.store_message_ins(message=message)
217
+ # Store objects
218
+ objects_to_push |= set(store.preregister(request.run_id, object_tree))
219
+ # Store message
220
+ message_id: str | None = state.store_message_ins(message=message)
211
221
  message_ids.append(message_id)
212
222
 
213
- # Store Message object to descendants mapping and preregister objects
214
- objects_to_push = store_mapping_and_register_objects(store, request=request)
215
-
216
223
  return PushAppMessagesResponse(
217
224
  message_ids=[
218
225
  str(message_id) if message_id else "" for message_id in message_ids
@@ -345,8 +352,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
345
352
  if result := ffs.get(run.fab_hash):
346
353
  fab = Fab(run.fab_hash, result[0], result[1])
347
354
  if run and fab and serverapp_ctxt:
348
- # Update run status to STARTING
349
- if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
355
+ # Update run status to RUNNING
356
+ if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
350
357
  log(INFO, "Starting run %d", run_id)
351
358
  return PullAppInputsResponse(
352
359
  context=context_to_proto(serverapp_ctxt),
@@ -355,8 +362,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
355
362
  )
356
363
 
357
364
  # Raise an exception if the Run or Fab is not found,
358
- # or if the status cannot be updated to STARTING
359
- raise RuntimeError(f"Failed to start run {run_id}")
365
+ # or if the status cannot be updated to RUNNING
366
+ context.abort(
367
+ grpc.StatusCode.FAILED_PRECONDITION,
368
+ f"Failed to start run {run_id}",
369
+ )
370
+ raise RuntimeError("Unreachable code") # for mypy
360
371
 
361
372
  def PushAppOutputs(
362
373
  self, request: PushAppOutputsRequest, context: grpc.ServicerContext
@@ -441,20 +452,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
441
452
  def SendAppHeartbeat(
442
453
  self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
443
454
  ) -> SendAppHeartbeatResponse:
444
- """Handle a heartbeat from the ServerApp."""
455
+ """Handle a heartbeat from an app process."""
445
456
  log(DEBUG, "ServerAppIoServicer.SendAppHeartbeat")
446
457
 
447
458
  # Init state
448
459
  state = self.state_factory.state()
449
460
 
450
461
  # Acknowledge the heartbeat
451
- # The app heartbeat can only be acknowledged if the run is in
452
- # starting or running status.
453
- success = state.acknowledge_app_heartbeat(
454
- run_id=request.run_id,
455
- heartbeat_interval=request.heartbeat_interval,
456
- )
457
-
462
+ success = state.acknowledge_app_heartbeat(request.token)
458
463
  return SendAppHeartbeatResponse(success=success)
459
464
 
460
465
  def PushObject(
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  from logging import INFO
19
- from typing import Optional
20
19
 
21
20
  import grpc
22
21
 
@@ -36,7 +35,7 @@ def run_simulationio_api_grpc(
36
35
  address: str,
37
36
  state_factory: LinkStateFactory,
38
37
  ffs_factory: FfsFactory,
39
- certificates: Optional[tuple[bytes, bytes, bytes]],
38
+ certificates: tuple[bytes, bytes, bytes] | None,
40
39
  ) -> grpc.Server:
41
40
  """Run SimulationIo API (gRPC, request-response)."""
42
41
  # Create SimulationIo API gRPC server
@@ -110,6 +110,13 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
110
110
  # Attempt to create a token for the provided run ID
111
111
  token = state.create_token(request.run_id)
112
112
 
113
+ # Transition the run to STARTING if token creation was successful
114
+ if token:
115
+ state.update_run_status(
116
+ run_id=request.run_id,
117
+ new_status=RunStatus(Status.STARTING, "", ""),
118
+ )
119
+
113
120
  # Return the token
114
121
  return RequestTokenResponse(token=token or "")
115
122
 
@@ -152,8 +159,8 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
152
159
  if result := ffs.get(run.fab_hash):
153
160
  fab = Fab(run.fab_hash, result[0], result[1])
154
161
  if run and fab and serverapp_ctxt:
155
- # Update run status to STARTING
156
- if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
162
+ # Update run status to RUNNING
163
+ if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
157
164
  log(INFO, "Starting run %d", run_id)
158
165
  return PullAppInputsResponse(
159
166
  context=context_to_proto(serverapp_ctxt),
@@ -162,8 +169,12 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
162
169
  )
163
170
 
164
171
  # Raise an exception if the Run or Fab is not found,
165
- # or if the status cannot be updated to STARTING
166
- raise RuntimeError(f"Failed to start run {run_id}")
172
+ # or if the status cannot be updated to RUNNING
173
+ context.abort(
174
+ grpc.StatusCode.FAILED_PRECONDITION,
175
+ f"Failed to start run {run_id}",
176
+ )
177
+ raise RuntimeError("Unreachable code") # for mypy
167
178
 
168
179
  def PushAppOutputs(
169
180
  self, request: PushAppOutputsRequest, context: ServicerContext
@@ -257,20 +268,14 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
257
268
  def SendAppHeartbeat(
258
269
  self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
259
270
  ) -> SendAppHeartbeatResponse:
260
- """Handle a heartbeat from the ServerApp in simulation."""
261
- log(DEBUG, "SimultionIoServicer.SendAppHeartbeat")
271
+ """Handle a heartbeat from an app process."""
272
+ log(DEBUG, "SimulationIoServicer.SendAppHeartbeat")
262
273
 
263
274
  # Init state
264
275
  state = self.state_factory.state()
265
276
 
266
277
  # Acknowledge the heartbeat
267
- # The app heartbeat can only be acknowledged if the run is in
268
- # starting or running status.
269
- success = state.acknowledge_app_heartbeat(
270
- run_id=request.run_id,
271
- heartbeat_interval=request.heartbeat_interval,
272
- )
273
-
278
+ success = state.acknowledge_app_heartbeat(request.token)
274
279
  return SendAppHeartbeatResponse(success=success)
275
280
 
276
281
  def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
@@ -15,8 +15,6 @@
15
15
  """SuperLink utilities."""
16
16
 
17
17
 
18
- from typing import Optional, Union
19
-
20
18
  import grpc
21
19
 
22
20
  from flwr.common.constant import Status, SubStatus
@@ -36,8 +34,8 @@ def check_abort(
36
34
  run_id: int,
37
35
  abort_status_list: list[str],
38
36
  state: LinkState,
39
- store: Optional[ObjectStore] = None,
40
- ) -> Union[str, None]:
37
+ store: ObjectStore | None = None,
38
+ ) -> str | None:
41
39
  """Check if the status of the provided `run_id` is in `abort_status_list`."""
42
40
  run_status: RunStatus = state.get_run_status({run_id})[run_id]
43
41
 
@@ -54,7 +52,7 @@ def check_abort(
54
52
  return None
55
53
 
56
54
 
57
- def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
55
+ def abort_grpc_context(msg: str | None, context: grpc.ServicerContext) -> None:
58
56
  """Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
59
57
  if msg is not None:
60
58
  context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
@@ -64,7 +62,7 @@ def abort_if(
64
62
  run_id: int,
65
63
  abort_status_list: list[str],
66
64
  state: LinkState,
67
- store: Optional[ObjectStore],
65
+ store: ObjectStore | None,
68
66
  context: grpc.ServicerContext,
69
67
  ) -> None:
70
68
  """Abort context if status of the provided `run_id` is in `abort_status_list`."""
flwr/server/typing.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Custom types for Flower servers."""
16
16
 
17
17
 
18
- from typing import Callable
18
+ from collections.abc import Callable
19
19
 
20
20
  from flwr.common import Context
21
21
 
@@ -16,20 +16,16 @@
16
16
 
17
17
 
18
18
  import os
19
+ from collections.abc import Callable
19
20
  from datetime import datetime
20
21
  from logging import WARN
21
- from typing import Callable, Optional, Union, cast
22
+ from typing import cast
22
23
 
23
24
  from flwr.common import EvaluateRes, Scalar
24
25
  from flwr.common.logger import log
25
26
  from flwr.server.client_proxy import ClientProxy
26
27
  from flwr.server.strategy import Strategy
27
28
 
28
- try:
29
- import tensorflow as TF
30
- except ModuleNotFoundError:
31
- TF = None
32
-
33
29
  MISSING_EXTRA_TF = """
34
30
  Extra dependency required for using tensorboard are missing.
35
31
  The program will continue without tensorboard.
@@ -59,6 +55,17 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
59
55
  # Variant 2
60
56
  strategy = tensorboard(logdir=LOGDIR)(FedAvg)()
61
57
  """
58
+ log(
59
+ WARN,
60
+ "The `tensorboard` function is deprecated and will be removed "
61
+ "in a future release.",
62
+ )
63
+ # Lazy import of TensorFlow to avoid slow import times
64
+ try:
65
+ import tensorflow as TF # pylint: disable=import-outside-toplevel
66
+ except ModuleNotFoundError:
67
+ TF = None # pylint: disable=invalid-name
68
+
62
69
  print(
63
70
  "\n\t\033[32mStart TensorBoard with the following parameters"
64
71
  f"\n\t$ tensorboard --logdir {logdir}\033[39m\n"
@@ -93,8 +100,8 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
93
100
  self,
94
101
  server_round: int,
95
102
  results: list[tuple[ClientProxy, EvaluateRes]],
96
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
97
- ) -> tuple[Optional[float], dict[str, Scalar]]:
103
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
104
+ ) -> tuple[float | None, dict[str, Scalar]]:
98
105
  """Hooks into aggregate_evaluate for TensorBoard logging purpose."""
99
106
  # Execute decorated function and extract results for logging
100
107
  # They will be returned at the end of this function but also
@@ -18,7 +18,7 @@
18
18
  import io
19
19
  import timeit
20
20
  from logging import INFO, WARN
21
- from typing import Optional, Union, cast
21
+ from typing import cast
22
22
 
23
23
  import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
@@ -47,8 +47,8 @@ class DefaultWorkflow:
47
47
 
48
48
  def __init__(
49
49
  self,
50
- fit_workflow: Optional[Workflow] = None,
51
- evaluate_workflow: Optional[Workflow] = None,
50
+ fit_workflow: Workflow | None = None,
51
+ evaluate_workflow: Workflow | None = None,
52
52
  ) -> None:
53
53
  if fit_workflow is None:
54
54
  fit_workflow = default_fit_workflow
@@ -275,7 +275,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
275
275
 
276
276
  # Aggregate training results
277
277
  results: list[tuple[ClientProxy, FitRes]] = []
278
- failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = []
278
+ failures: list[tuple[ClientProxy, FitRes] | BaseException] = []
279
279
  for msg in messages:
280
280
  if msg.has_content():
281
281
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
@@ -357,7 +357,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
357
357
 
358
358
  # Aggregate the evaluation results
359
359
  results: list[tuple[ClientProxy, EvaluateRes]] = []
360
- failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = []
360
+ failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = []
361
361
  for msg in messages:
362
362
  if msg.has_content():
363
363
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
@@ -15,8 +15,6 @@
15
15
  """Workflow for the SecAgg protocol."""
16
16
 
17
17
 
18
- from typing import Optional, Union
19
-
20
18
  from .secaggplus_workflow import SecAggPlusWorkflow
21
19
 
22
20
 
@@ -94,13 +92,13 @@ class SecAggWorkflow(SecAggPlusWorkflow):
94
92
 
95
93
  def __init__( # pylint: disable=R0913
96
94
  self,
97
- reconstruction_threshold: Union[int, float],
95
+ reconstruction_threshold: int | float,
98
96
  *,
99
97
  max_weight: float = 1000.0,
100
98
  clipping_range: float = 8.0,
101
99
  quantization_range: int = 4194304,
102
100
  modulus_range: int = 4294967296,
103
- timeout: Optional[float] = None,
101
+ timeout: float | None = None,
104
102
  ) -> None:
105
103
  super().__init__(
106
104
  num_shares=1.0,
@@ -18,7 +18,7 @@
18
18
  import random
19
19
  from dataclasses import dataclass, field
20
20
  from logging import DEBUG, ERROR, INFO, WARN
21
- from typing import Optional, Union, cast
21
+ from typing import cast
22
22
 
23
23
  import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
@@ -169,14 +169,14 @@ class SecAggPlusWorkflow:
169
169
 
170
170
  def __init__( # pylint: disable=R0913
171
171
  self,
172
- num_shares: Union[int, float],
173
- reconstruction_threshold: Union[int, float],
172
+ num_shares: int | float,
173
+ reconstruction_threshold: int | float,
174
174
  *,
175
175
  max_weight: float = 1000.0,
176
176
  clipping_range: float = 8.0,
177
177
  quantization_range: int = 4194304,
178
178
  modulus_range: int = 4294967296,
179
- timeout: Optional[float] = None,
179
+ timeout: float | None = None,
180
180
  ) -> None:
181
181
  self.num_shares = num_shares
182
182
  self.reconstruction_threshold = reconstruction_threshold
@@ -211,7 +211,7 @@ class SecAggPlusWorkflow:
211
211
 
212
212
  def _check_init_params(self) -> None: # pylint: disable=R0912
213
213
  # Check `num_shares`
214
- if not isinstance(self.num_shares, (int, float)):
214
+ if not isinstance(self.num_shares, (int | float)):
215
215
  raise TypeError("`num_shares` must be of type int or float.")
216
216
  if isinstance(self.num_shares, int):
217
217
  if self.num_shares == 1:
@@ -229,7 +229,7 @@ class SecAggPlusWorkflow:
229
229
  raise ValueError("`num_shares` as a float must be greater than 0.")
230
230
 
231
231
  # Check `reconstruction_threshold`
232
- if not isinstance(self.reconstruction_threshold, (int, float)):
232
+ if not isinstance(self.reconstruction_threshold, (int | float)):
233
233
  raise TypeError("`reconstruction_threshold` must be of type int or float.")
234
234
  if isinstance(self.reconstruction_threshold, int):
235
235
  if self.reconstruction_threshold == 1:
@@ -467,7 +467,7 @@ class SecAggPlusWorkflow:
467
467
  dsts += dst_lst
468
468
  ciphertexts += ctxt_lst
469
469
 
470
- for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
470
+ for src, dst, ciphertext in zip(srcs, dsts, ciphertexts, strict=True):
471
471
  if dst in fwd_ciphertexts:
472
472
  fwd_ciphertexts[dst].append(ciphertext)
473
473
  fwd_srcs[dst].append(src)
@@ -604,7 +604,7 @@ class SecAggPlusWorkflow:
604
604
  res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
605
605
  nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
606
606
  shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
607
- for owner_nid, share in zip(nids, shares):
607
+ for owner_nid, share in zip(nids, shares, strict=True):
608
608
  collected_shares_dict[owner_nid].append(share)
609
609
 
610
610
  # Remove masks for every active client after collect_masked_vectors stage
@@ -18,10 +18,9 @@ Paper: arxiv.org/abs/1802.07927
18
18
  """
19
19
 
20
20
 
21
- from collections import OrderedDict
22
- from collections.abc import Iterable
21
+ from collections.abc import Callable, Iterable
23
22
  from logging import INFO, WARN
24
- from typing import Callable, Optional, cast
23
+ from typing import cast
25
24
 
26
25
  import numpy as np
27
26
 
@@ -104,15 +103,15 @@ class Bulyan(FedAvg):
104
103
  weighted_by_key: str = "num-examples",
105
104
  arrayrecord_key: str = "arrays",
106
105
  configrecord_key: str = "config",
107
- train_metrics_aggr_fn: Optional[
108
- Callable[[list[RecordDict], str], MetricRecord]
109
- ] = None,
110
- evaluate_metrics_aggr_fn: Optional[
111
- Callable[[list[RecordDict], str], MetricRecord]
112
- ] = None,
113
- selection_rule: Optional[
114
- Callable[[list[RecordDict], int, int], list[RecordDict]]
115
- ] = None,
106
+ train_metrics_aggr_fn: (
107
+ Callable[[list[RecordDict], str], MetricRecord] | None
108
+ ) = None,
109
+ evaluate_metrics_aggr_fn: (
110
+ Callable[[list[RecordDict], str], MetricRecord] | None
111
+ ) = None,
112
+ selection_rule: (
113
+ Callable[[list[RecordDict], int, int], list[RecordDict]] | None
114
+ ) = None,
116
115
  ) -> None:
117
116
  super().__init__(
118
117
  fraction_train=fraction_train,
@@ -140,7 +139,7 @@ class Bulyan(FedAvg):
140
139
  self,
141
140
  server_round: int,
142
141
  replies: Iterable[Message],
143
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
142
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
144
143
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
145
144
  valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
146
145
 
@@ -175,7 +174,9 @@ class Bulyan(FedAvg):
175
174
  ]
176
175
 
177
176
  # Compute median
178
- median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]
177
+ median_ndarrays = [
178
+ np.median(arr, axis=0) for arr in zip(*selected_ndarrays, strict=True)
179
+ ]
179
180
 
180
181
  # Aggregate the beta closest weights element-wise
181
182
  aggregated_ndarrays = aggregate_n_closest_weights(
@@ -184,7 +185,7 @@ class Bulyan(FedAvg):
184
185
 
185
186
  # Convert to ArrayRecord
186
187
  arrays = ArrayRecord(
187
- OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
188
+ dict(zip(array_keys, map(Array, aggregated_ndarrays), strict=True))
188
189
  )
189
190
 
190
191
  # Aggregate MetricRecords
@@ -19,10 +19,8 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
19
19
 
20
20
  import math
21
21
  from abc import ABC
22
- from collections import OrderedDict
23
22
  from collections.abc import Iterable
24
23
  from logging import INFO
25
- from typing import Optional
26
24
 
27
25
  import numpy as np
28
26
 
@@ -53,7 +51,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
53
51
  initial_clipping_norm: float = 0.1,
54
52
  target_clipped_quantile: float = 0.5,
55
53
  clip_norm_lr: float = 0.2,
56
- clipped_count_stddev: Optional[float] = None,
54
+ clipped_count_stddev: float | None = None,
57
55
  ) -> None:
58
56
  super().__init__()
59
57
 
@@ -96,7 +94,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
96
94
  add_gaussian_noise_inplace(nds, stdv)
97
95
  log(INFO, "aggregate_fit: central DP noise with %.4f stdev added", stdv)
98
96
  return ArrayRecord(
99
- OrderedDict({k: Array(v) for k, v in zip(aggregated.keys(), nds)})
97
+ {k: Array(v) for k, v in zip(aggregated.keys(), nds, strict=True)}
100
98
  )
101
99
 
102
100
  def _noisy_fraction(self, count: int, total: int) -> float:
@@ -115,7 +113,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
115
113
 
116
114
  def aggregate_evaluate(
117
115
  self, server_round: int, replies: Iterable[Message]
118
- ) -> Optional[MetricRecord]:
116
+ ) -> MetricRecord | None:
119
117
  """Aggregate MetricRecords in the received Messages."""
120
118
  return self.strategy.aggregate_evaluate(server_round, replies)
121
119
 
@@ -136,7 +134,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
136
134
  initial_clipping_norm: float = 0.1,
137
135
  target_clipped_quantile: float = 0.5,
138
136
  clip_norm_lr: float = 0.2,
139
- clipped_count_stddev: Optional[float] = None,
137
+ clipped_count_stddev: float | None = None,
140
138
  ) -> None:
141
139
  super().__init__(
142
140
  strategy,
@@ -171,7 +169,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
171
169
 
172
170
  def aggregate_train(
173
171
  self, server_round: int, replies: Iterable[Message]
174
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
172
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
175
173
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
176
174
  if not validate_replies(replies, self.num_sampled_clients):
177
175
  return None, None
@@ -184,16 +182,19 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
184
182
  for arr_name, record in reply.content.array_records.items():
185
183
  reply_nd = record.to_numpy_ndarrays()
186
184
  model_update = [
187
- np.subtract(x, y) for (x, y) in zip(reply_nd, current_nd)
185
+ np.subtract(x, y)
186
+ for (x, y) in zip(reply_nd, current_nd, strict=True)
188
187
  ]
189
188
  norm_bit = adaptive_clip_inputs_inplace(
190
189
  model_update, self.clipping_norm
191
190
  )
192
191
  clipped_indicator_count += int(norm_bit)
193
192
  # reconstruct array using clipped contribution from current round
194
- restored = [c + u for c, u in zip(current_nd, model_update)]
193
+ restored = [
194
+ c + u for c, u in zip(current_nd, model_update, strict=True)
195
+ ]
195
196
  reply.content[arr_name] = ArrayRecord(
196
- OrderedDict({k: Array(v) for k, v in zip(record.keys(), restored)})
197
+ {k: Array(v) for k, v in zip(record.keys(), restored, strict=True)}
197
198
  )
198
199
  log(
199
200
  INFO,
@@ -287,7 +288,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
287
288
 
288
289
  def aggregate_train(
289
290
  self, server_round: int, replies: Iterable[Message]
290
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
291
+ ) -> tuple[ArrayRecord | None, MetricRecord | None]:
291
292
  """Aggregate ArrayRecords and MetricRecords in the received Messages."""
292
293
  if not validate_replies(replies, self.num_sampled_clients):
293
294
  return None, None