fusion-bench 0.2.9 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
@@ -0,0 +1,1034 @@
+ import logging
+ import math
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch import Tensor, nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.cache_utils import (
+     Cache,
+     DynamicCache,
+     SlidingWindowCache,
+     StaticCache,
+ )
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.models.mistral.modeling_mistral import (
+     ACT2FN,
+     MistralRMSNorm,
+     MistralRotaryEmbedding,
+ )
+
+ from .configuration_smile_mistral import SmileMistralConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q
+             and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k
+             have the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
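
As a quick illustration of the shape conventions in the docstring above (a minimal sketch, not part of the package: q/k in [batch_size, heads, seq_len, head_dim], cos/sin in [batch_size, seq_len, head_dim], so the default unsqueeze_dim=1 broadcasts over heads):

import torch

batch_size, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch_size, heads, seq_len, head_dim)
k = torch.randn(batch_size, heads, seq_len, head_dim)
# cos/sin would normally come from a rotary embedding module; random stand-ins here
cos = torch.randn(batch_size, seq_len, head_dim)
sin = torch.randn(batch_size, seq_len, head_dim)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 by default
assert q_embed.shape == q.shape and k_embed.shape == k.shape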
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
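
The equivalence claimed in the docstring can be sanity-checked directly (a small sketch, assuming repeat_kv is in scope):

import torch

x = torch.randn(2, 3, 5, 4)  # (batch, num_key_value_heads, seqlen, head_dim)
expanded = repeat_kv(x, n_rep=4)
assert expanded.shape == (2, 12, 5, 4)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=4, dim=1))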
+
+
+ class SmileGate(nn.Module):
+     __constants__ = ["in_features", "num_experts", "k"]
+     in_features: int
+     num_experts: int
+     k: int
+     weight: Tensor
+
+     def __init__(
+         self,
+         in_features: int,
+         num_experts: int,
+         k: int,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.num_experts = num_experts
+         self.k = k
+
+         self.weight = nn.Parameter(
+             torch.empty(num_experts * k, in_features, **factory_kwargs)
+         )
+
+     def forward(self, x: Tensor):
+         batch_size = x.size(0)
+         if self.num_experts == 1:
+             return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
+
+         routing_weights = F.linear(x, self.weight).view(
+             batch_size, self.num_experts, self.k
+         )
+         routing_weights = routing_weights.norm(p=2, dim=2)
+         return routing_weights
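
In other words, the gate projects each token onto num_experts * k directions and scores each expert by the L2 norm of its k-dimensional block. A hypothetical shape walk-through (the random initialization is illustrative only; the package sets this weight elsewhere):

import torch

gate = SmileGate(in_features=32, num_experts=4, k=2)
torch.nn.init.normal_(gate.weight)  # illustrative; torch.empty above leaves it uninitialized
x = torch.randn(10, 32)             # 10 tokens
scores = gate(x)                    # linear -> (10, 8) -> view (10, 4, 2) -> norm over dim 2
assert scores.shape == (10, 4)      # one non-negative routing score per expert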
+
+
+ class SmileLinearExpert(nn.Module):
+     __constants__ = ["in_features", "out_features", "k"]
+     in_features: int
+     out_features: int
+     k: int
+
+     def __init__(
+         self,
+         in_features,
+         out_features,
+         k: int,
+         bias: bool,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.k = k
+
+         self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
+         self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
+
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+         else:
+             self.register_parameter("bias", None)
+
+     def forward(self, x):
+         x = F.linear(x, self.svh)
+         x = F.linear(x, self.u, self.bias)
+         return x
153
+
154
+
155
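The `u`/`svh` parameter names suggest the two factors of a truncated SVD (U and S·Vᵀ). Under that reading, here is a hedged sketch of how such a low-rank expert approximates a full weight matrix; all names and sizes are illustrative:

```python
import torch

out_features, in_features, k = 64, 32, 4
delta = torch.randn(out_features, in_features)  # e.g. a task-specific weight update

# Rank-k truncated SVD; `u` absorbs U, `svh` absorbs diag(S) @ V^T,
# mirroring the `u` / `svh` parameters of SmileLinearExpert.
U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
u = U[:, :k]
svh = S[:k, None] * Vh[:k]

x = torch.randn(5, in_features)
approx = x @ svh.T @ u.T  # what the expert computes (bias omitted)
exact = x @ delta.T
print((approx - exact).norm() / exact.norm())  # relative error of the rank-k model
```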
+ class SmileLinear(nn.Module):
+     @torch.no_grad()
+     def __init__(
+         self,
+         config: SmileMistralConfig,
+         in_features,
+         out_features,
+         bias: bool,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.num_local_experts = config.num_local_experts
+         self.num_experts_per_tok = config.num_experts_per_tok
+         self.rank_of_expert = config.rank_of_expert
+         self.rank_of_router = config.rank_of_router
+         self.in_features = in_features
+         self.out_features = out_features
+
+         # construct the gate network
+         self.gate = SmileGate(
+             in_features=in_features,
+             num_experts=self.num_local_experts,
+             k=self.rank_of_router,
+             **factory_kwargs,
+         )
+
+         # the shared linear layer
+         self.shared_linear = nn.Linear(
+             in_features, out_features, bias=bias, **factory_kwargs
+         )
+
+         # construct the experts: low-rank if rank_of_expert > 0, full-rank otherwise
+         if self.rank_of_expert > 0:
+             self.experts = nn.ModuleList(
+                 [
+                     SmileLinearExpert(
+                         in_features=in_features,
+                         out_features=out_features,
+                         bias=bias,
+                         k=self.rank_of_expert,
+                         **factory_kwargs,
+                     )
+                     for _ in range(self.num_local_experts)
+                 ]
+             )
+         else:
+             self.experts = nn.ModuleList(
+                 [
+                     nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
+                     for _ in range(self.num_local_experts)
+                 ]
+             )
+
+     def forward(self, hidden_states: Tensor):
+         pretrained_out = self.shared_linear(hidden_states)
+
+         input_shape = hidden_states.size()
+         hidden_states = hidden_states.view(-1, self.in_features)
+
+         router_logits = self.gate(hidden_states)
+         routing_weights = F.softmax(router_logits, dim=1)
+         # select the top-k experts according to the routing weights
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.num_experts_per_tok, dim=-1
+         )
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+         final_hidden_states = torch.zeros(
+             (hidden_states.size(0), self.out_features),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # One-hot encode the selected experts to create an expert mask;
+         # this will be used to easily index which expert is going to be solicited
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_local_experts
+         ).permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_local_experts):
+             expert_layer = self.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens
+             current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
+             if current_state.numel() == 0:
+                 continue
+             current_hidden_states = (
+                 expert_layer(current_state) * routing_weights[top_x, idx, None]
+             )
+
+             # `index_add_` only supports torch tensors for indexing, so we use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype)
+             )
+         final_hidden_states = final_hidden_states.reshape(
+             *input_shape[:-1], self.out_features
+         )
+         # the routed expert outputs are added on top of the shared projection
+         final_hidden_states = pretrained_out + final_hidden_states
+         return final_hidden_states
+
+     @property
+     def weight(self):
+         """
+         Mimic a plain linear layer. In some cases, the user might query the device (or dtype)
+         of the layer's parameters via `linear_layer.weight.device`.
+         """
+         return self.shared_linear.weight
+
+     @property
+     def bias(self):
+         return self.shared_linear.bias
+
+     def __repr__(self):
+         return (
+             f"SmileLinear("
+             f"in_features={self.shared_linear.in_features}, "
+             f"out_features={self.shared_linear.out_features}, "
+             f"num_local_experts={self.num_local_experts}, "
+             f"num_experts_per_tok={self.num_experts_per_tok}, "
+             f"rank_of_router={self.rank_of_router}, "
+             f"rank_of_expert={self.rank_of_expert}"
+             f")"
+         )
+
+
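The one-hot/`index_add_` dispatch inside `SmileLinear.forward` is easier to follow in isolation. This standalone sketch (illustrative names, plain `nn.Linear` experts, random gate logits) routes a batch of token vectors the same way:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, top_k, dim, tokens = 4, 2, 8, 10
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
x = torch.randn(tokens, dim)
logits = torch.randn(tokens, num_experts)  # stand-in for the gate output

weights, selected = torch.topk(F.softmax(logits, dim=1), top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(x)
mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)
for e in range(num_experts):
    # idx: which top-k slot chose expert e; top_x: which tokens chose it
    idx, top_x = torch.where(mask[e])
    if top_x.numel() == 0:
        continue
    out.index_add_(0, top_x, experts[e](x[top_x]) * weights[top_x, idx, None])
```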
+ class SmileMistralAttention(nn.Module):
+     """
+     Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window attention,
+     as in Longformer and "Generating Long Sequences with Sparse Transformers".
+     """
+
+     def __init__(self, config: SmileMistralConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         if layer_idx is None:
+             logger.warning_once(
+                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                 "when creating this class."
+             )
+
+         self.attention_dropout = config.attention_dropout
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+         self.is_causal = True
+
+         self.q_proj = SmileLinear(
+             config,
+             self.hidden_size,
+             self.num_heads * self.head_dim,
+             bias=False,
+         )
+         self.k_proj = SmileLinear(
+             config,
+             self.hidden_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=False,
+         )
+         self.v_proj = SmileLinear(
+             config,
+             self.hidden_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=False,
+         )
+         self.o_proj = SmileLinear(
+             config,
+             self.num_heads * self.head_dim,
+             self.hidden_size,
+             bias=False,
+         )
+
+         self.rotary_emb = MistralRotaryEmbedding(
+             self.head_dim,
+             max_position_embeddings=self.max_position_embeddings,
+             base=self.rope_theta,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+
+         cos, sin = self.rotary_emb(value_states, position_ids)
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin
+         )
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, cache_kwargs
+             )
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(
+             query_states, key_states.transpose(2, 3)
+         ) / math.sqrt(self.head_dim)
+
+         if attention_mask is not None:  # no matter the length, we just slice it
+             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+             attn_weights = attn_weights + causal_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(
+             attn_weights, dim=-1, dtype=torch.float32
+         ).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(
+             attn_weights, p=self.attention_dropout, training=self.training
+         )
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+
+         attn_output = attn_output.view(bsz, q_len, -1)
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
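Stripped of caching, masking, and the SMILE projections, the eager attention path above reduces to a few tensor ops. A minimal sketch with illustrative sizes, including the grouped-query repetition:

```python
import math
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
groups = num_heads // num_kv_heads  # num_key_value_groups

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
v = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Grouped-query attention: each KV head serves `groups` query heads.
k = torch.repeat_interleave(k, groups, dim=1)
v = torch.repeat_interleave(v, groups, dim=1)

attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
out = torch.matmul(attn, v).transpose(1, 2).reshape(bsz, q_len, -1)
assert out.shape == (bsz, q_len, num_heads * head_dim)
```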
+ class SmileMistralMLP(nn.Module):
+     def __init__(self, config: SmileMistralConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = SmileLinear(
+             config,
+             in_features=self.hidden_size,
+             out_features=self.intermediate_size,
+             bias=False,
+         )
+         self.up_proj = SmileLinear(
+             config,
+             in_features=self.hidden_size,
+             out_features=self.intermediate_size,
+             bias=False,
+         )
+         self.down_proj = SmileLinear(
+             config,
+             in_features=self.intermediate_size,
+             out_features=self.hidden_size,
+             bias=False,
+         )
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, hidden_state):
+         return self.down_proj(
+             self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
+         )
+
+
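The MLP is the standard gated (SwiGLU-style) block, `down(act(gate(x)) * up(x))`; a plain-`nn.Linear` sketch with illustrative sizes shows the same dataflow:

```python
import torch
import torch.nn as nn

hidden, intermediate = 32, 128
gate_proj = nn.Linear(hidden, intermediate, bias=False)
up_proj = nn.Linear(hidden, intermediate, bias=False)
down_proj = nn.Linear(intermediate, hidden, bias=False)
act = nn.SiLU()  # Mistral's hidden_act is "silu"

x = torch.randn(2, 16, hidden)
y = down_proj(act(gate_proj(x)) * up_proj(x))
assert y.shape == x.shape
```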
+ SMILE_MISTRAL_ATTENTION_CLASSES = {
+     "eager": SmileMistralAttention,
+     # "flash_attention_2": MistralFlashAttention2,
+     # "sdpa": MistralSdpaAttention,
+ }
+
+
+ class SmileMistralDecoderLayer(nn.Module):
+     def __init__(self, config: SmileMistralConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = SMILE_MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
+             config=config, layer_idx=layer_idx
+         )
+
+         self.mlp = SmileMistralMLP(config)
+         self.input_layernorm = MistralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = MistralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*):
+                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                 query_sequence_length, key_sequence_length)` if default attention is used.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                 Indices depicting the position of the input sequence tokens in the sequence
+             kwargs (`dict`, *optional*):
+                 Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                 into the model
+         """
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
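The layer applies the usual pre-norm residual pattern twice (normalize, transform, add back). A generic sketch with illustrative names; the real model uses `MistralRMSNorm` rather than `nn.LayerNorm`:

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """x -> x + sublayer(norm(x)); the structure used twice per decoder layer."""

    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))

block = PreNormResidual(16, nn.Linear(16, 16))
y = block(torch.randn(2, 4, 16))  # same shape in, same shape out
```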
+ class SmileMistralPreTrainedModel(PreTrainedModel):
+     config_class = SmileMistralConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["SmileMistralDecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = False
+     _supports_sdpa = False
+     _supports_cache_class = True
+     _supports_static_cache = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class SmileMistralModel(SmileMistralPreTrainedModel):
+     def __init__(self, config: SmileMistralConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size, config.hidden_size, self.padding_idx
+         )
+         self.layers = nn.ModuleList(
+             [
+                 SmileMistralDecoderLayer(config, layer_idx)
+                 for layer_idx in range(config.num_hidden_layers)
+             ]
+         )
+         self._attn_implementation = config._attn_implementation
+         self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache) and not self.training:
+             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             return_legacy_cache = True
+             logger.warning_once(
+                 "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+                 "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+             )
+
+         if cache_position is None:
+             past_seen_tokens = (
+                 past_key_values.get_seq_length() if past_key_values is not None else 0
+             )
+             cache_position = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + inputs_embeds.shape[1],
+                 device=inputs_embeds.device,
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask,
+             inputs_embeds,
+             cache_position,
+             past_key_values,
+             use_cache,
+             output_attentions,
+         )
+
+         hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         use_cache: bool,
+         output_attentions: bool,
+     ):
+         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+         if self._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and use_cache:
+                 is_padding_right = (
+                     attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                 )
+                 if is_padding_right:
+                     raise ValueError(
+                         "You are attempting to perform batched generation with padding_side='right', "
+                         "which may lead to unexpected behaviour with the Flash Attention version of Mistral. "
+                         "Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
+                     )
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+
+         # cache_position must be valid here no matter which cache we use
+         past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+         using_static_cache = isinstance(past_key_values, StaticCache)
+         using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and not (using_static_cache or using_sliding_window_cache)
+             and not output_attentions
+         ):
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 sliding_window=self.config.sliding_window,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         min_dtype = torch.finfo(dtype).min
+         sequence_length = input_tensor.shape[1]
+         # SlidingWindowCache
+         if using_sliding_window_cache:
+             target_length = max(sequence_length, self.config.sliding_window)
+         # StaticCache
+         elif using_static_cache:
+             target_length = past_key_values.get_max_length()
+         # DynamicCache or no cache
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+             if attention_mask.max() != 0:
+                 raise ValueError(
+                     "Custom 4D attention mask should be passed in inverted form with max==0"
+                 )
+             causal_mask = attention_mask
+         else:
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             exclude_mask = torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             if self.config.sliding_window is not None:
+                 if (
+                     not using_sliding_window_cache
+                     or sequence_length > self.config.sliding_window
+                 ):
+                     exclude_mask.bitwise_or_(
+                         torch.arange(target_length, device=device)
+                         <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
+                     )
+             causal_mask *= exclude_mask
+             causal_mask = causal_mask[None, None, :, :].expand(
+                 input_tensor.shape[0], 1, -1, -1
+             )
+         if attention_mask is not None:
+             causal_mask = (
+                 causal_mask.clone()
+             )  # copy to contiguous memory for in-place edit
+             if attention_mask.dim() == 2:
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             causal_mask = AttentionMaskConverter._unmask_unattended(
+                 causal_mask, min_dtype
+             )
+
+         return causal_mask
+
+
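In the simplest case (no cache, no sliding window, no padding mask), the additive mask built above reduces to the familiar upper-triangular pattern. A tiny illustration:

```python
import torch

seq, dtype = 5, torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(seq)

mask = torch.full((seq, seq), fill_value=min_dtype, dtype=dtype)
exclude = torch.arange(seq) > cache_position.reshape(-1, 1)
mask = mask * exclude  # zeros on/below the diagonal, large negatives above
print(mask)  # row i can attend to positions <= i
```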
+ class SmileMistralForCausalLM(SmileMistralPreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = SmileMistralModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+         >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+         >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+         ```"""
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Ensure tensors are on the same device
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
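The label shift in the loss computation implements standard next-token prediction; a tiny, self-contained illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq = 10, 6
logits = torch.randn(1, seq, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq))

# Tokens < n predict token n: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
```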
+     # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         cache_position=None,
+         position_ids=None,
+         use_cache=True,
+         **kwargs,
+     ):
+         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+         # Exception 1: when passing input_embeds, input_ids may be missing entries
+         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+         if past_key_values is not None:
+             if inputs_embeds is not None:  # Exception 1
+                 input_ids = input_ids[:, -cache_position.shape[0] :]
+             elif (
+                 input_ids.shape[1] != cache_position.shape[0]
+             ):  # Default case (the "else", a no op, is Exception 2)
+                 input_ids = input_ids[:, cache_position]
+
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and cache_position[0] == 0:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {
+                 "input_ids": input_ids.contiguous()
+             }  # `contiguous()` needed for compilation use cases
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "use_cache": use_cache,
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs