flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (339):
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +262 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +318 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +211 -130
  21. flwr/cli/new/new.py +123 -331
  22. flwr/cli/pull.py +10 -5
  23. flwr/cli/run/run.py +71 -29
  24. flwr/cli/run_utils.py +148 -0
  25. flwr/cli/stop.py +26 -8
  26. flwr/cli/supernode/ls.py +25 -12
  27. flwr/cli/supernode/register.py +9 -4
  28. flwr/cli/supernode/unregister.py +5 -3
  29. flwr/cli/utils.py +239 -16
  30. flwr/client/__init__.py +1 -1
  31. flwr/client/dpfedavg_numpy_client.py +4 -1
  32. flwr/client/grpc_adapter_client/connection.py +8 -9
  33. flwr/client/grpc_rere_client/connection.py +16 -14
  34. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  35. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  36. flwr/client/message_handler/message_handler.py +2 -2
  37. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  38. flwr/client/numpy_client.py +1 -1
  39. flwr/client/rest_client/connection.py +18 -18
  40. flwr/client/run_info_store.py +4 -5
  41. flwr/client/typing.py +1 -1
  42. flwr/clientapp/client_app.py +9 -10
  43. flwr/clientapp/mod/centraldp_mods.py +16 -17
  44. flwr/clientapp/mod/localdp_mod.py +8 -9
  45. flwr/clientapp/typing.py +1 -1
  46. flwr/clientapp/utils.py +3 -3
  47. flwr/common/address.py +1 -2
  48. flwr/common/args.py +3 -4
  49. flwr/common/config.py +13 -16
  50. flwr/common/constant.py +5 -2
  51. flwr/common/differential_privacy.py +3 -4
  52. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  53. flwr/common/exit/exit.py +15 -2
  54. flwr/common/exit/exit_code.py +19 -0
  55. flwr/common/exit/exit_handler.py +6 -2
  56. flwr/common/exit/signal_handler.py +5 -5
  57. flwr/common/grpc.py +6 -6
  58. flwr/common/inflatable_protobuf_utils.py +1 -1
  59. flwr/common/inflatable_utils.py +38 -21
  60. flwr/common/logger.py +19 -19
  61. flwr/common/message.py +4 -4
  62. flwr/common/object_ref.py +7 -7
  63. flwr/common/record/array.py +3 -3
  64. flwr/common/record/arrayrecord.py +18 -30
  65. flwr/common/record/configrecord.py +3 -3
  66. flwr/common/record/recorddict.py +5 -5
  67. flwr/common/record/typeddict.py +9 -2
  68. flwr/common/recorddict_compat.py +7 -10
  69. flwr/common/retry_invoker.py +20 -20
  70. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  71. flwr/common/serde.py +11 -4
  72. flwr/common/serde_utils.py +2 -2
  73. flwr/common/telemetry.py +9 -5
  74. flwr/common/typing.py +58 -37
  75. flwr/compat/client/app.py +38 -37
  76. flwr/compat/client/grpc_client/connection.py +11 -11
  77. flwr/compat/server/app.py +5 -6
  78. flwr/proto/appio_pb2.py +13 -3
  79. flwr/proto/appio_pb2.pyi +134 -65
  80. flwr/proto/appio_pb2_grpc.py +20 -0
  81. flwr/proto/appio_pb2_grpc.pyi +27 -0
  82. flwr/proto/clientappio_pb2.py +17 -7
  83. flwr/proto/clientappio_pb2.pyi +15 -0
  84. flwr/proto/clientappio_pb2_grpc.py +206 -40
  85. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  86. flwr/proto/control_pb2.py +71 -52
  87. flwr/proto/control_pb2.pyi +277 -111
  88. flwr/proto/control_pb2_grpc.py +249 -40
  89. flwr/proto/control_pb2_grpc.pyi +185 -52
  90. flwr/proto/error_pb2.py +13 -3
  91. flwr/proto/error_pb2.pyi +24 -6
  92. flwr/proto/error_pb2_grpc.py +20 -0
  93. flwr/proto/error_pb2_grpc.pyi +27 -0
  94. flwr/proto/fab_pb2.py +14 -4
  95. flwr/proto/fab_pb2.pyi +59 -31
  96. flwr/proto/fab_pb2_grpc.py +20 -0
  97. flwr/proto/fab_pb2_grpc.pyi +27 -0
  98. flwr/proto/federation_pb2.py +38 -0
  99. flwr/proto/federation_pb2.pyi +56 -0
  100. flwr/proto/federation_pb2_grpc.py +24 -0
  101. flwr/proto/federation_pb2_grpc.pyi +31 -0
  102. flwr/proto/fleet_pb2.py +24 -14
  103. flwr/proto/fleet_pb2.pyi +141 -61
  104. flwr/proto/fleet_pb2_grpc.py +189 -48
  105. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  106. flwr/proto/grpcadapter_pb2.py +14 -4
  107. flwr/proto/grpcadapter_pb2.pyi +38 -16
  108. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  109. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  110. flwr/proto/heartbeat_pb2.py +17 -7
  111. flwr/proto/heartbeat_pb2.pyi +51 -22
  112. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  113. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  114. flwr/proto/log_pb2.py +13 -3
  115. flwr/proto/log_pb2.pyi +34 -11
  116. flwr/proto/log_pb2_grpc.py +20 -0
  117. flwr/proto/log_pb2_grpc.pyi +27 -0
  118. flwr/proto/message_pb2.py +15 -5
  119. flwr/proto/message_pb2.pyi +154 -86
  120. flwr/proto/message_pb2_grpc.py +20 -0
  121. flwr/proto/message_pb2_grpc.pyi +27 -0
  122. flwr/proto/node_pb2.py +15 -5
  123. flwr/proto/node_pb2.pyi +50 -25
  124. flwr/proto/node_pb2_grpc.py +20 -0
  125. flwr/proto/node_pb2_grpc.pyi +27 -0
  126. flwr/proto/recorddict_pb2.py +13 -3
  127. flwr/proto/recorddict_pb2.pyi +184 -107
  128. flwr/proto/recorddict_pb2_grpc.py +20 -0
  129. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  130. flwr/proto/run_pb2.py +40 -31
  131. flwr/proto/run_pb2.pyi +158 -84
  132. flwr/proto/run_pb2_grpc.py +20 -0
  133. flwr/proto/run_pb2_grpc.pyi +27 -0
  134. flwr/proto/serverappio_pb2.py +13 -3
  135. flwr/proto/serverappio_pb2.pyi +32 -8
  136. flwr/proto/serverappio_pb2_grpc.py +246 -65
  137. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  138. flwr/proto/simulationio_pb2.py +16 -8
  139. flwr/proto/simulationio_pb2.pyi +15 -0
  140. flwr/proto/simulationio_pb2_grpc.py +162 -41
  141. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  142. flwr/proto/transport_pb2.py +20 -10
  143. flwr/proto/transport_pb2.pyi +249 -160
  144. flwr/proto/transport_pb2_grpc.py +35 -4
  145. flwr/proto/transport_pb2_grpc.pyi +38 -8
  146. flwr/server/app.py +39 -17
  147. flwr/server/client_manager.py +4 -5
  148. flwr/server/client_proxy.py +10 -11
  149. flwr/server/compat/app.py +4 -5
  150. flwr/server/compat/app_utils.py +2 -1
  151. flwr/server/compat/grid_client_proxy.py +10 -12
  152. flwr/server/compat/legacy_context.py +3 -4
  153. flwr/server/fleet_event_log_interceptor.py +2 -1
  154. flwr/server/grid/grid.py +2 -3
  155. flwr/server/grid/grpc_grid.py +10 -8
  156. flwr/server/grid/inmemory_grid.py +4 -4
  157. flwr/server/run_serverapp.py +2 -3
  158. flwr/server/server.py +34 -39
  159. flwr/server/server_app.py +7 -8
  160. flwr/server/server_config.py +1 -2
  161. flwr/server/serverapp/app.py +34 -28
  162. flwr/server/serverapp_components.py +4 -5
  163. flwr/server/strategy/aggregate.py +9 -8
  164. flwr/server/strategy/bulyan.py +13 -11
  165. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  166. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  167. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  168. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  169. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  170. flwr/server/strategy/fedadagrad.py +18 -14
  171. flwr/server/strategy/fedadam.py +16 -14
  172. flwr/server/strategy/fedavg.py +16 -17
  173. flwr/server/strategy/fedavg_android.py +15 -15
  174. flwr/server/strategy/fedavgm.py +21 -18
  175. flwr/server/strategy/fedmedian.py +2 -3
  176. flwr/server/strategy/fedopt.py +11 -10
  177. flwr/server/strategy/fedprox.py +10 -9
  178. flwr/server/strategy/fedtrimmedavg.py +12 -11
  179. flwr/server/strategy/fedxgb_bagging.py +13 -11
  180. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  181. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  182. flwr/server/strategy/fedyogi.py +16 -14
  183. flwr/server/strategy/krum.py +12 -11
  184. flwr/server/strategy/qfedavg.py +16 -15
  185. flwr/server/strategy/strategy.py +6 -9
  186. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  187. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  190. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  192. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  193. flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
  194. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  195. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  196. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  197. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  198. flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
  199. flwr/server/superlink/linkstate/linkstate.py +91 -43
  200. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  201. flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
  202. flwr/server/superlink/linkstate/utils.py +6 -6
  203. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  204. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  205. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  206. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  207. flwr/server/superlink/utils.py +4 -6
  208. flwr/server/typing.py +1 -1
  209. flwr/server/utils/tensorboard.py +15 -8
  210. flwr/server/workflow/default_workflows.py +5 -5
  211. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  213. flwr/serverapp/strategy/bulyan.py +16 -15
  214. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  215. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  216. flwr/serverapp/strategy/fedadagrad.py +10 -11
  217. flwr/serverapp/strategy/fedadam.py +10 -11
  218. flwr/serverapp/strategy/fedavg.py +9 -10
  219. flwr/serverapp/strategy/fedavgm.py +17 -16
  220. flwr/serverapp/strategy/fedmedian.py +2 -2
  221. flwr/serverapp/strategy/fedopt.py +10 -11
  222. flwr/serverapp/strategy/fedprox.py +7 -8
  223. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  224. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  225. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  226. flwr/serverapp/strategy/fedyogi.py +9 -11
  227. flwr/serverapp/strategy/krum.py +7 -7
  228. flwr/serverapp/strategy/multikrum.py +9 -9
  229. flwr/serverapp/strategy/qfedavg.py +17 -16
  230. flwr/serverapp/strategy/strategy.py +6 -9
  231. flwr/serverapp/strategy/strategy_utils.py +7 -8
  232. flwr/simulation/app.py +46 -42
  233. flwr/simulation/legacy_app.py +12 -12
  234. flwr/simulation/ray_transport/ray_actor.py +10 -11
  235. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  236. flwr/simulation/run_simulation.py +43 -43
  237. flwr/simulation/simulationio_connection.py +4 -4
  238. flwr/supercore/cli/flower_superexec.py +3 -4
  239. flwr/supercore/constant.py +34 -1
  240. flwr/supercore/corestate/corestate.py +24 -3
  241. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  242. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  243. flwr/supercore/ffs/disk_ffs.py +1 -2
  244. flwr/supercore/ffs/ffs.py +1 -2
  245. flwr/supercore/ffs/ffs_factory.py +1 -2
  246. flwr/{common → supercore}/heartbeat.py +20 -25
  247. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  248. flwr/supercore/object_store/object_store.py +1 -2
  249. flwr/supercore/object_store/object_store_factory.py +1 -2
  250. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  251. flwr/supercore/primitives/asymmetric.py +1 -1
  252. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  253. flwr/supercore/sqlite_mixin.py +37 -34
  254. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  255. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  256. flwr/supercore/superexec/run_superexec.py +9 -13
  257. flwr/supercore/utils.py +190 -0
  258. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  259. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  260. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  261. flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
  262. flwr/superlink/federation/federation_manager.py +64 -0
  263. flwr/superlink/federation/noop_federation_manager.py +71 -0
  264. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  265. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  266. flwr/superlink/servicer/control/control_grpc.py +7 -6
  267. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  268. flwr/superlink/servicer/control/control_servicer.py +190 -23
  269. flwr/supernode/cli/flower_supernode.py +58 -3
  270. flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
  271. flwr/supernode/nodestate/nodestate.py +52 -8
  272. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  273. flwr/supernode/runtime/run_clientapp.py +41 -22
  274. flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
  275. flwr/supernode/start_client_internal.py +165 -46
  276. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
  277. flwr-1.25.0.dist-info/RECORD +393 -0
  278. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  279. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  280. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  281. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  282. flwr/cli/new/templates/app/README.md.tpl +0 -37
  283. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  284. flwr/cli/new/templates/app/code/__init__.py +0 -15
  285. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  286. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  287. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  288. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  289. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  290. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  291. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  292. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  293. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  294. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  295. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  296. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  297. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  298. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  299. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  300. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  301. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  302. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  303. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  304. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  305. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  306. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  307. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  308. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  309. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  310. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  311. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  312. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  313. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  314. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  315. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  316. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  317. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  318. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  319. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
  320. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  321. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  322. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  323. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  324. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  325. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  326. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  327. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  328. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  329. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  330. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  331. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  332. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  333. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  334. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  335. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  336. flwr/supercore/object_store/utils.py +0 -43
  337. flwr-1.23.0.dist-info/RECORD +0 -439
  338. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  339. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
@@ -38,12 +38,14 @@ class ExitCode:
38
38
  SERVERAPP_STRATEGY_PRECONDITION_UNMET = 200
39
39
  SERVERAPP_EXCEPTION = 201
40
40
  SERVERAPP_STRATEGY_AGGREGATION_ERROR = 202
41
+ SERVERAPP_RUN_START_REJECTED = 203
41
42
 
42
43
  # SuperNode-specific exit codes (300-399)
43
44
  SUPERNODE_REST_ADDRESS_INVALID = 300
44
45
  # SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301 --- DELETED ---
45
46
  SUPERNODE_NODE_AUTH_KEY_INVALID = 302
46
47
  SUPERNODE_STARTED_WITHOUT_TLS_BUT_NODE_AUTH_ENABLED = 303
48
+ SUPERNODE_INVALID_TRUSTED_ENTITIES = 304
47
49
 
48
50
  # SuperExec-specific exit codes (400-499)
49
51
  SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
@@ -56,6 +58,9 @@ class ExitCode:
56
58
  COMMON_MISSING_EXTRA_REST = 601
57
59
  COMMON_TLS_NOT_SUPPORTED = 602
58
60
 
61
+ # Simulation exit codes (700-799)
62
+ SIMULATION_EXCEPTION = 700
63
+
59
64
  def __new__(cls) -> ExitCode:
60
65
  """Prevent instantiation."""
61
66
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
@@ -101,6 +106,11 @@ EXIT_CODE_HELP = {
101
106
  "The strategy encountered an error during aggregation. Please check the logs "
102
107
  "for more details."
103
108
  ),
109
+ ExitCode.SERVERAPP_RUN_START_REJECTED: (
110
+ "The SuperLink rejected the request to start the run. This may occur if the "
111
+ "run has been stopped, the run ID or FAB is invalid, or the run failed to "
112
+ "start within the allowed time."
113
+ ),
104
114
  # SuperNode-specific exit codes (300-399)
105
115
  ExitCode.SUPERNODE_REST_ADDRESS_INVALID: (
106
116
  "When using the REST API, please provide `https://` or "
@@ -115,6 +125,11 @@ EXIT_CODE_HELP = {
115
125
  "The private key for SuperNode authentication was provided, but TLS is not "
116
126
  "enabled. Node authentication can only be used when TLS is enabled."
117
127
  ),
128
+ ExitCode.SUPERNODE_INVALID_TRUSTED_ENTITIES: (
129
+ "Failed to read the trusted entities YAML file. "
130
+ "Please ensure that a valid file is provided using "
131
+ "the `--trusted-entities` option."
132
+ ),
118
133
  # SuperExec-specific exit codes (400-499)
119
134
  ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
120
135
  "The YAML configuration for the SuperExec plugin is invalid."
@@ -138,4 +153,8 @@ To use the REST API, install `flwr` with the `rest` extra:
138
153
  `pip install "flwr[rest]"`.
139
154
  """,
140
155
  ExitCode.COMMON_TLS_NOT_SUPPORTED: "Please use the '--insecure' flag.",
156
+ # Simulation exit codes (700-799)
157
+ ExitCode.SIMULATION_EXCEPTION: (
158
+ "An unhandled exception occurred when running the simulation."
159
+ ),
141
160
  }
@@ -17,7 +17,7 @@
17
17
 
18
18
  import signal
19
19
  import threading
20
- from typing import Callable
20
+ from collections.abc import Callable
21
21
 
22
22
  from .exit_code import ExitCode
23
23
 
@@ -58,5 +58,9 @@ def trigger_exit_handlers() -> None:
58
58
  """Trigger all registered exit handlers in LIFO order."""
59
59
  with _lock_handlers:
60
60
  for handler in reversed(registered_exit_handlers):
61
- handler()
61
+ try:
62
+ handler()
63
+ except Exception: # pylint: disable=broad-exception-caught
64
+ # Ignore exceptions in exit handlers
65
+ pass
62
66
  registered_exit_handlers.clear()
@@ -16,9 +16,9 @@
16
16
 
17
17
 
18
18
  import signal
19
+ from collections.abc import Callable
19
20
  from threading import Thread
20
21
  from types import FrameType
21
- from typing import Callable, Optional
22
22
 
23
23
  from grpc import Server
24
24
 
@@ -40,10 +40,10 @@ if hasattr(signal, "SIGQUIT"):
40
40
 
41
41
  def register_signal_handlers(
42
42
  event_type: EventType,
43
- exit_message: Optional[str] = None,
44
- grpc_servers: Optional[list[Server]] = None,
45
- bckg_threads: Optional[list[Thread]] = None,
46
- exit_handlers: Optional[list[Callable[[], None]]] = None,
43
+ exit_message: str | None = None,
44
+ grpc_servers: list[Server] | None = None,
45
+ bckg_threads: list[Thread] | None = None,
46
+ exit_handlers: list[Callable[[], None]] | None = None,
47
47
  ) -> None:
48
48
  """Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.
49
49
 
flwr/common/grpc.py CHANGED
@@ -18,9 +18,9 @@
18
18
  import concurrent.futures
19
19
  import os
20
20
  import sys
21
- from collections.abc import Sequence
21
+ from collections.abc import Callable, Sequence
22
22
  from logging import DEBUG, ERROR
23
- from typing import Any, Callable, Optional
23
+ from typing import Any
24
24
 
25
25
  import grpc
26
26
 
@@ -46,9 +46,9 @@ if "GRPC_VERBOSITY" not in os.environ:
46
46
  def create_channel(
47
47
  server_address: str,
48
48
  insecure: bool,
49
- root_certificates: Optional[bytes] = None,
49
+ root_certificates: bytes | None = None,
50
50
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
51
- interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
51
+ interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] | None = None,
52
52
  ) -> grpc.Channel:
53
53
  """Create a gRPC channel, either secure or insecure."""
54
54
  # Check for conflicting parameters
@@ -104,8 +104,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments, R0914, R0
104
104
  max_concurrent_workers: int = 1000,
105
105
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
106
106
  keepalive_time_ms: int = 210000,
107
- certificates: Optional[tuple[bytes, bytes, bytes]] = None,
108
- interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
107
+ certificates: tuple[bytes, bytes, bytes] | None = None,
108
+ interceptors: Sequence[grpc.ServerInterceptor] | None = None,
109
109
  ) -> grpc.Server:
110
110
  """Create a gRPC server with a single servicer.
111
111
 
@@ -15,7 +15,7 @@
15
15
  """InflatableObject gRPC utils."""
16
16
 
17
17
 
18
- from typing import Callable
18
+ from collections.abc import Callable
19
19
 
20
20
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
21
21
  ConfirmMessageReceivedRequest,
@@ -14,13 +14,15 @@
14
14
  # ==============================================================================
15
15
  """InflatableObject utilities."""
16
16
 
17
+
17
18
  import concurrent.futures
18
19
  import os
19
20
  import random
20
21
  import threading
21
22
  import time
22
- from collections.abc import Iterable, Iterator
23
- from typing import Callable, Optional, TypeVar
23
+ from collections.abc import Callable, Iterable, Iterator
24
+ from queue import Queue
25
+ from typing import TypeVar
24
26
 
25
27
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
26
28
 
@@ -116,7 +118,7 @@ def push_objects(
116
118
  objects: dict[str, InflatableObject],
117
119
  push_object_fn: Callable[[str, bytes], None],
118
120
  *,
119
- object_ids_to_push: Optional[set[str]] = None,
121
+ object_ids_to_push: set[str] | None = None,
120
122
  keep_objects: bool = False,
121
123
  max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
122
124
  ) -> None:
@@ -184,12 +186,21 @@ def push_object_contents_from_iterable(
184
186
  max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
185
187
  The maximum number of concurrent pushes to perform.
186
188
  """
189
+ error_event = threading.Event()
190
+ err_queue: Queue[Exception] = Queue()
187
191
 
188
192
  def push(args: tuple[str, bytes]) -> None:
189
193
  """Push a single object."""
194
+ if error_event.is_set():
195
+ return
190
196
  obj_id, obj_content = args
191
197
  # Push the object using the provided function
192
- push_object_fn(obj_id, obj_content)
198
+ try:
199
+ push_object_fn(obj_id, obj_content)
200
+ except Exception as err: # pylint: disable=broad-except
201
+ # Unexpected error during pushing
202
+ error_event.set()
203
+ err_queue.put(err)
193
204
 
194
205
  # Push all object contents concurrently
195
206
  num_workers = get_num_workers(max_concurrent_pushes)
@@ -205,14 +216,18 @@ def push_object_contents_from_iterable(
205
216
  # Remove the executor from the list of tracked executors
206
217
  _untrack_executor(executor)
207
218
 
219
+ # If an error occurred during pushing, raise it
220
+ if not err_queue.empty():
221
+ raise err_queue.get()
222
+
208
223
 
209
224
  def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
210
225
  object_ids: list[str],
211
226
  pull_object_fn: Callable[[str], bytes],
212
227
  *,
213
228
  max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
214
- max_time: Optional[float] = PULL_MAX_TIME,
215
- max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
229
+ max_time: float | None = PULL_MAX_TIME,
230
+ max_tries_per_object: int | None = PULL_MAX_TRIES_PER_OBJECT,
216
231
  initial_backoff: float = PULL_INITIAL_BACKOFF,
217
232
  backoff_cap: float = PULL_BACKOFF_CAP,
218
233
  ) -> dict[str, bytes]:
@@ -254,13 +269,16 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
254
269
 
255
270
  results: dict[str, bytes] = {}
256
271
  results_lock = threading.Lock()
257
- err_to_raise: Optional[Exception] = None
272
+ err_queue: Queue[Exception] = Queue()
258
273
  early_stop = threading.Event()
259
274
  start = time.monotonic()
260
275
 
276
+ def stop_on_error(err: Exception) -> None:
277
+ early_stop.set()
278
+ err_queue.put(err)
279
+
261
280
  def pull_with_retries(object_id: str) -> None:
262
281
  """Attempt to pull a single object with retry and backoff."""
263
- nonlocal err_to_raise
264
282
  tries = 0
265
283
  delay = initial_backoff
266
284
 
@@ -278,10 +296,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
278
296
  or time.monotonic() - start >= max_time
279
297
  ):
280
298
  # Stop all work if one object exhausts retries
281
- early_stop.set()
282
- with results_lock:
283
- if err_to_raise is None:
284
- err_to_raise = err
299
+ stop_on_error(err)
285
300
  return
286
301
 
287
302
  # Apply exponential backoff with ±20% jitter
@@ -291,10 +306,12 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
291
306
 
292
307
  except ObjectIdNotPreregisteredError as err:
293
308
  # Permanent failure: object ID is invalid
294
- early_stop.set()
295
- with results_lock:
296
- if err_to_raise is None:
297
- err_to_raise = err
309
+ stop_on_error(err)
310
+ return
311
+
312
+ except Exception as err: # pylint: disable=broad-except
313
+ # Permanent failure: unexpected error
314
+ stop_on_error(err)
298
315
  return
299
316
 
300
317
  # Submit all pull tasks concurrently
@@ -312,8 +329,8 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
312
329
  _untrack_executor(executor)
313
330
 
314
331
  # If an error occurred during pulling, raise it
315
- if err_to_raise is not None:
316
- raise err_to_raise
332
+ if not err_queue.empty():
333
+ raise err_queue.get()
317
334
 
318
335
  return results
319
336
 
@@ -323,7 +340,7 @@ def inflate_object_from_contents(
323
340
  object_contents: dict[str, bytes],
324
341
  *,
325
342
  keep_object_contents: bool = False,
326
- objects: Optional[dict[str, InflatableObject]] = None,
343
+ objects: dict[str, InflatableObject] | None = None,
327
344
  ) -> InflatableObject:
328
345
  """Inflate an object from object contents.
329
346
 
@@ -443,8 +460,8 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
443
460
  *,
444
461
  return_type: type[T] = InflatableObject, # type: ignore
445
462
  max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
446
- max_time: Optional[float] = PULL_MAX_TIME,
447
- max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
463
+ max_time: float | None = PULL_MAX_TIME,
464
+ max_tries_per_object: int | None = PULL_MAX_TRIES_PER_OBJECT,
448
465
  initial_backoff: float = PULL_INITIAL_BACKOFF,
449
466
  backoff_cap: float = PULL_BACKOFF_CAP,
450
467
  ) -> T:
flwr/common/logger.py CHANGED
@@ -26,7 +26,7 @@ from io import StringIO
26
26
  from logging import ERROR, WARN, LogRecord
27
27
  from logging.handlers import HTTPHandler
28
28
  from queue import Empty, Queue
29
- from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
29
+ from typing import TYPE_CHECKING, Any, TextIO
30
30
 
31
31
  import grpc
32
32
  import typer
@@ -68,7 +68,7 @@ class ConsoleHandler(StreamHandler):
68
68
  timestamps: bool = False,
69
69
  json: bool = False,
70
70
  colored: bool = True,
71
- stream: Optional[TextIO] = None,
71
+ stream: TextIO | None = None,
72
72
  ) -> None:
73
73
  super().__init__(stream)
74
74
  self.timestamps = timestamps
@@ -103,9 +103,9 @@ class ConsoleHandler(StreamHandler):
103
103
 
104
104
 
105
105
  def update_console_handler(
106
- level: Optional[Union[int, str]] = None,
107
- timestamps: Optional[bool] = None,
108
- colored: Optional[bool] = None,
106
+ level: int | str | None = None,
107
+ timestamps: bool | None = None,
108
+ colored: bool | None = None,
109
109
  ) -> None:
110
110
  """Update the logging handler."""
111
111
  for handler in logging.getLogger(LOGGER_NAME).handlers:
@@ -160,7 +160,7 @@ class CustomHTTPHandler(HTTPHandler):
160
160
  url: str,
161
161
  method: str = "GET",
162
162
  secure: bool = False,
163
- credentials: Optional[tuple[str, str]] = None,
163
+ credentials: tuple[str, str] | None = None,
164
164
  ) -> None:
165
165
  super().__init__(host, url, method, secure, credentials)
166
166
  self.identifier = identifier
@@ -180,7 +180,7 @@ class CustomHTTPHandler(HTTPHandler):
180
180
 
181
181
 
182
182
  def configure(
183
- identifier: str, filename: Optional[str] = None, host: Optional[str] = None
183
+ identifier: str, filename: str | None = None, host: str | None = None
184
184
  ) -> None:
185
185
  """Configure logging to file and/or remote log server."""
186
186
  # Create formatter
@@ -298,7 +298,7 @@ def set_logger_propagation(
298
298
  return child_logger
299
299
 
300
300
 
301
- def mirror_output_to_queue(log_queue: Queue[Optional[str]]) -> None:
301
+ def mirror_output_to_queue(log_queue: Queue[str | None]) -> None:
302
302
  """Mirror stdout and stderr output to the provided queue."""
303
303
 
304
304
  def get_write_fn(stream: TextIO) -> Any:
@@ -335,7 +335,7 @@ def redirect_output(output_buffer: StringIO) -> None:
335
335
 
336
336
 
337
337
  def _log_uploader(
338
- log_queue: Queue[Optional[str]], node_id: int, run_id: int, stub: ServerAppIoStub
338
+ log_queue: Queue[str | None], node_id: int, run_id: int, stub: ServerAppIoStub
339
339
  ) -> None:
340
340
  """Upload logs to the SuperLink."""
341
341
  exit_flag = False
@@ -378,10 +378,10 @@ def _log_uploader(
378
378
 
379
379
 
380
380
  def start_log_uploader(
381
- log_queue: Queue[Optional[str]],
381
+ log_queue: Queue[str | None],
382
382
  node_id: int,
383
383
  run_id: int,
384
- stub: Union[ServerAppIoStub, SimulationIoStub],
384
+ stub: ServerAppIoStub | SimulationIoStub,
385
385
  ) -> threading.Thread:
386
386
  """Start the log uploader thread and return it."""
387
387
  thread = threading.Thread(
@@ -392,7 +392,7 @@ def start_log_uploader(
392
392
 
393
393
 
394
394
  def stop_log_uploader(
395
- log_queue: Queue[Optional[str]], log_uploader: threading.Thread
395
+ log_queue: Queue[str | None], log_uploader: threading.Thread
396
396
  ) -> None:
397
397
  """Stop the log uploader thread."""
398
398
  log_queue.put(None)
@@ -403,19 +403,19 @@ def _remove_emojis(text: str) -> str:
403
403
  """Remove emojis from the provided text."""
404
404
  emoji_pattern = re.compile(
405
405
  "["
406
- "\U0001F600-\U0001F64F" # Emoticons
407
- "\U0001F300-\U0001F5FF" # Symbols & Pictographs
408
- "\U0001F680-\U0001F6FF" # Transport & Map Symbols
409
- "\U0001F1E0-\U0001F1FF" # Flags
410
- "\U00002702-\U000027B0" # Dingbats
411
- "\U000024C2-\U0001F251"
406
+ "\U0001f600-\U0001f64f" # Emoticons
407
+ "\U0001f300-\U0001f5ff" # Symbols & Pictographs
408
+ "\U0001f680-\U0001f6ff" # Transport & Map Symbols
409
+ "\U0001f1e0-\U0001f1ff" # Flags
410
+ "\U00002702-\U000027b0" # Dingbats
411
+ "\U000024c2-\U0001f251"
412
412
  "]+",
413
413
  flags=re.UNICODE,
414
414
  )
415
415
  return emoji_pattern.sub(r"", text)
416
416
 
417
417
 
418
- def print_json_error(msg: str, e: Union[typer.Exit, Exception]) -> None:
418
+ def print_json_error(msg: str, e: typer.Exit | Exception) -> None:
419
419
  """Print error message as JSON."""
420
420
  Console().print_json(
421
421
  _json.dumps(
flwr/common/message.py CHANGED
@@ -105,7 +105,7 @@ class Message(InflatableObject):
105
105
  """
106
106
 
107
107
  @overload
108
- def __init__( # pylint: disable=too-many-arguments # noqa: E704
108
+ def __init__( # pylint: disable=too-many-arguments
109
109
  self,
110
110
  content: RecordDict,
111
111
  dst_node_id: int,
@@ -116,12 +116,12 @@ class Message(InflatableObject):
116
116
  ) -> None: ...
117
117
 
118
118
  @overload
119
- def __init__( # noqa: E704
119
+ def __init__(
120
120
  self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
121
121
  ) -> None: ...
122
122
 
123
123
  @overload
124
- def __init__( # noqa: E704
124
+ def __init__(
125
125
  self, error: Error, *, reply_to: Message, ttl: float | None = None
126
126
  ) -> None: ...
127
127
 
@@ -511,7 +511,7 @@ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
511
511
  and (message_type is None or isinstance(message_type, str))
512
512
  and (content is None or isinstance(content, RecordDict))
513
513
  and (error is None or isinstance(error, Error))
514
- and (ttl is None or isinstance(ttl, (int, float)))
514
+ and (ttl is None or isinstance(ttl, (int | float)))
515
515
  and (group_id is None or isinstance(group_id, str))
516
516
  and (reply_to is None or isinstance(reply_to, Message))
517
517
  and (metadata is None or isinstance(metadata, Metadata))
flwr/common/object_ref.py CHANGED
@@ -21,7 +21,7 @@ import sys
21
21
  from importlib.util import find_spec
22
22
  from pathlib import Path
23
23
  from threading import Lock
24
- from typing import Any, Optional, Union
24
+ from typing import Any
25
25
 
26
26
  OBJECT_REF_HELP_STR = """
27
27
  \n\nThe object reference string should have the form <module>:<attribute>. Valid
@@ -31,15 +31,15 @@ attribute.
31
31
  """
32
32
 
33
33
 
34
- _current_sys_path: Optional[str] = None
34
+ _current_sys_path: str | None = None
35
35
  _import_lock = Lock()
36
36
 
37
37
 
38
38
  def validate(
39
39
  module_attribute_str: str,
40
40
  check_module: bool = True,
41
- project_dir: Optional[Union[str, Path]] = None,
42
- ) -> tuple[bool, Optional[str]]:
41
+ project_dir: str | Path | None = None,
42
+ ) -> tuple[bool, str | None]:
43
43
  """Validate object reference.
44
44
 
45
45
  Parameters
@@ -114,7 +114,7 @@ def validate(
114
114
  def load_app( # pylint: disable= too-many-branches
115
115
  module_attribute_str: str,
116
116
  error_type: type[Exception],
117
- project_dir: Optional[Union[str, Path]] = None,
117
+ project_dir: str | Path | None = None,
118
118
  ) -> Any:
119
119
  """Return the object specified in a module attribute string.
120
120
 
@@ -194,12 +194,12 @@ def _unload_modules(project_dir: Path) -> None:
194
194
  """Unload modules from the project directory."""
195
195
  dir_str = str(project_dir.absolute())
196
196
  for name, m in list(sys.modules.items()):
197
- path: Optional[str] = getattr(m, "__file__", None)
197
+ path: str | None = getattr(m, "__file__", None)
198
198
  if path is not None and path.startswith(dir_str):
199
199
  del sys.modules[name]
200
200
 
201
201
 
202
- def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
202
+ def _set_sys_path(directory: str | Path | None) -> None:
203
203
  """Set the system path."""
204
204
  if directory is None:
205
205
  directory = Path.cwd()
@@ -117,15 +117,15 @@ class Array(InflatableObject):
117
117
  data: bytes
118
118
 
119
119
  @overload
120
- def __init__( # noqa: E704
120
+ def __init__(
121
121
  self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
122
122
  ) -> None: ...
123
123
 
124
124
  @overload
125
- def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
125
+ def __init__(self, ndarray: NDArray) -> None: ...
126
126
 
127
127
  @overload
128
- def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
128
+ def __init__(self, torch_tensor: torch.Tensor) -> None: ...
129
129
 
130
130
  def __init__( # pylint: disable=too-many-arguments, too-many-locals
131
131
  self,
@@ -63,7 +63,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
63
63
 
64
64
  A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
65
65
  including model parameters, gradients, embeddings or non-parameter arrays.
66
- Internally, this behaves similarly to an ``OrderedDict[str, Array]``.
66
+ Internally, this behaves similarly to an ``dict[str, Array]``.
67
67
  An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
68
68
  but it holds arrays in a serialized form.
69
69
 
@@ -80,13 +80,13 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
80
80
 
81
81
  Parameters
82
82
  ----------
83
- array_dict : Optional[OrderedDict[str, Array]] (default: None)
83
+ array_dict : Optional[dict[str, Array]] (default: None)
84
84
  An existing dictionary containing named :class:`Array` instances. If
85
85
  provided, these entries will be used directly to populate the record.
86
86
  numpy_ndarrays : Optional[list[NDArray]] (default: None)
87
87
  A list of NumPy arrays. Each array will be automatically converted
88
88
  into an :class:`Array` and stored in this record with generated keys.
89
- torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
89
+ torch_state_dict : Optional[dict[str, torch.Tensor]] (default: None)
90
90
  A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
91
91
  tensor will be converted into an :class:`Array` and stored in this record.
92
92
  keep_input : bool (default: True)
@@ -127,31 +127,23 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
127
127
  """
128
128
 
129
129
  @overload
130
- def __init__(self) -> None: ... # noqa: E704
130
+ def __init__(self) -> None: ...
131
131
 
132
132
  @overload
133
- def __init__( # noqa: E704
134
- self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
133
+ def __init__(
134
+ self, array_dict: dict[str, Array], *, keep_input: bool = True
135
135
  ) -> None: ...
136
136
 
137
137
  @overload
138
- def __init__( # noqa: E704
138
+ def __init__(
139
139
  self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
140
140
  ) -> None: ...
141
141
 
142
142
  @overload
143
- def __init__( # noqa: E704
143
+ def __init__(
144
144
  self,
145
- torch_state_dict: OrderedDict[str, torch.Tensor],
146
- *,
147
- keep_input: bool = True,
148
- ) -> None: ...
149
-
150
- # This is also required for PyTorch state dict because they are not strongly typed
151
- @overload
152
- def __init__( # noqa: E704
153
- self,
154
- torch_state_dict: dict[str, Any],
145
+ # `Any` is required for PyTorch state dict because they are not strongly typed
146
+ torch_state_dict: dict[str, torch.Tensor] | dict[str, Any],
155
147
  *,
156
148
  keep_input: bool = True,
157
149
  ) -> None: ...
@@ -160,15 +152,15 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
160
152
  self,
161
153
  *args: Any,
162
154
  numpy_ndarrays: list[NDArray] | None = None,
163
- torch_state_dict: OrderedDict[str, torch.Tensor] | dict[str, Any] | None = None,
164
- array_dict: OrderedDict[str, Array] | None = None,
155
+ torch_state_dict: dict[str, torch.Tensor] | dict[str, Any] | None = None,
156
+ array_dict: dict[str, Array] | None = None,
165
157
  keep_input: bool = True,
166
158
  ) -> None:
167
159
  super().__init__(_check_key, _check_value)
168
160
 
169
161
  # Determine the initialization method and validates input arguments.
170
162
  # Support the following initialization formats:
171
- # 1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
163
+ # 1. cls(array_dict: dict[str, Array], keep_input: bool)
172
164
  # 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
173
165
  # 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)
174
166
 
@@ -213,7 +205,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
213
205
  and all(isinstance(k, str) for k in arg.keys())
214
206
  and all(isinstance(v, Array) for v in arg.values())
215
207
  ):
216
- array_dict = cast(OrderedDict[str, Array], arg)
208
+ array_dict = cast(dict[str, Array], arg)
217
209
  converted = self.from_array_dict(array_dict, keep_input=keep_input)
218
210
  self.__dict__.update(converted.__dict__)
219
211
  return
@@ -239,9 +231,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
239
231
  and all(isinstance(k, str) for k in arg.keys())
240
232
  and all(isinstance(v, torch.Tensor) for v in arg.values())
241
233
  ):
242
- torch_state_dict = cast(
243
- OrderedDict[str, torch.Tensor], arg # type: ignore
244
- )
234
+ torch_state_dict = cast(dict[str, torch.Tensor], arg) # type: ignore
245
235
  converted = self.from_torch_state_dict(
246
236
  torch_state_dict, keep_input=keep_input
247
237
  )
@@ -253,7 +243,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
253
243
  @classmethod
254
244
  def from_array_dict(
255
245
  cls,
256
- array_dict: OrderedDict[str, Array],
246
+ array_dict: dict[str, Array],
257
247
  *,
258
248
  keep_input: bool = True,
259
249
  ) -> ArrayRecord:
@@ -300,7 +290,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
300
290
  @classmethod
301
291
  def from_torch_state_dict(
302
292
  cls,
303
- state_dict: OrderedDict[str, torch.Tensor],
293
+ state_dict: dict[str, torch.Tensor],
304
294
  *,
305
295
  keep_input: bool = True,
306
296
  ) -> ArrayRecord:
@@ -433,9 +423,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
433
423
 
434
424
  # Instantiate new ArrayRecord
435
425
  return ArrayRecord(
436
- OrderedDict(
437
- {name: children[object_id] for name, object_id in array_refs.items()}
438
- )
426
+ {name: children[object_id] for name, object_id in array_refs.items()}
439
427
  )
440
428
 
441
429
  @property
@@ -142,11 +142,11 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
142
142
  var_bytes = 0
143
143
  if isinstance(value, bool):
144
144
  var_bytes = 1
145
- elif isinstance(value, (int, float)):
145
+ elif isinstance(value, (int | float)):
146
146
  var_bytes = (
147
147
  8 # the profobufing represents int/floats in ConfigRecords as 64bit
148
148
  )
149
- if isinstance(value, (str, bytes)):
149
+ if isinstance(value, (str | bytes)):
150
150
  var_bytes = len(value)
151
151
  if var_bytes == 0:
152
152
  raise ValueError(
@@ -159,7 +159,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
159
159
 
160
160
  for k, v in self.items():
161
161
  if isinstance(v, list):
162
- if isinstance(v[0], (bytes, str)):
162
+ if isinstance(v[0], (bytes | str)):
163
163
  # not all str are of equal length necessarily
164
164
  # for both the footprint of each element is 1 Byte
165
165
  num_bytes += int(sum(len(s) for s in v)) # type: ignore