fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/dataset/llama/alpaca.py

@@ -1,16 +1,100 @@
  import logging
  import os
+ import warnings
  from typing import Any, Dict, List, Optional

  from datasets import Dataset, load_dataset, load_from_disk
+ from lightning.fabric.utilities import rank_zero_only
+ from tqdm.auto import tqdm
  from transformers import PreTrainedTokenizer

  import fusion_bench
+ from fusion_bench.utils import timeit_context

  log = logging.getLogger(__name__)


- def tokenize_alpaca_dataset(
+ def convert_alpaca_to_conversation(alpaca_data: List[Dict[str, str]]):
+     """
+     Convert Alpaca-format data to conversation format.
+
+     Args:
+         alpaca_data (list): List of dictionaries in Alpaca format with
+             'instruction', 'input', and 'output' keys
+
+     Returns:
+         list: List of conversations in ChatML format
+     """
+     conversations = []
+
+     for item in tqdm(
+         alpaca_data,
+         "Converting Alpaca to conversations",
+         disable=not rank_zero_only.rank == 0,
+     ):
+         # Skip if required fields are missing
+         if not item.get("instruction") or not item.get("output"):
+             continue
+
+         conversation = []
+
+         # Create user message
+         user_content = item["instruction"]
+         if item.get("input") and item["input"].strip():
+             user_content += f"\n\n{item['input']}"
+
+         conversation.append({"role": "user", "content": user_content})
+
+         # Create assistant message
+         conversation.append({"role": "assistant", "content": item["output"]})
+
+         conversations.append(conversation)
+
+     return conversations
+
+
+ def load_tokenized_alpaca_dataset(
+     tokenizer: PreTrainedTokenizer,
+     path: str = "yahma/alpaca-cleaned",
+     split: str = "train",
+     cache_path: Optional[str] = None,
+ ):
+     """
+     Load and tokenize the Alpaca dataset (or an Alpaca-like dataset).
+
+     Args:
+         tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the dataset.
+         path (str, optional): The path to the Alpaca dataset. Defaults to "yahma/alpaca-cleaned".
+         split (str, optional): The dataset split to load (e.g., "train", "test"). Defaults to "train".
+         cache_path (Optional[str], optional): The path to cache the tokenized dataset. If provided and the cache exists,
+             the dataset will be loaded from the cache. Defaults to None.
+
+     Returns:
+         Dataset: The tokenized dataset.
+     """
+     if cache_path is not None and os.path.exists(cache_path):
+         dataset = load_from_disk(cache_path)
+         if split is not None and split in dataset:
+             return dataset[split]
+         else:
+             return dataset
+
+     dataset = load_dataset(path, split=split)
+
+     alpaca_data = dataset.to_list()
+     conversations = convert_alpaca_to_conversation(alpaca_data)
+     with timeit_context("Tokenizing dataset"):
+         tokenized_dataset = tokenizer.apply_chat_template(
+             conversations, return_dict=True
+         )
+         tokenized_dataset = Dataset.from_dict(tokenized_dataset)
+
+     if cache_path is not None and rank_zero_only.rank == 0:
+         tokenized_dataset.save_to_disk(cache_path)
+     return tokenized_dataset
+
+
+ def _tokenize_alpaca_dataset_with_template(
      dataset: Dataset,
      tokenizer: PreTrainedTokenizer,
      max_length: int = 2048,
@@ -32,6 +116,10 @@ def tokenize_alpaca_dataset(
      Returns:
          Tokenized dataset
      """
+     warnings.warn(
+         "This function is deprecated. Use `apply_chat_template` from `transformers` instead.",
+         DeprecationWarning,
+     )

      def prepare_samples(samples: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
          # Format prompts based on whether input field exists
@@ -115,7 +203,7 @@ def tokenize_alpaca_dataset(
      return tokenized_dataset


- def load_tokenized_alpaca_dataset_from_json(
+ def load_tokenized_alpaca_dataset_from_json_with_prompt(
      data_files: str,
      tokenizer: PreTrainedTokenizer,
      max_length: int,
@@ -138,5 +226,7 @@ def load_tokenized_alpaca_dataset_from_json(
      dataset = load_dataset("json", data_files=data_files)
      if split is not None:
          dataset = dataset[split]
-     dataset = tokenize_alpaca_dataset(dataset, tokenizer, max_length=max_length)
+     dataset = _tokenize_alpaca_dataset_with_template(
+         dataset, tokenizer, max_length=max_length
+     )
      return dataset
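
For reference, a minimal usage sketch of the new `load_tokenized_alpaca_dataset` (illustrative only, not part of the diff; the checkpoint name and cache path are placeholders, and the tokenizer is assumed to define a chat template):

    # Hypothetical usage of the loader added above; checkpoint and cache path are illustrative.
    from transformers import AutoTokenizer

    from fusion_bench.dataset.llama.alpaca import load_tokenized_alpaca_dataset

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_alpaca_dataset(
        tokenizer,
        path="yahma/alpaca-cleaned",
        split="train",
        cache_path="outputs/cache/alpaca_tokenized",  # saved with save_to_disk on rank 0
    )
    print(dataset[0].keys())  # expected to include "input_ids" and "attention_mask"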
fusion_bench/dataset/llama/collate.py

@@ -7,13 +7,14 @@ from torch.nn.utils.rnn import pad_sequence

  def padded_collate_sft(
      batch: List[Dict[str, List[int]]],
-     padding_idx: int = 0,
+     pad_token_id: int = 0,
      input_ids_key: str = "input_ids",
      attention_mask_key: Optional[str] = "attention_mask",
      labels_key: Optional[str] = "labels",
      ignore_idx: int = -100,
  ) -> Dict[str, torch.Tensor]:
-     """Pad a batch of sequences to the longest sequence length in the batch, and
+     """
+     Pad (right) a batch of sequences to the longest sequence length in the batch, and
      convert integer lists to tensors.

      Args:
@@ -27,7 +28,7 @@ def padded_collate_sft(
      input_ids = pad_sequence(
          [torch.tensor(x[input_ids_key]) for x in batch],
          batch_first=True,
-         padding_value=padding_idx,
+         padding_value=pad_token_id,
      )
      if attention_mask_key is not None and attention_mask_key in batch[0]:
          attention_mask = pad_sequence(
@@ -37,6 +38,12 @@ def padded_collate_sft(
          )
      else:
          attention_mask = None
+
+     for i, item in enumerate(batch):
+         # if labels_key not in item, copy input_ids to labels_key
+         if labels_key not in item:
+             item[labels_key] = item[input_ids_key]
+
      labels = pad_sequence(
          [torch.tensor(x[labels_key]) for x in batch],
          batch_first=True,
@@ -44,10 +51,70 @@ def padded_collate_sft(
      )

      if attention_mask is not None:
-         return {
+         collated_batch = {
              input_ids_key: input_ids,
              attention_mask_key: attention_mask,
              labels_key: labels,
          }
      else:
-         return {input_ids_key: input_ids, labels_key: labels}
+         collated_batch = {input_ids_key: input_ids, labels_key: labels}
+
+     for key in batch[0]:
+         if key not in [input_ids_key, attention_mask_key, labels_key]:
+             collated_batch[key] = [x[key] for x in batch]
+
+     return collated_batch
+
+
+ def bradley_terry_rm_collate(
+     batch: List[Dict[str, List[int]]],
+     pad_token_id: int = 0,
+     padding_side="right",
+ ):
+     """
+     Collate function for Bradley-Terry reward modeling.
+
+     Args:
+         batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
+         pad_token_id (int): Padding index for input ids. Defaults to 0.
+
+     Returns:
+         Dict[str, torch.Tensor]: Collated input and label tensors. The first half of the batch is the winner, and the second half is the loser.
+     """
+     converted_batch = []
+     for item in batch:
+         new_item = {
+             "input_ids": item["chosen_input_ids"],
+             "attention_mask": item["chosen_attention_mask"],
+         }
+         converted_batch.append(new_item)
+     for item in batch:
+         new_item = {
+             "input_ids": item["rejected_input_ids"],
+             "attention_mask": item["rejected_attention_mask"],
+         }
+         converted_batch.append(new_item)
+
+     input_ids = pad_sequence(
+         [torch.tensor(x["input_ids"]) for x in converted_batch],
+         batch_first=True,
+         padding_value=pad_token_id,
+         padding_side=padding_side,
+     )
+     attention_mask = pad_sequence(
+         [torch.tensor(x["attention_mask"]) for x in converted_batch],
+         batch_first=True,
+         padding_value=0,
+         padding_side=padding_side,
+     )
+
+     collated_batch = {"input_ids": input_ids, "attention_mask": attention_mask}
+     for key in batch[0]:
+         if key not in [
+             "chosen_input_ids",
+             "chosen_attention_mask",
+             "rejected_input_ids",
+             "rejected_attention_mask",
+         ]:
+             collated_batch[key] = [x[key] for x in batch]
+     return collated_batch
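
To make the collate semantics above concrete, here is a small self-contained sketch (illustrative, not from the package; the token-id lists are made up). Note that `bradley_terry_rm_collate` forwards `padding_side` to `torch.nn.utils.rnn.pad_sequence`, which only accepts that argument in recent PyTorch releases:

    # Illustrative example: collating a synthetic preference pair and an SFT batch.
    from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate, padded_collate_sft

    preference_batch = [
        {
            "chosen_input_ids": [1, 5, 6, 2],
            "chosen_attention_mask": [1, 1, 1, 1],
            "rejected_input_ids": [1, 7, 2],
            "rejected_attention_mask": [1, 1, 1],
        }
    ]
    out = bradley_terry_rm_collate(preference_batch, pad_token_id=0)
    # Rows are ordered chosen-first, then rejected; both are padded to the longest sequence.
    print(out["input_ids"].shape)  # torch.Size([2, 4])

    # padded_collate_sft copies input_ids into labels when a sample has no labels field.
    sft_out = padded_collate_sft([{"input_ids": [1, 5, 2]}, {"input_ids": [1, 2]}])
    print(sft_out["labels"].shape)  # torch.Size([2, 3])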
fusion_bench/dataset/llama/metamathqa.py (new file)

@@ -0,0 +1,50 @@
+ import os
+ from typing import TYPE_CHECKING, Optional
+
+ from datasets import Dataset, load_dataset, load_from_disk
+ from lightning.fabric.utilities import rank_zero_only
+ from tqdm.auto import tqdm
+
+ from fusion_bench.utils import timeit_context
+
+ from .alpaca import convert_alpaca_to_conversation
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+
+ def load_tokenized_metamathqa(
+     tokenizer: "PreTrainedTokenizer",
+     path: str = "meta-math/MetaMathQA",
+     split: str = "train",
+     cache_path: Optional[str] = None,
+ ):
+     if cache_path is not None and os.path.exists(cache_path):
+         dataset = load_from_disk(cache_path)
+         if split is not None and split in dataset:
+             return dataset[split]
+         else:
+             return dataset
+
+     dataset = load_dataset(path, split=split)
+
+     # convert the dataset to Alpaca format
+     alpaca_dataset = []
+     for example in tqdm(dataset, disable=not rank_zero_only.rank == 0):
+         alpaca_example = {
+             "instruction": example["query"],
+             "input": "",
+             "output": example["response"],
+         }
+         alpaca_dataset.append(alpaca_example)
+
+     conversations = convert_alpaca_to_conversation(alpaca_dataset)
+     with timeit_context("Tokenizing dataset"):
+         tokenized_dataset = tokenizer.apply_chat_template(
+             conversations, return_dict=True
+         )
+         tokenized_dataset = Dataset.from_dict(tokenized_dataset)
+
+     if cache_path is not None and rank_zero_only.rank == 0:
+         tokenized_dataset.save_to_disk(cache_path)
+     return tokenized_dataset
fusion_bench/dataset/llama/preference_700k.py (new file)

@@ -0,0 +1,70 @@
+ import os
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Optional
+
+ from datasets import Dataset, load_dataset, load_from_disk
+ from lightning.fabric.utilities import rank_zero_only
+ from tqdm.auto import tqdm
+
+ from fusion_bench.utils import timeit_context
+ import logging
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+ log = logging.getLogger(__name__)
+
+
+ def load_tokenized_preference_700k_for_rlhf(
+     tokenizer: "PreTrainedTokenizer",
+     path: str = "hendrydong/preference_700K",
+     split: str = "train",
+     num_proc: int = 8,
+     cache_path: Optional[str] = None,
+ ):
+     R"""
+     Load and tokenize the Preference 700K dataset for a Bradley-Terry ranking model.
+
+     The returned dataset contains the following fields:
+
+     - chosen_input_ids: The input token ids for the winner.
+     - chosen_attention_mask: The attention mask for the winner.
+     - rejected_input_ids: The input token ids for the loser.
+     - rejected_attention_mask: The attention mask for the loser.
+     """
+     if cache_path is not None and os.path.exists(cache_path):
+         dataset = load_from_disk(cache_path)
+         return dataset
+
+     dataset = load_dataset(path, split=split)
+
+     def tokenize(sample):
+         sample["chosen_chat"] = tokenizer.apply_chat_template(
+             sample["chosen"], tokenize=False, add_generation_prompt=False
+         )
+         sample["rejected_chat"] = tokenizer.apply_chat_template(
+             sample["rejected"], tokenize=False, add_generation_prompt=False
+         )
+
+         tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
+         tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)
+
+         # Warn if the chosen response contains a PAD token
+         sample["chosen_input_ids"] = tokenized_pos["input_ids"]
+         sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
+         if tokenizer.pad_token_id in tokenized_pos["input_ids"]:
+             log.warning(f"Prompt contains PAD token: {sample['chosen_chat']}")
+
+         sample["rejected_input_ids"] = tokenized_neg["input_ids"]
+         sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
+         # Warn if the rejected response contains a PAD token
+         if tokenizer.pad_token_id in tokenized_neg["input_ids"]:
+             log.warning(f"Prompt contains PAD token: {sample['rejected_chat']}")
+
+         return sample
+
+     dataset = dataset.map(tokenize, num_proc=num_proc)
+
+     if cache_path is not None and rank_zero_only.rank == 0:
+         dataset.save_to_disk(cache_path)
+     return dataset
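
A hedged sketch of how this dataset is presumably paired with `bradley_terry_rm_collate` from `collate.py` (an assumption for illustration, not code taken from the package; the checkpoint and batch size are placeholders):

    # Assumed wiring: tokenized preference pairs -> Bradley-Terry collate -> DataLoader.
    from functools import partial

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
    from fusion_bench.dataset.llama.preference_700k import load_tokenized_preference_700k_for_rlhf

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_preference_700k_for_rlhf(tokenizer, split="train")
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        # Llama tokenizers may define no pad token, so fall back to the EOS id.
        collate_fn=partial(
            bradley_terry_rm_collate,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        ),
    )
    batch = next(iter(loader))
    # batch["input_ids"] stacks the chosen rows first, then the rejected rows (2 * batch_size total).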
fusion_bench/dataset/llama/stanford_shp.py (new file)

@@ -0,0 +1,90 @@
+ import os
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Optional
+
+ from datasets import Dataset, load_dataset, load_from_disk
+ from lightning.fabric.utilities import rank_zero_only
+ from tqdm.auto import tqdm
+
+ from fusion_bench.utils import timeit_context
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+
+ def load_tokenized_stanford_shp_for_rlhf(
+     tokenizer: "PreTrainedTokenizer",
+     path: str = "stanfordnlp/SHP",
+     split: str = "train",
+     num_proc: int = 8,
+     cache_path: Optional[str] = None,
+ ):
+     if cache_path is not None and os.path.isdir(cache_path):
+         dataset = load_from_disk(cache_path)
+         return dataset
+
+     dataset = load_dataset(path, split=split)
+
+     def tokenize(sample):
+         """
+         - history: the post title concatenated to the post body (string)
+         - human_ref_A: text of comment A (string)
+         - human_ref_B: text of comment B (string)
+         - labels: the preference label -- it is 1 if A is preferred to B; 0 if B is preferred to A. This was randomized such that the label distribution is roughly 50/50. (integer)
+         """
+         # Create a conversation with the post title and body, followed by comments
+         conversation = [{"role": "user", "content": sample["history"]}]
+         if sample["labels"] == 0:
+             sample["chosen"] = deepcopy(conversation).append(
+                 {"role": "assistant", "content": sample["human_ref_B"]}
+             )
+             sample["rejected"] = deepcopy(conversation).append(
+                 {"role": "assistant", "content": sample["human_ref_A"]}
+             )
+         else:
+             sample["chosen"] = deepcopy(conversation).append(
+                 {"role": "assistant", "content": sample["human_ref_A"]}
+             )
+             sample["rejected"] = deepcopy(conversation).append(
+                 {"role": "assistant", "content": sample["human_ref_B"]}
+             )
+
+         # apply chat template
+         sample["chosen_chat"] = tokenizer.apply_chat_template(
+             sample["chosen"], tokenize=False, add_generation_prompt=False
+         )
+         sample["rejected_chat"] = tokenizer.apply_chat_template(
+             sample["rejected"], tokenize=False, add_generation_prompt=False
+         )
+
+         # tokenize the conversation
+         tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
+         tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)
+
+         # Ensure that the chosen response ends with an EOS token and contains no earlier EOS token
+         sample["chosen_input_ids"] = tokenized_pos["input_ids"]
+         sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
+         assert (
+             tokenizer.eos_token_id not in tokenized_pos["input_ids"][:-1]
+         ), f"Prompt contains EOS token: {sample['positive']}"
+         if sample["chosen_input_ids"][-1] != tokenizer.eos_token_id:
+             sample["chosen_input_ids"].append(tokenizer.eos_token_id)
+             sample["chosen_attention_mask"].append(1)
+
+         sample["rejected_input_ids"] = tokenized_neg["input_ids"]
+         sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
+         # Ensure that the rejected response ends with an EOS token and contains no earlier EOS token
+         assert (
+             tokenizer.eos_token_id not in tokenized_neg["input_ids"][:-1]
+         ), f"Prompt contains EOS token: {sample['rejected']}"
+         if sample["rejected_input_ids"][-1] != tokenizer.eos_token_id:
+             sample["rejected_input_ids"].append(tokenizer.eos_token_id)
+             sample["rejected_attention_mask"].append(1)
+
+         return sample
+
+     dataset = dataset.map(tokenize, num_proc=num_proc)
+
+     if cache_path is not None and rank_zero_only.rank == 0:
+         dataset.save_to_disk(cache_path)
+     return dataset
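
One caveat in the hunk above: `deepcopy(conversation).append(...)` returns `None` because `list.append` mutates in place, so `sample["chosen"]` and `sample["rejected"]` would be set to `None` and the later `apply_chat_template` calls would fail (the first assert message also references a nonexistent `sample['positive']` key). A corrected sketch of that step, offered as an editorial suggestion rather than code from the package:

    # Suggested rewrite (not in the diff): build the chosen/rejected conversations explicitly,
    # since list.append() returns None and cannot be chained onto deepcopy().
    from copy import deepcopy

    def build_preference_pair(sample: dict) -> dict:
        conversation = [{"role": "user", "content": sample["history"]}]
        if sample["labels"] == 0:
            chosen_ref, rejected_ref = sample["human_ref_B"], sample["human_ref_A"]
        else:
            chosen_ref, rejected_ref = sample["human_ref_A"], sample["human_ref_B"]
        sample["chosen"] = deepcopy(conversation) + [
            {"role": "assistant", "content": chosen_ref}
        ]
        sample["rejected"] = deepcopy(conversation) + [
            {"role": "assistant", "content": rejected_ref}
        ]
        return sample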
fusion_bench/dataset/llama/ultrachat.py (new file)

@@ -0,0 +1,58 @@
+ import os
+ from typing import TYPE_CHECKING, Optional
+
+ from datasets import Dataset, load_dataset, load_from_disk
+ from lightning.fabric.utilities import rank_zero_only
+ from tqdm.auto import tqdm
+
+ from fusion_bench.utils import timeit_context
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+
+ def load_tokenized_ultrachat_200k(
+     tokenizer: "PreTrainedTokenizer",
+     path: str = "HuggingFaceH4/ultrachat_200k",
+     split: str = "train_sft",
+     num_proc: int = 8,
+     cache_path: Optional[str] = None,
+ ):
+     R"""
+     Load and tokenize the UltraChat 200k dataset.
+
+     The returned dataset contains the following fields:
+
+     - input_ids: The input token ids.
+     - attention_mask: The attention mask.
+     """
+     if cache_path is not None and os.path.exists(cache_path):
+         dataset = load_from_disk(cache_path)
+         return dataset
+
+     dataset = load_dataset(path, split=split)
+
+     def tokenize(sample):
+
+         # ? is it necessary to `.replace(tokenizer.bos_token, "")`?
+         sample["input_ids"] = tokenizer.apply_chat_template(
+             sample["messages"], tokenize=True, add_generation_prompt=False
+         )
+         sample["attention_mask"] = [1] * len(sample["input_ids"])
+
+         return sample
+
+     dataset = dataset.map(tokenize, num_proc=num_proc)
+
+     if cache_path is not None and rank_zero_only.rank == 0:
+         dataset.save_to_disk(cache_path)
+     return dataset
+
+
+ if __name__ == "__main__":
+     # Example usage and testing
+     from transformers import AutoTokenizer
+
+     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+     dataset = load_tokenized_ultrachat_200k(tokenizer)
+     print(dataset)
fusion_bench/dataset/llama/utils/__init__.py — file without content changes
fusion_bench/method/__init__.py

@@ -5,11 +5,12 @@ from typing import TYPE_CHECKING
  from fusion_bench.utils.lazy_imports import LazyImporter

  _import_structure = {
+     # --------------
      "base_algorithm": ["BaseModelFusionAlgorithm", "BaseAlgorithm"],
      "dummy": ["DummyAlgorithm"],
      # single task learning (fine-tuning)
      "classification": ["ImageClassificationFineTuningForCLIP"],
-     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT"],
+     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
      # analysis
      "analysis": ["TaskVectorCosSimilarity", "TaskVectorViolinPlot"],
      # model ensemble methods
@@ -64,6 +65,7 @@ _import_structure = {
      ],
      "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
      "we_moe": ["CLIPWeightEnsemblingMoEAlgorithm"],
+     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
      "sparse_we_moe": [
          "SparseWeightEnsemblingMoEAlgorithm",
          "SparseCLIPWeightEnsemblingMoEAlgorithm",
@@ -134,6 +136,7 @@ if TYPE_CHECKING:
          PWEMoELinearScalarizationForCLIP,
          PWEMoExactParetoOptimalForCLIP,
      )
+     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
      from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
      from .simple_average import SimpleAverageAlgorithm
      from .smile_upscaling import (
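
The `_import_structure` entries above feed fusion_bench's lazy-import machinery: submodules are imported only when one of their symbols is first accessed, while the `if TYPE_CHECKING:` imports keep static type checkers aware of the names. A simplified sketch of the pattern (illustrative; the package's actual `LazyImporter` may differ in detail):

    # Simplified lazy-import sketch using PEP 562 module __getattr__ (not the package's implementation).
    import importlib
    from typing import TYPE_CHECKING

    _import_structure = {
        "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
        "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
    }

    if TYPE_CHECKING:
        from .lm_finetune import BradleyTerryRewardModeling  # noqa: F401

    def __getattr__(name):
        # Resolve e.g. `fusion_bench.method.BradleyTerryRewardModeling` on first access.
        for module_name, symbols in _import_structure.items():
            if name in symbols:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")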
fusion_bench/method/adamerging/__init__.py

@@ -1,6 +1,6 @@
  # flake8: noqa F401
  from .clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm
  from .clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
+ from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
  from .gpt2_layer_wise_adamerging import GPT2LayerWiseAdaMergingAlgorithm
  from .llama_adamerging import LayerWiseAdaMergingForLlamaSFT
- from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
fusion_bench/method/adamerging/layer_wise_adamerging.py

@@ -1,12 +1,12 @@
  import logging
  import os
  from abc import abstractmethod
- from typing import Any, List, Mapping, Union, cast  # noqa: F401
+ from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast  # noqa: F401

  import torch
  from lightning.fabric.utilities.rank_zero import rank_zero_only
  from omegaconf import DictConfig
- from torch import Tensor
+ from torch import Tensor, nn
  from torch.utils.data import DataLoader
  from tqdm.autonotebook import tqdm

@@ -19,10 +19,14 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
      get_layer_wise_weights,
  )
  from fusion_bench.utils.data import load_tensor_from_file
+ from fusion_bench.utils.type import TorchModelType

  from .entropy_loss import entropy_loss
  from .utils import get_memory_usage

+ if TYPE_CHECKING:
+     from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
+
  log = logging.getLogger(__name__)


@@ -31,6 +35,9 @@ class LayerWiseAdaMergingAlgorithm(
      LightningFabricMixin,
      SimpleProfilerMixin,
  ):
+     _program: "FabricModelFusionProgram"
+     """The program that this algorithm is running on."""
+
      """
      Implements the Layer-Wise AdaMerging Algorithm.

@@ -48,7 +55,7 @@
          super().__init__(algorithm_config)

      @torch.no_grad()
-     def construct_layer_wise_merged_model(self, modelpool: ModelPool):
+     def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
          """
          Constructs a wrapped layer-wise merged model from model pool.

@@ -183,7 +190,7 @@
          """
          pass

-     def test_time_adaptation(self, module: LayerWiseMergedModel):
+     def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
          """
          Perform test-time adaptation on the merged model.

fusion_bench/method/adamerging/min_norm_solvers.py

@@ -49,7 +49,7 @@ class MinNormSolver:
          return gamma, cost

      def _min_norm_2d(vecs, dps):
-         """
+         R"""
          Find the minimum norm solution as combination of two points
          This is correct only in 2D
          ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
@@ -85,7 +85,7 @@ class MinNormSolver:
          return sol, dps

      def _projection2simplex(y):
-         """
+         R"""
          Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
          """
          m = len(y)
@@ -117,7 +117,7 @@ class MinNormSolver:
          return next_point

      def find_min_norm_element(vecs):
-         """
+         R"""
          Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
          as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
          It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
@@ -163,7 +163,7 @@ class MinNormSolver:
          sol_vec = new_sol_vec

      def find_min_norm_element_FW(vecs):
-         """
+         R"""
          Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
          as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
          It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
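
The only change in this file is prefixing docstrings that contain LaTeX-style backslashes (e.g. `\sum`) with `R`, making them raw strings. A one-line illustration of why (not from the package):

    # "\s" in a normal string literal is an invalid escape sequence (a SyntaxWarning on
    # recent Python versions, slated to become an error); a raw string keeps the backslash as-is.
    plain = "min_c |\\sum c_i x_i|_2^2"  # must double the backslash
    raw = R"min_c |\sum c_i x_i|_2^2"    # raw string: no escaping needed
    assert plain == raw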