fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
- from typing import Literal, Optional, Union # noqa: F401
+ from typing import Dict, Literal, Optional, Union # noqa: F401
 
  import torch
- from torch import Dict, nn
+ from torch import nn
  from tqdm.auto import tqdm
  from transformers import LlamaForCausalLM, LlamaModel
 
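Note: this hunk corrects a latent import bug. `Dict` is a `typing` construct rather than a public `torch` export, so `from torch import Dict, nn` was at best relying on an accidental re-export. A minimal sketch of the corrected idiom (the annotation example is illustrative):

```python
from typing import Dict  # typing, not torch, provides Dict

from torch import nn

# e.g. annotating a mapping from submodule names to modules
layers: Dict[str, nn.Module] = {"proj": nn.Linear(8, 8)}
```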
@@ -0,0 +1 @@
+ from .clip_layer_wise_adamerging_surgery import CLIPLayerWiseAdaMergingSurgeryAlgorithm
@@ -0,0 +1,157 @@
+ """
+ Implementation of the Layer-Wise AdaMerging+Surgery Algorithm.
+
+ For more details, please refer to:
+
+ - (ICLR 2024) Yang, et.al. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
+ - (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging. https://arxiv.org/abs/2402.02705
+
+ Basic Example:
+
+ ```shell
+ fusion_bench \
+     method=surgery/adamerging_surgery \
+     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+ ```
+ """
+
+ import copy
+ import functools
+ import gc
+ import logging
+ from typing import TYPE_CHECKING, cast
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import CLIPVisionModel
+
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.method.adamerging.layer_wise_adamerging import (
+     LayerWiseAdaMergingAlgorithm,
+ )
+ from fusion_bench.method.adamerging.utils import get_memory_usage
+ from fusion_bench.mixins import CLIPClassificationMixin
+ from fusion_bench.modelpool import CLIPVisionModelPool
+ from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+ from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
+
+ log = logging.getLogger(__name__)
+
+
+ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
+     CLIPClassificationMixin,
+     LayerWiseAdaMergingAlgorithm,
+ ):
+
+     def on_test_time_adaptation_start(self):
+         """
+         Here we load the CLIP processor and construct the zero-shot classification head for each task.
+         """
+         self.setup_zero_shot_classification_head()
+
+     @functools.cache
+     def get_shuffled_test_loader_iter(self, task: str):
+         return super().get_shuffled_test_loader_iter(
+             task,
+             batch_size=self.config.batch_size,
+             num_workers=self.config.num_workers,
+         )
+
+     def run(self, modelpool: CLIPVisionModelPool, **kwargs):
+         """
+         Run the Layer-Wise AdaMerging+Surgery Algorithm.
+
+         This method constructs the wrapped model and performs test-time adaptation if necessary. Then, it will perform surgery.
+
+         Args:
+             modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+         Returns:
+             LayerWiseMergedModel: The merged model after test-time adaptation.
+         """
+         log.info("Fusing models using layer-wise adaptive merging.")
+         self.modelpool = modelpool
+         self.log_hyperparams(self.config)
+
+         # === Start of the AdaMerging Algorithm ===
+         with self.profile("construct the wrapped model"):
+             module = cast(
+                 LayerWiseMergedModel[CLIPVisionModel],
+                 self.construct_layer_wise_merged_model(modelpool),
+             )
+
+         if self.config.weights is not None:
+             # skip the test-time adaptation
+             merged_model = copy.deepcopy(module.merge_and_unload())
+         else:
+             with self.profile("test-time adaptation"):
+                 module = self.test_time_adaptation(module)
+             if self.config.get("save_merging_weights", False):
+                 self.save_merging_weights(
+                     self.config.save_merging_weights, module.merge_weight
+                 )
+             merged_model = copy.deepcopy(module.merge_and_unload())
+
+         # free memory
+         del module
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         # === Start of the Surgery Algorithm ===
+         log.info("start performing Surgery")
+         alpha_model = SurgeryModelWrapper(
+             merged_model,
+             modelpool.model_names,
+             projection_dim=merged_model.config.projection_dim,
+         )
+         alpha_model = self.fabric.setup(alpha_model)
+         log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
+
+         optimizer = torch.optim.Adam(
+             alpha_model.collect_trainable_params(),
+             lr=1e-3,
+             betas=(0.9, 0.999),
+             weight_decay=0.0,
+         )
+
+         finetuned_models = {
+             model_name: modelpool.load_model(model_name)
+             for model_name in modelpool.model_names
+         }
+         for name, model in finetuned_models.items():
+             model.requires_grad_(False)
+             model = self.fabric.to_device(model)
+             model.eval()
+
+         for iteration in tqdm(
+             range(self.config.surgery_steps),
+             "surgery",
+             dynamic_ncols=True,
+         ):
+             for dataset_name in modelpool.model_names:
+                 batch = next(self.get_shuffled_test_loader_iter(dataset_name))
+                 finetuned_feature = self.compute_features(
+                     finetuned_models[dataset_name], batch[0]
+                 )
+                 features, _, _ = alpha_model.compute_surgery_features(
+                     lambda model: self.compute_features(model, batch[0]),
+                     dataset_name,
+                 )
+
+                 loss = F.l1_loss(features, finetuned_feature)
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+             if ((iteration + 1) % self.config.eval_iterations) == 0:
+                 # print(list(alpha_model.collect_trainable_params()))
+                 # Evaluate try to use the test module in fusion bench
+                 log.info(f"iteration: {iteration+1}")
+                 self._program.evaluate_merged_model(self._program.taskpool, alpha_model)
+
+         log.info("test the result of Adamerging")
+         return {"adamerging": merged_model, "surgery": alpha_model}
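For orientation, the surgery stage above reduces to a feature-alignment objective: the merged backbone and the per-task fine-tuned backbones stay frozen, and only a small task-specific adapter is trained so that adapted merged features match the fine-tuned features under an L1 loss. A self-contained toy sketch of that objective (the linear backbones, feature width, and step count are illustrative assumptions, not the package's API):

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Frozen stand-ins for the merged and per-task fine-tuned feature extractors
# (assumption: 512-d features, matching CLIP ViT-B/32's projection_dim).
merged_backbone = nn.Linear(32, 512).requires_grad_(False)
finetuned_backbone = nn.Linear(32, 512).requires_grad_(False)

# The "surgery" part: a small trainable adapter on top of merged features,
# playing the role of SurgeryModelWrapper's task-specific parameters.
adapter = nn.Linear(512, 512)
optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(16, 32)  # a batch of unlabeled test inputs
    with torch.no_grad():
        merged_feat = merged_backbone(x)
        target_feat = finetuned_backbone(x)  # alignment target
    loss = F.l1_loss(adapter(merged_feat), target_feat)  # same loss as the loop above
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```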
@@ -0,0 +1,234 @@
+ import copy
+ import os
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+
+ from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
+
+
+ def generate_task_masks(
+     tv_flat_checks: torch.Tensor,
+     flat_ft: torch.Tensor,
+     flat_ptm: torch.Tensor,
+     tv: Optional[torch.Tensor] = None,
+     tall_mask_lambda: float = 1.0,
+ ) -> torch.Tensor:
+     """
+     Generate task-specific TALL masks
+     TALL masks are generated as: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda
+
+     Args:
+         tv_flat_checks: individual task vectors
+         flat_ft: individual theta_t (fine-tuned weights)
+         flat_ptm: theta_0 (pre-trained weight)
+         tv: multi-task vector
+         tall_mask_lambda: hyper-parameter lambda for generating TALL masks
+     Returns:
+         final_mask: generated TALL masks with the given lambda, in shape (n_task, n_parameter)
+     """
+
+     print(f"Generating TALL masks.")
+
+     if tv is None:
+         tv = tv_flat_checks.sum(0)
+
+     flat_multi = flat_ptm + tv
+
+     original_shape = flat_ft.shape
+
+     # generate masks by comparing the l1 distance between |theta_0 - theta_t| and |theta_mt - theta_t|
+     diff_pt_ft = (flat_ptm - flat_ft).abs()
+     diff_multi_ft = (flat_multi - flat_ft).abs()
+     # compare the l1 distance, scaled with hyper-parameter lambda
+     mask = diff_pt_ft > diff_multi_ft * tall_mask_lambda
+
+     final_mask = (
+         mask.squeeze() if original_shape == tv_flat_checks.squeeze().shape else mask
+     )
+
+     print(
+         f"Average sparsity for the mask with tall_mask_lambda of {tall_mask_lambda}: {final_mask.float().mean():.4f}"
+     )
+     return final_mask
+
+
+ def construct_tall_mask(
+     tv_flat_checks: torch.Tensor,
+     flat_ft: torch.Tensor,
+     flat_ptm: torch.Tensor,
+     merged_tv: torch.Tensor,
+     ptm_check: torch.Tensor,
+     remove_keys: List[str],
+     config,
+ ):
+     """
+     Construct TALL masks for all tasks for each lambda, and store in dictionary
+
+     Args:
+         tv_flat_checks: individual task vectors
+         flat_ft: individual theta_t (fine-tuned weights)
+         flat_ptm: theta_0 (pre-trained weight)
+         merged_tv: multi-task vector
+         ptm_check: pre-trained weight as state dictionary
+         remove_keys: the keys to be removed when converting between dictionary and vector
+     Returns:
+         tall_masks: constructed TALL masks in dictionary format of {lambda: {task: mask}}
+     """
+     tall_masks = {}
+     for tall_mask_lambda in [0.2, 0.3, 0.4, 0.5, 0.6]:
+         # generate tall masks for each lambda
+         masks_at_scale = generate_task_masks(
+             tv_flat_checks,
+             flat_ft,
+             flat_ptm,
+             tall_mask_lambda=tall_mask_lambda,
+             tv=merged_tv,
+         )
+         # convert vectors to dictionary
+         masks_at_scale = [
+             vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
+             for mask in masks_at_scale
+         ]
+         # store the masks with {dataset: mask}
+         tall_masks[tall_mask_lambda] = {
+             key: value for key, value in zip(config.DATASETS, masks_at_scale)
+         }
+     return tall_masks
+
+
+ def find_optimal_mask(val_metrics, eval_masks, args, save_masks=True):
+     """
+     Respectively finds the optimal mask for each data task based on the validation accuracy
+
+     Args:
+         val_metrics: validation metrics for each lambda
+         eval_masks: all generated masks
+
+     Returns:
+         best_masks_for_test: the best masks for each task, selected based on validation accuracy from each task
+         best_val_metrics: best validation metrics for each task
+     """
+     # transpose the dict from lambda-task to task-lambda
+     transposed_dict = {}
+     for key, inner_dict in val_metrics.items():
+         for inner_key, value in inner_dict.items():
+             if inner_key not in transposed_dict:
+                 transposed_dict[inner_key] = {}
+             transposed_dict[inner_key][key] = value
+
+     # for each task, find the best lambda
+     max_subkeys = {
+         key: max(inner_dict, key=inner_dict.get)
+         for key, inner_dict in transposed_dict.items()
+     }
+
+     # select the best mask for each task, which will be used for testing later
+     best_masks_for_test = {}
+     best_masks_for_test_vector = {}  # the selected masks as vectors
+     best_val_metrics = {}
+     # respectively for each task:
+     for ds in args.DATASETS:
+         # select the lambda which achieves the best valdiation accuracy
+         best_lambda = float(max_subkeys[ds + "Val:top1"])
+         # select the mask based on the selected lambda, save as dictionaries
+         best_masks_for_test[ds] = eval_masks[best_lambda][ds]
+         # select the mask based on the selected lambda, save as vectors
+         best_masks_for_test_vector[ds] = state_dict_to_vector(
+             eval_masks[best_lambda][ds], remove_keys=[]
+         )
+         print(f"Best lambda for {ds} is {best_lambda}")
+         # save the best validation metric based on the selected lambda
+         best_val_metrics[ds + "Val:top1"] = val_metrics[best_lambda][ds + "Val:top1"]
+
+     # save the best masks in disk
+     if save_masks and not args.method.load_mask:
+         # convert to numpy to save with np.packbits for saving storage
+         best_masks_for_test_vector = {
+             k: np.packbits(v) for k, v in best_masks_for_test_vector.items()
+         }
+         mask_save_dir = args.model_location.replace("checkpoints", "tall_masks")
+         mask_name = (
+             f"TALL_mask_{args.num_tasks}task.npy"
+             if not args.method.use_ties
+             else f"TALL_mask_{args.num_tasks}task_use_ties_{args.method.ties_agg}.npy"
+         )
+         np.save(
+             os.path.join(mask_save_dir, args.model, mask_name),
+             best_masks_for_test_vector,
+         )
+         del best_masks_for_test_vector
+
+     return best_masks_for_test, best_val_metrics
+
+
+ def load_tall_mask(remove_keys, ptm_check, config):
+     """Loads TALL masks from disk, unpack and transform to state dictionaries."""
+     mask_location = config.model_location.replace("checkpoints", "tall_masks")
+     try:
+         if config.method.use_ties:
+             print("==== Loading TALL Masks built with TIES ====")
+             tall_masks = torch.load(
+                 os.path.join(
+                     mask_location,
+                     config.model,
+                     f"TALL_mask_{config.num_tasks}task_use_ties.npy",
+                 )
+             )
+         else:
+             print("==== Loading TALL Masks built with Task Arithmetic ====")
+             tall_masks = torch.load(
+                 os.path.join(
+                     mask_location, config.model, f"TALL_mask_{config.num_tasks}task.npy"
+                 )
+             )
+     except:
+         raise Exception("TALL Masks are not constructed yet.")
+
+     # unpack masks and convert back to torch tensors
+     tall_masks = {k: torch.from_numpy(np.unpackbits(v)) for k, v in tall_masks.items()}
+
+     # convert vectors to dictionaries
+     tall_masks = {
+         dataset: vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
+         for dataset, mask in tall_masks.items()
+     }
+
+     return tall_masks
+
+
+ def construct_consensus_mask(ptm_check, prun_thre_k, config, remove_keys=[]):
+     """
+     Generate consensus mask by filtering out least-used parameters
+
+     Args:
+         ptm_check: pretrained_checkpoint as state dictionary
+         prun_thre_k: weight-pruning threhold, stands for the least number of activated tasks for a parameter to be preserved from pruning
+             if prun_thre_k is set to 2: remove both catastrophic and selfish weights;
+             if prun_thre_k is set to 1: remove only catastrophic weights;
+             if prun_thre_k is set to 0: remove no weights -> reduce to TA or TIES
+             if prun_thre_k is set to > num_tasks: remove all weights -> reduce to zero-shot
+     Returns:
+         consensus_mask_vector: constructed consensus mask as vector (boolean in shape (n_parameter, ))
+     """
+
+     print("==== Generating Consensus Mask ====")
+     # load TALL masks (in shape (n_task, n_parameter))
+     tall_masks = load_tall_mask(remove_keys, ptm_check, config)
+     tall_masks = list(tall_masks.values())
+
+     # generate consensus masks
+     consensus_mask = copy.deepcopy(tall_masks[0])
+     for key, value in consensus_mask.items():
+         consensus_mask[key] = torch.zeros_like(value)
+         # count for each parameter, the tasks it has been activated for
+         for mask in tall_masks:
+             consensus_mask[key] = consensus_mask[key] + mask[key].float()
+         # filter out the least-activated parameters based on given threshold
+         consensus_mask[key] = consensus_mask[key].float() >= prun_thre_k
+     consensus_mask_vector = state_dict_to_vector(
+         consensus_mask, remove_keys=remove_keys
+     )
+
+     return consensus_mask_vector
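To make the masking rule `mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda` and the consensus thresholding concrete, here is a toy run over four parameters and two tasks (all numbers are illustrative):

```python
import torch

theta_0 = torch.tensor([0.0, 0.0, 0.0, 0.0])   # pre-trained weights (flattened)
theta_t = torch.tensor([1.0, 0.1, -0.5, 0.0])  # fine-tuned weights for task t
tau_other = torch.tensor([0.0, 0.8, 0.5, 0.2])  # another task's task vector
theta_mt = theta_0 + (theta_t - theta_0) + tau_other  # merged (multi-task) weights

lam = 0.4
# Keep a parameter for task t where the pre-trained model is farther from
# theta_t than the merged model is, scaled by lambda.
mask_t = (theta_0 - theta_t).abs() > (theta_mt - theta_t).abs() * lam
print(mask_t)  # tensor([ True, False,  True, False])

# Consensus mask: keep a parameter if it is selected by at least prun_thre_k
# tasks; k = 2 drops weights that only a single task (or no task) relies on.
mask_other = torch.tensor([True, True, True, False])
activation_count = mask_t.float() + mask_other.float()
consensus = activation_count >= 2
print(consensus)  # tensor([ True, False,  True, False])
```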
@@ -0,0 +1,16 @@
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench import BaseAlgorithm
+
+ from .utils import TSVC_utils, check_parameterNamesMatch
+
+
+ class TaskSingularVectorCompression(BaseAlgorithm):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def run(self, modelpool):
+         raise NotImplementedError(
+             "Task Singular Vector Compression is not implemented yet."
+         )
@@ -0,0 +1,63 @@
+ """
+ Example:
+
+ ```bash
+ fusion_bench \
+     method=task_singular_vector/TaskSingularVectorMerging \
+     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
+     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
+ ```
+ """
+
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench import BaseAlgorithm
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub, state_dict_add
+ from fusion_bench.utils.type import StateDictType
+
+ from .utils import (
+     TSVM_utils,
+     check_parameterNamesMatch,
+     check_state_dicts_equal,
+     state_dict_to_vector,
+     vector_to_state_dict,
+ )
+
+
+ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
+
+     def __init__(
+         self,
+         remove_keys: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         self.remove_keys = remove_keys if remove_keys is not None else []
+         super().__init__(**kwargs)
+
+     def run(self, modelpool):
+         # Load the pre-trained model and the fine-tuned models
+         pretrained_model = modelpool.load_pretrained_model()
+         finetuned_models = list(modelpool.models())
+
+         ptm_check = pretrained_model.state_dict()
+         ft_checks = [model.state_dict() for model in finetuned_models]
+         check_parameterNamesMatch(ft_checks + [ptm_check])
+
+         with timeit_context("Flattening out Checkpoints"):
+             task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
+
+         new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+             task_vectors,
+             exclude_keys=self.remove_keys,
+             accelerator=self.fabric.device,
+         )
+
+         pretrained_model.load_state_dict(
+             state_dict_add(new_merged_tv, pretrained_model.state_dict())
+         )
+         return pretrained_model
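As a rough sketch of what `TSVM_utils.compute_and_sum_svd_mem_reduction` does per 2-D layer: each task's weight delta is truncated to a low-rank SVD approximation (a rank budget of roughly 1/n_tasks each) before the deltas are combined; the actual implementation additionally orthogonalizes the retained singular vectors across tasks to reduce interference. A simplified, illustrative version that only does truncate-and-sum:

```python
import torch

def low_rank_sum(task_deltas, keep_frac=None):
    """Sum SVD-truncated per-task weight deltas for a single 2-D layer.

    keep_frac defaults to 1 / n_tasks, so the merged update uses roughly
    the same rank budget as one full task vector.
    """
    n = len(task_deltas)
    keep_frac = keep_frac if keep_frac is not None else 1.0 / n
    merged = torch.zeros_like(task_deltas[0])
    for delta in task_deltas:
        u, s, vh = torch.linalg.svd(delta, full_matrices=False)
        k = max(1, int(s.numel() * keep_frac))
        merged += (u[:, :k] * s[:k]) @ vh[:k, :]  # rank-k reconstruction
    return merged

deltas = [torch.randn(64, 64) for _ in range(4)]  # toy task vectors for one layer
merged_delta = low_rank_sum(deltas)
# The merged model is then theta_pretrained + merged_delta, layer by layer,
# which is what run() assembles via state_dict_add above.
```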
@@ -0,0 +1,9 @@
+ """
+ This module is modified from the original code of the paper:
+
+ - Gargiulo, et.al. Task Singular Vectors: Reducing Task Interference in Model Merging
+ - http://arxiv.org/abs/2412.00081
+ - https://github.com/AntoAndGar/task_singular_vectors/
+ """
+
+ from .TSVM import TaskSingularVectorMerging
@@ -0,0 +1,50 @@
+ import torch
+
+
+ def compute_svd_and_compress(key, matrix, sv_reduction):
+     """
+     Computes the Singular Value Decomposition (SVD) of a given matrix and compresses it by reducing the number of singular values.
+
+     Args:
+         key (Any): An identifier for the matrix.
+         matrix (torch.Tensor): The input matrix to decompose.
+         sv_reduction (float): The fraction of singular values to retain (0 < sv_reduction <= 1).
+
+     Returns:
+         tuple: A tuple containing:
+             - key (Any): The original identifier for the matrix.
+             - u (torch.Tensor): The left singular vectors of the reduced SVD.
+             - s (torch.Tensor): The reduced singular values.
+             - v (torch.Tensor): The right singular vectors of the reduced SVD.
+     """
+     u, s, v = torch.linalg.svd(matrix, full_matrices=False)
+     reduced_index_s = int(s.shape[0] * sv_reduction)
+     return key, u[:, :reduced_index_s], s[:reduced_index_s], v[:reduced_index_s, :]
+
+
+ def compress_tv(task_vectors, sv_reduction):
+     """
+     Compress task vectors using Singular Value Decomposition (SVD).
+
+     Args:
+         task_vectors (dict): A dictionary where keys are dataset names and values are task vectors.
+             Each task vector is expected to have a 'vector' attribute which is a dictionary
+             with keys as layer names and values as layer matrices.
+         sv_reduction (int): The fraction of singular values to keep for compression.
+
+     Returns:
+         dict: A dictionary with the same structure as `task_vectors`, but with each layer matrix
+             replaced by its compressed SVD components (u, s, v) if the layer is 2-dimensional.
+             If the layer is not 2-dimensional, it is stored as is under the key "dim1".
+     """
+     with torch.no_grad():
+         svd_dict = {}
+         for dataset, task_vector in task_vectors.items():
+             svd_dict[dataset] = {}
+             for key, layer in task_vector.vector.items():
+                 if len(layer.shape) == 2:  # and "text_projection" not in key:
+                     _, u, s, v = compute_svd_and_compress(key, layer, sv_reduction)
+                     svd_dict[dataset][key] = {"u": u, "s": s, "v": v}
+                 else:
+                     svd_dict[dataset][key] = {"dim1": layer}
+     return svd_dict
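A quick, illustrative check of the truncation performed by `compute_svd_and_compress` (the shapes and the 25% fraction are arbitrary; this assumes the function from the hunk above is in scope):

```python
import torch

matrix = torch.randn(100, 80)
key, u, s, v = compute_svd_and_compress("layer.weight", matrix, sv_reduction=0.25)

# int(80 * 0.25) = 20 singular triplets are kept
assert u.shape == (100, 20) and s.shape == (20,) and v.shape == (20, 80)

approx = (u * s) @ v  # rank-20 reconstruction of the layer
rel_err = (matrix - approx).norm() / matrix.norm()
print(f"kept {s.numel()} singular values, relative error {rel_err:.3f}")
```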