flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (339) hide show
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +262 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +318 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +211 -130
  21. flwr/cli/new/new.py +123 -331
  22. flwr/cli/pull.py +10 -5
  23. flwr/cli/run/run.py +71 -29
  24. flwr/cli/run_utils.py +148 -0
  25. flwr/cli/stop.py +26 -8
  26. flwr/cli/supernode/ls.py +25 -12
  27. flwr/cli/supernode/register.py +9 -4
  28. flwr/cli/supernode/unregister.py +5 -3
  29. flwr/cli/utils.py +239 -16
  30. flwr/client/__init__.py +1 -1
  31. flwr/client/dpfedavg_numpy_client.py +4 -1
  32. flwr/client/grpc_adapter_client/connection.py +8 -9
  33. flwr/client/grpc_rere_client/connection.py +16 -14
  34. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  35. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  36. flwr/client/message_handler/message_handler.py +2 -2
  37. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  38. flwr/client/numpy_client.py +1 -1
  39. flwr/client/rest_client/connection.py +18 -18
  40. flwr/client/run_info_store.py +4 -5
  41. flwr/client/typing.py +1 -1
  42. flwr/clientapp/client_app.py +9 -10
  43. flwr/clientapp/mod/centraldp_mods.py +16 -17
  44. flwr/clientapp/mod/localdp_mod.py +8 -9
  45. flwr/clientapp/typing.py +1 -1
  46. flwr/clientapp/utils.py +3 -3
  47. flwr/common/address.py +1 -2
  48. flwr/common/args.py +3 -4
  49. flwr/common/config.py +13 -16
  50. flwr/common/constant.py +5 -2
  51. flwr/common/differential_privacy.py +3 -4
  52. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  53. flwr/common/exit/exit.py +15 -2
  54. flwr/common/exit/exit_code.py +19 -0
  55. flwr/common/exit/exit_handler.py +6 -2
  56. flwr/common/exit/signal_handler.py +5 -5
  57. flwr/common/grpc.py +6 -6
  58. flwr/common/inflatable_protobuf_utils.py +1 -1
  59. flwr/common/inflatable_utils.py +38 -21
  60. flwr/common/logger.py +19 -19
  61. flwr/common/message.py +4 -4
  62. flwr/common/object_ref.py +7 -7
  63. flwr/common/record/array.py +3 -3
  64. flwr/common/record/arrayrecord.py +18 -30
  65. flwr/common/record/configrecord.py +3 -3
  66. flwr/common/record/recorddict.py +5 -5
  67. flwr/common/record/typeddict.py +9 -2
  68. flwr/common/recorddict_compat.py +7 -10
  69. flwr/common/retry_invoker.py +20 -20
  70. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  71. flwr/common/serde.py +11 -4
  72. flwr/common/serde_utils.py +2 -2
  73. flwr/common/telemetry.py +9 -5
  74. flwr/common/typing.py +58 -37
  75. flwr/compat/client/app.py +38 -37
  76. flwr/compat/client/grpc_client/connection.py +11 -11
  77. flwr/compat/server/app.py +5 -6
  78. flwr/proto/appio_pb2.py +13 -3
  79. flwr/proto/appio_pb2.pyi +134 -65
  80. flwr/proto/appio_pb2_grpc.py +20 -0
  81. flwr/proto/appio_pb2_grpc.pyi +27 -0
  82. flwr/proto/clientappio_pb2.py +17 -7
  83. flwr/proto/clientappio_pb2.pyi +15 -0
  84. flwr/proto/clientappio_pb2_grpc.py +206 -40
  85. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  86. flwr/proto/control_pb2.py +71 -52
  87. flwr/proto/control_pb2.pyi +277 -111
  88. flwr/proto/control_pb2_grpc.py +249 -40
  89. flwr/proto/control_pb2_grpc.pyi +185 -52
  90. flwr/proto/error_pb2.py +13 -3
  91. flwr/proto/error_pb2.pyi +24 -6
  92. flwr/proto/error_pb2_grpc.py +20 -0
  93. flwr/proto/error_pb2_grpc.pyi +27 -0
  94. flwr/proto/fab_pb2.py +14 -4
  95. flwr/proto/fab_pb2.pyi +59 -31
  96. flwr/proto/fab_pb2_grpc.py +20 -0
  97. flwr/proto/fab_pb2_grpc.pyi +27 -0
  98. flwr/proto/federation_pb2.py +38 -0
  99. flwr/proto/federation_pb2.pyi +56 -0
  100. flwr/proto/federation_pb2_grpc.py +24 -0
  101. flwr/proto/federation_pb2_grpc.pyi +31 -0
  102. flwr/proto/fleet_pb2.py +24 -14
  103. flwr/proto/fleet_pb2.pyi +141 -61
  104. flwr/proto/fleet_pb2_grpc.py +189 -48
  105. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  106. flwr/proto/grpcadapter_pb2.py +14 -4
  107. flwr/proto/grpcadapter_pb2.pyi +38 -16
  108. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  109. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  110. flwr/proto/heartbeat_pb2.py +17 -7
  111. flwr/proto/heartbeat_pb2.pyi +51 -22
  112. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  113. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  114. flwr/proto/log_pb2.py +13 -3
  115. flwr/proto/log_pb2.pyi +34 -11
  116. flwr/proto/log_pb2_grpc.py +20 -0
  117. flwr/proto/log_pb2_grpc.pyi +27 -0
  118. flwr/proto/message_pb2.py +15 -5
  119. flwr/proto/message_pb2.pyi +154 -86
  120. flwr/proto/message_pb2_grpc.py +20 -0
  121. flwr/proto/message_pb2_grpc.pyi +27 -0
  122. flwr/proto/node_pb2.py +15 -5
  123. flwr/proto/node_pb2.pyi +50 -25
  124. flwr/proto/node_pb2_grpc.py +20 -0
  125. flwr/proto/node_pb2_grpc.pyi +27 -0
  126. flwr/proto/recorddict_pb2.py +13 -3
  127. flwr/proto/recorddict_pb2.pyi +184 -107
  128. flwr/proto/recorddict_pb2_grpc.py +20 -0
  129. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  130. flwr/proto/run_pb2.py +40 -31
  131. flwr/proto/run_pb2.pyi +158 -84
  132. flwr/proto/run_pb2_grpc.py +20 -0
  133. flwr/proto/run_pb2_grpc.pyi +27 -0
  134. flwr/proto/serverappio_pb2.py +13 -3
  135. flwr/proto/serverappio_pb2.pyi +32 -8
  136. flwr/proto/serverappio_pb2_grpc.py +246 -65
  137. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  138. flwr/proto/simulationio_pb2.py +16 -8
  139. flwr/proto/simulationio_pb2.pyi +15 -0
  140. flwr/proto/simulationio_pb2_grpc.py +162 -41
  141. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  142. flwr/proto/transport_pb2.py +20 -10
  143. flwr/proto/transport_pb2.pyi +249 -160
  144. flwr/proto/transport_pb2_grpc.py +35 -4
  145. flwr/proto/transport_pb2_grpc.pyi +38 -8
  146. flwr/server/app.py +39 -17
  147. flwr/server/client_manager.py +4 -5
  148. flwr/server/client_proxy.py +10 -11
  149. flwr/server/compat/app.py +4 -5
  150. flwr/server/compat/app_utils.py +2 -1
  151. flwr/server/compat/grid_client_proxy.py +10 -12
  152. flwr/server/compat/legacy_context.py +3 -4
  153. flwr/server/fleet_event_log_interceptor.py +2 -1
  154. flwr/server/grid/grid.py +2 -3
  155. flwr/server/grid/grpc_grid.py +10 -8
  156. flwr/server/grid/inmemory_grid.py +4 -4
  157. flwr/server/run_serverapp.py +2 -3
  158. flwr/server/server.py +34 -39
  159. flwr/server/server_app.py +7 -8
  160. flwr/server/server_config.py +1 -2
  161. flwr/server/serverapp/app.py +34 -28
  162. flwr/server/serverapp_components.py +4 -5
  163. flwr/server/strategy/aggregate.py +9 -8
  164. flwr/server/strategy/bulyan.py +13 -11
  165. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  166. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  167. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  168. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  169. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  170. flwr/server/strategy/fedadagrad.py +18 -14
  171. flwr/server/strategy/fedadam.py +16 -14
  172. flwr/server/strategy/fedavg.py +16 -17
  173. flwr/server/strategy/fedavg_android.py +15 -15
  174. flwr/server/strategy/fedavgm.py +21 -18
  175. flwr/server/strategy/fedmedian.py +2 -3
  176. flwr/server/strategy/fedopt.py +11 -10
  177. flwr/server/strategy/fedprox.py +10 -9
  178. flwr/server/strategy/fedtrimmedavg.py +12 -11
  179. flwr/server/strategy/fedxgb_bagging.py +13 -11
  180. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  181. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  182. flwr/server/strategy/fedyogi.py +16 -14
  183. flwr/server/strategy/krum.py +12 -11
  184. flwr/server/strategy/qfedavg.py +16 -15
  185. flwr/server/strategy/strategy.py +6 -9
  186. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  187. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  190. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  192. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  193. flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
  194. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  195. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  196. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  197. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  198. flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
  199. flwr/server/superlink/linkstate/linkstate.py +91 -43
  200. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  201. flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
  202. flwr/server/superlink/linkstate/utils.py +6 -6
  203. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  204. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  205. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  206. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  207. flwr/server/superlink/utils.py +4 -6
  208. flwr/server/typing.py +1 -1
  209. flwr/server/utils/tensorboard.py +15 -8
  210. flwr/server/workflow/default_workflows.py +5 -5
  211. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  213. flwr/serverapp/strategy/bulyan.py +16 -15
  214. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  215. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  216. flwr/serverapp/strategy/fedadagrad.py +10 -11
  217. flwr/serverapp/strategy/fedadam.py +10 -11
  218. flwr/serverapp/strategy/fedavg.py +9 -10
  219. flwr/serverapp/strategy/fedavgm.py +17 -16
  220. flwr/serverapp/strategy/fedmedian.py +2 -2
  221. flwr/serverapp/strategy/fedopt.py +10 -11
  222. flwr/serverapp/strategy/fedprox.py +7 -8
  223. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  224. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  225. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  226. flwr/serverapp/strategy/fedyogi.py +9 -11
  227. flwr/serverapp/strategy/krum.py +7 -7
  228. flwr/serverapp/strategy/multikrum.py +9 -9
  229. flwr/serverapp/strategy/qfedavg.py +17 -16
  230. flwr/serverapp/strategy/strategy.py +6 -9
  231. flwr/serverapp/strategy/strategy_utils.py +7 -8
  232. flwr/simulation/app.py +46 -42
  233. flwr/simulation/legacy_app.py +12 -12
  234. flwr/simulation/ray_transport/ray_actor.py +10 -11
  235. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  236. flwr/simulation/run_simulation.py +43 -43
  237. flwr/simulation/simulationio_connection.py +4 -4
  238. flwr/supercore/cli/flower_superexec.py +3 -4
  239. flwr/supercore/constant.py +34 -1
  240. flwr/supercore/corestate/corestate.py +24 -3
  241. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  242. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  243. flwr/supercore/ffs/disk_ffs.py +1 -2
  244. flwr/supercore/ffs/ffs.py +1 -2
  245. flwr/supercore/ffs/ffs_factory.py +1 -2
  246. flwr/{common → supercore}/heartbeat.py +20 -25
  247. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  248. flwr/supercore/object_store/object_store.py +1 -2
  249. flwr/supercore/object_store/object_store_factory.py +1 -2
  250. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  251. flwr/supercore/primitives/asymmetric.py +1 -1
  252. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  253. flwr/supercore/sqlite_mixin.py +37 -34
  254. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  255. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  256. flwr/supercore/superexec/run_superexec.py +9 -13
  257. flwr/supercore/utils.py +190 -0
  258. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  259. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  260. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  261. flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
  262. flwr/superlink/federation/federation_manager.py +64 -0
  263. flwr/superlink/federation/noop_federation_manager.py +71 -0
  264. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  265. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  266. flwr/superlink/servicer/control/control_grpc.py +7 -6
  267. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  268. flwr/superlink/servicer/control/control_servicer.py +190 -23
  269. flwr/supernode/cli/flower_supernode.py +58 -3
  270. flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
  271. flwr/supernode/nodestate/nodestate.py +52 -8
  272. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  273. flwr/supernode/runtime/run_clientapp.py +41 -22
  274. flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
  275. flwr/supernode/start_client_internal.py +165 -46
  276. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
  277. flwr-1.25.0.dist-info/RECORD +393 -0
  278. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  279. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  280. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  281. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  282. flwr/cli/new/templates/app/README.md.tpl +0 -37
  283. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  284. flwr/cli/new/templates/app/code/__init__.py +0 -15
  285. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  286. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  287. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  288. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  289. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  290. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  291. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  292. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  293. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  294. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  295. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  296. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  297. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  298. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  299. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  300. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  301. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  302. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  303. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  304. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  305. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  306. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  307. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  308. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  309. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  310. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  311. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  312. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  313. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  314. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  315. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  316. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  317. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  318. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  319. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
  320. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  321. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  322. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  323. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  324. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  325. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  326. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  327. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  328. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  329. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  330. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  331. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  332. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  333. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  334. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  335. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  336. flwr/supercore/object_store/utils.py +0 -43
  337. flwr-1.23.0.dist-info/RECORD +0 -439
  338. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  339. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
@@ -1,71 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import jax
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- from $import_name.task import evaluation as evaluation_fn
8
- from $import_name.task import get_params, load_data, load_model, loss_fn, set_params
9
- from $import_name.task import train as train_fn
10
-
11
- # Flower ClientApp
12
- app = ClientApp()
13
-
14
-
15
- @app.train()
16
- def train(msg: Message, context: Context):
17
- """Train the model on local data."""
18
-
19
- # Read from config
20
- input_dim = context.run_config["input-dim"]
21
-
22
- # Load data and model
23
- train_x, train_y, _, _ = load_data()
24
- model = load_model((input_dim,))
25
- grad_fn = jax.grad(loss_fn)
26
-
27
- # Set model parameters
28
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
29
- set_params(model, ndarrays)
30
-
31
- # Train the model on local data
32
- model, loss, num_examples = train_fn(model, grad_fn, train_x, train_y)
33
-
34
- # Construct and return reply Message
35
- model_record = ArrayRecord(get_params(model))
36
- metrics = {
37
- "train_loss": float(loss),
38
- "num-examples": num_examples,
39
- }
40
- metric_record = MetricRecord(metrics)
41
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
42
- return Message(content=content, reply_to=msg)
43
-
44
-
45
- @app.evaluate()
46
- def evaluate(msg: Message, context: Context):
47
- """Evaluate the model on local data."""
48
-
49
- # Read from config
50
- input_dim = context.run_config["input-dim"]
51
-
52
- # Load data and model
53
- _, _, test_x, test_y = load_data()
54
- model = load_model((input_dim,))
55
- grad_fn = jax.grad(loss_fn)
56
-
57
- # Set model parameters
58
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
59
- set_params(model, ndarrays)
60
-
61
- # Evaluate the model on local data
62
- loss, num_examples = evaluation_fn(model, grad_fn, test_x, test_y)
63
-
64
- # Construct and return reply Message
65
- metrics = {
66
- "test_loss": float(loss),
67
- "num-examples": num_examples,
68
- }
69
- metric_record = MetricRecord(metrics)
70
- content = RecordDict({"metrics": metric_record})
71
- return Message(content=content, reply_to=msg)
@@ -1,102 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import mlx.core as mx
4
- import mlx.nn as nn
5
- import mlx.optimizers as optim
6
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
7
- from flwr.clientapp import ClientApp
8
-
9
- from $import_name.task import (
10
- MLP,
11
- batch_iterate,
12
- eval_fn,
13
- get_params,
14
- load_data,
15
- loss_fn,
16
- set_params,
17
- )
18
-
19
- # Flower ClientApp
20
- app = ClientApp()
21
-
22
-
23
- @app.train()
24
- def train(msg: Message, context: Context):
25
- """Train the model on local data."""
26
-
27
- # Read config
28
- num_layers = context.run_config["num-layers"]
29
- input_dim = context.run_config["input-dim"]
30
- hidden_dim = context.run_config["hidden-dim"]
31
- batch_size = context.run_config["batch-size"]
32
- learning_rate = context.run_config["lr"]
33
- num_epochs = context.run_config["local-epochs"]
34
-
35
- # Instantiate model and apply global parameters
36
- model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
37
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
38
- set_params(model, ndarrays)
39
-
40
- # Define optimizer and loss function
41
- optimizer = optim.SGD(learning_rate=learning_rate)
42
- loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
43
-
44
- # Load data
45
- partition_id = context.node_config["partition-id"]
46
- num_partitions = context.node_config["num-partitions"]
47
- train_images, train_labels, _, _ = load_data(partition_id, num_partitions)
48
-
49
- # Train the model on local data
50
- for _ in range(num_epochs):
51
- for X, y in batch_iterate(batch_size, train_images, train_labels):
52
- _, grads = loss_and_grad_fn(model, X, y)
53
- optimizer.update(model, grads)
54
- mx.eval(model.parameters(), optimizer.state)
55
-
56
- # Compute train accuracy and loss
57
- accuracy = eval_fn(model, train_images, train_labels)
58
- loss = loss_fn(model, train_images, train_labels)
59
- # Construct and return reply Message
60
- model_record = ArrayRecord(get_params(model))
61
- metrics = {
62
- "num-examples": len(train_images),
63
- "accuracy": float(accuracy.item()),
64
- "loss": float(loss.item()),
65
- }
66
- metric_record = MetricRecord(metrics)
67
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
68
- return Message(content=content, reply_to=msg)
69
-
70
-
71
- @app.evaluate()
72
- def evaluate(msg: Message, context: Context):
73
- """Evaluate the model on local data."""
74
-
75
- # Read config
76
- num_layers = context.run_config["num-layers"]
77
- input_dim = context.run_config["input-dim"]
78
- hidden_dim = context.run_config["hidden-dim"]
79
-
80
- # Instantiate model and apply global parameters
81
- model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
82
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
83
- set_params(model, ndarrays)
84
-
85
- # Load data
86
- partition_id = context.node_config["partition-id"]
87
- num_partitions = context.node_config["num-partitions"]
88
- _, _, test_images, test_labels = load_data(partition_id, num_partitions)
89
-
90
- # Evaluate the model on local data
91
- accuracy = eval_fn(model, test_images, test_labels)
92
- loss = loss_fn(model, test_images, test_labels)
93
-
94
- # Construct and return reply Message
95
- metrics = {
96
- "num-examples": len(test_images),
97
- "accuracy": float(accuracy.item()),
98
- "loss": float(loss.item()),
99
- }
100
- metric_record = MetricRecord(metrics)
101
- content = RecordDict({"metrics": metric_record})
102
- return Message(content=content, reply_to=msg)
@@ -1,46 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import numpy as np
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- # Flower ClientApp
8
- app = ClientApp()
9
-
10
-
11
- @app.train()
12
- def train(msg: Message, context: Context):
13
- """Train the model on local data."""
14
-
15
- # The model is the global arrays
16
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
17
-
18
- # Simulate local training (here we just add random noise to model parameters)
19
- model = [m + np.random.rand(*m.shape) for m in ndarrays]
20
-
21
- # Construct and return reply Message
22
- model_record = ArrayRecord(model)
23
- metrics = {
24
- "random_metric": np.random.rand(),
25
- "num-examples": 1,
26
- }
27
- metric_record = MetricRecord(metrics)
28
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
29
- return Message(content=content, reply_to=msg)
30
-
31
-
32
- @app.evaluate()
33
- def evaluate(msg: Message, context: Context):
34
- """Evaluate the model on local data."""
35
-
36
- # The model is the global arrays
37
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
38
-
39
- # Return reply Message
40
- metrics = {
41
- "random_metric": np.random.rand(3).tolist(),
42
- "num-examples": 1,
43
- }
44
- metric_record = MetricRecord(metrics)
45
- content = RecordDict({"metrics": metric_record})
46
- return Message(content=content, reply_to=msg)
@@ -1,80 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- from $import_name.task import Net, load_data
8
- from $import_name.task import test as test_fn
9
- from $import_name.task import train as train_fn
10
-
11
- # Flower ClientApp
12
- app = ClientApp()
13
-
14
-
15
- @app.train()
16
- def train(msg: Message, context: Context):
17
- """Train the model on local data."""
18
-
19
- # Load the model and initialize it with the received weights
20
- model = Net()
21
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- model.to(device)
24
-
25
- # Load the data
26
- partition_id = context.node_config["partition-id"]
27
- num_partitions = context.node_config["num-partitions"]
28
- trainloader, _ = load_data(partition_id, num_partitions)
29
-
30
- # Call the training function
31
- train_loss = train_fn(
32
- model,
33
- trainloader,
34
- context.run_config["local-epochs"],
35
- msg.content["config"]["lr"],
36
- device,
37
- )
38
-
39
- # Construct and return reply Message
40
- model_record = ArrayRecord(model.state_dict())
41
- metrics = {
42
- "train_loss": train_loss,
43
- "num-examples": len(trainloader.dataset),
44
- }
45
- metric_record = MetricRecord(metrics)
46
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
- return Message(content=content, reply_to=msg)
48
-
49
-
50
- @app.evaluate()
51
- def evaluate(msg: Message, context: Context):
52
- """Evaluate the model on local data."""
53
-
54
- # Load the model and initialize it with the received weights
55
- model = Net()
56
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
- model.to(device)
59
-
60
- # Load the data
61
- partition_id = context.node_config["partition-id"]
62
- num_partitions = context.node_config["num-partitions"]
63
- _, valloader = load_data(partition_id, num_partitions)
64
-
65
- # Call the evaluation function
66
- eval_loss, eval_acc = test_fn(
67
- model,
68
- valloader,
69
- device,
70
- )
71
-
72
- # Construct and return reply Message
73
- metrics = {
74
- "eval_loss": eval_loss,
75
- "eval_acc": eval_acc,
76
- "num-examples": len(valloader.dataset),
77
- }
78
- metric_record = MetricRecord(metrics)
79
- content = RecordDict({"metrics": metric_record})
80
- return Message(content=content, reply_to=msg)
@@ -1,55 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
-
5
- from flwr.client import ClientApp, NumPyClient
6
- from flwr.common import Context
7
- from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
-
9
-
10
- # Define Flower Client and client_fn
11
- class FlowerClient(NumPyClient):
12
- def __init__(self, net, trainloader, valloader, local_epochs):
13
- self.net = net
14
- self.trainloader = trainloader
15
- self.valloader = valloader
16
- self.local_epochs = local_epochs
17
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
- self.net.to(self.device)
19
-
20
- def fit(self, parameters, config):
21
- set_weights(self.net, parameters)
22
- train_loss = train(
23
- self.net,
24
- self.trainloader,
25
- self.local_epochs,
26
- self.device,
27
- )
28
- return (
29
- get_weights(self.net),
30
- len(self.trainloader.dataset),
31
- {"train_loss": train_loss},
32
- )
33
-
34
- def evaluate(self, parameters, config):
35
- set_weights(self.net, parameters)
36
- loss, accuracy = test(self.net, self.valloader, self.device)
37
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
-
39
-
40
- def client_fn(context: Context):
41
- # Load model and data
42
- net = Net()
43
- partition_id = context.node_config["partition-id"]
44
- num_partitions = context.node_config["num-partitions"]
45
- trainloader, valloader = load_data(partition_id, num_partitions)
46
- local_epochs = context.run_config["local-epochs"]
47
-
48
- # Return Client instance
49
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
50
-
51
-
52
- # Flower ClientApp
53
- app = ClientApp(
54
- client_fn,
55
- )
@@ -1,108 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import warnings
4
-
5
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
6
- from flwr.clientapp import ClientApp
7
- from sklearn.metrics import (
8
- accuracy_score,
9
- f1_score,
10
- log_loss,
11
- precision_score,
12
- recall_score,
13
- )
14
-
15
- from $import_name.task import (
16
- get_model,
17
- get_model_params,
18
- load_data,
19
- set_initial_params,
20
- set_model_params,
21
- )
22
-
23
- # Flower ClientApp
24
- app = ClientApp()
25
-
26
-
27
- @app.train()
28
- def train(msg: Message, context: Context):
29
- """Train the model on local data."""
30
-
31
- # Create LogisticRegression Model
32
- penalty = context.run_config["penalty"]
33
- local_epochs = context.run_config["local-epochs"]
34
- model = get_model(penalty, local_epochs)
35
- # Setting initial parameters, akin to model.compile for keras models
36
- set_initial_params(model)
37
-
38
- # Apply received pararameters
39
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
40
- set_model_params(model, ndarrays)
41
-
42
- # Load the data
43
- partition_id = context.node_config["partition-id"]
44
- num_partitions = context.node_config["num-partitions"]
45
- X_train, _, y_train, _ = load_data(partition_id, num_partitions)
46
-
47
- # Ignore convergence failure due to low local epochs
48
- with warnings.catch_warnings():
49
- warnings.simplefilter("ignore")
50
- # Train the model on local data
51
- model.fit(X_train, y_train)
52
-
53
- # Let's compute train loss
54
- y_train_pred_proba = model.predict_proba(X_train)
55
- train_logloss = log_loss(y_train, y_train_pred_proba)
56
-
57
- # Construct and return reply Message
58
- ndarrays = get_model_params(model)
59
- model_record = ArrayRecord(ndarrays)
60
- metrics = {"num-examples": len(X_train), "train_logloss": train_logloss}
61
- metric_record = MetricRecord(metrics)
62
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
63
- return Message(content=content, reply_to=msg)
64
-
65
-
66
- @app.evaluate()
67
- def evaluate(msg: Message, context: Context):
68
- """Evaluate the model on test data."""
69
-
70
- # Create LogisticRegression Model
71
- penalty = context.run_config["penalty"]
72
- local_epochs = context.run_config["local-epochs"]
73
- model = get_model(penalty, local_epochs)
74
-
75
- # Setting initial parameters, akin to model.compile for keras models
76
- set_initial_params(model)
77
-
78
- # Apply received pararameters
79
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
80
- set_model_params(model, ndarrays)
81
-
82
- # Load the data
83
- partition_id = context.node_config["partition-id"]
84
- num_partitions = context.node_config["num-partitions"]
85
- _, X_test, _, y_test = load_data(partition_id, num_partitions)
86
-
87
- # Evaluate the model on local data
88
- y_train_pred = model.predict(X_test)
89
- y_train_pred_proba = model.predict_proba(X_test)
90
-
91
- accuracy = accuracy_score(y_test, y_train_pred)
92
- loss = log_loss(y_test, y_train_pred_proba)
93
- precision = precision_score(y_test, y_train_pred, average="macro", zero_division=0)
94
- recall = recall_score(y_test, y_train_pred, average="macro", zero_division=0)
95
- f1 = f1_score(y_test, y_train_pred, average="macro", zero_division=0)
96
-
97
- # Construct and return reply Message
98
- metrics = {
99
- "num-examples": len(X_test),
100
- "test_logloss": loss,
101
- "accuracy": accuracy,
102
- "precision": precision,
103
- "recall": recall,
104
- "f1": f1,
105
- }
106
- metric_record = MetricRecord(metrics)
107
- content = RecordDict({"metrics": metric_record})
108
- return Message(content=content, reply_to=msg)
@@ -1,82 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
4
- from flwr.clientapp import ClientApp
5
-
6
- from $import_name.task import load_data, load_model
7
-
8
- # Flower ClientApp
9
- app = ClientApp()
10
-
11
-
12
- @app.train()
13
- def train(msg: Message, context: Context):
14
- """Train the model on local data."""
15
-
16
- # Load the model and initialize it with the received weights
17
- model = load_model()
18
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
19
- model.set_weights(ndarrays)
20
-
21
- # Read from config
22
- epochs = context.run_config["local-epochs"]
23
- batch_size = context.run_config["batch-size"]
24
- verbose = context.run_config.get("verbose")
25
-
26
- # Load the data
27
- partition_id = context.node_config["partition-id"]
28
- num_partitions = context.node_config["num-partitions"]
29
- x_train, y_train, _, _ = load_data(partition_id, num_partitions)
30
-
31
- # Train the model on local data
32
- history = model.fit(
33
- x_train,
34
- y_train,
35
- epochs=epochs,
36
- batch_size=batch_size,
37
- verbose=verbose,
38
- )
39
-
40
- # Get final training loss and accuracy
41
- train_loss = history.history["loss"][-1] if "loss" in history.history else None
42
- train_acc = history.history.get("accuracy")
43
- train_acc = train_acc[-1] if train_acc is not None else None
44
-
45
- # Construct and return reply Message
46
- model_record = ArrayRecord(model.get_weights())
47
- metrics = {"num-examples": len(x_train)}
48
- if train_loss is not None:
49
- metrics["train_loss"] = train_loss
50
- if train_acc is not None:
51
- metrics["train_acc"] = train_acc
52
- metric_record = MetricRecord(metrics)
53
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
54
- return Message(content=content, reply_to=msg)
55
-
56
-
57
- @app.evaluate()
58
- def evaluate(msg: Message, context: Context):
59
- """Evaluate the model on local data."""
60
-
61
- # Load the model and initialize it with the received weights
62
- model = load_model()
63
- ndarrays = msg.content["arrays"].to_numpy_ndarrays()
64
- model.set_weights(ndarrays)
65
-
66
- # Load the data
67
- partition_id = context.node_config["partition-id"]
68
- num_partitions = context.node_config["num-partitions"]
69
- _, _, x_test, y_test = load_data(partition_id, num_partitions)
70
-
71
- # Evaluate the model on local data
72
- loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
73
-
74
- # Construct and return reply Message
75
- metrics = {
76
- "eval_loss": loss,
77
- "eval_acc": accuracy,
78
- "num-examples": len(x_test),
79
- }
80
- metric_record = MetricRecord(metrics)
81
- content = RecordDict({"metrics": metric_record})
82
- return Message(content=content, reply_to=msg)
@@ -1,110 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import warnings
4
-
5
- import numpy as np
6
- import xgboost as xgb
7
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
8
- from flwr.clientapp import ClientApp
9
- from flwr.common.config import unflatten_dict
10
-
11
- from $import_name.task import load_data, replace_keys
12
-
13
- warnings.filterwarnings("ignore", category=UserWarning)
14
-
15
-
16
- # Flower ClientApp
17
- app = ClientApp()
18
-
19
-
20
- def _local_boost(bst_input, num_local_round, train_dmatrix):
21
- # Update trees based on local training data.
22
- for i in range(num_local_round):
23
- bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
24
-
25
- # Bagging: extract the last N=num_local_round trees for sever aggregation
26
- bst = bst_input[
27
- bst_input.num_boosted_rounds()
28
- - num_local_round : bst_input.num_boosted_rounds()
29
- ]
30
- return bst
31
-
32
-
33
- @app.train()
34
- def train(msg: Message, context: Context) -> Message:
35
- # Load model and data
36
- partition_id = context.node_config["partition-id"]
37
- num_partitions = context.node_config["num-partitions"]
38
- train_dmatrix, _, num_train, _ = load_data(partition_id, num_partitions)
39
-
40
- # Read from run config
41
- num_local_round = context.run_config["local-epochs"]
42
- # Flatted config dict and replace "-" with "_"
43
- cfg = replace_keys(unflatten_dict(context.run_config))
44
- params = cfg["params"]
45
-
46
- global_round = msg.content["config"]["server-round"]
47
- if global_round == 1:
48
- # First round local training
49
- bst = xgb.train(
50
- params,
51
- train_dmatrix,
52
- num_boost_round=num_local_round,
53
- )
54
- else:
55
- bst = xgb.Booster(params=params)
56
- global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
57
-
58
- # Load global model into booster
59
- bst.load_model(global_model)
60
-
61
- # Local training
62
- bst = _local_boost(bst, num_local_round, train_dmatrix)
63
-
64
- # Save model
65
- local_model = bst.save_raw("json")
66
- model_np = np.frombuffer(local_model, dtype=np.uint8)
67
-
68
- # Construct reply message
69
- # Note: we store the model as the first item in a list into ArrayRecord,
70
- # which can be accessed using index ["0"].
71
- model_record = ArrayRecord([model_np])
72
- metrics = {
73
- "num-examples": num_train,
74
- }
75
- metric_record = MetricRecord(metrics)
76
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
77
- return Message(content=content, reply_to=msg)
78
-
79
-
80
- @app.evaluate()
81
- def evaluate(msg: Message, context: Context) -> Message:
82
- # Load model and data
83
- partition_id = context.node_config["partition-id"]
84
- num_partitions = context.node_config["num-partitions"]
85
- _, valid_dmatrix, _, num_val = load_data(partition_id, num_partitions)
86
-
87
- # Load config
88
- cfg = replace_keys(unflatten_dict(context.run_config))
89
- params = cfg["params"]
90
-
91
- # Load global model
92
- bst = xgb.Booster(params=params)
93
- global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
94
- bst.load_model(global_model)
95
-
96
- # Run evaluation
97
- eval_results = bst.eval_set(
98
- evals=[(valid_dmatrix, "valid")],
99
- iteration=bst.num_boosted_rounds() - 1,
100
- )
101
- auc = float(eval_results.split("\t")[1].split(":")[1])
102
-
103
- # Construct and return reply Message
104
- metrics = {
105
- "auc": auc,
106
- "num-examples": num_val,
107
- }
108
- metric_record = MetricRecord(metrics)
109
- content = RecordDict({"metrics": metric_record})
110
- return Message(content=content, reply_to=msg)
@@ -1,36 +0,0 @@
1
- """$project_name: A Flower Baseline."""
2
-
3
- from flwr_datasets import FederatedDataset
4
- from flwr_datasets.partitioner import IidPartitioner
5
- from torch.utils.data import DataLoader
6
- from torchvision.transforms import Compose, Normalize, ToTensor
7
-
8
- FDS = None # Cache FederatedDataset
9
-
10
-
11
- def load_data(partition_id: int, num_partitions: int):
12
- """Load partition CIFAR10 data."""
13
- # Only initialize `FederatedDataset` once
14
- global FDS # pylint: disable=global-statement
15
- if FDS is None:
16
- partitioner = IidPartitioner(num_partitions=num_partitions)
17
- FDS = FederatedDataset(
18
- dataset="uoft-cs/cifar10",
19
- partitioners={"train": partitioner},
20
- )
21
- partition = FDS.load_partition(partition_id)
22
- # Divide data on each node: 80% train, 20% test
23
- partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
24
- pytorch_transforms = Compose(
25
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
26
- )
27
-
28
- def apply_transforms(batch):
29
- """Apply transforms to the partition from FederatedDataset."""
30
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
31
- return batch
32
-
33
- partition_train_test = partition_train_test.with_transform(apply_transforms)
34
- trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
35
- testloader = DataLoader(partition_train_test["test"], batch_size=32)
36
- return trainloader, testloader