fusion-bench 0.2.14__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,19 @@
1
+ """
2
+ Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/causal_lm
3
+ """
4
+
1
5
  import logging
2
6
  import os
3
7
  from copy import deepcopy
4
- from typing import Any, Optional, TypeAlias, Union, cast # noqa: F401
8
+ from typing import Any, Dict, Optional, TypeAlias, Union, cast # noqa: F401
5
9
 
6
10
  import peft
7
11
  from omegaconf import DictConfig, flag_override
8
12
  from torch import nn
9
13
  from torch.nn.modules import Module
10
14
  from transformers import (
11
- LlamaForCausalLM,
12
- MistralForCausalLM,
15
+ AutoModelForCausalLM,
16
+ AutoTokenizer,
13
17
  PreTrainedModel,
14
18
  PreTrainedTokenizer,
15
19
  )
@@ -21,8 +25,6 @@ from fusion_bench.utils.dtype import parse_dtype
21
25
 
22
26
  log = logging.getLogger(__name__)
23
27
 
24
- CausalLM: TypeAlias = Union[LlamaForCausalLM, MistralForCausalLM, Any]
25
-
26
28
 
27
29
  class CausalLMPool(BaseModelPool):
28
30
  _config_mapping = BaseModelPool._config_mapping | {
@@ -56,17 +58,78 @@ class CausalLMPool(BaseModelPool):
56
58
  model_name_or_config: str | DictConfig,
57
59
  *args,
58
60
  **kwargs,
59
- ) -> LlamaForCausalLM | MistralForCausalLM | nn.Module:
61
+ ) -> PreTrainedModel:
62
+ """
63
+ Example of YAML config:
64
+
65
+ ```yaml
66
+ models:
67
+ _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
68
+ model_a: path_to_model_a
69
+ model_b: path_to_model_b
70
+ ```
71
+
72
+ or equivalently,
73
+
74
+ ```yaml
75
+ models:
76
+ _pretrained_:
77
+ _target_: transformers.AutoModelForCausalLM.from_pretrained # any callable that returns a model
78
+ pretrained_model_name_or_path: path_to_pretrained_model
79
+ model_a:
80
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
81
+ pretrained_model_name_or_path: path_to_model_a
82
+ model_b:
83
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
84
+ pretrained_model_name_or_path: path_to_model_b
85
+ ```
86
+ """
60
87
  model_kwargs = deepcopy(self._model_kwargs)
61
88
  model_kwargs.update(kwargs)
89
+
62
90
  if isinstance(model_name_or_config, str):
63
91
  log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
64
- return super().load_model(model_name_or_config, *args, **model_kwargs)
92
+ if model_name_or_config in self._models.keys():
93
+ model_config = self._models[model_name_or_config]
94
+ if isinstance(model_config, str):
95
+ # model_config is a string
96
+ model = AutoModelForCausalLM.from_pretrained(
97
+ model_config,
98
+ *args,
99
+ **model_kwargs,
100
+ )
101
+ return model
102
+ elif isinstance(model_name_or_config, (DictConfig, Dict)):
103
+ model_config = model_name_or_config
104
+
105
+ model = instantiate(model_config, *args, **model_kwargs)
106
+ return model
65
107
 
66
108
  def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
109
+ """
110
+ Example of YAML config:
111
+
112
+ ```yaml
113
+ tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
114
+ ```
115
+
116
+ or equivalently,
117
+
118
+ ```yaml
119
+ tokenizer:
120
+ _target_: transformers.AutoTokenizer.from_pretrained # any callable that returns a tokenizer
121
+ pretrained_model_name_or_path: google/gemma-2-2b-it
122
+ ```
123
+
124
+ Returns:
125
+ PreTrainedTokenizer: The tokenizer.
126
+ """
67
127
  assert self._tokenizer is not None, "Tokenizer is not defined in the config"
68
128
  log.info("Loading tokenizer.", stacklevel=2)
69
- tokenizer = instantiate(self._tokenizer, *args, **kwargs)
129
+ if isinstance(self._tokenizer, str):
130
+ tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
131
+ else:
132
+ tokenizer = instantiate(self._tokenizer, *args, **kwargs)
70
133
  return tokenizer
71
134
 
72
135
  @override
@@ -113,7 +176,7 @@ class CausalLMBackbonePool(CausalLMPool):
113
176
  def load_model(
114
177
  self, model_name_or_config: str | DictConfig, *args, **kwargs
115
178
  ) -> Module:
116
- model: Union[MistralForCausalLM, LlamaForCausalLM, Any] = super().load_model(
179
+ model: AutoModelForCausalLM = super().load_model(
117
180
  model_name_or_config, *args, **kwargs
118
181
  )
119
182
  return model.model.layers
@@ -126,7 +189,7 @@ def load_peft_causal_lm(
126
189
  is_trainable: bool = True,
127
190
  merge_and_unload: bool = False,
128
191
  ):
129
- base_model = LlamaForCausalLM.from_pretrained(
192
+ base_model = AutoModelForCausalLM.from_pretrained(
130
193
  base_model_path, torch_dtype=torch_dtype
131
194
  )
132
195
  model = peft.PeftModel.from_pretrained(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fusion_bench
3
- Version: 0.2.14
3
+ Version: 0.2.15
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  License: MIT License
@@ -63,7 +63,7 @@ Dynamic: license-file
63
63
 
64
64
  </div>
65
65
 
66
- > [!TIP]
66
+ > [!TIP]
67
67
  > Documentation is available at [tanganke.github.io/fusion_bench/](https://tanganke.github.io/fusion_bench/).
68
68
 
69
69
  ## Overview
@@ -157,6 +157,9 @@ pip install -e ".[lm-eval-harness]"
157
157
  This will install the latest version of fusion-bench and the dependencies required for LM-Eval Harness.
158
158
  Documentation for using LM-Eval Harness within the FusionBench framework can be found at [this online documentation](https://tanganke.github.io/fusion_bench/taskpool/lm_eval_harness) or in the [`docs/taskpool/lm_eval_harness.md`](docs/taskpool/lm_eval_harness.md) markdown file.
159
159
 
160
+ > [!TIP]
161
+ > Documentation for merging large language models using FusionBench can be found at [this online documentation](https://tanganke.github.io/fusion_bench/modelpool/causal_lm) or in the [`docs/modelpool/causal_lm.md`](docs/modelpool/causal_lm.md) markdown file.
162
+
160
163
  ## Introduction to Deep Model Fusion
161
164
 
162
165
  Deep model fusion is a technique that merges, ensembles, or fuses multiple deep neural networks to obtain a unified model.
@@ -219,7 +219,7 @@ fusion_bench/modelpool/huggingface_automodel.py,sha256=OJ6EyYyjNv1_Bhjn-zli-e__B
219
219
  fusion_bench/modelpool/huggingface_gpt2_classification.py,sha256=j8nicVwtoLXY4RPE2dcepeEB3agBKkkH-xA3yMj1czw,2014
220
220
  fusion_bench/modelpool/nyuv2_modelpool.py,sha256=btuXmYxwfjI6MnGakhoOf53Iyb9fxYH20CavGTrTcnA,1375
221
221
  fusion_bench/modelpool/causal_lm/__init__.py,sha256=F432-aDIgAbUITj4GNZS9dgUKKhaDMCbTeHB-9MecaQ,99
222
- fusion_bench/modelpool/causal_lm/causal_lm.py,sha256=k0eOOcFbswVgBYhM9CEXvdCRU9zVC8Gw78QaiMWzeWo,4487
222
+ fusion_bench/modelpool/causal_lm/causal_lm.py,sha256=fO8lF8YWwoe43sVVOqHW9Ike7x-924-I6QQgZqx9EgA,6505
223
223
  fusion_bench/modelpool/clip_vision/__init__.py,sha256=3b9gN2bWUsoA1EmpitnIMnIlX7nklxbkn4WJ0QJtS2c,43
224
224
  fusion_bench/modelpool/clip_vision/modelpool.py,sha256=JH1wLdWefvE242SYpXTnoSLkKX-YcadnidWd2bo8tWQ,5486
225
225
  fusion_bench/modelpool/openclip_vision/__init__.py,sha256=QDmAitKqUwRygN9QncdS_kGWZdfTKL4uUifC8xh9c10,47
@@ -393,7 +393,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
393
393
  fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
394
394
  fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
395
395
  fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
396
- fusion_bench-0.2.14.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
396
+ fusion_bench-0.2.15.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
397
397
  fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
398
398
  fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=7IxLQoLRz-sRWyV8Vqc5kQcmYE_9YQz2_77pmvAkum8,1207
399
399
  fusion_bench_config/fabric_model_fusion.yaml,sha256=5iPgaM8UOhuvBW2Hap_csst-eqlYRwb_lru8ngjrZ_g,948
@@ -729,6 +729,14 @@ fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml,sha256=MpgshGtmM
729
729
  fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml,sha256=Kbpam1Hds5URMP35dXGdVibH-vTmYPh3xHMkhj6Mgtg,648
730
730
  fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml,sha256=FynhZ1PRvyzsyzrHIuMpGgQGRMlu_xI7earm-CeIVeY,824
731
731
  fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml,sha256=zQWfp7mYm6jQ8g41Eeh2d9vAbocZJ5btPX1ft9QpEZU,546
732
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml,sha256=NDq_prH-b9Vw7lRjsyJIcbeF4MXVVdszxK1FPJxIJYs,453
733
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml,sha256=Mg_z2vnw7IkNPoMvhl_Ja6gT9tX942sqaNfjXQRzBvg,390
734
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml,sha256=SfPEji6mWx9Dw48rE0B8MDrYv2NVLC-S98DK5xaU6So,453
735
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml,sha256=2vpOp9t8SUP2rkBw21mqwRYApkqXQiaYXcZm2oxLox4,390
736
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml,sha256=8up_cqEhabGeK6l6tMha9DJzsPoEIFN8bS_Kwv7LmCc,389
737
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml,sha256=SODG0kcnAP6yC0_J_SpSVMRV-v5qGV22gcWdiBaZo1I,368
738
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml,sha256=zwInWJS8yrhch4vOL1ypRKNWWpJKlhQsyY0Ln14CC-M,389
739
+ fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml,sha256=ufmu4b3lyxn2XLDMVYxP-bKwYaGTjB5-JoYXLG8v8tY,368
732
740
  fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md,sha256=DC0HF-isCHshipHTC0Rof6GvjTUa0i2DVQZKrklQQlU,2416
733
741
  fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml,sha256=jbJqqciORJQknpSzh2zKiFm6VKDOsmaSk9XfPCVmHGg,1220
734
742
  fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml,sha256=q2_E2R1wIOdxd-AF-wjXkPO64gJgD27YXsZ8FFLWUIo,1607
@@ -790,8 +798,8 @@ fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml,sha256=45kSz44pc
790
798
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml,sha256=GjpiiRownrBCpl-TNwWRW2PYePbF-Cl99jlLNPrK5T4,1017
791
799
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml,sha256=WwiYMQKehtJixDPnu5o3vcWe4yJksXTWRqOzm3uVWXQ,1017
792
800
  fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml,sha256=xGRt0J9joXTzWUew6DvoYprAWlPXhaVFw5AX4im5VQw,1017
793
- fusion_bench-0.2.14.dist-info/METADATA,sha256=X13MPJ_FA0D5Gc5T-CvbcYOK03QtTiyIHnDNbI7_aOo,20904
794
- fusion_bench-0.2.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
795
- fusion_bench-0.2.14.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
796
- fusion_bench-0.2.14.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
797
- fusion_bench-0.2.14.dist-info/RECORD,,
801
+ fusion_bench-0.2.15.dist-info/METADATA,sha256=abOyRl-ejl7CvLRCaRP20vn7rdb5OF92GxS_S9qTK3Q,21171
802
+ fusion_bench-0.2.15.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
803
+ fusion_bench-0.2.15.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
804
+ fusion_bench-0.2.15.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
805
+ fusion_bench-0.2.15.dist-info/RECORD,,
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: meta-llama/Llama-3.1-8B-Instruct
4
+ instruction: MergeBench/Llama-3.1-8B-Instruct_instruction
5
+ math: MergeBench/Llama-3.1-8B-Instruct_math
6
+ coding: MergeBench/Llama-3.1-8B-Instruct_coding
7
+ multilingual: MergeBench/Llama-3.1-8B-Instruct_multilingual
8
+ safety: MergeBench/Llama-3.1-8B-Instruct_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: meta-llama/Llama-3.1-8B-Instruct
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: meta-llama/Llama-3.1-8B
4
+ instruction: MergeBench/Llama-3.1-8B_instruction
5
+ math: MergeBench/Llama-3.1-8B_math
6
+ coding: MergeBench/Llama-3.1-8B_coding
7
+ multilingual: MergeBench/Llama-3.1-8B_multilingual
8
+ safety: MergeBench/Llama-3.1-8B_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: meta-llama/Llama-3.1-8B
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: meta-llama/Llama-3.2-3B-Instruct
4
+ instruction: MergeBench/Llama-3.2-3B-Instruct_instruction
5
+ math: MergeBench/Llama-3.2-3B-Instruct_math
6
+ coding: MergeBench/Llama-3.2-3B-Instruct_coding
7
+ multilingual: MergeBench/Llama-3.2-3B-Instruct_multilingual
8
+ safety: MergeBench/Llama-3.2-3B-Instruct_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: meta-llama/Llama-3.2-3B-Instruct
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: meta-llama/Llama-3.2-3B
4
+ instruction: MergeBench/Llama-3.2-3B_instruction
5
+ math: MergeBench/Llama-3.2-3B_math
6
+ coding: MergeBench/Llama-3.2-3B_coding
7
+ multilingual: MergeBench/Llama-3.2-3B_multilingual
8
+ safety: MergeBench/Llama-3.2-3B_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: meta-llama/Llama-3.2-3B
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: google/gemma-2-2b-it
4
+ instruction: MergeBench/gemma-2-2b-it_instruction
5
+ math: MergeBench/gemma-2-2b-it_math
6
+ coding: MergeBench/gemma-2-2b-it_coding
7
+ multilingual: MergeBench/gemma-2-2b-it_multilingual
8
+ safety: MergeBench/gemma-2-2b-it_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: google/gemma-2-2b-it
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: google/gemma-2-2b
4
+ instruction: MergeBench/gemma-2-2b_instruction
5
+ math: MergeBench/gemma-2-2b_math
6
+ coding: MergeBench/gemma-2-2b_coding
7
+ multilingual: MergeBench/gemma-2-2b_multilingual
8
+ safety: MergeBench/gemma-2-2b_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: google/gemma-2-2b
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: google/gemma-2-9b-it
4
+ instruction: MergeBench/gemma-2-9b-it_instruction
5
+ math: MergeBench/gemma-2-9b-it_math
6
+ coding: MergeBench/gemma-2-9b-it_coding
7
+ multilingual: MergeBench/gemma-2-9b-it_multilingual
8
+ safety: MergeBench/gemma-2-9b-it_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: google/gemma-2-9b-it
@@ -0,0 +1,11 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+ models:
3
+ _pretrained_: google/gemma-2-9b
4
+ instruction: MergeBench/gemma-2-9b_instruction
5
+ math: MergeBench/gemma-2-9b_math
6
+ coding: MergeBench/gemma-2-9b_coding
7
+ multilingual: MergeBench/gemma-2-9b_multilingual
8
+ safety: MergeBench/gemma-2-9b_safety
9
+ model_kwargs:
10
+ torch_dtype: bfloat16
11
+ tokenizer: google/gemma-2-9b