flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (339) hide show
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +262 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +318 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +211 -130
  21. flwr/cli/new/new.py +123 -331
  22. flwr/cli/pull.py +10 -5
  23. flwr/cli/run/run.py +71 -29
  24. flwr/cli/run_utils.py +148 -0
  25. flwr/cli/stop.py +26 -8
  26. flwr/cli/supernode/ls.py +25 -12
  27. flwr/cli/supernode/register.py +9 -4
  28. flwr/cli/supernode/unregister.py +5 -3
  29. flwr/cli/utils.py +239 -16
  30. flwr/client/__init__.py +1 -1
  31. flwr/client/dpfedavg_numpy_client.py +4 -1
  32. flwr/client/grpc_adapter_client/connection.py +8 -9
  33. flwr/client/grpc_rere_client/connection.py +16 -14
  34. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  35. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  36. flwr/client/message_handler/message_handler.py +2 -2
  37. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  38. flwr/client/numpy_client.py +1 -1
  39. flwr/client/rest_client/connection.py +18 -18
  40. flwr/client/run_info_store.py +4 -5
  41. flwr/client/typing.py +1 -1
  42. flwr/clientapp/client_app.py +9 -10
  43. flwr/clientapp/mod/centraldp_mods.py +16 -17
  44. flwr/clientapp/mod/localdp_mod.py +8 -9
  45. flwr/clientapp/typing.py +1 -1
  46. flwr/clientapp/utils.py +3 -3
  47. flwr/common/address.py +1 -2
  48. flwr/common/args.py +3 -4
  49. flwr/common/config.py +13 -16
  50. flwr/common/constant.py +5 -2
  51. flwr/common/differential_privacy.py +3 -4
  52. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  53. flwr/common/exit/exit.py +15 -2
  54. flwr/common/exit/exit_code.py +19 -0
  55. flwr/common/exit/exit_handler.py +6 -2
  56. flwr/common/exit/signal_handler.py +5 -5
  57. flwr/common/grpc.py +6 -6
  58. flwr/common/inflatable_protobuf_utils.py +1 -1
  59. flwr/common/inflatable_utils.py +38 -21
  60. flwr/common/logger.py +19 -19
  61. flwr/common/message.py +4 -4
  62. flwr/common/object_ref.py +7 -7
  63. flwr/common/record/array.py +3 -3
  64. flwr/common/record/arrayrecord.py +18 -30
  65. flwr/common/record/configrecord.py +3 -3
  66. flwr/common/record/recorddict.py +5 -5
  67. flwr/common/record/typeddict.py +9 -2
  68. flwr/common/recorddict_compat.py +7 -10
  69. flwr/common/retry_invoker.py +20 -20
  70. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  71. flwr/common/serde.py +11 -4
  72. flwr/common/serde_utils.py +2 -2
  73. flwr/common/telemetry.py +9 -5
  74. flwr/common/typing.py +58 -37
  75. flwr/compat/client/app.py +38 -37
  76. flwr/compat/client/grpc_client/connection.py +11 -11
  77. flwr/compat/server/app.py +5 -6
  78. flwr/proto/appio_pb2.py +13 -3
  79. flwr/proto/appio_pb2.pyi +134 -65
  80. flwr/proto/appio_pb2_grpc.py +20 -0
  81. flwr/proto/appio_pb2_grpc.pyi +27 -0
  82. flwr/proto/clientappio_pb2.py +17 -7
  83. flwr/proto/clientappio_pb2.pyi +15 -0
  84. flwr/proto/clientappio_pb2_grpc.py +206 -40
  85. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  86. flwr/proto/control_pb2.py +71 -52
  87. flwr/proto/control_pb2.pyi +277 -111
  88. flwr/proto/control_pb2_grpc.py +249 -40
  89. flwr/proto/control_pb2_grpc.pyi +185 -52
  90. flwr/proto/error_pb2.py +13 -3
  91. flwr/proto/error_pb2.pyi +24 -6
  92. flwr/proto/error_pb2_grpc.py +20 -0
  93. flwr/proto/error_pb2_grpc.pyi +27 -0
  94. flwr/proto/fab_pb2.py +14 -4
  95. flwr/proto/fab_pb2.pyi +59 -31
  96. flwr/proto/fab_pb2_grpc.py +20 -0
  97. flwr/proto/fab_pb2_grpc.pyi +27 -0
  98. flwr/proto/federation_pb2.py +38 -0
  99. flwr/proto/federation_pb2.pyi +56 -0
  100. flwr/proto/federation_pb2_grpc.py +24 -0
  101. flwr/proto/federation_pb2_grpc.pyi +31 -0
  102. flwr/proto/fleet_pb2.py +24 -14
  103. flwr/proto/fleet_pb2.pyi +141 -61
  104. flwr/proto/fleet_pb2_grpc.py +189 -48
  105. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  106. flwr/proto/grpcadapter_pb2.py +14 -4
  107. flwr/proto/grpcadapter_pb2.pyi +38 -16
  108. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  109. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  110. flwr/proto/heartbeat_pb2.py +17 -7
  111. flwr/proto/heartbeat_pb2.pyi +51 -22
  112. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  113. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  114. flwr/proto/log_pb2.py +13 -3
  115. flwr/proto/log_pb2.pyi +34 -11
  116. flwr/proto/log_pb2_grpc.py +20 -0
  117. flwr/proto/log_pb2_grpc.pyi +27 -0
  118. flwr/proto/message_pb2.py +15 -5
  119. flwr/proto/message_pb2.pyi +154 -86
  120. flwr/proto/message_pb2_grpc.py +20 -0
  121. flwr/proto/message_pb2_grpc.pyi +27 -0
  122. flwr/proto/node_pb2.py +15 -5
  123. flwr/proto/node_pb2.pyi +50 -25
  124. flwr/proto/node_pb2_grpc.py +20 -0
  125. flwr/proto/node_pb2_grpc.pyi +27 -0
  126. flwr/proto/recorddict_pb2.py +13 -3
  127. flwr/proto/recorddict_pb2.pyi +184 -107
  128. flwr/proto/recorddict_pb2_grpc.py +20 -0
  129. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  130. flwr/proto/run_pb2.py +40 -31
  131. flwr/proto/run_pb2.pyi +158 -84
  132. flwr/proto/run_pb2_grpc.py +20 -0
  133. flwr/proto/run_pb2_grpc.pyi +27 -0
  134. flwr/proto/serverappio_pb2.py +13 -3
  135. flwr/proto/serverappio_pb2.pyi +32 -8
  136. flwr/proto/serverappio_pb2_grpc.py +246 -65
  137. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  138. flwr/proto/simulationio_pb2.py +16 -8
  139. flwr/proto/simulationio_pb2.pyi +15 -0
  140. flwr/proto/simulationio_pb2_grpc.py +162 -41
  141. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  142. flwr/proto/transport_pb2.py +20 -10
  143. flwr/proto/transport_pb2.pyi +249 -160
  144. flwr/proto/transport_pb2_grpc.py +35 -4
  145. flwr/proto/transport_pb2_grpc.pyi +38 -8
  146. flwr/server/app.py +39 -17
  147. flwr/server/client_manager.py +4 -5
  148. flwr/server/client_proxy.py +10 -11
  149. flwr/server/compat/app.py +4 -5
  150. flwr/server/compat/app_utils.py +2 -1
  151. flwr/server/compat/grid_client_proxy.py +10 -12
  152. flwr/server/compat/legacy_context.py +3 -4
  153. flwr/server/fleet_event_log_interceptor.py +2 -1
  154. flwr/server/grid/grid.py +2 -3
  155. flwr/server/grid/grpc_grid.py +10 -8
  156. flwr/server/grid/inmemory_grid.py +4 -4
  157. flwr/server/run_serverapp.py +2 -3
  158. flwr/server/server.py +34 -39
  159. flwr/server/server_app.py +7 -8
  160. flwr/server/server_config.py +1 -2
  161. flwr/server/serverapp/app.py +34 -28
  162. flwr/server/serverapp_components.py +4 -5
  163. flwr/server/strategy/aggregate.py +9 -8
  164. flwr/server/strategy/bulyan.py +13 -11
  165. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  166. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  167. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  168. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  169. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  170. flwr/server/strategy/fedadagrad.py +18 -14
  171. flwr/server/strategy/fedadam.py +16 -14
  172. flwr/server/strategy/fedavg.py +16 -17
  173. flwr/server/strategy/fedavg_android.py +15 -15
  174. flwr/server/strategy/fedavgm.py +21 -18
  175. flwr/server/strategy/fedmedian.py +2 -3
  176. flwr/server/strategy/fedopt.py +11 -10
  177. flwr/server/strategy/fedprox.py +10 -9
  178. flwr/server/strategy/fedtrimmedavg.py +12 -11
  179. flwr/server/strategy/fedxgb_bagging.py +13 -11
  180. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  181. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  182. flwr/server/strategy/fedyogi.py +16 -14
  183. flwr/server/strategy/krum.py +12 -11
  184. flwr/server/strategy/qfedavg.py +16 -15
  185. flwr/server/strategy/strategy.py +6 -9
  186. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  187. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  190. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  192. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  193. flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
  194. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  195. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  196. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  197. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  198. flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
  199. flwr/server/superlink/linkstate/linkstate.py +91 -43
  200. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  201. flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
  202. flwr/server/superlink/linkstate/utils.py +6 -6
  203. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  204. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  205. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  206. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  207. flwr/server/superlink/utils.py +4 -6
  208. flwr/server/typing.py +1 -1
  209. flwr/server/utils/tensorboard.py +15 -8
  210. flwr/server/workflow/default_workflows.py +5 -5
  211. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  213. flwr/serverapp/strategy/bulyan.py +16 -15
  214. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  215. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  216. flwr/serverapp/strategy/fedadagrad.py +10 -11
  217. flwr/serverapp/strategy/fedadam.py +10 -11
  218. flwr/serverapp/strategy/fedavg.py +9 -10
  219. flwr/serverapp/strategy/fedavgm.py +17 -16
  220. flwr/serverapp/strategy/fedmedian.py +2 -2
  221. flwr/serverapp/strategy/fedopt.py +10 -11
  222. flwr/serverapp/strategy/fedprox.py +7 -8
  223. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  224. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  225. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  226. flwr/serverapp/strategy/fedyogi.py +9 -11
  227. flwr/serverapp/strategy/krum.py +7 -7
  228. flwr/serverapp/strategy/multikrum.py +9 -9
  229. flwr/serverapp/strategy/qfedavg.py +17 -16
  230. flwr/serverapp/strategy/strategy.py +6 -9
  231. flwr/serverapp/strategy/strategy_utils.py +7 -8
  232. flwr/simulation/app.py +46 -42
  233. flwr/simulation/legacy_app.py +12 -12
  234. flwr/simulation/ray_transport/ray_actor.py +10 -11
  235. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  236. flwr/simulation/run_simulation.py +43 -43
  237. flwr/simulation/simulationio_connection.py +4 -4
  238. flwr/supercore/cli/flower_superexec.py +3 -4
  239. flwr/supercore/constant.py +34 -1
  240. flwr/supercore/corestate/corestate.py +24 -3
  241. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  242. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  243. flwr/supercore/ffs/disk_ffs.py +1 -2
  244. flwr/supercore/ffs/ffs.py +1 -2
  245. flwr/supercore/ffs/ffs_factory.py +1 -2
  246. flwr/{common → supercore}/heartbeat.py +20 -25
  247. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  248. flwr/supercore/object_store/object_store.py +1 -2
  249. flwr/supercore/object_store/object_store_factory.py +1 -2
  250. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  251. flwr/supercore/primitives/asymmetric.py +1 -1
  252. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  253. flwr/supercore/sqlite_mixin.py +37 -34
  254. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  255. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  256. flwr/supercore/superexec/run_superexec.py +9 -13
  257. flwr/supercore/utils.py +190 -0
  258. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  259. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  260. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  261. flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
  262. flwr/superlink/federation/federation_manager.py +64 -0
  263. flwr/superlink/federation/noop_federation_manager.py +71 -0
  264. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  265. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  266. flwr/superlink/servicer/control/control_grpc.py +7 -6
  267. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  268. flwr/superlink/servicer/control/control_servicer.py +190 -23
  269. flwr/supernode/cli/flower_supernode.py +58 -3
  270. flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
  271. flwr/supernode/nodestate/nodestate.py +52 -8
  272. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  273. flwr/supernode/runtime/run_clientapp.py +41 -22
  274. flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
  275. flwr/supernode/start_client_internal.py +165 -46
  276. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
  277. flwr-1.25.0.dist-info/RECORD +393 -0
  278. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  279. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  280. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  281. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  282. flwr/cli/new/templates/app/README.md.tpl +0 -37
  283. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  284. flwr/cli/new/templates/app/code/__init__.py +0 -15
  285. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  286. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  287. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  288. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  289. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  290. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  291. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  292. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  293. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  294. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  295. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  296. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  297. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  298. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  299. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  300. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  301. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  302. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  303. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  304. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  305. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  306. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  307. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  308. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  309. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  310. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  311. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  312. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  313. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  314. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  315. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  316. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  317. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  318. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  319. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
  320. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  321. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  322. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  323. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  324. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  325. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  326. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  327. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  328. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  329. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  330. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  331. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  332. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  333. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  334. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  335. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  336. flwr/supercore/object_store/utils.py +0 -43
  337. flwr-1.23.0.dist-info/RECORD +0 -439
  338. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  339. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, ConfigRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- fraction_train: float = context.run_config["fraction-train"]
20
- num_rounds: int = context.run_config["num-server-rounds"]
21
- lr: float = context.run_config["lr"]
22
-
23
- # Load global model
24
- global_model = Net()
25
- arrays = ArrayRecord(global_model.state_dict())
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg(fraction_train=fraction_train)
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- train_config=ConfigRecord({"lr": lr}),
35
- num_rounds=num_rounds,
36
- )
37
-
38
- # Save final model to disk
39
- print("\nSaving final model to disk...")
40
- state_dict = result.arrays.to_torch_state_dict()
41
- torch.save(state_dict, "final_model.pt")
@@ -1,31 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import Net, get_weights
7
-
8
-
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
12
- fraction_fit = context.run_config["fraction-fit"]
13
-
14
- # Initialize model parameters
15
- ndarrays = get_weights(Net())
16
- parameters = ndarrays_to_parameters(ndarrays)
17
-
18
- # Define strategy
19
- strategy = FedAvg(
20
- fraction_fit=fraction_fit,
21
- fraction_evaluate=1.0,
22
- min_available_clients=2,
23
- initial_parameters=parameters,
24
- )
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
28
-
29
-
30
- # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
@@ -1,44 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import joblib
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- num_rounds: int = context.run_config["num-server-rounds"]
20
-
21
- # Create LogisticRegression Model
22
- penalty = context.run_config["penalty"]
23
- local_epochs = context.run_config["local-epochs"]
24
- model = get_model(penalty, local_epochs)
25
- # Setting initial parameters, akin to model.compile for keras models
26
- set_initial_params(model)
27
- # Construct ArrayRecord representation
28
- arrays = ArrayRecord(get_model_params(model))
29
-
30
- # Initialize FedAvg strategy
31
- strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
32
-
33
- # Start strategy, run FedAvg for `num_rounds`
34
- result = strategy.start(
35
- grid=grid,
36
- initial_arrays=arrays,
37
- num_rounds=num_rounds,
38
- )
39
-
40
- # Save final model parameters
41
- print("\nSaving final model to disk...")
42
- ndarrays = result.arrays.to_numpy_ndarrays()
43
- set_model_params(model, ndarrays)
44
- joblib.dump(model, "logreg_model.pkl")
@@ -1,38 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.app import ArrayRecord, Context
4
- from flwr.serverapp import Grid, ServerApp
5
- from flwr.serverapp.strategy import FedAvg
6
-
7
- from $import_name.task import load_model
8
-
9
- # Create ServerApp
10
- app = ServerApp()
11
-
12
-
13
- @app.main()
14
- def main(grid: Grid, context: Context) -> None:
15
- """Main entry point for the ServerApp."""
16
-
17
- # Read run config
18
- num_rounds: int = context.run_config["num-server-rounds"]
19
-
20
- # Load global model
21
- model = load_model()
22
- arrays = ArrayRecord(model.get_weights())
23
-
24
- # Initialize FedAvg strategy
25
- strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
26
-
27
- # Start strategy, run FedAvg for `num_rounds`
28
- result = strategy.start(
29
- grid=grid,
30
- initial_arrays=arrays,
31
- num_rounds=num_rounds,
32
- )
33
-
34
- # Save final model to disk
35
- print("\nSaving final model to disk...")
36
- ndarrays = result.arrays.to_numpy_ndarrays()
37
- model.set_weights(ndarrays)
38
- model.save("final_model.keras")
@@ -1,56 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import numpy as np
4
- import xgboost as xgb
5
- from flwr.app import ArrayRecord, Context
6
- from flwr.common.config import unflatten_dict
7
- from flwr.serverapp import Grid, ServerApp
8
- from flwr.serverapp.strategy import FedXgbBagging
9
-
10
- from $import_name.task import replace_keys
11
-
12
- # Create ServerApp
13
- app = ServerApp()
14
-
15
-
16
- @app.main()
17
- def main(grid: Grid, context: Context) -> None:
18
- # Read run config
19
- num_rounds = context.run_config["num-server-rounds"]
20
- fraction_train = context.run_config["fraction-train"]
21
- fraction_evaluate = context.run_config["fraction-evaluate"]
22
- # Flatted config dict and replace "-" with "_"
23
- cfg = replace_keys(unflatten_dict(context.run_config))
24
- params = cfg["params"]
25
-
26
- # Init global model
27
- # Init with an empty object; the XGBooster will be created
28
- # and trained on the client side.
29
- global_model = b""
30
- # Note: we store the model as the first item in a list into ArrayRecord,
31
- # which can be accessed using index ["0"].
32
- arrays = ArrayRecord([np.frombuffer(global_model, dtype=np.uint8)])
33
-
34
- # Initialize FedXgbBagging strategy
35
- strategy = FedXgbBagging(
36
- fraction_train=fraction_train,
37
- fraction_evaluate=fraction_evaluate,
38
- )
39
-
40
- # Start strategy, run FedXgbBagging for `num_rounds`
41
- result = strategy.start(
42
- grid=grid,
43
- initial_arrays=arrays,
44
- num_rounds=num_rounds,
45
- )
46
-
47
- # Save final model to disk
48
- bst = xgb.Booster(params=params)
49
- global_model = bytearray(result.arrays["0"].numpy().tobytes())
50
-
51
- # Load global model into booster
52
- bst.load_model(global_model)
53
-
54
- # Save model
55
- print("\nSaving final model to disk...")
56
- bst.save_model("final_model.json")
@@ -1 +0,0 @@
1
- """$project_name: A Flower Baseline."""
@@ -1,98 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import warnings
4
-
5
- import torch
6
- import transformers
7
- from datasets.utils.logging import disable_progress_bar
8
- from evaluate import load as load_metric
9
- from flwr_datasets import FederatedDataset
10
- from flwr_datasets.partitioner import IidPartitioner
11
- from torch.optim import AdamW
12
- from torch.utils.data import DataLoader
13
- from transformers import AutoTokenizer, DataCollatorWithPadding
14
-
15
- warnings.filterwarnings("ignore", category=UserWarning)
16
- warnings.filterwarnings("ignore", category=FutureWarning)
17
- disable_progress_bar()
18
- transformers.logging.set_verbosity_error()
19
-
20
-
21
- fds = None # Cache FederatedDataset
22
-
23
-
24
- def load_data(partition_id: int, num_partitions: int, model_name: str):
25
- """Load IMDB data (training and eval)"""
26
- # Only initialize `FederatedDataset` once
27
- global fds
28
- if fds is None:
29
- partitioner = IidPartitioner(num_partitions=num_partitions)
30
- fds = FederatedDataset(
31
- dataset="stanfordnlp/imdb",
32
- partitioners={"train": partitioner},
33
- )
34
- partition = fds.load_partition(partition_id)
35
- # Divide data: 80% train, 20% test
36
- partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
37
-
38
- tokenizer = AutoTokenizer.from_pretrained(model_name)
39
-
40
- def tokenize_function(examples):
41
- return tokenizer(
42
- examples["text"], truncation=True, add_special_tokens=True, max_length=512
43
- )
44
-
45
- partition_train_test = partition_train_test.map(tokenize_function, batched=True)
46
- partition_train_test = partition_train_test.remove_columns("text")
47
- partition_train_test = partition_train_test.rename_column("label", "labels")
48
-
49
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
50
- trainloader = DataLoader(
51
- partition_train_test["train"],
52
- shuffle=True,
53
- batch_size=32,
54
- collate_fn=data_collator,
55
- )
56
-
57
- testloader = DataLoader(
58
- partition_train_test["test"], batch_size=32, collate_fn=data_collator
59
- )
60
-
61
- return trainloader, testloader
62
-
63
-
64
- def train(net, trainloader, num_steps, device):
65
- optimizer = AdamW(net.parameters(), lr=5e-5)
66
- net.train()
67
- running_loss = 0.0
68
- step_cnt = 0
69
- for batch in trainloader:
70
- batch = {k: v.to(device) for k, v in batch.items()}
71
- outputs = net(**batch)
72
- loss = outputs.loss
73
- loss.backward()
74
- optimizer.step()
75
- optimizer.zero_grad()
76
- running_loss += loss.item()
77
- step_cnt += 1
78
- if step_cnt >= num_steps:
79
- break
80
- avg_trainloss = running_loss / step_cnt
81
- return avg_trainloss
82
-
83
-
84
- def test(net, testloader, device):
85
- metric = load_metric("accuracy")
86
- loss = 0
87
- net.eval()
88
- for batch in testloader:
89
- batch = {k: v.to(device) for k, v in batch.items()}
90
- with torch.no_grad():
91
- outputs = net(**batch)
92
- logits = outputs.logits
93
- loss += outputs.loss.item()
94
- predictions = torch.argmax(logits, dim=-1)
95
- metric.add_batch(predictions=predictions, references=batch["labels"])
96
- loss /= len(testloader.dataset)
97
- accuracy = metric.compute()["accuracy"]
98
- return loss, accuracy
@@ -1,57 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
- from sklearn.datasets import make_regression
7
- from sklearn.model_selection import train_test_split
8
-
9
- key = jax.random.PRNGKey(0)
10
-
11
-
12
- def load_data():
13
- # Load dataset
14
- X, y = make_regression(n_features=3, random_state=0)
15
- X, X_test, y, y_test = train_test_split(X, y)
16
- return X, y, X_test, y_test
17
-
18
-
19
- def load_model(model_shape):
20
- # Extract model parameters
21
- params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
22
- return params
23
-
24
-
25
- def loss_fn(params, X, y):
26
- # Return MSE as loss
27
- err = jnp.dot(X, params["w"]) + params["b"] - y
28
- return jnp.mean(jnp.square(err))
29
-
30
-
31
- def train(params, grad_fn, X, y):
32
- loss = 1_000_000
33
- num_examples = X.shape[0]
34
- for _ in range(50):
35
- grads = grad_fn(params, X, y)
36
- params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
37
- loss = loss_fn(params, X, y)
38
- return params, loss, num_examples
39
-
40
-
41
- def evaluation(params, grad_fn, X_test, y_test):
42
- num_examples = X_test.shape[0]
43
- err_test = loss_fn(params, X_test, y_test)
44
- loss_test = jnp.mean(jnp.square(err_test))
45
- return loss_test, num_examples
46
-
47
-
48
- def get_params(params):
49
- parameters = []
50
- for _, val in params.items():
51
- parameters.append(np.array(val))
52
- return parameters
53
-
54
-
55
- def set_params(local_params, global_params):
56
- for key, value in list(zip(local_params.keys(), global_params)):
57
- local_params[key] = value
@@ -1,102 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import mlx.core as mx
4
- import mlx.nn as nn
5
- import numpy as np
6
- from flwr_datasets import FederatedDataset
7
- from flwr_datasets.partitioner import IidPartitioner
8
-
9
- from datasets.utils.logging import disable_progress_bar
10
-
11
- disable_progress_bar()
12
-
13
-
14
- class MLP(nn.Module):
15
- """A simple MLP."""
16
-
17
- def __init__(
18
- self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
19
- ):
20
- super().__init__()
21
- layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
22
- self.layers = [
23
- nn.Linear(idim, odim)
24
- for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
25
- ]
26
-
27
- def __call__(self, x):
28
- for l in self.layers[:-1]:
29
- x = mx.maximum(l(x), 0.0)
30
- return self.layers[-1](x)
31
-
32
-
33
- def loss_fn(model, X, y):
34
- return mx.mean(nn.losses.cross_entropy(model(X), y))
35
-
36
-
37
- def eval_fn(model, X, y):
38
- return mx.mean(mx.argmax(model(X), axis=1) == y)
39
-
40
-
41
- def batch_iterate(batch_size, X, y):
42
- perm = mx.array(np.random.permutation(y.size))
43
- for s in range(0, y.size, batch_size):
44
- ids = perm[s : s + batch_size]
45
- yield X[ids], y[ids]
46
-
47
-
48
- fds = None # Cache FederatedDataset
49
-
50
-
51
- def load_data(partition_id: int, num_partitions: int):
52
- # Only initialize `FederatedDataset` once
53
- global fds
54
- if fds is None:
55
- partitioner = IidPartitioner(num_partitions=num_partitions)
56
- fds = FederatedDataset(
57
- dataset="ylecun/mnist",
58
- partitioners={"train": partitioner},
59
- trust_remote_code=True,
60
- )
61
- partition = fds.load_partition(partition_id)
62
- partition_splits = partition.train_test_split(test_size=0.2, seed=42)
63
-
64
- partition_splits["train"].set_format("numpy")
65
- partition_splits["test"].set_format("numpy")
66
-
67
- train_partition = partition_splits["train"].map(
68
- lambda img: {
69
- "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
70
- },
71
- input_columns="image",
72
- )
73
- test_partition = partition_splits["test"].map(
74
- lambda img: {
75
- "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
76
- },
77
- input_columns="image",
78
- )
79
-
80
- data = (
81
- train_partition["img"],
82
- train_partition["label"].astype(np.uint32),
83
- test_partition["img"],
84
- test_partition["label"].astype(np.uint32),
85
- )
86
-
87
- train_images, train_labels, test_images, test_labels = map(mx.array, data)
88
- return train_images, train_labels, test_images, test_labels
89
-
90
-
91
- def get_params(model):
92
- layers = model.parameters()["layers"]
93
- return [np.array(val) for layer in layers for _, val in layer.items()]
94
-
95
-
96
- def set_params(model, parameters):
97
- new_params = {}
98
- new_params["layers"] = [
99
- {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
100
- for i in range(0, len(parameters), 2)
101
- ]
102
- model.update(new_params)
@@ -1,7 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import numpy as np
4
-
5
-
6
- def get_dummy_model():
7
- return [np.ones((1, 1))]
@@ -1,98 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from flwr_datasets import FederatedDataset
7
- from flwr_datasets.partitioner import IidPartitioner
8
- from torch.utils.data import DataLoader
9
- from torchvision.transforms import Compose, Normalize, ToTensor
10
-
11
-
12
- class Net(nn.Module):
13
- """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
14
-
15
- def __init__(self):
16
- super(Net, self).__init__()
17
- self.conv1 = nn.Conv2d(3, 6, 5)
18
- self.pool = nn.MaxPool2d(2, 2)
19
- self.conv2 = nn.Conv2d(6, 16, 5)
20
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
21
- self.fc2 = nn.Linear(120, 84)
22
- self.fc3 = nn.Linear(84, 10)
23
-
24
- def forward(self, x):
25
- x = self.pool(F.relu(self.conv1(x)))
26
- x = self.pool(F.relu(self.conv2(x)))
27
- x = x.view(-1, 16 * 5 * 5)
28
- x = F.relu(self.fc1(x))
29
- x = F.relu(self.fc2(x))
30
- return self.fc3(x)
31
-
32
-
33
- fds = None # Cache FederatedDataset
34
-
35
- pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
-
37
-
38
- def apply_transforms(batch):
39
- """Apply transforms to the partition from FederatedDataset."""
40
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
- return batch
42
-
43
-
44
- def load_data(partition_id: int, num_partitions: int):
45
- """Load partition CIFAR10 data."""
46
- # Only initialize `FederatedDataset` once
47
- global fds
48
- if fds is None:
49
- partitioner = IidPartitioner(num_partitions=num_partitions)
50
- fds = FederatedDataset(
51
- dataset="uoft-cs/cifar10",
52
- partitioners={"train": partitioner},
53
- )
54
- partition = fds.load_partition(partition_id)
55
- # Divide data on each node: 80% train, 20% test
56
- partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
57
- # Construct dataloaders
58
- partition_train_test = partition_train_test.with_transform(apply_transforms)
59
- trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
60
- testloader = DataLoader(partition_train_test["test"], batch_size=32)
61
- return trainloader, testloader
62
-
63
-
64
- def train(net, trainloader, epochs, lr, device):
65
- """Train the model on the training set."""
66
- net.to(device) # move model to GPU if available
67
- criterion = torch.nn.CrossEntropyLoss().to(device)
68
- optimizer = torch.optim.Adam(net.parameters(), lr=lr)
69
- net.train()
70
- running_loss = 0.0
71
- for _ in range(epochs):
72
- for batch in trainloader:
73
- images = batch["img"].to(device)
74
- labels = batch["label"].to(device)
75
- optimizer.zero_grad()
76
- loss = criterion(net(images), labels)
77
- loss.backward()
78
- optimizer.step()
79
- running_loss += loss.item()
80
- avg_trainloss = running_loss / len(trainloader)
81
- return avg_trainloss
82
-
83
-
84
- def test(net, testloader, device):
85
- """Validate the model on the test set."""
86
- net.to(device)
87
- criterion = torch.nn.CrossEntropyLoss()
88
- correct, loss = 0, 0.0
89
- with torch.no_grad():
90
- for batch in testloader:
91
- images = batch["img"].to(device)
92
- labels = batch["label"].to(device)
93
- outputs = net(images)
94
- loss += criterion(outputs, labels).item()
95
- correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
96
- accuracy = correct / len(testloader.dataset)
97
- loss = loss / len(testloader)
98
- return loss, accuracy
@@ -1,111 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from collections import OrderedDict
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from flwr_datasets import FederatedDataset
9
- from flwr_datasets.partitioner import IidPartitioner
10
- from torch.utils.data import DataLoader
11
- from torchvision.transforms import Compose, Normalize, ToTensor
12
-
13
-
14
- class Net(nn.Module):
15
- """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
16
-
17
- def __init__(self):
18
- super(Net, self).__init__()
19
- self.conv1 = nn.Conv2d(3, 6, 5)
20
- self.pool = nn.MaxPool2d(2, 2)
21
- self.conv2 = nn.Conv2d(6, 16, 5)
22
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
23
- self.fc2 = nn.Linear(120, 84)
24
- self.fc3 = nn.Linear(84, 10)
25
-
26
- def forward(self, x):
27
- x = self.pool(F.relu(self.conv1(x)))
28
- x = self.pool(F.relu(self.conv2(x)))
29
- x = x.view(-1, 16 * 5 * 5)
30
- x = F.relu(self.fc1(x))
31
- x = F.relu(self.fc2(x))
32
- return self.fc3(x)
33
-
34
-
35
- fds = None # Cache FederatedDataset
36
-
37
-
38
- def load_data(partition_id: int, num_partitions: int):
39
- """Load partition CIFAR10 data."""
40
- # Only initialize `FederatedDataset` once
41
- global fds
42
- if fds is None:
43
- partitioner = IidPartitioner(num_partitions=num_partitions)
44
- fds = FederatedDataset(
45
- dataset="uoft-cs/cifar10",
46
- partitioners={"train": partitioner},
47
- )
48
- partition = fds.load_partition(partition_id)
49
- # Divide data on each node: 80% train, 20% test
50
- partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
51
- pytorch_transforms = Compose(
52
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
- )
54
-
55
- def apply_transforms(batch):
56
- """Apply transforms to the partition from FederatedDataset."""
57
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
- return batch
59
-
60
- partition_train_test = partition_train_test.with_transform(apply_transforms)
61
- trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
62
- testloader = DataLoader(partition_train_test["test"], batch_size=32)
63
- return trainloader, testloader
64
-
65
-
66
- def train(net, trainloader, epochs, device):
67
- """Train the model on the training set."""
68
- net.to(device) # move model to GPU if available
69
- criterion = torch.nn.CrossEntropyLoss().to(device)
70
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
71
- net.train()
72
- running_loss = 0.0
73
- for _ in range(epochs):
74
- for batch in trainloader:
75
- images = batch["img"]
76
- labels = batch["label"]
77
- optimizer.zero_grad()
78
- loss = criterion(net(images.to(device)), labels.to(device))
79
- loss.backward()
80
- optimizer.step()
81
- running_loss += loss.item()
82
-
83
- avg_trainloss = running_loss / len(trainloader)
84
- return avg_trainloss
85
-
86
-
87
- def test(net, testloader, device):
88
- """Validate the model on the test set."""
89
- net.to(device)
90
- criterion = torch.nn.CrossEntropyLoss()
91
- correct, loss = 0, 0.0
92
- with torch.no_grad():
93
- for batch in testloader:
94
- images = batch["img"].to(device)
95
- labels = batch["label"].to(device)
96
- outputs = net(images)
97
- loss += criterion(outputs, labels).item()
98
- correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
99
- accuracy = correct / len(testloader.dataset)
100
- loss = loss / len(testloader)
101
- return loss, accuracy
102
-
103
-
104
- def get_weights(net):
105
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
-
107
-
108
- def set_weights(net, parameters):
109
- params_dict = zip(net.state_dict().keys(), parameters)
110
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
- net.load_state_dict(state_dict, strict=True)