joonmyung 1.6.1.tar.gz → 1.7.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {joonmyung-1.6.1 → joonmyung-1.7.0}/PKG-INFO +1 -1
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/analysis.py +1 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/dataset.py +30 -11
- joonmyung-1.7.0/joonmyung/compression/__init__.py +2 -0
- joonmyung-1.7.0/joonmyung/compression/compression.py +227 -0
- joonmyung-1.7.0/joonmyung/compression/utils.py +564 -0
- joonmyung-1.7.0/joonmyung/data.py +104 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/draw.py +210 -76
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/log.py +2 -1
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/utils.py +4 -14
- joonmyung-1.7.0/joonmyung/metric.py +165 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/status.py +1 -1
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/utils.py +4 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/PKG-INFO +1 -1
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/SOURCES.txt +2 -12
- {joonmyung-1.6.1 → joonmyung-1.7.0}/setup.py +1 -1
- joonmyung-1.6.1/joonmyung/analysis/analysis_bak.py +0 -218
- joonmyung-1.6.1/joonmyung/analysis/analysis_ㅠㅏ.py +0 -218
- joonmyung-1.6.1/joonmyung/analysis/evaluate.py +0 -39
- joonmyung-1.6.1/joonmyung/analysis/hook.py +0 -0
- joonmyung-1.6.1/joonmyung/analysis/utils.py +0 -14
- joonmyung-1.6.1/joonmyung/compression/__init__.py +0 -0
- joonmyung-1.6.1/joonmyung/compression/compression.py +0 -202
- joonmyung-1.6.1/joonmyung/data.py +0 -47
- joonmyung-1.6.1/joonmyung/dummy.py +0 -4
- joonmyung-1.6.1/joonmyung/metric.py +0 -133
- joonmyung-1.6.1/joonmyung/model/__init__.py +0 -0
- joonmyung-1.6.1/joonmyung/model/compression.py +0 -202
- joonmyung-1.6.1/joonmyung/model.py +0 -0
- joonmyung-1.6.1/joonmyung/models/__init__.py +0 -0
- joonmyung-1.6.1/joonmyung/models/tome.py +0 -386
- {joonmyung-1.6.1 → joonmyung-1.7.0}/LICENSE.txt +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/README.md +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/__init__.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/__init__.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/model.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/app.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/__init__.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/clip.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/model.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/simple_tokenizer.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/compression/apply.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/file.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/gradcam.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/__init__.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/label.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/script.py +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/dependency_links.txt +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/not-zip-safe +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/top_level.txt +0 -0
- {joonmyung-1.6.1 → joonmyung-1.7.0}/setup.cfg +0 -0
--- joonmyung-1.6.1/joonmyung/analysis/dataset.py
+++ joonmyung-1.7.0/joonmyung/analysis/dataset.py
@@ -21,10 +21,19 @@ class JDataset():
             "num_classes" : 1000,
             "data_types" : ["val", "train"],
             "label_name" : imnet_label,
-            "distributions" : {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
-            # "distributions" : {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, # CLIP
+            "distributions" : {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
             "size": (224, 224)
         },
+
+        "imagenet_clip": {
+            "num_classes": 1000,
+            "data_types": ["val", "train"],
+            "label_name": imnet_label,
+            "distributions" : {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]},
+            "size": (224, 224)
+        },
+
+
         "cifar100" : {
             "num_classes" : 100,
             "data_types": ["test", "train"],
@@ -45,8 +54,12 @@ class JDataset():
         size = size if size else setting["size"]
 
         self.transform = [
-
-            transforms.Compose([transforms.Resize(256, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            # DEIT
+            transforms.Compose([transforms.Resize((256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            # CLIP
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
             transforms.Compose([transforms.Resize(224, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
             transforms.Compose([transforms.Resize(224, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
             transforms.Compose([transforms.ToTensor()])
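A side note on the new transform list: `transforms.Resize(256)` with an int (the DEIT branch; the extra parentheses in `(256)` are only grouping) scales the shorter image side to 256 and preserves the aspect ratio, whereas `transforms.Resize((256, 256))` (the CLIP branch) forces an exact 256x256 image before the centre crop. A minimal sketch of the difference, using a hypothetical 640x480 input that is not part of the package:

    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    img = Image.new("RGB", (640, 480))  # hypothetical 4:3 input image

    short_side = transforms.Resize(256, interpolation=InterpolationMode.BICUBIC)           # keeps aspect ratio
    exact_square = transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC)  # forces 256x256

    print(short_side(img).size)    # (341, 256): shorter side scaled to 256
    print(exact_square(img).size)  # (256, 256): both sides forced to 256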
@@ -59,24 +72,30 @@ class JDataset():
         # self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, "*", "*")))]
         # self.img_paths = [[path, idx] for idx, label_path in enumerate(self.label_paths) for path in sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*")))]
         self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*"))) for label_path in self.label_paths]
-
+        self.img_len = [len(labels) for labels in self.img_paths]
+        self.img_cum_len = torch.Tensor([0] + [sum(self.img_len[:i+1]) for i in range(len(self.img_len))])
 
     def __getitem__(self, idx):
-
+        if type(idx) == int:
+            label_num = (self.img_cum_len <= idx).sum().item() - 1
+            img_num = idx - int(self.img_cum_len[label_num].item())
+        else:
+            label_num, img_num = idx
         img_path = self.img_paths[label_num][img_num]
         sample = default_loader(img_path)
         sample = self.transform[self.transform_type](sample)
 
-        return sample[None].to(self.device), torch.tensor(label_num).to(self.device), self.label_name[int(label_num)]
+        return sample[None].to(self.device), torch.tensor(label_num).to(self.device), self.label_name[int(label_num)], img_path
 
     def getItems(self, indexs):
-        ds, ls, lns = [], [], []
+        ds, ls, lns, ips = [], [], [], []
         for index in indexs:
-            d, l, ln = self.__getitem__(index)
+            d, l, ln, ip = self.__getitem__(index)
             ds.append(d)
             ls.append(l)
             lns.append(ln)
-
+            ips.append(ip)
+        return torch.cat(ds, dim=0), torch.stack(ls, dim=0), lns, ips
 
     def getAllItems(self, batch_size=32):
         dataset = create_dataset(
@@ -99,7 +118,7 @@ class JDataset():
         return c_i
 
     def __len__(self):
-        return
+        return self.img_cum_len[-1]
 
 
     def validation(self, data):
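The `img_len` / `img_cum_len` bookkeeping added above is what lets `__getitem__` accept either a flat integer index or a `(label_num, img_num)` pair, and is also what the fixed `__len__` now returns. A minimal standalone sketch of the flat-index mapping, with made-up per-class counts (not taken from the package):

    import torch

    img_len = [3, 5, 2]  # hypothetical number of images per class
    img_cum_len = torch.Tensor([0] + [sum(img_len[:i + 1]) for i in range(len(img_len))])
    # tensor([ 0.,  3.,  8., 10.]); img_cum_len[-1] is what __len__ returns

    def flat_to_pair(idx):
        # same arithmetic as the new integer branch of __getitem__
        label_num = (img_cum_len <= idx).sum().item() - 1
        img_num = idx - int(img_cum_len[label_num].item())
        return label_num, img_num

    print(flat_to_pair(0))  # (0, 0): first image of class 0
    print(flat_to_pair(4))  # (1, 1): second image of class 1
    print(flat_to_pair(9))  # (2, 1): last image of class 2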
--- /dev/null
+++ joonmyung-1.7.0/joonmyung/compression/compression.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+
+
+from typing import Callable
+import torch
+import math
+
+def token_compression(x, info, diffDropScheduler, layer_idx, others = []):
+    [x, TD] = [x[None], True] if len(x.shape) == 2 else [x, False]
+    B, T, D = x.shape
+    if not info["use"] or T == 1:
+        return x.squeeze(0) if TD else x, others
+
+    T_vis = T if info["img_idx"][0] == None else info["img_idx"][1] - info["img_idx"][0]
+    if diffDropScheduler.benchmark:
+        r_use, r_diff = None, None
+        r_throughput = diffDropScheduler.drop_ratio_avg[layer_idx + 1]
+    else:
+        r_throughput = None
+        r_use, r_diff = (info["prune_layer"] == layer_idx and info["prune_r"]), \
+                        diffDropScheduler(info["difficulty"])
+
+    if (r_use or r_diff or r_throughput):
+        prune_r, prune_thr = None, None
+        if r_throughput is not None:
+            prune_r = r_throughput
+        elif info["r_type"] == 0:
+            prune_r = int(T_vis * info["prune_r"]) if r_use else int(T_vis * r_diff)
+        else:
+            prune_thr = info["prune_r"] if r_use else r_diff
+
+        scores = info["importance"] if not diffDropScheduler.benchmark else torch.randn(1, T_vis, device=x.device)
+        if info["source"] is None: info["source"] = torch.ones((B, (T // info["group_num"])), dtype=torch.bool, device=x.device)
+        if info["size"] is None: info["size"] = torch.ones_like(x[..., 0, None])  # (B, T, 1)
+
+        x, info["source"], others = pruning(x,
+                                            prune_r=prune_r,
+                                            prune_thr=prune_thr,
+                                            scores=scores,
+                                            source=info["source"],
+                                            cls=info["cls"],
+                                            group_num=info["group_num"],
+                                            SE=info["img_idx"],
+                                            others=others)
+
+    return x.squeeze(0) if TD else x, others
+
+def merging(
+        metric : torch.Tensor,
+        r_merge : int,
+        scores : torch.Tensor,
+        tau_sim : int,
+        tau_info: int,
+        tau_size: int,
+        mass: int,
+        size: torch.Tensor):
+
+    B, T, _ = metric.shape  # (4(B), 197(T), 384(4))
+    with torch.no_grad():
+        metric = metric / metric.norm(dim=-1, keepdim=True)  # (12, 197, 64)
+        a, b = metric[..., ::2, :], metric[..., 1::2, :]  # (12, 99, 64), (12, 98, 64)
+
+        if tau_sim:
+            W_sim = a @ b.transpose(-1, -2)
+            W_sim = ((W_sim + 1) / 2) ** (1 / tau_sim)
+        else:
+            W_sim = torch.ones((a.shape[0], a.shape[1], b.shape[1]), device=a.device)
+
+        if tau_info > 0 and scores is not None:
+            attn_info = scores
+            attn_info = 1 / attn_info  # (1(B), 1024(T))
+            attn_info = attn_info / attn_info.max(1, keepdim=True)[0]  # (192(B), 197(T))
+            attn_a, attn_b = attn_info[..., ::2, None], attn_info[..., 1::2, None].transpose(1, 2)
+
+            W_info = (attn_a * attn_b) ** (1 / tau_info)
+        else:
+            W_info = 1
+
+        if tau_size and size is not None:
+            size_info = 1 / size
+            size_info = size_info / size_info.max(1, keepdim=True)[0]  # (4(B), 197(T), 1)
+            size_a, size_b = size_info[..., ::2, :], size_info[..., 1::2, :].transpose(1, 2)
+
+            W_size = (size_a * size_b) ** (1 / tau_size)
+        else:
+            W_size = 1
+
+        scores = W_sim * W_info * W_size
+
+        n, t1, t2 = scores.shape
+        node_max, node_idx = scores.max(dim=-1)  # (12, 99), (12, 99)
+        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]  # (12, 99, 1)
+        unm_idx = edge_idx[..., r_merge:, :]  # Unmerged Tokens (12, 83, 1)
+        src_idx = edge_idx[..., :r_merge, :]  # Merged Tokens (12, 16, 1)
+        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)  # (12, 16, 1)
+        unm_idx = unm_idx.sort(dim=1)[0]
+
+        if mass:
+            src_so, dst_so = scores[..., ::2, :], scores[..., 1::2, :]  # (1, 1176, 1)
+            src_so = src_so.gather(dim=-2, index=src_idx)  # (12, 91, 197)
+
+
+    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
+        src, dst = x[..., ::2, :], x[..., 1::2, :]  # (12, 99, 197), (12, 98, 197)
+        n, mid, c = src.shape[0], src.shape[1:-2], src.shape[-1]
+        unm = src.gather(dim=-2, index=unm_idx.expand(n, *mid, t1 - r_merge, c))  # (12, 91, 197)
+        src = src.gather(dim=-2, index=src_idx.expand(n, *mid, r_merge, c))
+        if mass:
+            src = src * src_so
+        dst = dst.scatter_reduce(-2, dst_idx.expand(n, *mid, r_merge, c), src, reduce=mode)  # (12, 98, 197)
+        x = torch.cat([unm, dst], dim=-2)  # (12, 1 + 180, 197)
+        return x
+
+    return merge
+
+
+def merge_wavg(
+    merge: Callable, x: torch.Tensor, size: torch.Tensor = None, scores=None, pooling_type = 0, source = None,
+):
+
+    size_max = size.amax(dim=-2, keepdim=True)
+    if pooling_type:
+        norm = merge(scores * size, mode="sum")  # (1, 197, 1)
+
+        x = merge(x * scores * size, mode="sum")
+        size = merge(size, mode="sum")
+        x = x / norm
+    else:
+        x = merge(x * (size / size_max), mode="sum")
+        size = merge(size, mode="sum")
+        x = x / (size / size_max)
+
+    if source is not None:
+        source = merge(source, mode="amax")
+    return x, size, source
+
+def pruning(
+        x: torch.Tensor,
+        prune_r : int,
+        prune_thr : float,
+        scores : torch.Tensor,
+        source : torch.Tensor,
+        cls : False,
+        group_num : int = 1,
+        others : [] = None,
+        SE : [] = None):
+    b, t_full, d = x.shape
+    scores_block = scores.reshape(b, -1, group_num).mean(dim=-1)  # (B, T)
+    scores_block = scores_block / scores_block.mean(dim=-1, keepdim=True)
+    t_vis = scores_block.shape[1]
+
+    if cls: scores_block[:, 0] = math.inf
+
+    x_block = x.reshape(b, -1, group_num, d)
+    if prune_thr:  # REMOVE BASED THRESHOLD
+        mask_block = (scores_block >= prune_thr)
+    else:
+        idx_unprune = scores_block.topk(t_vis - int(prune_r // group_num), dim=1, largest=True, sorted=False).indices
+        mask_block = torch.zeros_like(scores_block, dtype=torch.bool)
+        mask_block = mask_block.scatter(1, idx_unprune, torch.ones_like(idx_unprune, device=idx_unprune.device, dtype=torch.bool))
+
+    if SE[0] is not None:
+        start, end, length = SE
+
+        mask_F, mask_L = torch.ones((b, start), device=mask_block.device, dtype=torch.bool), torch.ones(b, t_full - end, device=mask_block.device, dtype=torch.bool)
+        mask_block = torch.cat([mask_F, mask_block, mask_L], dim=-1)
+        t_num = mask_block.sum().item()
+        SE[1], SE[2] = t_num - (t_full - end), t_num
+
+    x_unprune = x_block.masked_select(mask_block.reshape(1, -1, 1, 1)).view(b, -1, d)  # (1, 10032(T), 1280) > (1, 4880(T'), 1280)
+
+    if others is not None:
+        T_remain = x_unprune.shape[-2]
+        if len(others) == 1:  # RET : ENCODER
+            cu_lens = others[0]
+            if cu_lens is not None: cu_lens[1:] = torch.stack([mask_block[:, :cu_lens[c + 1] // 4].sum() for c in range(len(cu_lens) - 1)]) * group_num
+            others = [cu_lens]
+        elif len(others) == 2:  # QA : ENCODER
+            cu_lens, rotary_pos_emb = others
+            cu_lens[1:] = torch.stack([mask_block[:, :cu_lens[c + 1] // 4].sum() for c in range(len(cu_lens) - 1)]) * group_num
+
+            rotary_pos_emb = rotary_pos_emb.reshape(-1, group_num, 40).masked_select(mask_block.reshape(-1, 1, 1)).view(-1, 40)
+            others = [cu_lens, rotary_pos_emb]
+        elif len(others) == 3:  # LLM
+            attention_mask, position_ids, cache_position = others
+            attention_mask = attention_mask[:, :, :T_remain, :T_remain] if attention_mask is not None else None
+            position_ids = position_ids.masked_select(mask_block.reshape(b, 1, -1)).reshape(3, 1, -1)
+            cache_position = cache_position.masked_select(mask_block)
+            others = [attention_mask, position_ids, cache_position]
+        else:  # LLM
+            attention_mask, position_ids, cache_position, position_embeddings = others
+            attention_mask = attention_mask[:, :, :T_remain, :T_remain] if attention_mask is not None else None
+            position_ids = position_ids.masked_select(mask_block.reshape(b, 1, -1)).reshape(3, 1, -1)
+            cache_position = cache_position.masked_select(mask_block)
+            position_embeddings = tuple([v.masked_select(mask_block.reshape(1, 1, -1, 1)).reshape(3, 1, -1, 128) for v in position_embeddings])
+            others = [attention_mask, position_ids, cache_position, position_embeddings]
+
+    if source is not None:
+        restored_mask = torch.zeros_like(source, device=source.device)
+        restored_mask[source] = mask_block
+        source = restored_mask
+
+
+    x = x_unprune
+
+    return x, source, others
+
+def needNaive(info, layer_idx):
+    if info["compression"]["use"]:
+        if info["compression"]["info_type"] in [1, 2, 3, 4]:
+            if (info["compression"]["prune_r"] and info["compression"]["prune_layer"] == layer_idx):
+                return True
+    return False
+
+def needAttn(info, layer_idx):
+    if info["compression"]["use"]:
+        if info["compression"]["info_type"] in [1, 2, 3, 4]:
+            if (info["compression"]["prune_r"] and info["compression"]["prune_layer"] == layer_idx):
+                return True
+    return False
+
+
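Most of `pruning` above is bookkeeping for grouped tokens, the optional protected span (`SE`), and the various `others` side tensors; the core operation is a per-sample top-k over the group-averaged importance scores, turned into a boolean keep mask. A minimal sketch of just that step, with made-up shapes and scores (`group_num=1`, no CLS protection, not taken from the package):

    import torch

    b, t, d = 1, 8, 4
    x = torch.randn(b, t, d)
    scores = torch.rand(b, t)  # hypothetical per-token importance scores
    prune_r = 3                # number of tokens to drop

    # keep the (t - prune_r) highest-scoring tokens, as in pruning() when r_type == 0
    idx_keep = scores.topk(t - prune_r, dim=1, largest=True, sorted=False).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask = mask.scatter(1, idx_keep, torch.ones_like(idx_keep, dtype=torch.bool))

    x_kept = x.masked_select(mask[..., None]).view(b, -1, d)
    print(x_kept.shape)  # torch.Size([1, 5, 4])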