fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py
@@ -0,0 +1,462 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def pack(x, n_bits=32):
+    """
+    pack n_bits of x into a single integer
+
+    x: bool tensor (*, K, N)
+    return: int tensor (*, K // n_bits, N)
+    """
+    assert x.shape[-2] % n_bits == 0, "K must be divisible by n_bits"
+
+    shift = torch.arange(n_bits, device=x.device)
+    shape = x.shape[:-2]
+    x = x.view(-1, x.shape[-2] // n_bits, n_bits, x.shape[-1])
+    x = x << shift[None, None, :, None]
+    x = x.sum(-2)
+    x = x.view(*shape, *x.shape[-2:])
+
+    # determine dtype
+    if n_bits == 8:
+        dtype = torch.uint8
+    elif n_bits == 16:
+        dtype = torch.int16
+    elif n_bits == 32:
+        dtype = torch.int32
+    elif n_bits == 64:
+        dtype = torch.int64
+
+    return x.to(dtype)
+
+
+def unpack(x, n_bits=32):
+    """
+    unpack each integer of x back into n_bits booleans
+
+    x: int tensor (*, K // n_bits, N)
+    return: bool tensor (*, K, N)
+    """
+    shift = torch.arange(n_bits, device=x.device)
+    shape = x.shape[:-2]
+    x = x.view(-1, x.shape[-2], 1, x.shape[-1])
+    x = (x >> shift[None, None, :, None]) & 0x1
+    x = x.view(*shape, -1, x.shape[-1])
+    return x.bool()
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+    ],
+    key=["M", "N", "K"],
+)
+@triton.jit
+def binary_matmul_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    n_bits,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+    # by to get the element one row down (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+    """Kernel for computing the matmul C = A x B.
+    A has shape (M, K), float
+    B has shape (K//n_bits, N), int, packed boolean
+    C has shape (M, N),
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    # See above `L2 Cache Optimizations` section for details.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetics` section for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    # b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # Adapted from GPTQ-Triton (https://github.com/fpgaminer/GPTQ-triton)
+    # b_ptrs is set up such that it repeats elements along the K axis n_bits times
+    b_ptrs = b_ptr + (
+        (offs_k[:, None] // n_bits) * stride_bk + offs_bn[None, :] * stride_bn
+    )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+    # shifter is used to extract each bit of each element in the int matrix
+    shifter = (offs_k % n_bits)[:, None]
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
+        b = tl.load(b_ptrs)
+
+        # Convert B from int to a.dtype, for each bit in B, 0 becomes -1.0, 1 becomes 1.0
+        # b: (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        b = (b >> shifter) & 0x1
+        b = b.to(a.dtype) * 2 - 1
+
+        # Simply convert to a.dtype
+        # b = b.to(a.dtype)
+        # We accumulate along the K dimension.
+        accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        # b_ptrs += BLOCK_SIZE_K * stride_bk
+        b_ptrs += (BLOCK_SIZE_K // n_bits) * stride_bk
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    # if ACTIVATION == "leaky_relu":
+    #     accumulator = leaky_relu(accumulator)
+    c = accumulator.to(tl.float16)
+
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def binary_matmul(a, b, n_bits=32, activation=""):
+    """
+    a: float tensor (M, K)
+    b: int tensor (K // n_bits, N), packed boolean
+    n_bits: int, number of bits that each element in b represents
+    """
+    # Check constraints.
+    assert a.shape[1] == b.shape[0] * n_bits, "Incompatible dimensions"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    assert b.is_contiguous(), "Matrix B must be contiguous"
+    M, K = a.shape
+    _, N = b.shape
+
+    # Allocates output.
+    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    # print(f"Launching kernel with M = {M}, N = {N}, K = {K}, n_bits = {n_bits}, activation = {activation}")
+
+    binary_matmul_kernel[grid](
+        a,
+        b,
+        c,
+        M,
+        N,
+        K,
+        n_bits,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+        ACTIVATION=activation,
+    )
+    return c
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=2,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=2,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=2,
+        ),
+    ],
+    key=["M", "N", "K"],
+)
+@triton.jit
+def binary_bmm_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    n_bits,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+    # by to get the element one row down (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_batch_a,
+    stride_batch_b,
+    stride_batch_c,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+    """Kernel for computing the matmul C = A x B.
+    A has shape (B, M, K), float
+    B has shape (B, K//n_bits, N), int, packed boolean
+    C has shape (B, M, N),
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    # See above `L2 Cache Optimizations` section for details.
+    pid = tl.program_id(axis=0)
+    pid_batch = tl.program_id(axis=1)
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See above `Pointer Arithmetics` section for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = (
+        a_ptr
+        + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        + pid_batch * stride_batch_a
+    )
+    # b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # Adapted from GPTQ-Triton (https://github.com/fpgaminer/GPTQ-triton)
+    # b_ptrs is set up such that it repeats elements along the K axis n_bits times
+    b_ptrs = (
+        b_ptr
+        + ((offs_k[:, None] // n_bits) * stride_bk + offs_bn[None, :] * stride_bn)
+        + pid_batch * stride_batch_b
+    )
+    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+    # shifter is used to extract each bit of each element in the int matrix
+    shifter = (offs_k % n_bits)[:, None]
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
+        b = tl.load(b_ptrs)
+
+        # Convert B from int to a.dtype, for each bit in B, 0 becomes -1.0, 1 becomes 1.0
+        # b: (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        b = (b >> shifter) & 0x1
+        # b = b.to(a.dtype) * 2 - 1
+        b = (2 * b - 1).to(a.dtype)
+
+        # Simply convert to a.dtype
+        # b = b.to(a.dtype)
+        # We accumulate along the K dimension.
+        accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        # b_ptrs += BLOCK_SIZE_K * stride_bk
+        b_ptrs += (BLOCK_SIZE_K // n_bits) * stride_bk
+    # You can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32!
+    # if ACTIVATION == "leaky_relu":
+    #     accumulator = leaky_relu(accumulator)
+    c = accumulator.to(tl.float16)
+
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = (
+        c_ptr
+        + stride_cm * offs_cm[:, None]
+        + stride_cn * offs_cn[None, :]
+        + pid_batch * stride_batch_c
+    )
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def binary_bmm(a, b, n_bits=32, activation=""):
+    """
+    a: float tensor (B, M, K)
+    b: int tensor (B, K // n_bits, N), packed boolean
+    n_bits: int, number of bits that each element in b represents
+    """
+    assert a.dim() == 3, "Matrix A must be 3D"
+    assert b.dim() == 3, "Matrix B must be 3D"
+    assert a.shape[2] == b.shape[1] * n_bits, "Incompatible dimensions"
+    assert a.shape[0] == b.shape[0], "Incompatible batch dimensions"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    assert b.is_contiguous(), "Matrix B must be contiguous"
+    assert a.device == b.device, "A and B must be on the same device"
+    B, M, K = a.shape
+    _, _, N = b.shape
+
+    # Allocates output.
+    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        B,
+    )
+
+    # print(f"Launching kernel with M = {M}, N = {N}, K = {K}, n_bits = {n_bits}, activation = {activation}")
+
+    # wrap this, otherwise triton tries to launch from cuda:0
+    with torch.cuda.device(a.device):
+        binary_bmm_kernel[grid](
+            a,
+            b,
+            c,
+            M,
+            N,
+            K,
+            n_bits,
+            a.stride(1),
+            a.stride(2),
+            b.stride(1),
+            b.stride(2),
+            c.stride(1),
+            c.stride(2),
+            a.stride(0),
+            b.stride(0),
+            c.stride(0),
+            ACTIVATION=activation,
+        )
+    return c
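
To make the packed layout concrete, here is a minimal usage sketch (not part of the diff): the shapes are illustrative, a CUDA device is required since the kernels are written in Triton, and the import path simply mirrors the wheel layout above.

```python
import torch

from fusion_bench.method.bitdelta.bitdelta_utils.binary_gemm_kernel import (
    binary_matmul,
    pack,
    unpack,
)

M, K, N = 64, 256, 128  # K must be divisible by n_bits (default 32)
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
bits = torch.rand(K, N, device="cuda") > 0.5  # bool sign mask, True -> +1

packed = pack(bits)                       # int32, shape (K // 32, N)
assert torch.equal(unpack(packed), bits)  # pack/unpack round-trip

# The kernel expands each bit to {-1, +1} on the fly, so the result should match
# a dense matmul against the equivalent {-1, +1} matrix up to fp16 rounding.
dense = bits.to(a.dtype) * 2 - 1
torch.testing.assert_close(binary_matmul(a, packed), a @ dense, rtol=1e-2, atol=1e-1)
```

Packing 32 signs into one int32 is what gives BitDelta its roughly 1-bit-per-parameter delta storage; `binary_bmm` is the batched variant that `BinaryDiff.forward` in `diff.py` calls.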

fusion_bench/method/bitdelta/bitdelta_utils/data.py
@@ -0,0 +1,35 @@
+import torch
+from datasets import load_dataset
+from transformers import default_data_collator
+
+
+def _preprocess(tokenizer, examples, max_length=128):
+    return tokenizer(
+        examples["text"], padding="max_length", truncation=True, max_length=max_length
+    )
+
+
+def get_dataset(dataset_name: str, subset: str, split: str, size: int = None):
+    if size is None:
+        dataset = load_dataset(dataset_name, subset)[split]
+    else:
+        dataset = load_dataset(dataset_name, subset, streaming=True)[split]
+        dataset = dataset.take(size)
+
+    return dataset
+
+
+def get_dataloader(dataset, tokenizer, batch_size, num_workers=4, max_length=128):
+    dataset = dataset.map(
+        lambda examples: _preprocess(tokenizer, examples, max_length),
+        batched=True,
+        batch_size=batch_size,
+        remove_columns=["text", "timestamp", "url"],
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=0,
+        collate_fn=default_data_collator,
+    )
+    return dataloader
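
Two details of this helper are worth flagging: the hard-coded `remove_columns=["text", "timestamp", "url"]` matches the C4 schema, so it is evidently intended for C4-style calibration text, and the `num_workers` argument is accepted but ignored (the `DataLoader` is always built with `num_workers=0`). A hedged usage sketch follows; the dataset and tokenizer choices are illustrative, not prescribed by the diff.

```python
from transformers import AutoTokenizer

# Hypothetical setup: any tokenizer works; gpt2 is only an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # padding="max_length" needs a pad token

dataset = get_dataset("allenai/c4", "en", "train", size=256)  # streams, takes 256 rows
dataloader = get_dataloader(dataset, tokenizer, batch_size=8, max_length=128)
batch = next(iter(dataloader))  # dict of input_ids / attention_mask tensors
```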

fusion_bench/method/bitdelta/bitdelta_utils/diff.py
@@ -0,0 +1,129 @@
+import gc
+
+import torch
+from torch import Tensor, nn
+
+from .binary_gemm_kernel import binary_bmm, pack, unpack
+
+
+class BinaryDiff(nn.Module):
+    def __init__(self, base: Tensor, finetune: Tensor):
+        super().__init__()
+        diff = finetune - base
+        quantile = diff.float().abs().mean()
+
+        mask = torch.ones_like(diff)
+        mask[diff < 0] = 0
+        mask = pack(mask.bool().T)
+
+        self.register_buffer("mask", mask)
+        self.register_buffer("base", base.T)
+        self.register_parameter(
+            "coeff",
+            nn.Parameter(
+                torch.tensor(
+                    quantile,
+                    dtype=torch.float32,
+                    requires_grad=True,
+                    device=base.device,
+                )
+            ),
+        )
+        del base, finetune, diff
+
+    def forward(self, x):
+        # print(x.shape, self.base.shape, self.coeff.shape, self.mask.shape)
+        # [B, seq, in] @ [in, out] + [B, seq, in] @ [B, in/32, out]
+
+        # TODO: This can be faster
+        repeated_mask = self.mask.unsqueeze(0).repeat(x.size(0), 1, 1)
+        return x @ self.base + self.coeff * binary_bmm(x, repeated_mask)
+
+
+def compress_diff(base_model, finetuned_model, finetuned_compressed_model):
+    def compress_submodule(name, subname, module, submodule):
+        target_device = submodule.weight.device
+
+        base_weight = (
+            base_model.get_submodule(f"{name}.{subname}")
+            .weight.detach()
+            .to(target_device)
+        )
+        finetuned_weight = (
+            finetuned_model.get_submodule(f"{name}.{subname}")
+            .weight.detach()
+            .to(target_device)
+        )
+
+        compressed = BinaryDiff(
+            base=base_weight,
+            finetune=finetuned_weight,
+        ).to(target_device)
+
+        del submodule, base_weight
+        setattr(module, subname, None)
+        gc.collect()
+        torch.cuda.empty_cache()
+        setattr(module, subname, compressed)
+
+    # TODO: this can be parallelized
+    for name, module in finetuned_compressed_model.named_modules():
+        if "mlp" in name or "self_attn" in name:
+            for subname, submodule in module.named_children():
+                if "proj" in subname:
+                    compress_submodule(name, subname, module, submodule)
+
+
+def save_diff(finetuned_compressed_model, save_dir):
+    diff_dict = {}
+
+    for name, module in finetuned_compressed_model.named_modules():
+        if isinstance(module, BinaryDiff):
+            # diff_dict[name + ".mask"] = (module.mask == 1).bool().cpu()
+            diff_dict[name + ".mask"] = module.mask.cpu()
+            diff_dict[name + ".coeff"] = module.coeff.cpu()
+
+    for name, param in finetuned_compressed_model.named_parameters():
+        if param.requires_grad:
+            diff_dict[name] = param.cpu()
+
+    torch.save(diff_dict, save_dir)
+
+
+@torch.no_grad()
+def load_diff(model, diff_dir):
+    device = model.device
+    diff_dict = torch.load(diff_dir)
+
+    for name, module in model.named_modules():
+        if name + ".mask" in diff_dict:
+            coeff = diff_dict[name + ".coeff"].to(device)
+            mask = diff_dict[name + ".mask"].to(device)
+
+            # setattr(module, "mask", mask)
+            # setattr(module, "coeff", coeff)
+            weight = (unpack(mask) * 2 - 1) * coeff
+
+            module.weight.add_(weight.T.to(module.weight.dtype))
+        elif name + ".weight" in diff_dict:
+            module.weight = nn.Parameter(
+                diff_dict[name + ".weight"].to(device).to(module.weight.dtype)
+            )
+
+        elif name + ".A" in diff_dict:
+            A = diff_dict[name + ".A"].to(device)
+            B = diff_dict[name + ".B"].to(device)
+
+            mask = (A @ B).T
+            module.weight.add_(mask.to(module.weight.dtype))
+
+    model.config.vocab_size = model.lm_head.weight.size(0)
+
+
+def save_full_model(base_model, tokenizer, diff_dir, save_dir):
+    load_diff(base_model, diff_dir)
+
+    base_model.save_pretrained(save_dir)
+    tokenizer.save_pretrained(save_dir)
+
+    del base_model
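
To summarize what `BinaryDiff` stores: with Δ = W_ft − W_base, it keeps the base weight, a packed sign mask of Δ, and a single trainable scale `coeff` initialized to mean|Δ|, so `forward(x)` computes x·W_baseᵀ + coeff · x·sign(Δ)ᵀ. A hedged end-to-end sketch of the helpers above follows; the model names and file paths are placeholders, not values taken from the diff.

```python
import copy

from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("path/to/base-model")            # placeholder
finetuned = AutoModelForCausalLM.from_pretrained("path/to/finetuned-model")  # placeholder
compressed = copy.deepcopy(finetuned)

# Swap every *_proj weight inside mlp / self_attn blocks for a BinaryDiff module.
compress_diff(base, finetuned, compressed)

# Persist only the 1-bit masks, per-layer scales, and trainable parameters ...
save_diff(compressed, "bitdelta_diff.pt")

# ... or materialize a dense merged checkpoint by adding the delta back onto the base.
tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")
save_full_model(base, tokenizer, "bitdelta_diff.pt", "merged-model/")
```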

fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
@@ -372,7 +372,6 @@ class ConcreteLayerWiseAdaMergingForCLIP(
             clamp_weights=self.config.clamp_weights,
             tie_weights=self.config.tie_weights,
             strict=self.config.strict,
-            layer_vector_dtype=self.merge_dtype,
         )
         return module, mask_model
 

fusion_bench/method/depth_upscaling/depth_upscaling.py
@@ -1,17 +1,17 @@
 import logging
 from copy import deepcopy
-from typing import List, Mapping, Union  # noqa: F401
+from typing import Any, List, Mapping, Union  # noqa: F401
 
 import torch
 from torch import nn
 from tqdm.autonotebook import tqdm
 
-from fusion_bench.method import BaseAlgorithm
-from fusion_bench.modelpool import BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class DepthUpscalingAlgorithm(BaseAlgorithm):
     R"""
     Implements the Depth Upscaling Algorithm.
@@ -26,12 +26,7 @@ class DepthUpscalingAlgorithm(BaseAlgorithm):
         **kwargs: Additional keyword arguments.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "layer_indices": "layer_indices",
-    }
-
-    def __init__(self, layer_indices: Union[str, List[int]], **kwargs):
-        self.layer_indices = layer_indices
+    def __init__(self, layer_indices: Union[str, List[int]], **kwargs: Any):
         super().__init__(**kwargs)
 
     @torch.no_grad()
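
The `_config_mapping` entry and the `self.layer_indices = layer_indices` assignment dropped here are now supplied by the `@auto_register_config` decorator imported from the top-level `fusion_bench` package. A hedged sketch of the pattern, inferred from this diff rather than from the library's documentation:

```python
# Inferred behavior: @auto_register_config inspects __init__'s signature and
# registers its arguments (here `layer_indices`) both on the instance and in
# the class's config mapping, so subclasses only forward **kwargs to super().
from fusion_bench import BaseAlgorithm, auto_register_config


@auto_register_config
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, layer_indices, **kwargs):
        super().__init__(**kwargs)  # no manual self.layer_indices assignment

    def run(self, modelpool):  # BaseAlgorithm is abstract; a no-op stub for the sketch
        return modelpool


algo = MyAlgorithm(layer_indices=[0, 1, 2, 3])
print(algo.layer_indices)  # [0, 1, 2, 3], registered by the decorator
```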

fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py
@@ -3,13 +3,12 @@ Example Usage:
 
 ```bash
 fusion_bench \
-    method=adamerging \
+    path.log_dir=outputs/ViT-B-32/layer_wise_adamerging \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging \
     method.save_merging_weights=merging_weights.pt \
-    modelpool=clip-vit-base-patch32_TA8 \
-    taskpool=clip-vit-classification_TA8 \
-    fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
-    fabric.loggers.name=clip_layer_wise_adamerging_adamerging
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
 ```
 """
 

fusion_bench/method/doge_ta/doge_ta.py
@@ -12,7 +12,7 @@ fusion_bench \
     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
 
 fusion_bench \
-    method=adamerging \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging_doge_ta \
     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8