fusion-bench 0.2.5-py3-none-any.whl → 0.2.7-py3-none-any.whl

Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,17 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast  # noqa: F401
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)

 import torch
 from omegaconf import DictConfig
@@ -25,6 +35,10 @@ from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.parameters import count_parameters

+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

 log = logging.getLogger(__name__)
@@ -198,14 +212,16 @@ class CLIPVisionModelTaskPool(
         classifier: HFCLIPClassifier,
         test_loader: DataLoader,
         num_classes: int,
+        task_name: str = None,
     ):
         """
-        Evaluate the classifier on the test dataset.
+        Evaluate the classifier on the test dataset (single-task evaluation).

         Args:
             classifier (HFCLIPClassifier): The classifier to evaluate.
             test_loader (DataLoader): The data loader for the test dataset.
             num_classes (int): The number of classes in the classification task.
+            task_name (str): The name of the task.

         Returns:
             Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.
@@ -228,7 +244,12 @@ class CLIPVisionModelTaskPool(
             )
         ):
             inputs, targets = batch
-            outputs = classifier(inputs, return_image_embeds=True, return_dict=True)
+            outputs = classifier(
+                inputs,
+                return_image_embeds=True,
+                return_dict=True,
+                task_name=task_name,
+            )
             logits: Tensor = outputs["logits"]

             loss = F.cross_entropy(logits, targets)
@@ -246,12 +267,18 @@ class CLIPVisionModelTaskPool(
         results = {"accuracy": acc, "loss": loss}
         return results

-    def evaluate(self, model: Union[CLIPVisionModel, CLIPVisionTransformer], name=None):
+    def evaluate(
+        self,
+        model: Union[CLIPVisionModel, CLIPVisionTransformer],
+        name=None,
+        **kwargs,
+    ):
         """
         Evaluate the model on the image classification task.

         Args:
             model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
+            name (Optional[str]): The name of the model. This will be logged into the report if not None.

         Returns:
             Dict[str, Any]: A dictionary containing the evaluation results for each task.
@@ -261,8 +288,17 @@ class CLIPVisionModelTaskPool(

         report = {}
         # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
-        self.clip_model.vision_model = model
-        classifier = HFCLIPClassifier(self.clip_model, processor=self.processor)
+        if hasattr(model, "is_surgery_model") and model.is_surgery_model:
+            log.info("running evaluation on a surgery model.")
+            model: "SurgeryModelWrapper" = model
+            self.clip_model.vision_model = model
+        else:
+            # replace the vision encoder with the model
+            self.clip_model.vision_model = model
+        classifier = HFCLIPClassifier(
+            self.clip_model,
+            processor=self.processor,
+        )
         classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
@@ -285,6 +321,7 @@ class CLIPVisionModelTaskPool(
                 classifier,
                 test_dataloader,
                 num_classes=len(classnames),
+                task_name=task_name,
             )
             report[task_name] = result
         self.on_task_evaluation_end()
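Note: the new branch in `evaluate` above relies only on duck typing: any module exposing a truthy `is_surgery_model` attribute is installed as the CLIP vision tower as-is. A minimal illustrative sketch of that contract follows; the toy wrapper is hypothetical and only stands in for the real `SurgeryModelWrapper` added in `fusion_bench/models/surgery/surgerymodelwrapper.py`.

    from torch import nn

    class ToySurgeryWrapper(nn.Module):
        # flag checked via hasattr(model, "is_surgery_model") in evaluate()
        is_surgery_model = True

        def __init__(self, vision_model: nn.Module):
            super().__init__()
            self.vision_model = vision_model

        def forward(self, *args, **kwargs):
            # a real wrapper would also post-process the representations here
            return self.vision_model(*args, **kwargs)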
@@ -0,0 +1,157 @@
+"""
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+"""
+
+import functools
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+import lightning as L
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import Subset
+import numpy as np
+from tqdm.auto import tqdm
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import BaseTaskPool
+from fusion_bench.utils import instantiate
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+
+def evaluate_batch(model: "LlamaForSequenceClassification", batch):
+    batch_size = batch["input_ids"].size(0)
+    assert batch_size % 2 == 0, "Batch size must be even."
+
+    outputs = model(
+        input_ids=batch["input_ids"],
+        attention_mask=batch["attention_mask"],
+    )
+
+    rewards = outputs[0]
+    chosen_reward = rewards[: batch_size // 2]
+    rejected_rewards = rewards[batch_size // 2 :]
+
+    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+    correct = (chosen_reward > rejected_rewards).sum().item()
+    total = batch_size // 2
+
+    return {
+        "loss": loss.item(),
+        "correct": correct,
+        "total": total,
+    }
+
+
+def evaluate_dataloader(model: "LlamaForSequenceClassification", dataloader):
+    """
+    Compute the accuracy of the reward model on the given dataloader.
+
+    Args:
+        model: The reward model
+        dataloader: The dataloader for the dataset
+
+    Returns:
+        float: The accuracy of the reward model on the dataset
+    """
+    metrics = {
+        "loss": 0.0,
+        "correct": 0,
+        "total": 0,
+    }
+    with torch.no_grad():
+        for batch in (pbar := tqdm(dataloader)):
+            batch_result = evaluate_batch(model, batch)
+            new_total = metrics["total"] + batch_result["total"]
+            metrics["loss"] = (
+                metrics["loss"] * metrics["total"] / new_total
+                + batch_result["loss"] * batch_result["total"] / new_total
+            )
+            metrics["correct"] += batch_result["correct"]
+            metrics["total"] += batch_result["total"]
+            pbar.set_postfix(metrics)
+
+    metrics["accuracy"] = metrics["correct"] / metrics["total"]
+    return metrics
+
+
+class RewardModelEvaluationTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    def __init__(
+        self,
+        test_datasets: List[DictConfig],
+        dataloader_kwargs: DictConfig,
+        tokenizer: Optional[DictConfig],
+        max_num_samples: int = -1,
+        seed: int = 0,
+        **kwargs,
+    ):
+        self.seed = seed
+        L.seed_everything(seed)
+        self._test_datasets = test_datasets
+        self.dataloader_kwargs = dataloader_kwargs
+        self._tokenizer = tokenizer
+        self.max_num_samples = max_num_samples
+        super().__init__(**kwargs)
+
+    def setup(self):
+        if self._tokenizer is None:
+            # try to load the tokenizer from the model pool
+            tokenizer = self._program.modelpool.load_tokenizer()
+        else:
+            tokenizer = instantiate(self._tokenizer)
+        self.tokenizer = tokenizer
+
+        test_datasets = {
+            dataset_name: instantiate(self._test_datasets[dataset_name])
+            for dataset_name in self._test_datasets
+        }
+        if self.max_num_samples > 0:
+            test_datasets = {
+                dataset_name: Subset(
+                    test_dataset,
+                    np.random.permutation(len(test_dataset))[: self.max_num_samples],
+                )
+                for dataset_name, test_dataset in test_datasets.items()
+            }
+        test_dataloaders = {
+            dataset_name: torch.utils.data.DataLoader(
+                test_dataset,
+                collate_fn=functools.partial(
+                    bradley_terry_rm_collate,
+                    pad_token_id=tokenizer.pad_token_id,
+                ),
+                **self.dataloader_kwargs,
+            )
+            for dataset_name, test_dataset in test_datasets.items()
+        }
+
+        self.test_dataloaders = {
+            dataset_name: self.fabric.setup_dataloaders(test_dataloader)
+            for dataset_name, test_dataloader in test_dataloaders.items()
+        }
+
+    @torch.no_grad()
+    def evaluate(self, model: "LlamaForSequenceClassification"):
+        self.setup()
+
+        model = self.fabric.setup_module(model)
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        model.eval()
+        report = {}
+        for dataset_name, test_dataloader in self.test_dataloaders.items():
+            report[dataset_name] = evaluate_dataloader(model, test_dataloader)
+
+        print(report)
+        return report
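Note: `evaluate_batch` above assumes the `bradley_terry_rm_collate` collate function stacks the chosen sequences in the first half of the batch and the rejected sequences in the second half, so splitting the reward tensor at `batch_size // 2` recovers the preference pairs. A minimal sketch of the Bradley-Terry loss and ranking accuracy with dummy reward values (placeholders, not the package's collate output):

    import torch

    # hypothetical scalar rewards for a batch of 4 sequences (2 preference pairs)
    rewards = torch.tensor([[1.2], [0.4], [0.3], [0.9]])
    batch_size = rewards.size(0)
    chosen_reward = rewards[: batch_size // 2]    # rewards of the preferred responses
    rejected_reward = rewards[batch_size // 2 :]  # rewards of the rejected responses

    # Bradley-Terry pairwise loss and accuracy, as computed in evaluate_batch
    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()
    accuracy = (chosen_reward > rejected_reward).float().mean()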
@@ -60,4 +60,6 @@ class NYUv2TaskPool(TaskPool):
             num_workers=self.config.num_workers,
         )
         report = self.trainer.validate(model, val_loader)
+        if isinstance(report, list) and len(report) == 1:
+            report = report[0]
         return report
@@ -3,9 +3,10 @@ import os
 from typing import Optional

 from datasets import load_dataset, load_from_disk
+from omegaconf import DictConfig

 from fusion_bench.utils import instantiate, timeit_context
-from omegaconf import DictConfig
+
 from .glue_preprocessors import glue_processors
 from .glue_prompt_templates import glue_prompt_templates

@@ -4,3 +4,25 @@ import hydra.core.hydra_config
 def get_hydra_output_dir():
     hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
     return hydra_cfg.runtime.output_dir
+
+
+def config_priority_get(priority_config, general_config, key, default):
+    """
+    Retrieve a configuration value with priority.
+
+    This function retrieves the value associated with `key` from `priority_config` if it exists.
+    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
+    If the key is not found in either configuration, it returns the provided `default` value.
+
+    Args:
+        priority_config (dict): The configuration dictionary with higher priority.
+        general_config (dict): The general configuration dictionary.
+        key (str): The key to look up in the configuration dictionaries.
+        default: The default value to return if the key is not found in either configuration.
+
+    Returns:
+        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
+    """
+    if key in priority_config:
+        return priority_config[key]
+    return general_config.get(key, default)
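Note: `config_priority_get` is a plain two-level lookup (priority config first, general config second, hard-coded default last). A short usage sketch with hypothetical config dicts:

    method_config = {"lr": 1e-4}               # higher-priority, e.g. per-method overrides
    general_config = {"lr": 1e-3, "seed": 42}  # general defaults

    config_priority_get(method_config, general_config, "lr", 3e-4)    # -> 1e-4, priority config wins
    config_priority_get(method_config, general_config, "seed", 0)     # -> 42, falls back to general config
    config_priority_get(method_config, general_config, "epochs", 1)   # -> 1, neither config has the key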
File without changes
@@ -0,0 +1,52 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def visualize_model_inputs(input_ids, attention_mask, labels, tokenizer=None):
+    """
+    Visualize model inputs: attention mask, labels and input_ids
+
+    Parameters:
+    -----------
+    attention_mask: numpy array or tensor
+        The attention mask array
+    labels: numpy array or tensor
+        The labels array
+    input_ids: numpy array or tensor
+        The input ids array
+    tokenizer: optional
+        The tokenizer object to decode input_ids
+    """
+
+    # Convert inputs to numpy if they're tensors
+    attention_mask = np.array(attention_mask)
+    labels = np.array(labels)
+    input_ids = np.array(input_ids)
+
+    # Create figure with 3 subplots
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))
+
+    # Plot attention mask
+    sns.heatmap(attention_mask.reshape(1, -1), ax=ax1, cmap="Blues", cbar=True)
+    ax1.set_title("**Attention Mask**")
+    ax1.set_ylabel("Sequence")
+
+    # Plot labels
+    sns.heatmap(labels.reshape(1, -1), ax=ax2, cmap="Reds", cbar=True)
+    ax2.set_title("**Labels**")
+    ax2.set_ylabel("Sequence")
+
+    # Plot input_ids
+    sns.heatmap(input_ids.reshape(1, -1), ax=ax3, cmap="Greens", cbar=True)
+    ax3.set_title("**Input IDs**")
+    ax3.set_ylabel("Sequence")
+
+    # If tokenizer is provided, add decoded tokens as x-axis labels
+    if tokenizer:
+        decoded_tokens = [tokenizer.decode(token_id) for token_id in input_ids]
+        ax3.set_xticks(np.arange(len(decoded_tokens)) + 0.5)
+        ax3.set_xticklabels(decoded_tokens, rotation=45, ha="right")
+
+    plt.tight_layout()
+    return fig
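Note: `visualize_model_inputs` expects flat (1-D) `input_ids`, `attention_mask` and `labels` arrays and returns a matplotlib figure with one heatmap per array. A hypothetical usage sketch; the tokenizer name and dummy labels below are placeholders, not taken from the package:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    enc = tokenizer("Merging models is fun.", return_tensors="np")
    labels = enc["input_ids"].copy()  # dummy labels: supervise every token

    fig = visualize_model_inputs(
        input_ids=enc["input_ids"][0],
        attention_mask=enc["attention_mask"][0],
        labels=labels[0],
        tokenizer=tokenizer,
    )
    fig.savefig("model_inputs.png")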
@@ -0,0 +1,127 @@
+import numpy as np
+from IPython.display import HTML, display
+
+
+def create_color_style():
+    return """
+    <style>
+    .token-container { font-family: monospace; white-space: pre; }
+    .attention { background-color: #90EE90; } /* Light green */
+    .label { background-color: #FFB6C6; } /* Light red */
+    .token { color: #0066cc; } /* Blue */
+    .stats { font-weight: bold; }
+    </style>
+    """
+
+
+def escape_special_chars(text):
+    """Convert special characters to their string representation"""
+    return (
+        text.replace("\n", "\\n")
+        .replace("\t", "\\t")
+        .replace("\r", "\\r")
+        .replace(" ", "␣")
+    )  # Optional: show spaces with visible character
+
+
+def visualize_tokens_html(input_ids, attention_mask, labels, tokenizer):
+    """
+    Visualize model inputs using HTML colored text representation for Jupyter Notebook
+    with special characters shown as strings
+    """
+    # Convert to numpy if tensors
+    attention_mask = np.array(attention_mask).flatten()
+    labels = np.array(labels).flatten()
+    input_ids = np.array(input_ids).flatten()
+
+    # Decode tokens and escape special characters
+    tokens = [escape_special_chars(tokenizer.decode(id_)) for id_ in input_ids]
+
+    # Create HTML output
+    html_output = [create_color_style()]
+
+    # Header
+    html_output.append("<h3>**Token Visualization**</h3>")
+
+    # Legend
+    html_output.append(
+        """
+        <div style='margin: 10px 0;'>
+        <strong>Legend:</strong><br>
+        <span class='attention'>&nbsp;&nbsp;&nbsp;&nbsp;</span> Active Attention<br>
+        <span class='label'>&nbsp;&nbsp;&nbsp;&nbsp;</span> Label Present<br>
+        <span class='token'>Text</span> Token Text<br>
+        Special Characters: \\n (newline), \\t (tab), ␣ (space)
+        </div>
+        """
+    )
+
+    # Token alignment
+    html_output.append("<strong>Token Alignment:</strong>")
+    html_output.append("<div class='token-container'>")
+
+    # Calculate maximum token length for better alignment
+    max_token_len = max(len(str(token)) for token in tokens)
+
+    for i, (input_id, token, mask, label) in enumerate(
+        zip(input_ids, tokens, attention_mask, labels)
+    ):
+        # Pad token for alignment
+        token_text = f"{token:{max_token_len}s}"
+
+        # Create classes for styling
+        classes = []
+        if mask == 1:
+            classes.append("attention")
+        if label != -100 and label != 0:
+            classes.append("label")
+
+        class_str = f"class='{' '.join(classes)}'" if classes else ""
+
+        # Create the line
+        line = f"Position {i:3d}: <span {class_str}><span class='token'>{token_text}</span></span> "
+        line += (
+            f"(Mask: {int(mask)}, Label: {int(label)}, Inpu_id: {int(input_id)})<br>"
+        )
+        html_output.append(line)
+
+    html_output.append("</div>")
+
+    # Statistics
+    html_output.append(
+        """
+        <div class='stats' style='margin-top: 10px;'>
+        Statistics:<br>
+        Total tokens: {}<br>
+        Active attention tokens: {}<br>
+        Labeled tokens: {}<br>
+        </div>
+        """.format(
+            len(tokens), attention_mask.sum(), sum(labels != -100)
+        )
+    )
+
+    # Display the HTML
+    display(HTML("".join(html_output)))
+
+
+# Example usage:
+"""
+from transformers import AutoTokenizer
+import torch
+
+# Initialize tokenizer
+tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+# Sample input with special characters
+text = "Hello,\nhow are\tyou?"
+inputs = tokenizer(text, return_tensors='pt')
+labels = torch.zeros_like(inputs['input_ids'])  # dummy labels
+
+visualize_tokens_html(
+    inputs['attention_mask'][0],
+    labels[0],
+    inputs['input_ids'][0],
+    tokenizer
+)
+"""
@@ -6,18 +6,20 @@ from typing_extensions import TypeAlias

 try:
     import torch
-    from torch import Tensor
+    from torch import Tensor, nn

     StateDictType: TypeAlias = Dict[str, Tensor]
+    TorchModelType = TypeVar("TorchModelType", bound=nn.Module)
+
 except ImportError:
     pass


-ModuleType = type(sys)
+PyModuleType = type(sys)
 T = TypeVar("T")
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
 T4 = TypeVar("T4")

-__all__ = ["StateDictType", "ModuleType", "T", "T1", "T2", "T3", "T4"]
+__all__ = ["StateDictType", "PyModuleType", "TorchModelType", "T", "T1", "T2", "T3", "T4"]
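Note: `TorchModelType` is a `TypeVar` bound to `nn.Module`, so helpers annotated with it preserve the concrete module type they receive instead of widening it to `nn.Module`. An illustrative sketch; the `freeze` helper below is hypothetical, not part of the package:

    from typing import TypeVar

    from torch import nn

    TorchModelType = TypeVar("TorchModelType", bound=nn.Module)

    def freeze(model: TorchModelType) -> TorchModelType:
        # disable gradients and return the same, concretely typed model
        for p in model.parameters():
            p.requires_grad_(False)
        return model

    linear: nn.Linear = freeze(nn.Linear(4, 2))  # type checkers keep nn.Linear, not just nn.Module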
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: fusion_bench
-Version: 0.2.5
+Version: 0.2.7
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License