fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/clip_vision/taskpool.py CHANGED

@@ -3,7 +3,17 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)

 import torch
 from omegaconf import DictConfig

@@ -25,6 +35,10 @@ from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.parameters import count_parameters

+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

 log = logging.getLogger(__name__)

@@ -198,14 +212,16 @@ class CLIPVisionModelTaskPool(
         classifier: HFCLIPClassifier,
         test_loader: DataLoader,
         num_classes: int,
+        task_name: str = None,
     ):
         """
-        Evaluate the classifier on the test dataset.
+        Evaluate the classifier on the test dataset (single-task evaluation).

         Args:
             classifier (HFCLIPClassifier): The classifier to evaluate.
             test_loader (DataLoader): The data loader for the test dataset.
             num_classes (int): The number of classes in the classification task.
+            task_name (str): The name of the task.

         Returns:
             Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.

@@ -228,7 +244,12 @@ class CLIPVisionModelTaskPool(
             )
         ):
             inputs, targets = batch
-            outputs = classifier(
+            outputs = classifier(
+                inputs,
+                return_image_embeds=True,
+                return_dict=True,
+                task_name=task_name,
+            )
             logits: Tensor = outputs["logits"]

             loss = F.cross_entropy(logits, targets)

@@ -246,12 +267,18 @@ class CLIPVisionModelTaskPool(
         results = {"accuracy": acc, "loss": loss}
         return results

-    def evaluate(
+    def evaluate(
+        self,
+        model: Union[CLIPVisionModel, CLIPVisionTransformer],
+        name=None,
+        **kwargs,
+    ):
         """
         Evaluate the model on the image classification task.

         Args:
             model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
+            name (Optional[str]): The name of the model. This will be logged into the report if not None.

         Returns:
             Dict[str, Any]: A dictionary containing the evaluation results for each task.

@@ -261,8 +288,17 @@ class CLIPVisionModelTaskPool(

         report = {}
         # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
-
-
+        if hasattr(model, "is_surgery_model") and model.is_surgery_model:
+            log.info("running evaluation on a surgery model.")
+            model: "SurgeryModelWrapper" = model
+            self.clip_model.vision_model = model
+        else:
+            # replace the vision encoder with the model
+            self.clip_model.vision_model = model
+        classifier = HFCLIPClassifier(
+            self.clip_model,
+            processor=self.processor,
+        )
         classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)

@@ -285,6 +321,7 @@ class CLIPVisionModelTaskPool(
                 classifier,
                 test_dataloader,
                 num_classes=len(classnames),
+                task_name=task_name,
             )
             report[task_name] = result
         self.on_task_evaluation_end()
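The new branch in evaluate detects a surgery-wrapped model by duck typing on an is_surgery_model attribute rather than by an explicit isinstance check, so any wrapper exposing that flag (such as the new SurgeryModelWrapper) is routed through the surgery path. A minimal sketch of that pattern follows; DummySurgeryWrapper is a hypothetical stand-in, not part of fusion_bench:

import torch.nn as nn

class DummySurgeryWrapper(nn.Module):
    # hypothetical stand-in for SurgeryModelWrapper: wraps a vision model and flags itself
    is_surgery_model = True

    def __init__(self, vision_model: nn.Module):
        super().__init__()
        self.vision_model = vision_model

    def forward(self, *args, **kwargs):
        return self.vision_model(*args, **kwargs)

def uses_surgery_path(model: nn.Module) -> bool:
    # same duck-typing check as in the diff above
    return hasattr(model, "is_surgery_model") and bool(model.is_surgery_model)

print(uses_surgery_path(nn.Identity()))                       # False -> plain vision encoder path
print(uses_surgery_path(DummySurgeryWrapper(nn.Identity())))  # True  -> surgery path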
fusion_bench/taskpool/llama/reward_model.py ADDED

@@ -0,0 +1,157 @@
+"""
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+"""
+
+import functools
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+import lightning as L
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import Subset
+import numpy as np
+from tqdm.auto import tqdm
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import BaseTaskPool
+from fusion_bench.utils import instantiate
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+
+def evaluate_batch(model: "LlamaForSequenceClassification", batch):
+    batch_size = batch["input_ids"].size(0)
+    assert batch_size % 2 == 0, "Batch size must be even."
+
+    outputs = model(
+        input_ids=batch["input_ids"],
+        attention_mask=batch["attention_mask"],
+    )
+
+    rewards = outputs[0]
+    chosen_reward = rewards[: batch_size // 2]
+    rejected_rewards = rewards[batch_size // 2 :]
+
+    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+    correct = (chosen_reward > rejected_rewards).sum().item()
+    total = batch_size // 2
+
+    return {
+        "loss": loss.item(),
+        "correct": correct,
+        "total": total,
+    }
+
+
+def evaluate_dataloader(model: "LlamaForSequenceClassification", dataloader):
+    """
+    Compute the accuracy of the reward model on the given dataloader.
+
+    Args:
+        model: The reward model
+        dataloader: The dataloader for the dataset
+
+    Returns:
+        float: The accuracy of the reward model on the dataset
+    """
+    metrics = {
+        "loss": 0.0,
+        "correct": 0,
+        "total": 0,
+    }
+    with torch.no_grad():
+        for batch in (pbar := tqdm(dataloader)):
+            batch_result = evaluate_batch(model, batch)
+            new_total = metrics["total"] + batch_result["total"]
+            metrics["loss"] = (
+                metrics["loss"] * metrics["total"] / new_total
+                + batch_result["loss"] * batch_result["total"] / new_total
+            )
+            metrics["correct"] += batch_result["correct"]
+            metrics["total"] += batch_result["total"]
+            pbar.set_postfix(metrics)
+
+    metrics["accuracy"] = metrics["correct"] / metrics["total"]
+    return metrics
+
+
+class RewardModelEvaluationTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    def __init__(
+        self,
+        test_datasets: List[DictConfig],
+        dataloader_kwargs: DictConfig,
+        tokenizer: Optional[DictConfig],
+        max_num_samples: int = -1,
+        seed: int = 0,
+        **kwargs,
+    ):
+        self.seed = seed
+        L.seed_everything(seed)
+        self._test_datasets = test_datasets
+        self.dataloader_kwargs = dataloader_kwargs
+        self._tokenizer = tokenizer
+        self.max_num_samples = max_num_samples
+        super().__init__(**kwargs)
+
+    def setup(self):
+        if self._tokenizer is None:
+            # try to load the tokenizer from the model pool
+            tokenizer = self._program.modelpool.load_tokenizer()
+        else:
+            tokenizer = instantiate(self._tokenizer)
+        self.tokenizer = tokenizer
+
+        test_datasets = {
+            dataset_name: instantiate(self._test_datasets[dataset_name])
+            for dataset_name in self._test_datasets
+        }
+        if self.max_num_samples > 0:
+            test_datasets = {
+                dataset_name: Subset(
+                    test_dataset,
+                    np.random.permutation(len(test_dataset))[: self.max_num_samples],
+                )
+                for dataset_name, test_dataset in test_datasets.items()
+            }
+        test_dataloaders = {
+            dataset_name: torch.utils.data.DataLoader(
+                test_dataset,
+                collate_fn=functools.partial(
+                    bradley_terry_rm_collate,
+                    pad_token_id=tokenizer.pad_token_id,
+                ),
+                **self.dataloader_kwargs,
+            )
+            for dataset_name, test_dataset in test_datasets.items()
+        }
+
+        self.test_dataloaders = {
+            dataset_name: self.fabric.setup_dataloaders(test_dataloader)
+            for dataset_name, test_dataloader in test_dataloaders.items()
+        }
+
+    @torch.no_grad()
+    def evaluate(self, model: "LlamaForSequenceClassification"):
+        self.setup()
+
+        model = self.fabric.setup_module(model)
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        model.eval()
+        report = {}
+        for dataset_name, test_dataloader in self.test_dataloaders.items():
+            report[dataset_name] = evaluate_dataloader(model, test_dataloader)
+
+        print(report)
+        return report
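evaluate_batch relies on bradley_terry_rm_collate stacking the chosen sequences in the first half of the batch and the rejected sequences in the second half, so the Bradley-Terry loss reduces to -log(sigmoid(r_chosen - r_rejected)) per pair. A small self-contained sketch of that computation on made-up reward values (no model or tokenizer involved):

import torch

# made-up rewards for 3 preference pairs, laid out as [chosen..., rejected...]
rewards = torch.tensor([1.2, 0.3, -0.5, 0.4, 0.5, -1.0])
batch_size = rewards.size(0)  # 6, must be even

chosen_reward = rewards[: batch_size // 2]    # rewards of the preferred responses
rejected_reward = rewards[batch_size // 2 :]  # rewards of the rejected responses

# same Bradley-Terry pairwise loss and accuracy as evaluate_batch
loss = -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()
accuracy = (chosen_reward > rejected_reward).float().mean()
print(round(loss.item(), 2), round(accuracy.item(), 2))  # approximately 0.55 and 0.67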
fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py CHANGED

@@ -3,9 +3,10 @@ import os
 from typing import Optional

 from datasets import load_dataset, load_from_disk
+from omegaconf import DictConfig

 from fusion_bench.utils import instantiate, timeit_context
-
+
 from .glue_preprocessors import glue_processors
 from .glue_prompt_templates import glue_prompt_templates

fusion_bench/utils/hydra_utils.py CHANGED

@@ -4,3 +4,25 @@ import hydra.core.hydra_config
 def get_hydra_output_dir():
     hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
     return hydra_cfg.runtime.output_dir
+
+
+def config_priority_get(priority_config, general_config, key, default):
+    """
+    Retrieve a configuration value with priority.
+
+    This function retrieves the value associated with `key` from `priority_config` if it exists.
+    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
+    If the key is not found in either configuration, it returns the provided `default` value.
+
+    Args:
+        priority_config (dict): The configuration dictionary with higher priority.
+        general_config (dict): The general configuration dictionary.
+        key (str): The key to look up in the configuration dictionaries.
+        default: The default value to return if the key is not found in either configuration.
+
+    Returns:
+        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
+    """
+    if key in priority_config:
+        return priority_config[key]
+    return general_config.get(key, default)
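A short usage sketch of the new helper; the keys and values below are made up for illustration:

from fusion_bench.utils.hydra_utils import config_priority_get

# e.g. a per-model config section overriding pool-level defaults
priority_config = {"learning_rate": 1e-5}
general_config = {"learning_rate": 2e-4, "batch_size": 8}

print(config_priority_get(priority_config, general_config, "learning_rate", 1e-4))  # 1e-05 (priority wins)
print(config_priority_get(priority_config, general_config, "batch_size", 16))       # 8 (falls back to general)
print(config_priority_get(priority_config, general_config, "num_epochs", 3))        # 3 (default)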
fusion_bench/utils/plot/__init__.py ADDED

File without changes
fusion_bench/utils/plot/token.py ADDED

@@ -0,0 +1,52 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def visualize_model_inputs(input_ids, attention_mask, labels, tokenizer=None):
+    """
+    Visualize model inputs: attention mask, labels and input_ids
+
+    Parameters:
+    -----------
+    attention_mask: numpy array or tensor
+        The attention mask array
+    labels: numpy array or tensor
+        The labels array
+    input_ids: numpy array or tensor
+        The input ids array
+    tokenizer: optional
+        The tokenizer object to decode input_ids
+    """
+
+    # Convert inputs to numpy if they're tensors
+    attention_mask = np.array(attention_mask)
+    labels = np.array(labels)
+    input_ids = np.array(input_ids)
+
+    # Create figure with 3 subplots
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))
+
+    # Plot attention mask
+    sns.heatmap(attention_mask.reshape(1, -1), ax=ax1, cmap="Blues", cbar=True)
+    ax1.set_title("**Attention Mask**")
+    ax1.set_ylabel("Sequence")
+
+    # Plot labels
+    sns.heatmap(labels.reshape(1, -1), ax=ax2, cmap="Reds", cbar=True)
+    ax2.set_title("**Labels**")
+    ax2.set_ylabel("Sequence")
+
+    # Plot input_ids
+    sns.heatmap(input_ids.reshape(1, -1), ax=ax3, cmap="Greens", cbar=True)
+    ax3.set_title("**Input IDs**")
+    ax3.set_ylabel("Sequence")
+
+    # If tokenizer is provided, add decoded tokens as x-axis labels
+    if tokenizer:
+        decoded_tokens = [tokenizer.decode(token_id) for token_id in input_ids]
+        ax3.set_xticks(np.arange(len(decoded_tokens)) + 0.5)
+        ax3.set_xticklabels(decoded_tokens, rotation=45, ha="right")
+
+    plt.tight_layout()
+    return fig
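A usage sketch for visualize_model_inputs; the tokenizer choice and the dummy labels are illustrative assumptions, not part of the package:

from transformers import AutoTokenizer
from fusion_bench.utils.plot.token import visualize_model_inputs

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any Hugging Face tokenizer works here
enc = tokenizer("Merging models is fun.", return_tensors="np")

# reuse the input ids as dummy labels just to fill the third heatmap row
fig = visualize_model_inputs(
    input_ids=enc["input_ids"][0],
    attention_mask=enc["attention_mask"][0],
    labels=enc["input_ids"][0],
    tokenizer=tokenizer,
)
fig.savefig("model_inputs.png")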
fusion_bench/utils/plot/token_notebook.py ADDED

@@ -0,0 +1,127 @@
+import numpy as np
+from IPython.display import HTML, display
+
+
+def create_color_style():
+    return """
+    <style>
+        .token-container { font-family: monospace; white-space: pre; }
+        .attention { background-color: #90EE90; }  /* Light green */
+        .label { background-color: #FFB6C6; }  /* Light red */
+        .token { color: #0066cc; }  /* Blue */
+        .stats { font-weight: bold; }
+    </style>
+    """
+
+
+def escape_special_chars(text):
+    """Convert special characters to their string representation"""
+    return (
+        text.replace("\n", "\\n")
+        .replace("\t", "\\t")
+        .replace("\r", "\\r")
+        .replace(" ", "␣")
+    )  # Optional: show spaces with visible character
+
+
+def visualize_tokens_html(input_ids, attention_mask, labels, tokenizer):
+    """
+    Visualize model inputs using HTML colored text representation for Jupyter Notebook
+    with special characters shown as strings
+    """
+    # Convert to numpy if tensors
+    attention_mask = np.array(attention_mask).flatten()
+    labels = np.array(labels).flatten()
+    input_ids = np.array(input_ids).flatten()
+
+    # Decode tokens and escape special characters
+    tokens = [escape_special_chars(tokenizer.decode(id_)) for id_ in input_ids]
+
+    # Create HTML output
+    html_output = [create_color_style()]
+
+    # Header
+    html_output.append("<h3>**Token Visualization**</h3>")
+
+    # Legend
+    html_output.append(
+        """
+    <div style='margin: 10px 0;'>
+        <strong>Legend:</strong><br>
+        <span class='attention'> </span> Active Attention<br>
+        <span class='label'> </span> Label Present<br>
+        <span class='token'>Text</span> Token Text<br>
+        Special Characters: \\n (newline), \\t (tab), ␣ (space)
+    </div>
+    """
+    )
+
+    # Token alignment
+    html_output.append("<strong>Token Alignment:</strong>")
+    html_output.append("<div class='token-container'>")
+
+    # Calculate maximum token length for better alignment
+    max_token_len = max(len(str(token)) for token in tokens)
+
+    for i, (input_id, token, mask, label) in enumerate(
+        zip(input_ids, tokens, attention_mask, labels)
+    ):
+        # Pad token for alignment
+        token_text = f"{token:{max_token_len}s}"
+
+        # Create classes for styling
+        classes = []
+        if mask == 1:
+            classes.append("attention")
+        if label != -100 and label != 0:
+            classes.append("label")
+
+        class_str = f"class='{' '.join(classes)}'" if classes else ""
+
+        # Create the line
+        line = f"Position {i:3d}: <span {class_str}><span class='token'>{token_text}</span></span> "
+        line += (
+            f"(Mask: {int(mask)}, Label: {int(label)}, Inpu_id: {int(input_id)})<br>"
+        )
+        html_output.append(line)
+
+    html_output.append("</div>")
+
+    # Statistics
+    html_output.append(
+        """
+    <div class='stats' style='margin-top: 10px;'>
+        Statistics:<br>
+        Total tokens: {}<br>
+        Active attention tokens: {}<br>
+        Labeled tokens: {}<br>
+    </div>
+    """.format(
+            len(tokens), attention_mask.sum(), sum(labels != -100)
+        )
+    )
+
+    # Display the HTML
+    display(HTML("".join(html_output)))
+
+
+# Example usage:
+"""
+from transformers import AutoTokenizer
+import torch
+
+# Initialize tokenizer
+tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+# Sample input with special characters
+text = "Hello,\nhow are\tyou?"
+inputs = tokenizer(text, return_tensors='pt')
+labels = torch.zeros_like(inputs['input_ids'])  # dummy labels
+
+visualize_tokens_html(
+    inputs['attention_mask'][0],
+    labels[0],
+    inputs['input_ids'][0],
+    tokenizer
+)
+"""
fusion_bench/utils/type.py CHANGED

@@ -6,18 +6,20 @@ from typing_extensions import TypeAlias

 try:
     import torch
-    from torch import Tensor
+    from torch import Tensor, nn

     StateDictType: TypeAlias = Dict[str, Tensor]
+    TorchModelType = TypeVar("TorchModelType", bound=nn.Module)
+
 except ImportError:
     pass


-
+PyModuleType = type(sys)
 T = TypeVar("T")
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
 T4 = TypeVar("T4")

-__all__ = ["StateDictType", "
+__all__ = ["StateDictType", "PyModuleType", "TorchModelType", "T", "T1", "T2", "T3", "T4"]