fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
  30. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  32. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  33. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  34. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  35. fusion_bench/programs/fabric_fusion_program.py +5 -0
  36. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +1 -1
  39. fusion_bench/utils/lazy_state_dict.py +268 -0
  40. fusion_bench/utils/parameters.py +33 -0
  41. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  42. fusion_bench/utils/type.py +1 -0
  43. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
  44. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
  45. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  46. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  53. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  54. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  55. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  56. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  57. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  58. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  60. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  61. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  63. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  64. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  68. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  72. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  73. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  74. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
  75. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
  76. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
  77. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
  78. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
  79. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
  80. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
  81. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
  82. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  83. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  84. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  85. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  86. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py
@@ -0,0 +1,128 @@
+ import math
+ import time
+
+ import torch
+ import torch.nn as nn
+ import transformers
+
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+
+
+ ## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
+ class SparseGPT:
+
+     def __init__(self, layer):
+         self.layer = layer
+         self.dev = self.layer.weight.device
+         W = layer.weight.data.clone()
+         if isinstance(self.layer, nn.Conv2d):
+             W = W.flatten(1)
+         if isinstance(self.layer, transformers.Conv1D):
+             W = W.t()
+         self.rows = W.shape[0]
+         self.columns = W.shape[1]
+         self.H = torch.zeros((self.columns, self.columns), device=self.dev)
+         self.nsamples = 0
+
+     def add_batch(self, inp, out):
+         if len(inp.shape) == 2:
+             inp = inp.unsqueeze(0)
+         tmp = inp.shape[0]
+         if isinstance(self.layer, nn.Linear) or isinstance(
+             self.layer, transformers.Conv1D
+         ):
+             if len(inp.shape) == 3:
+                 inp = inp.reshape((-1, inp.shape[-1]))
+             inp = inp.t()
+         self.H *= self.nsamples / (self.nsamples + tmp)
+         self.nsamples += tmp
+         inp = math.sqrt(2 / self.nsamples) * inp.float()
+         self.H += inp.matmul(inp.t())
+
+     def fasterprune(self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=0.01):
+         W = self.layer.weight.data.clone()
+         if isinstance(self.layer, nn.Conv2d):
+             W = W.flatten(1)
+         if isinstance(self.layer, transformers.Conv1D):
+             W = W.t()
+         W = W.float()
+
+         tick = time.time()
+
+         H = self.H
+         del self.H
+         dead = torch.diag(H) == 0
+         H[dead, dead] = 1
+         W[:, dead] = 0
+
+         Losses = torch.zeros(self.rows, device=self.dev)
+
+         damp = percdamp * torch.mean(torch.diag(H))
+         diag = torch.arange(self.columns, device=self.dev)
+         H[diag, diag] += damp
+         H = torch.linalg.cholesky(H)
+         H = torch.cholesky_inverse(H)
+         H = torch.linalg.cholesky(H, upper=True)
+         Hinv = H
+
+         mask = None
+
+         for i1 in range(0, self.columns, blocksize):
+             i2 = min(i1 + blocksize, self.columns)
+             count = i2 - i1
+
+             W1 = W[:, i1:i2].clone()
+             Q1 = torch.zeros_like(W1)
+             Err1 = torch.zeros_like(W1)
+             Losses1 = torch.zeros_like(W1)
+             Hinv1 = Hinv[i1:i2, i1:i2]
+
+             if prune_n == 0:
+                 if mask is not None:
+                     mask1 = mask[:, i1:i2]
+                 else:
+                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
+                     mask1 = tmp <= thresh
+             else:
+                 mask1 = torch.zeros_like(W1) == 1
+
+             for i in range(count):
+                 w = W1[:, i]
+                 d = Hinv1[i, i]
+
+                 if prune_n != 0 and i % prune_m == 0:
+                     tmp = (
+                         W1[:, i : (i + prune_m)] ** 2
+                         / (torch.diag(Hinv1)[i : (i + prune_m)].reshape((1, -1))) ** 2
+                     )
+                     mask1.scatter_(
+                         1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True
+                     )
+
+                 q = w.clone()
+                 q[mask1[:, i]] = 0
+
+                 Q1[:, i] = q
+                 Losses1[:, i] = (w - q) ** 2 / d**2
+
+                 err1 = (w - q) / d
+                 W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                 Err1[:, i] = err1
+
+             W[:, i1:i2] = Q1
+             Losses += torch.sum(Losses1, 1) / 2
+
+             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+         torch.cuda.synchronize()
+         if isinstance(self.layer, transformers.Conv1D):
+             W = W.t()
+         self.layer.weight.data = W.reshape(self.layer.weight.shape).to(
+             self.layer.weight.data.dtype
+         )
+
+     def free(self):
+         self.H = None
+         torch.cuda.empty_cache()
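For orientation, here is a minimal sketch of how a SparseGPT wrapper like the one above is typically driven: accumulate the Hessian proxy H from calibration activations via add_batch, then call fasterprune. The toy layer, calibration shapes, and sparsity targets below are illustrative assumptions, not taken from this diff (in llama_sparsegpt_prune.py the inputs would come from forward hooks on real transformer sublayers), and a CUDA device is assumed because fasterprune calls torch.cuda.synchronize().

```python
import torch
import torch.nn as nn

device = "cuda"  # fasterprune() calls torch.cuda.synchronize()
layer = nn.Linear(64, 32).to(device)  # toy stand-in for a transformer sublayer
gpt = SparseGPT(layer)

# Accumulate the Hessian proxy H from calibration activations.
for _ in range(8):
    calib = torch.randn(4, 16, 64, device=device)  # (batch, seq, features)
    gpt.add_batch(calib, out=None)  # `out` is accepted but unused

# Prune in place: unstructured 50% sparsity, compensating the remaining
# weights block by block as columns are zeroed.
gpt.fasterprune(sparsity=0.5)
# Semi-structured 2:4 pruning would instead be:
#   gpt.fasterprune(sparsity=0.5, prune_n=2, prune_m=4)
gpt.free()  # drop H and release cached GPU memory
```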
fusion_bench/method/pruning/wanda_utils/data.py
@@ -1,13 +1,15 @@
  # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
 
+ import os
  import random
  from typing import List, Optional, Tuple, cast  # noqa: F401
 
- from datasets import load_dataset
  from torch import Tensor
  from tqdm.auto import tqdm
  from transformers import PreTrainedTokenizer
 
+ from datasets import load_dataset
+
 
  # Wrapper for tokenized input IDs
  class TokenizerWrapper:
@@ -61,6 +63,7 @@ def get_c4(
      seqlen: int,
      tokenizer,
      data_path: str = "allenai/c4",
+     cache_dir: str = ".cache/allenai--c4",
  ) -> Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]:
      """
      Load and process the c4 dataset.
@@ -76,19 +79,35 @@ def get_c4(
          tuple (Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]): Tuple containing the training samples and the validation dataset.
      """
      # Load train and validation datasets
-     traindata = load_dataset(
-         data_path,
-         # "allenai--c4", # https://github.com/huggingface/datasets/issues/6559
-         data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
-         split="train",
-     )
-     valdata = load_dataset(
-         data_path,
-         # "allenai--c4",
-         data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
-         split="validation",
-     )
-
+     if os.path.exists(f"{cache_dir}/en/c4-train.00000-of-01024.json.gz"):
+         traindata = load_dataset(
+             "json",
+             data_files={"train": f"{cache_dir}/en/c4-train.00000-of-01024.json.gz"},
+             split="train",
+         )
+     else:
+         traindata = load_dataset(
+             data_path,
+             # "allenai--c4", # https://github.com/huggingface/datasets/issues/6559
+             data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+             split="train",
+         )
+
+     if os.path.exists(f"{cache_dir}/en/c4-validation.00000-of-00008.json.gz"):
+         valdata = load_dataset(
+             "json",
+             data_files={
+                 "validation": f"{cache_dir}/en/c4-validation.00000-of-00008.json.gz",
+             },
+             split="validation",
+         )
+     else:
+         valdata = load_dataset(
+             data_path,
+             # "allenai--c4",
+             data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+             split="validation",
+         )
      # Generate samples from training set
      if seed is not None:
          random.seed(seed)
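The new cache_dir branch only fires when the shards already exist on disk. As a hedged sketch of pre-populating that directory (huggingface_hub is my assumption here, not part of this diff; any way of placing the files under cache_dir/en/ works):

```python
# Sketch: pre-download the two C4 shards that get_c4 checks for, so the
# offline "json" loading path is taken on later runs.
from huggingface_hub import hf_hub_download

cache_dir = ".cache/allenai--c4"
for shard in (
    "en/c4-train.00000-of-01024.json.gz",
    "en/c4-validation.00000-of-00008.json.gz",
):
    hf_hub_download(
        repo_id="allenai/c4",
        repo_type="dataset",
        filename=shard,
        local_dir=cache_dir,  # lands at {cache_dir}/en/..., matching the check above
    )
```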
fusion_bench/method/randes/__init__.py
@@ -0,0 +1,15 @@
+ R"""
+ RanDeS: Randomized Delta Superposition
+
+ Implementation of "RanDeS: Randomized Delta Superposition for Multi-Model Compression"
+ paper link: http://arxiv.org/abs/2505.11204
+
+ Modified from https://github.com/Zhou-Hangyu/randes
+ """
+
+ from .base_algorithm import SuperposedAlgorithmBase
+ from .modelsoup import SuperposedModelSoupAlgorithm
+ from .task_arithmetic import (
+     SuperposedTaskArithmeticAlgorithm,
+     SuperposedTaskArithmeticLoRAAlgorithm,
+ )