fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +0 -1
  3. fusion_bench/compat/modelpool/__init__.py +2 -1
  4. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  5. fusion_bench/dataset/arc_agi/arc.py +21 -7
  6. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  7. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  8. fusion_bench/dataset/arc_agi/preprocess.py +50 -8
  9. fusion_bench/dataset/llama/collate.py +10 -3
  10. fusion_bench/method/__init__.py +3 -0
  11. fusion_bench/method/adamerging/__init__.py +1 -1
  12. fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
  13. fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
  14. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  15. fusion_bench/method/rankone_moe/__init__.py +3 -0
  16. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  17. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  18. fusion_bench/method/simple_average.py +1 -1
  19. fusion_bench/mixins/clip_classification.py +2 -7
  20. fusion_bench/mixins/lightning_fabric.py +2 -2
  21. fusion_bench/models/rankone_moe.py +410 -0
  22. fusion_bench/taskpool/__init__.py +10 -2
  23. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  24. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  25. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  26. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
  27. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
  28. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
  29. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
  30. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  31. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  32. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  33. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
  34. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
  35. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
  36. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  fusion_bench/__init__.py,sha256=68dF-zPvb8E2MgYnmgIJsxIHJBy1MApKeOrRZvQEVlg,421
2
2
  fusion_bench/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- fusion_bench/compat/method/__init__.py,sha256=OmhrEUm6lNGJFoqock0N0YvxipIRXhBh09pqkxWKc9A,4743
4
- fusion_bench/compat/method/base_algorithm.py,sha256=ebkl8wkmjYpEDIm0pu2SeS17JMd_6b9G9FHp7ngfPHY,1768
3
+ fusion_bench/compat/method/__init__.py,sha256=yY-ILbwNVTCbor4Z7SOp0wRDbB8FqlaXo4sgF12EhQM,4823
4
+ fusion_bench/compat/method/base_algorithm.py,sha256=Vsc9k04o6FAhu509xGYc1vZWkmegQOjqoqT7IJ8p7CA,1741
5
5
  fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py,sha256=m68BRGy4P-P9lLB10oXOBI-p58a-0FOPcrJ4r4MU32k,1100
6
- fusion_bench/compat/modelpool/__init__.py,sha256=eTglIrZdLgUkaiKFP8Pcf3nfOXFfIXBmJzR036fCq68,4664
6
+ fusion_bench/compat/modelpool/__init__.py,sha256=C0CFrqaIKRiAvhT0PT3vM98fZwmpxL34wfb4FbeKcdo,4665
7
7
  fusion_bench/compat/modelpool/base_pool.py,sha256=1gxQENvdcOSdHmUbw-x7-X-aXtoSa1Gsys_on1ys8FM,10639
8
8
  fusion_bench/compat/modelpool/huggingface_clip_vision.py,sha256=LyIPgepNOK0qrk_EnBdlTC0ZnEkEZvPUy45cO60TiPU,6918
9
9
  fusion_bench/compat/taskpool/__init__.py,sha256=fTHd7_7EwSM2K06gUCQZ1jxxhl8T_kP0ouv70wBLhpI,3630
@@ -19,30 +19,30 @@ fusion_bench/dataset/gsm8k.py,sha256=CmANZ0A89PfPwVu_myKhXk1D9IwypOpjH3iqDo1KxcQ
19
19
  fusion_bench/dataset/image_dataset.py,sha256=MSZE_UESyRRQDwnkm2KpyIARUg9SWcwqnH4fDNstzS4,1870
20
20
  fusion_bench/dataset/imdb.py,sha256=YRzeq5z-Fl0aYcC2QtwEBWFkvucvpNo975jwjL5SZvs,260
21
21
  fusion_bench/dataset/nyuv2.py,sha256=2OdIEaY1ywFYMLUxCTpFcIctcBMFTq4nnoOkudSo-jI,3750
22
- fusion_bench/dataset/arc_agi/__init__.py,sha256=JmMAKCk56GKbOOBnfMJtbrkkCVVFPYRhWqW2XvyiHf0,52
23
- fusion_bench/dataset/arc_agi/arc.py,sha256=jhEDJWSbKRjp1vVbKDwIaSBRzXcz6Ir_CiP4mjtiPpA,9221
24
- fusion_bench/dataset/arc_agi/arc_agi.py,sha256=BEYPktvG77zaLnPK7672jck7NiFGpKaRepJyIPm3bYM,7095
22
+ fusion_bench/dataset/arc_agi/__init__.py,sha256=xj8BMG296qPMiL4NYs-ZwqcLJ6yT2wwbubyCbWPe91w,149
23
+ fusion_bench/dataset/arc_agi/arc.py,sha256=AfRivFvuyumYKjlJq3LSbAzFAdHB0lY4NS8KlxhWqjU,9396
24
+ fusion_bench/dataset/arc_agi/arc_agi.py,sha256=SFOjp0yZrsoln4cQgWU2b-WfI39od6IE1Wof8Ee0888,11768
25
25
  fusion_bench/dataset/arc_agi/augmenters.py,sha256=yhTqyRk0_zamXRQ5Ev10xYc8Dc9D71BTSOkt856x33I,30890
26
26
  fusion_bench/dataset/arc_agi/messagers.py,sha256=E6BqF1iL68ge1m9wOJMSb2Pz6_5i9CR0HxBb7i73plE,53076
27
- fusion_bench/dataset/arc_agi/np_cache.py,sha256=1OoqMEdu9MwiaO086HZPOwfoYmwojFJfSRx9ApP8WgU,5440
28
- fusion_bench/dataset/arc_agi/preprocess.py,sha256=NSuM9ECucPamYd-Ost0voIxR19rBrD0JXLbxzXZr898,6741
27
+ fusion_bench/dataset/arc_agi/np_cache.py,sha256=Ec1DQFtlBdMy-f4dvGEhSr4jyVnBLQELwvX1ztxJKBs,5439
28
+ fusion_bench/dataset/arc_agi/preprocess.py,sha256=SLmkhq76RJ8zTto5JHNFORYEr2GkbrhP81pKz1A8_BE,8523
29
29
  fusion_bench/dataset/arc_agi/representers.py,sha256=-2eTYl-UcFW4zULDjkUrOQYv9P31nttMjc9eTJsaN0g,35852
30
30
  fusion_bench/dataset/llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
31
  fusion_bench/dataset/llama/alpaca.py,sha256=sITFsghX2w0KzLwQ71KRz6rfsI2WLjuuKwt8OetvmCQ,4778
32
- fusion_bench/dataset/llama/collate.py,sha256=exaJAi0EbPTwMf69rgzeaD2IhY4YwtgoWgRfTQtZha0,1717
32
+ fusion_bench/dataset/llama/collate.py,sha256=wcnt9Y2G4Isbdof3HAfe-xTbUThGo7IM0AZsn0FTmBs,1932
33
33
  fusion_bench/dataset/llama/openai.py,sha256=_QXz6ciUTN8u4ILDowZPT3SQTes7ngkFZe1MRLFtVQ8,5012
34
34
  fusion_bench/dataset/llama/sharegpt.py,sha256=8hdh_5BcxIyK0ByZoVLdhd_I06kpHffxQdaC6ezzHkM,5249
35
35
  fusion_bench/dataset/llama/squad.py,sha256=H0L0BHFzVTtkw7jfgTA8gzvZDhzsqfIALq1ip_BVwaM,4810
36
36
  fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
37
- fusion_bench/method/__init__.py,sha256=z_Bx0533GxTdgLd-x4hxdDoF5gND199XFiPthr86Yhw,5585
37
+ fusion_bench/method/__init__.py,sha256=NSBIKPSjcZbZDVuwr8srDDfntfz3jQilozRCqHPYj_w,5751
38
38
  fusion_bench/method/base_algorithm.py,sha256=5dutGZfPqNhO8F8FOlo3UFR91TZu2Xj7O0pTB40JvWo,1135
39
39
  fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
40
40
  fusion_bench/method/ensemble.py,sha256=rGxvJTeorfcBuE_e0XO-0-MAc9un7ZCC46ikKGuAcN4,3077
41
41
  fusion_bench/method/model_recombination.py,sha256=2tviqmYSPOL0_Ktv8_gt_YzQ4tyCANHxXquUot_3Cgo,5360
42
- fusion_bench/method/simple_average.py,sha256=yzQg-qzldJtfPG0uYLFLQTSpeXc8Q4H88pkztzIdXds,4481
42
+ fusion_bench/method/simple_average.py,sha256=2ghcL1E-eLbIYDCHYCoR9WtiYSb1GvFAH163OTTTEEI,4481
43
43
  fusion_bench/method/ada_svd/__init__.py,sha256=4XzQbbvE9HI3NtEmEFvo8iC3ds_85vJXe7P7qJfL7kk,77
44
44
  fusion_bench/method/ada_svd/clip_vision.py,sha256=QrT6cSwgVEGxXEpVhkvKQVQaoRW5P9V52Y3_8NX0f-o,12556
45
- fusion_bench/method/adamerging/__init__.py,sha256=tRQYeKUUIejmKn6YtQ0rtSlok5UQBnmQqbNokgjlNk4,376
45
+ fusion_bench/method/adamerging/__init__.py,sha256=nt0saBT_3bqghk-pINQ-XCWm9UWwSZllu4R1sDuAJAA,376
46
46
  fusion_bench/method/adamerging/clip_layer_wise_adamerging.py,sha256=YdQ4trHohW6QzWC2enYvXA44WHxvzmoH_6sMrPn6z60,1305
47
47
  fusion_bench/method/adamerging/clip_task_wise_adamerging.py,sha256=Tys9pDJzz5YNUCO43pO44fGAnizfSaeAwgH4-vVxRN4,6948
48
48
  fusion_bench/method/adamerging/entropy_loss.py,sha256=ZeVe0Hq1PaMfppLqDbB0MOscZUZRNh4CALrvt8pmQC0,736
@@ -84,8 +84,8 @@ fusion_bench/method/linear/simple_average_for_llama.py,sha256=7JlVrmTMmrePvNGnZN
84
84
  fusion_bench/method/linear/task_arithmetic_for_llama.py,sha256=4SZpiTD7OzhWUXtcdK3PYdXbBGyDqiZd7oZOQ0lraN0,1963
85
85
  fusion_bench/method/lm_finetune/__init__.py,sha256=rIkKoxrqKEYkA7XIR6jyhwvUK_ebX2k6Fm1d7K1kU5g,92
86
86
  fusion_bench/method/lm_finetune/causal_lm_pretrain.py,sha256=4CL9KGFsUzrt-edMfTooo4G4apzTH_57rso3DGGvKL0,219
87
- fusion_bench/method/lm_finetune/fullfinetune_sft.py,sha256=J0yhnmM1TXRUC-Mte0a26BAbnHNDLVb7JEq8zrJzvLs,16931
88
- fusion_bench/method/lm_finetune/peftfinetune_sft.py,sha256=sLWlJRwFtSw9vXbbDF_dKW2AbgHK1BNzihK9hPho7ck,17835
87
+ fusion_bench/method/lm_finetune/fullfinetune_sft.py,sha256=WoVOzFhg1PRUm8iPMYJ1g98-km3wux6nrUqnWXm27Pg,18364
88
+ fusion_bench/method/lm_finetune/peftfinetune_sft.py,sha256=EwJJITxYBFtjsjunOlpSdo70dWeoHUYI-qIyelgW4n4,19834
89
89
  fusion_bench/method/mixture_of_experts/__init__.py,sha256=r95iu1-3tgIUP7sWuAbLuqV7xexNYMYPZkM4_8egfp8,198
90
90
  fusion_bench/method/mixture_of_experts/mixtral_merging.py,sha256=-n1CLP1o08VyMSfaTq42kRutbw-cFDSCWHTu0iNh6ok,4237
91
91
  fusion_bench/method/mixture_of_experts/mixtral_upcycling.py,sha256=tQYAeS8MLFEfH3zDFfNZrML7lRnpGLN-HquQvjPtHNw,11208
@@ -93,7 +93,7 @@ fusion_bench/method/pruning/__init__.py,sha256=3gtmay2bkdIAEGjpAhbY2ztMZOZLKhiJc
93
93
  fusion_bench/method/pruning/llama_magnitude_prune.py,sha256=ihHa8SNe0WGPuZqRKI_6S6gmH4ooTmeTRARGkJHcsos,6300
94
94
  fusion_bench/method/pruning/llama_random_prune.py,sha256=c-qV1iFSKZK1dES6gYsgWna1BUn58dtO0NjV1eIfJrg,4566
95
95
  fusion_bench/method/pruning/llama_wanda_prune.py,sha256=8pcg3X1yn8vfhV0lEg1fHP3oTzAc_-ixLmsZRdH5uPo,12070
96
- fusion_bench/method/pruning/magnitude_diff_pruning.py,sha256=vMyhZF_dWLkgB9A1RGpuYugJ6B-estwTvICj5WC904g,6450
96
+ fusion_bench/method/pruning/magnitude_diff_pruning.py,sha256=nXRHW87_Nwiash-udnwR9iOaJMBDo7fPTmAwmSqsAaI,6451
97
97
  fusion_bench/method/pruning/prune_utils.py,sha256=ITWO8WtrhcOYXTcjc_fAAw7cyjvqFa6axawPr3uTT68,5882
98
98
  fusion_bench/method/pruning/wanda_utils/__init__.py,sha256=ujOZ9GUTwzqfVjXUL0e6y_gAEfTQU85rBq2MZ5om7oQ,320
99
99
  fusion_bench/method/pruning/wanda_utils/ablate.py,sha256=TUKsbInQD3UmS8FpuFeco6FeTMaJLZXho9ASWRPcurc,6459
@@ -109,6 +109,9 @@ fusion_bench/method/pwe_moe/module.py,sha256=D4HDx7iDfKX_vJ3vkzi6_atKKlzJT6nH0sr
109
109
  fusion_bench/method/pwe_moe/utils.py,sha256=K9BeVMrhYv7GNlJO76eoQbkI1dOO7XF18yK06WUh9ZA,1336
110
110
  fusion_bench/method/pwe_moe/phn/__init__.py,sha256=PXX-hb_bd7GdtLHcAcnGGsW_Wbg8g2YlRZMTCk3axUw,78
111
111
  fusion_bench/method/pwe_moe/phn/solvers.py,sha256=OO-ImNwsWIQ3eXPxzj1V-kNgXrJc4FKcK-RwaOl_np0,6156
112
+ fusion_bench/method/rankone_moe/__init__.py,sha256=hvYxnloCrzim9s7HUaNA3dcuThEcfrFL5EMw34YNHeE,119
113
+ fusion_bench/method/rankone_moe/clip_rankone_moe.py,sha256=2wnzyHHZSQagZenu9viJ-68MmRG0ppOLR5JHZuT1FKE,5457
114
+ fusion_bench/method/rankone_moe/rankone_moe.py,sha256=YPWneidBJjms2SrYgH5tAim4KBl3Rrcmeq9Xf5QwU58,8489
112
115
  fusion_bench/method/regmean/__init__.py,sha256=VVqAkdHkb005Sc2XmeiedQYzb3q5aQNI8xzEJnE4thg,158
113
116
  fusion_bench/method/regmean/clip_regmean.py,sha256=xhT7dYSCg9sPLL5ZUCCtcA-Ypw4PBHsOivrnz-3fDso,4931
114
117
  fusion_bench/method/regmean/gpt2_regmean.py,sha256=p2D3E8YAZsltsI6GM474UWNqPZfBqihLZ93ZLUpOJ_c,5565
@@ -151,8 +154,8 @@ fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py,sha256=-ZaD84E
151
154
  fusion_bench/metrics/text_to_image_generation/compressibility.py,sha256=x4dNTFnAN4naChBDZBO-jUghnHAyobRVOupctKYRg1w,1656
152
155
  fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py,sha256=aSWzl8k7z80Cirg5qdfkPsp3sMFEv_PjA1NJv3PPWXY,3115
153
156
  fusion_bench/mixins/__init__.py,sha256=hMxt39JDb_uIvNDtp6ZJEDmaQFwx8GId2VK2Wajw9Rg,791
154
- fusion_bench/mixins/clip_classification.py,sha256=rFF90BPrtkVWF8H1n1du9F2o0i2da9PfC3m0ipGsdus,8201
155
- fusion_bench/mixins/lightning_fabric.py,sha256=LPiBkOpUVltzFXBI6BkROMtYswITJyoALLboZrBItu8,6163
157
+ fusion_bench/mixins/clip_classification.py,sha256=devw9zTpyJsCfGCR_iKuuT9iPp1XWUqqRHRdliK6riM,8030
158
+ fusion_bench/mixins/lightning_fabric.py,sha256=S81Bf9IDktaz2RM5T69TgiwPewUJfliLy6kd-dq3kdc,6163
156
159
  fusion_bench/mixins/rich_live.py,sha256=j7wNgrgwfdpl6nCXZGF_2DLtNq2aqCb_52Qhe9QSltc,495
157
160
  fusion_bench/mixins/serialization.py,sha256=9W50JUcM6wgFlaE9H29mATLLVobYniSDxg94FfY25w0,4049
158
161
  fusion_bench/mixins/simple_profiler.py,sha256=UDPB8QAA3rtsSdnVgL9KMthDLBY1Rh4h8mtiquiCPp4,2106
@@ -173,6 +176,7 @@ fusion_bench/modelpool/seq2seq_lm/modelpool.py,sha256=IjLHi8qycWOA4Ul9jnqR48evgV
173
176
  fusion_bench/models/__init__.py,sha256=TNOEH_2yAQP51m9mdWepNEf9VGUZgDthtgXbs4rhb4M,100
174
177
  fusion_bench/models/hf_clip.py,sha256=yOQ6UKMymQ3GcfpPm26QiToPztij-cXukNMMKXTmUrw,5745
175
178
  fusion_bench/models/parameter_dict.py,sha256=hRie26WIeXU-wvY6JeGaP8LvpMqbuZA6Ia_1vOFMuu4,2294
179
+ fusion_bench/models/rankone_moe.py,sha256=uwpAqk1cwxxprQ0hxuAwRuPvHDxxBKBDahd9vcaafXs,14248
176
180
  fusion_bench/models/separate_io.py,sha256=5AJlCxkHdVVffITnIRlF3ZIaKLRWDhJESVQN1lX-ZhU,3835
177
181
  fusion_bench/models/sparse_we_moe.py,sha256=b-yIeCsl2rz0i7BP9g_fqCEam7KUNjNX_J8oyZV6MJ8,16509
178
182
  fusion_bench/models/utils.py,sha256=7HKXRiWHeoNWp8LyDemG2irnMPkT9qg2ExvxjE5mUck,1858
@@ -224,12 +228,13 @@ fusion_bench/scripts/nyuv2_mtl_train.py,sha256=hB_P_4DIT83CGOXoyyaBnh9fYnxTJtvAP
224
228
  fusion_bench/scripts/webui.py,sha256=ryA-2leSnHcYA88tTAYzJGDhiljbi0vl1Fibejzndlw,14398
225
229
  fusion_bench/scripts/clip/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
226
230
  fusion_bench/scripts/clip/convert_checkpoint.py,sha256=zncgRAhInFpJDSHIm3GO4F6BzgsdAQVj3LLmV7g-JiQ,1221
227
- fusion_bench/taskpool/__init__.py,sha256=YgWy1iMYBmy2jvejjxHAE6-idaz9NS9qfqE5OFLaC9g,954
231
+ fusion_bench/taskpool/__init__.py,sha256=_qaYgzYnvrJDrZ2DjKXMvOFbelsLrujCKa_gP3UQBBg,1094
228
232
  fusion_bench/taskpool/base_pool.py,sha256=FaP0nndeSsrwbdd9mKa_CedbX9T5AHJmxk7Lc0NEVNY,835
229
233
  fusion_bench/taskpool/dummy.py,sha256=Di9JZO3XyDYn6wAGukrJMTnkS_NaxGTeQYo_3j1JD3Y,1675
230
234
  fusion_bench/taskpool/gpt2_text_classification.py,sha256=S4YyrcJhD4JOgvHF-AVG-gENgVGl-wpQZr1SbiThM04,4886
231
235
  fusion_bench/taskpool/nyuv2_taskpool.py,sha256=lnaR1oVm0pO2CA9EVV4uk3fiWYHD-F0GzPrUUARD75I,1970
232
- fusion_bench/taskpool/clip_vision/__init__.py,sha256=V_xu4npg1XJV8PV82I4QqLTlNoOTJVqUHTwYt5FS6BE,141
236
+ fusion_bench/taskpool/clip_vision/__init__.py,sha256=4xGO7rRbRpXF-I34A3WEMU4vydgfdtvXQ57ThaFcpmE,214
237
+ fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py,sha256=JKbRrGaRYztgZ-P0U767HISe40UpDVQ7fn6Tf2rrug0,4891
233
238
  fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py,sha256=hVDTtg-oXqRFmAE2wZPFpk_kvtdk_wS-2-ev2ujEJBs,5390
234
239
  fusion_bench/taskpool/clip_vision/taskpool.py,sha256=NRFXsp2N8PMQzZgFHy2yfJMjoYbDaxQpPTZ4-4EHPBY,13942
235
240
  fusion_bench/taskpool/llama/__init__.py,sha256=iB4ESMgnsl0m-z0YtRdPZiwGGv96-86R8pbSnkdet8Q,57
@@ -258,7 +263,7 @@ fusion_bench/tasks/clip_classification/tiny_imagenet.py,sha256=Ar9uQOqUcgGl7MQX9
258
263
  fusion_bench/tasks/flan_t5_text_generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
259
264
  fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py,sha256=zo5S73jm7YDTMN_FxcPNM2dxQkqv2K2siw2xELARPwk,2448
260
265
  fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py,sha256=-B1wqVGp3wZVs0NB4fqoW0u2TvxOpLYzZF1RzppJ5sc,4357
261
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py,sha256=B9hLhJBDVilitvwdLkc3bpmIcUuhKlDY6AaQQsZz2R8,1832
266
+ fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py,sha256=sVihXHbqwi8IlDpiIxzvmDv-Ob7WKvi23GIRYbBUKOc,1833
262
267
  fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py,sha256=GhRmGmcJGF4oVgZQarsBtx8GNKrNEZUkrillNz3iBuY,13183
263
268
  fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py,sha256=mKMTXIr5o-BqS_Hvv1bbMvvjQLLeKNVw7BKS9qgQ8Dw,1890
264
269
  fusion_bench/utils/__init__.py,sha256=yFhiBlrdcsJqZe-C5wdlZZ3wpmSN8Tipfpa2-R7CFbc,337
@@ -387,13 +392,14 @@ fusion_bench_config/method/linear/simple_average_for_llama.yaml,sha256=QJR5qx9z4
387
392
  fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml,sha256=N7cyHm6a2QwNsV9uaJp-eZmdbs9kmdRrkxtO58QQQgM,116
388
393
  fusion_bench_config/method/linear/weighted_average.yaml,sha256=SmELszTsJU63e8KwIrPmSqKmOmH-rz42zeumQZHoVDY,187
389
394
  fusion_bench_config/method/linear/weighted_average_for_llama.yaml,sha256=r8BlNqzRfn--_gDSff6KI8FO-elWFIszZDRV7G_nvHw,499
390
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml,sha256=YgiRBeTCQKeMjkxRhABw3teEvGc6X74w43_QVVumcVg,1189
391
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml,sha256=UDwjd4vlQ-LgthHeOzyd3c1HeoY8lD_5F7kefpMXhNI,1471
395
+ fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml,sha256=iJgRZiT-fic7jJOMSmq-4vslQXBIoE7IdrxPC4GQ9Cs,1157
396
+ fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml,sha256=_LIlnNoGLJfJpchB9AYvZMRby8oG_PU3p7mdA24Eq0k,1556
392
397
  fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml,sha256=Px8LU_UtDz-YHDFfqQ7scEPOproiFOaudKVshrhCTgc,483
393
398
  fusion_bench_config/method/pruning/llama_random_pruning.yaml,sha256=0RiZS8d42PXZzwncPG8zcbnyYJ9vtfr2sOSqS8oDyT4,325
394
399
  fusion_bench_config/method/pruning/llama_wanda_pruning.yaml,sha256=qKe5yIRsmK2KUyYENENWlw1qlGet9TpDhR-E_uO7vAw,501
395
400
  fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml,sha256=GsxsQ2L3kfsdD7A8o7UAHfiSbAGh53zVXdlYuEIEWR0,130
396
- fusion_bench_config/method/regmean/clip_regmean.yaml,sha256=svZqwicYpbEx1vZL2IISfQulBNAmTm8X_mAP6JrLCDU,402
401
+ fusion_bench_config/method/rankone_moe/rankone_moe.yaml,sha256=RWf94HqYBinZxH-jhi3h8UOLXxv1P5doy0YcTQM-plw,866
402
+ fusion_bench_config/method/regmean/clip_regmean.yaml,sha256=dxSJMRam6YMks7zYx4ACgvrLP5cndxzraVO93SGhyYo,425
397
403
  fusion_bench_config/method/regmean/gpt2_regmean.yaml,sha256=CL6f3GKQTSiLonrak8uEFoFn6MrzQ-ZJp4zXCwCllSk,423
398
404
  fusion_bench_config/method/regmean/regmean.yaml,sha256=ZgVVLx-lHwVgjtjTl4VZUlthh8yyua87QvoJfmNHud4,101
399
405
  fusion_bench_config/method/slerp/slerp.yaml,sha256=DIsS8xS2CnKLyF5OHz_RWG87A48iElevDbVTUHYobDg,118
@@ -514,10 +520,11 @@ fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8
514
520
  fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml,sha256=UYOSR9RJhup6pSC0N7UvvnlpXTkiCdD4tzsx-HyQ_GA,269
515
521
  fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml,sha256=_hqQweyZdCztqvjtuYrhCx4Hdqe959FFCdL7_IspR2w,261
516
522
  fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml,sha256=9hbvC3k5x6NpA9tRDYeORhrjEyd2VH5ztMdLU67Adjk,249
523
+ fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml,sha256=iQMj2VpDTe_D8OfCo94w5Ud2MON-EGa0DzVr6UmphrA,436
517
524
  fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml,sha256=i5Bn8bLl2cgqvrgtIGmoovUfSMehk_m-6C2wwcx5JMU,435
518
- fusion_bench-0.2.5.dist-info/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
519
- fusion_bench-0.2.5.dist-info/METADATA,sha256=Kv69uDo6ROZOarhCQ81ldxjtsp_9oF9nrMMzY1WE4C4,13528
520
- fusion_bench-0.2.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
521
- fusion_bench-0.2.5.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
522
- fusion_bench-0.2.5.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
523
- fusion_bench-0.2.5.dist-info/RECORD,,
525
+ fusion_bench-0.2.6.dist-info/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
526
+ fusion_bench-0.2.6.dist-info/METADATA,sha256=eExQgyXjCnwYCSMfJ3h9yH4vWaviRwNogM0OMJktUDU,13528
527
+ fusion_bench-0.2.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
528
+ fusion_bench-0.2.6.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
529
+ fusion_bench-0.2.6.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
530
+ fusion_bench-0.2.6.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
- _target_: ttt.method.FullFinetuneSFT
1
+ _target_: fusion_bench.method.FullFinetuneSFT
2
2
  _recursive_: False
3
3
 
4
4
  optimizer:
@@ -8,9 +8,9 @@ optimizer:
8
8
  lr: 5e-5
9
9
 
10
10
  lr_scheduler:
11
- _target_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
12
- num_warmup_steps: 5
13
- num_training_steps: _T_max_ # this will be replaced by the expected number of training steps
11
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
12
+ T_max: _T_max_ # this will be replaced by the expected number of training steps
13
+ eta_min: 1e-6
14
14
 
15
15
  dataloader_kwargs:
16
16
  # per-gpu batch size
@@ -1,4 +1,4 @@
1
- _target_: ttt.method.FullFinetuneSFT
1
+ _target_: fusion_bench.method.PeftFinetuneSFT
2
2
  _recursive_: False
3
3
 
4
4
  optimizer:
@@ -8,9 +8,9 @@ optimizer:
8
8
  lr: 5e-5
9
9
 
10
10
  lr_scheduler:
11
- _target_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
12
- num_warmup_steps: 5
13
- num_training_steps: _T_max_ # this will be replaced by the expected number of training steps
11
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
12
+ T_max: _T_max_ # this will be replaced by the expected number of training steps
13
+ eta_min: 1e-6
14
14
 
15
15
  dataloader_kwargs:
16
16
  # per-gpu batch size
@@ -22,9 +22,14 @@ peft_config:
22
22
  _target_: peft.LoraConfig
23
23
  task_type: peft.TaskType.CAUSAL_LM
24
24
  target_modules:
25
- - query
26
- - value
27
- r: 16
25
+ # lora attention modules
26
+ - q_proj
27
+ - v_proj
28
+ # lora mlp modules
29
+ - gate_proj
30
+ - down_proj
31
+ - up_proj
32
+ r: 64
28
33
  lora_alpha: 16
29
34
  lora_dropout: 0
30
35
  bias: none
@@ -53,3 +58,4 @@ save_optimizer_state: false
53
58
  save_full_model: false
54
59
  # Path to checkpoint to load from, used for resuming training
55
60
  ckpt_path: null
61
+ max_length: 4096
@@ -0,0 +1,26 @@
1
+ name: ??? # mandatory value (OmegaConf `???`), must be provided by the user
2
+ # the path for loading the model weights; if specified, the test-time adaptation training is skipped
3
+ checkpoint: False
4
+ # the path for saving the model weights.
5
+ save_checkpoint: False
6
+ router_hidden_layers: 1
7
+ init_lambda: 0.3
8
+ batch_reduce: true
9
+
10
+ # device to compute svd
11
+ svd_accelerator: cuda
12
+ rank_k: 32 # How many experts are added to the pool per task?
13
+ select_k: -1 # How many experts are selected from the pool to merge? Range is (1, rank_k*task_num); in particular, -1 selects all the experts in the pool
14
+
15
+ # learning rate
16
+ lr: 1e-4
17
+ optimizer: adam
18
+ # this is overridden by `fabric.devices` if launched from the `fusion_bench` CLI.
19
+ devices: 1
20
+ batch_size: 16
21
+ num_workers: 16
22
+ max_steps: 1000 # default: 1000
23
+ # if true, use gradient accumulation across tasks to save memory
24
+ use_grad_accumulate: true
25
+ cache_dir: outputs
26
+ fast_dev_run: ${fast_dev_run}
@@ -3,6 +3,7 @@ _target_: fusion_bench.method.RegMeanAlgorithmForCLIP
3
3
  exclude_param_names_regex: []
4
4
  # numbers of examples to compute regmean weights
5
5
  num_regmean_examples: 256
6
+ weight_transpose: true
6
7
  # float, reduce non-diagonal elements in regmean weights by multiplying this scalar
7
8
  reduce_non_diagonal_ratio: 0.6
8
9
  dataloader_kwargs:
@@ -0,0 +1,18 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets:
4
+ - sun397
5
+ - stanford-cars
6
+ - resisc45
7
+ - eurosat
8
+ - svhn
9
+ - gtsrb
10
+ - mnist
11
+ - dtd
12
+ - _self_
13
+
14
+ _target_: fusion_bench.taskpool.RankoneWEMoECLIPVisionModelTaskPool
15
+
16
+ # === layer-wise routing weights saving ===
17
+ layer_wise_routing_weights_save_path: null
18
+ layer_wise_routing_weights_max_num: 1000