joonmyung 1.6.1__tar.gz → 1.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {joonmyung-1.6.1 → joonmyung-1.7.0}/PKG-INFO +1 -1
  2. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/analysis.py +1 -0
  3. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/dataset.py +30 -11
  4. joonmyung-1.7.0/joonmyung/compression/__init__.py +2 -0
  5. joonmyung-1.7.0/joonmyung/compression/compression.py +227 -0
  6. joonmyung-1.7.0/joonmyung/compression/utils.py +564 -0
  7. joonmyung-1.7.0/joonmyung/data.py +104 -0
  8. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/draw.py +210 -76
  9. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/log.py +2 -1
  10. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/utils.py +4 -14
  11. joonmyung-1.7.0/joonmyung/metric.py +165 -0
  12. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/status.py +1 -1
  13. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/utils.py +4 -0
  14. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/PKG-INFO +1 -1
  15. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/SOURCES.txt +2 -12
  16. {joonmyung-1.6.1 → joonmyung-1.7.0}/setup.py +1 -1
  17. joonmyung-1.6.1/joonmyung/analysis/analysis_bak.py +0 -218
  18. joonmyung-1.6.1/joonmyung/analysis/analysis_ㅠㅏ.py +0 -218
  19. joonmyung-1.6.1/joonmyung/analysis/evaluate.py +0 -39
  20. joonmyung-1.6.1/joonmyung/analysis/hook.py +0 -0
  21. joonmyung-1.6.1/joonmyung/analysis/utils.py +0 -14
  22. joonmyung-1.6.1/joonmyung/compression/__init__.py +0 -0
  23. joonmyung-1.6.1/joonmyung/compression/compression.py +0 -202
  24. joonmyung-1.6.1/joonmyung/data.py +0 -47
  25. joonmyung-1.6.1/joonmyung/dummy.py +0 -4
  26. joonmyung-1.6.1/joonmyung/metric.py +0 -133
  27. joonmyung-1.6.1/joonmyung/model/__init__.py +0 -0
  28. joonmyung-1.6.1/joonmyung/model/compression.py +0 -202
  29. joonmyung-1.6.1/joonmyung/model.py +0 -0
  30. joonmyung-1.6.1/joonmyung/models/__init__.py +0 -0
  31. joonmyung-1.6.1/joonmyung/models/tome.py +0 -386
  32. {joonmyung-1.6.1 → joonmyung-1.7.0}/LICENSE.txt +0 -0
  33. {joonmyung-1.6.1 → joonmyung-1.7.0}/README.md +0 -0
  34. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/__init__.py +0 -0
  35. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/__init__.py +0 -0
  36. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/model.py +0 -0
  37. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/app.py +0 -0
  38. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/__init__.py +0 -0
  39. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/clip.py +0 -0
  40. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/model.py +0 -0
  41. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/clip/simple_tokenizer.py +0 -0
  42. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/compression/apply.py +0 -0
  43. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/file.py +0 -0
  44. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/gradcam.py +0 -0
  45. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/__init__.py +0 -0
  46. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/meta_data/label.py +0 -0
  47. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/script.py +0 -0
  48. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/dependency_links.txt +0 -0
  49. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/not-zip-safe +0 -0
  50. {joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung.egg-info/top_level.txt +0 -0
  51. {joonmyung-1.6.1 → joonmyung-1.7.0}/setup.cfg +0 -0
{joonmyung-1.6.1 → joonmyung-1.7.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: joonmyung
-Version: 1.6.1
+Version: 1.7.0
 Summary: JoonMyung's Library
 Home-page: https://github.com/pizard/JoonMyung.git
 Author: JoonMyung Choi
{joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/analysis.py
@@ -43,6 +43,7 @@ def anaModel(transformer_class):
         def forward(self, *args, **kwdargs):
             self.resetInfo()
             return super().forward(*args, **kwdargs)
+
         def encode_image(self, *args, **kwdargs):
             self.resetInfo()
             return super().encode_image(*args, **kwdargs)
{joonmyung-1.6.1 → joonmyung-1.7.0}/joonmyung/analysis/dataset.py
@@ -21,10 +21,19 @@ class JDataset():
             "num_classes" : 1000,
             "data_types" : ["val", "train"],
             "label_name" : imnet_label,
-            "distributions" : {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, # DEIT PAPER
-            # "distributions" : {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]}, # CLIP
+            "distributions" : {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
             "size": (224, 224)
         },
+
+        "imagenet_clip": {
+            "num_classes": 1000,
+            "data_types": ["val", "train"],
+            "label_name": imnet_label,
+            "distributions" : {"mean": [0.48145466, 0.4578275, 0.40821073], "std": [0.26862954, 0.26130258, 0.27577711]},
+            "size": (224, 224)
+        },
+
+
         "cifar100" : {
             "num_classes" : 100,
             "data_types": ["test", "train"],
@@ -45,8 +54,12 @@ class JDataset():
         size = size if size else setting["size"]

         self.transform = [
-            transforms.Compose([transforms.Resize(256, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]), # DEIT PAPER
-            transforms.Compose([transforms.Resize(256, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            # DEIT
+            transforms.Compose([transforms.Resize((256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            # CLIP
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
             transforms.Compose([transforms.Resize(224, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
             transforms.Compose([transforms.Resize(224, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
             transforms.Compose([transforms.ToTensor()])
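Note: transforms.Resize((256)) is identical to transforms.Resize(256) (the parentheses do not make a tuple) and scales the shorter image side to 256 while keeping the aspect ratio, whereas the new CLIP-style transforms.Resize((256, 256)) forces an exact 256×256 output before the center crop. A small sketch of the difference (the input image is a stand-in for whatever default_loader returns):

    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    img = Image.new("RGB", (640, 480))  # stand-in for a loaded sample

    short_side = transforms.Resize(256, interpolation=InterpolationMode.BICUBIC)(img)
    square = transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC)(img)

    print(short_side.size)  # (341, 256): shorter side scaled to 256, aspect ratio kept
    print(square.size)      # (256, 256): both sides forced to 256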
@@ -59,24 +72,30 @@ class JDataset():
         # self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, "*", "*")))]
         # self.img_paths = [[path, idx] for idx, label_path in enumerate(self.label_paths) for path in sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*")))]
         self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*"))) for label_path in self.label_paths]
-
+        self.img_len = [len(labels) for labels in self.img_paths]
+        self.img_cum_len = torch.Tensor([0] + [sum(self.img_len[:i+1]) for i in range(len(self.img_len))])

     def __getitem__(self, idx):
-        label_num, img_num = idx
+        if type(idx) == int:
+            label_num = (self.img_cum_len <= idx).sum().item() - 1
+            img_num = idx - int(self.img_cum_len[label_num].item())
+        else:
+            label_num, img_num = idx
         img_path = self.img_paths[label_num][img_num]
         sample = default_loader(img_path)
         sample = self.transform[self.transform_type](sample)

-        return sample[None].to(self.device), torch.tensor(label_num).to(self.device), self.label_name[int(label_num)]
+        return sample[None].to(self.device), torch.tensor(label_num).to(self.device), self.label_name[int(label_num)], img_path

     def getItems(self, indexs):
-        ds, ls, lns = [], [], []
+        ds, ls, lns, ips = [], [], [], []
         for index in indexs:
-            d, l, ln = self.__getitem__(index)
+            d, l, ln, ip = self.__getitem__(index)
             ds.append(d)
             ls.append(l)
             lns.append(ln)
-        return torch.cat(ds, dim=0), torch.stack(ls, dim=0), lns
+            ips.append(ip)
+        return torch.cat(ds, dim=0), torch.stack(ls, dim=0), lns, ips

     def getAllItems(self, batch_size=32):
         dataset = create_dataset(
@@ -99,7 +118,7 @@ class JDataset():
         return c_i

     def __len__(self):
-        return
+        return self.img_cum_len[-1]


     def validation(self, data):
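Note: JDataset now accepts a plain integer index in addition to the previous (label_num, img_num) pair, using the new img_cum_len cumulative counts to map a flat index to a class folder and an offset within it; __len__ correspondingly returns self.img_cum_len[-1] (a 0-d tensor, so callers may need an int(...) cast). A small sketch of the lookup with made-up per-class counts:

    import torch

    img_len = [3, 5, 2]  # hypothetical number of images per class folder
    img_cum_len = torch.Tensor([0] + [sum(img_len[:i + 1]) for i in range(len(img_len))])
    # tensor([ 0.,  3.,  8., 10.])

    def flat_to_pair(idx: int):
        # Same arithmetic as the new __getitem__ branch for integer indices.
        label_num = (img_cum_len <= idx).sum().item() - 1
        img_num = idx - int(img_cum_len[label_num].item())
        return label_num, img_num

    print(flat_to_pair(0))  # (0, 0): first image of class 0
    print(flat_to_pair(4))  # (1, 1): second image of class 1
    print(flat_to_pair(9))  # (2, 1): last image of class 2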
joonmyung-1.7.0/joonmyung/compression/__init__.py
@@ -0,0 +1,2 @@
+from .compression import token_compression
+from .utils import getAnalysis, resetInfo, DiffDropScheduler
joonmyung-1.7.0/joonmyung/compression/compression.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+
+
+from typing import Callable
+import torch
+import math
+
+def token_compression(x, info, diffDropScheduler, layer_idx, others = []):
+    [x, TD] = [x[None], True] if len(x.shape) == 2 else [x, False]
+    B, T, D = x.shape
+    if not info["use"] or T == 1:
+        return x.squeeze(0) if TD else x, others
+
+    T_vis = T if info["img_idx"][0] == None else info["img_idx"][1] - info["img_idx"][0]
+    if diffDropScheduler.benchmark:
+        r_use, r_diff = None, None
+        r_throughput = diffDropScheduler.drop_ratio_avg[layer_idx + 1]
+    else:
+        r_throughput = None
+        r_use, r_diff = (info["prune_layer"] == layer_idx and info["prune_r"]), \
+                        diffDropScheduler(info["difficulty"])
+
+    if (r_use or r_diff or r_throughput):
+        prune_r, prune_thr = None, None
+        if r_throughput is not None:
+            prune_r = r_throughput
+        elif info["r_type"] == 0:
+            prune_r = int(T_vis * info["prune_r"]) if r_use else int(T_vis * r_diff)
+        else:
+            prune_thr = info["prune_r"] if r_use else r_diff
+
+        scores = info["importance"] if not diffDropScheduler.benchmark else torch.randn(1, T_vis, device=x.device)
+        if info["source"] is None: info["source"] = torch.ones((B, (T // info["group_num"]) ), dtype=torch.bool, device=x.device)
+        if info["size"] is None: info["size"] = torch.ones_like(x[..., 0, None]) # (B, T, 1)
+
+        x, info["source"], others = pruning(x,
+                                            prune_r=prune_r,
+                                            prune_thr=prune_thr,
+                                            scores=scores,
+                                            source=info["source"],
+                                            cls=info["cls"],
+                                            group_num=info["group_num"],
+                                            SE = info["img_idx"],
+                                            others = others)
+
+    return x.squeeze(0) if TD else x, others
+
+def merging(
+        metric : torch.Tensor,
+        r_merge : int,
+        scores : torch.Tensor,
+        tau_sim : int,
+        tau_info: int,
+        tau_size: int,
+        mass: int,
+        size: torch.Tensor):
+
+    B, T, _ = metric.shape # (4(B), 197(T), 384(4))
+    with torch.no_grad():
+        metric = metric / metric.norm(dim=-1, keepdim=True) # (12, 197, 64)
+        a, b = metric[..., ::2, :], metric[..., 1::2, :] # (12, 99, 64), (12, 98, 64)
+
+        if tau_sim:
+            W_sim = a @ b.transpose(-1, -2)
+            W_sim = ((W_sim + 1) / 2) ** (1 / tau_sim)
+        else:
+            W_sim = torch.ones((a.shape[0], a.shape[1], b.shape[1]), device=a.device)
+
+        if tau_info > 0 and scores is not None:
+            attn_info = scores
+            attn_info = 1 / attn_info # (1(B), 1024(T))
+            attn_info = attn_info / attn_info.max(1, keepdim=True)[0] # (192(B), 197(T))
+            attn_a, attn_b = attn_info[..., ::2, None], attn_info[..., 1::2, None].transpose(1, 2)
+
+            W_info = (attn_a * attn_b) ** (1 / tau_info)
+        else:
+            W_info = 1
+
+        if tau_size and size is not None:
+            size_info = 1 / size
+            size_info = size_info / size_info.max(1, keepdim=True)[0] # (4(B), 197(T), 1)
+            size_a, size_b = size_info[..., ::2, :], size_info[..., 1::2, :].transpose(1, 2)
+
+            W_size = (size_a * size_b) ** (1 / tau_size)
+        else:
+            W_size = 1
+
+        scores = W_sim * W_info * W_size
+
+        n, t1, t2 = scores.shape
+        node_max, node_idx = scores.max(dim=-1) # (12, 99), (12, 99)
+        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] # (12, 99, 1)
+        unm_idx = edge_idx[..., r_merge:, :] # Unmerged Tokens (12, 83, 1)
+        src_idx = edge_idx[..., :r_merge, :] # Merged Tokens (12, 16, 1)
+        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) # (12, 16, 1)
+        unm_idx = unm_idx.sort(dim=1)[0]
+
+        if mass:
+            src_so, dst_so = scores[..., ::2, :], scores[..., 1::2, :] # (1, 1176, 1)
+            src_so = src_so.gather(dim=-2, index=src_idx) # (12, 91, 197)
+
+
+    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
+        src, dst = x[..., ::2, :], x[..., 1::2, :] # (12, 99, 197), (12, 98, 197)
+        n, mid, c = src.shape[0], src.shape[1:-2], src.shape[-1]
+        unm = src.gather(dim=-2, index=unm_idx.expand(n, *mid, t1 - r_merge, c)) # (12, 91, 197)
+        src = src.gather(dim=-2, index=src_idx.expand(n, *mid, r_merge, c))
+        if mass:
+            src = src * src_so
+        dst = dst.scatter_reduce(-2, dst_idx.expand(n, *mid, r_merge, c), src, reduce=mode) # (12, 98, 197)
+        x = torch.cat([unm, dst], dim=-2) # (12, 1 + 180, 197)
+        return x
+
+    return merge
+
+
+def merge_wavg(
+    merge: Callable, x: torch.Tensor, size: torch.Tensor = None, scores=None, pooling_type = 0, source = None,
+):
+
+    size_max = size.amax(dim=-2, keepdim=True)
+    if pooling_type:
+        norm = merge(scores * size, mode="sum") # (1, 197, 1)
+
+        x = merge(x * scores * size, mode="sum")
+        size = merge(size, mode="sum")
+        x = x / norm
+    else:
+        x = merge(x * (size / size_max), mode="sum")
+        size = merge(size, mode="sum")
+        x = x / (size / size_max)
+
+    if source is not None:
+        source = merge(source, mode="amax")
+    return x, size, source
+
+def pruning(
+        x: torch.Tensor,
+        prune_r : int,
+        prune_thr : float,
+        scores : torch.Tensor,
+        source : torch.Tensor,
+        cls : False,
+        group_num : int = 1,
+        others : [] = None,
+        SE : [] = None):
+    b, t_full, d = x.shape
+    scores_block = scores.reshape(b, -1, group_num).mean(dim=-1) # (B, T)
+    scores_block = scores_block / scores_block.mean(dim=-1, keepdim=True)
+    t_vis = scores_block.shape[1]
+
+    if cls: scores_block[:, 0] = math.inf
+
+    x_block = x.reshape(b, -1, group_num, d)
+    if prune_thr: # REMOVE BASED THRESHOLD
+        mask_block = (scores_block >= prune_thr)
+    else:
+        idx_unprune = scores_block.topk(t_vis - int(prune_r // group_num), dim=1, largest=True, sorted=False).indices
+        mask_block = torch.zeros_like(scores_block, dtype=torch.bool)
+        mask_block = mask_block.scatter(1, idx_unprune, torch.ones_like(idx_unprune, device=idx_unprune.device, dtype=torch.bool))
+
+    if SE[0] is not None:
+        start, end, length = SE
+
+        mask_F, mask_L = torch.ones((b, start), device=mask_block.device, dtype=torch.bool), torch.ones(b, t_full - end, device=mask_block.device, dtype=torch.bool)
+        mask_block = torch.cat([mask_F, mask_block, mask_L], dim =-1)
+        t_num = mask_block.sum().item()
+        SE[1], SE[2] = t_num - (t_full - end), t_num
+
+    x_unprune = x_block.masked_select(mask_block.reshape(1, -1, 1, 1)).view(b, -1, d) # (1, 10032(T), 1280) > (1, 4880(T'), 1280)
+
+    if others is not None:
+        T_remain = x_unprune.shape[-2]
+        if len(others) == 1: # RET :ENCODER
+            cu_lens = others[0]
+            if cu_lens is not None: cu_lens[1:] = torch.stack([mask_block[:, :cu_lens[c + 1] // 4].sum() for c in range(len(cu_lens) - 1)]) * group_num
+            others = [cu_lens]
+        elif len(others) == 2: # QA : ENCODER
+            cu_lens, rotary_pos_emb = others
+            cu_lens[1:] = torch.stack([mask_block[:, :cu_lens[c + 1] // 4].sum() for c in range(len(cu_lens) - 1)]) * group_num
+
+            rotary_pos_emb = rotary_pos_emb.reshape(-1, group_num, 40).masked_select(mask_block.reshape(-1, 1, 1)).view(-1, 40)
+            others = [cu_lens, rotary_pos_emb]
+        elif len(others) == 3: # LLM
+            attention_mask, position_ids, cache_position = others
+            attention_mask = attention_mask[:, :, :T_remain, :T_remain] if attention_mask is not None else None
+            position_ids = position_ids.masked_select(mask_block.reshape(b, 1, -1)).reshape(3, 1, -1)
+            cache_position = cache_position.masked_select(mask_block)
+            others = [attention_mask, position_ids, cache_position]
+        else: # LLM
+            attention_mask, position_ids, cache_position, position_embeddings = others
+            attention_mask = attention_mask[:, :, :T_remain, :T_remain] if attention_mask is not None else None
+            position_ids = position_ids.masked_select(mask_block.reshape(b, 1, -1)).reshape(3, 1, -1)
+            cache_position = cache_position.masked_select(mask_block)
+            position_embeddings = tuple([v.masked_select(mask_block.reshape(1, 1, -1, 1)).reshape(3, 1, -1, 128) for v in position_embeddings])
+            others = [attention_mask, position_ids, cache_position, position_embeddings]
+
+    if source is not None:
+        restored_mask = torch.zeros_like(source, device=source.device)
+        restored_mask[source] = mask_block
+        source = restored_mask
+
+
+    x = x_unprune
+
+    return x, source, others
+
+def needNaive(info, layer_idx):
+    if info["compression"]["use"]:
+        if info["compression"]["info_type"] in [1, 2, 3, 4]:
+            if (info["compression"]["prune_r"] and info["compression"]["prune_layer"] == layer_idx):
+                return True
+    return False
+
+def needAttn(info, layer_idx):
+    if info["compression"]["use"]:
+        if info["compression"]["info_type"] in [1, 2, 3, 4]:
+            if (info["compression"]["prune_r"] and info["compression"]["prune_layer"] == layer_idx):
+                return True
+    return False
+
+
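Note: the new compression module exposes token_compression (re-exported from joonmyung.compression) as the entry point; it reads importance scores and pruning settings from an info dict plus a DiffDropScheduler and delegates to pruning(), which keeps the top-scoring tokens (optionally in groups, with special handling for a CLS token, an image start/end window, and framework-specific extras passed through others). A minimal sketch that calls pruning() directly on dummy tokens; the scores are random stand-ins for the attention-based importance the package normally provides, and the argument values are assumptions chosen to avoid the grouped/windowed paths:

    import torch

    from joonmyung.compression.compression import pruning

    B, T, D = 1, 16, 8
    x = torch.randn(B, T, D)                      # dummy token sequence
    scores = torch.rand(B, T)                     # stand-in importance, higher = keep
    source = torch.ones(B, T, dtype=torch.bool)   # tracks which original positions survive

    x_kept, source, _ = pruning(
        x, prune_r=4, prune_thr=None, scores=scores, source=source,
        cls=False, group_num=1, others=None, SE=[None, None, None],
    )
    print(x_kept.shape)   # torch.Size([1, 12, 8]): 4 of 16 tokens pruned
    print(source.sum())   # tensor(12): surviving positions marked True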