fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. fusion_bench/__init__.py +21 -2
  2. fusion_bench/constants/__init__.py +1 -0
  3. fusion_bench/constants/runtime.py +57 -0
  4. fusion_bench/method/__init__.py +8 -2
  5. fusion_bench/method/bitdelta/__init__.py +1 -0
  6. fusion_bench/method/classification/clip_finetune.py +1 -1
  7. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  8. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  9. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  10. fusion_bench/method/simple_average.py +7 -7
  11. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  12. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  13. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  14. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  15. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  16. fusion_bench/method/we_moe/__init__.py +1 -0
  17. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  18. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  19. fusion_bench/method/we_moe/utils.py +15 -0
  20. fusion_bench/method/weighted_average/llama.py +1 -1
  21. fusion_bench/mixins/clip_classification.py +11 -42
  22. fusion_bench/mixins/serialization.py +18 -8
  23. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
  24. fusion_bench/models/__init__.py +5 -0
  25. fusion_bench/models/hf_utils.py +65 -87
  26. fusion_bench/models/model_card_templates/default.md +46 -0
  27. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  28. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  29. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
  30. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  31. fusion_bench/programs/fabric_fusion_program.py +29 -60
  32. fusion_bench/scripts/cli.py +34 -1
  33. fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
  34. fusion_bench/utils/__init__.py +1 -0
  35. fusion_bench/utils/cache_utils.py +101 -1
  36. fusion_bench/utils/fabric.py +2 -2
  37. fusion_bench/utils/lazy_imports.py +23 -0
  38. fusion_bench/utils/lazy_state_dict.py +38 -3
  39. fusion_bench/utils/modelscope.py +3 -3
  40. fusion_bench/utils/path.py +56 -0
  41. fusion_bench/utils/pylogger.py +1 -1
  42. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
  43. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
  44. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  45. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  46. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  47. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  48. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  49. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  50. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
@@ -90,14 +90,21 @@ def auto_register_config(cls):
  # Auto-register parameters in _config_mapping
  if not "_config_mapping" in cls.__dict__:
  cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+ registered_parameters = tuple(cls._config_mapping.values())
+
  for param_name in list(sig.parameters.keys())[1:]: # Skip 'self'
- if sig.parameters[param_name].kind not in [
- _ParameterKind.VAR_POSITIONAL,
- _ParameterKind.VAR_KEYWORD,
- ]:
+ if (
+ sig.parameters[param_name].kind
+ not in [
+ _ParameterKind.VAR_POSITIONAL,
+ _ParameterKind.VAR_KEYWORD,
+ ]
+ ) and (param_name not in registered_parameters):
  cls._config_mapping[param_name] = param_name

  def __init__(self, *args, **kwargs):
+ nonlocal original_init, registered_parameters
+
  # auto-register the attributes based on the signature
  sig = inspect.signature(original_init)
  param_names = list(sig.parameters.keys())[1:] # Skip 'self'
@@ -114,10 +121,13 @@ def auto_register_config(cls):

  # Handle keyword arguments and defaults
  for param_name in param_names:
- if sig.parameters[param_name].kind not in [
- _ParameterKind.VAR_POSITIONAL,
- _ParameterKind.VAR_KEYWORD,
- ]:
+ if (
+ sig.parameters[param_name].kind
+ not in [
+ _ParameterKind.VAR_POSITIONAL,
+ _ParameterKind.VAR_KEYWORD,
+ ]
+ ) and (param_name not in registered_parameters):
  # Skip if already set by positional argument
  param_index = param_names.index(param_name)
  if param_index >= 0 and param_index < len(args):
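The change above makes `auto_register_config` skip constructor parameters whose names already appear as values in an existing `_config_mapping`, instead of re-registering them. A standalone sketch of the new skip rule (this is not the library code; the `Demo` class and mapping below are hypothetical):

```python
import inspect

def register_params(init, config_mapping: dict) -> dict:
    """Illustrative re-implementation: a parameter is mapped only if it is not
    *args/**kwargs and is not already a value in the mapping."""
    registered = tuple(config_mapping.values())
    sig = inspect.signature(init)
    for name, param in list(sig.parameters.items())[1:]:  # skip 'self'
        if param.kind not in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ) and name not in registered:
            config_mapping[name] = name
    return config_mapping

class Demo:
    def __init__(self, alpha: float = 0.5, beta: float = 1.0, **kwargs): ...

# "alpha" is already mapped (here to a private attribute), so only "beta" is added.
print(register_params(Demo.__init__, {"_alpha": "alpha"}))
# -> {'_alpha': 'alpha', 'beta': 'beta'}
```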
@@ -8,7 +8,7 @@ from copy import deepcopy
  from typing import Any, Dict, Optional, TypeAlias, Union, cast # noqa: F401

  import peft
- from omegaconf import DictConfig, flag_override
+ from omegaconf import DictConfig, OmegaConf, flag_override
  from torch import nn
  from torch.nn.modules import Module
  from transformers import (
@@ -19,43 +19,32 @@ from transformers import (
  )
  from typing_extensions import override

- from fusion_bench.modelpool import BaseModelPool
- from fusion_bench.utils import instantiate
- from fusion_bench.utils.dtype import parse_dtype
+ from fusion_bench import (
+ BaseModelPool,
+ auto_register_config,
+ import_object,
+ instantiate,
+ parse_dtype,
+ )
  from fusion_bench.utils.lazy_state_dict import LazyStateDict
- from fusion_bench.utils.packages import import_object

  log = logging.getLogger(__name__)


+ @auto_register_config
  class CausalLMPool(BaseModelPool):
- _config_mapping = BaseModelPool._config_mapping | {
- "_tokenizer": "tokenizer",
- "_model_kwargs": "model_kwargs",
- "load_lazy": "load_lazy",
- }
-
  def __init__(
  self,
  models,
  *,
- tokenizer: Optional[DictConfig],
+ tokenizer: Optional[DictConfig | str],
  model_kwargs: Optional[DictConfig] = None,
- load_lazy: bool = False,
+ enable_lazy_loading: bool = False,
  **kwargs,
  ):
  super().__init__(models, **kwargs)
- # process `model_kwargs`
- self._tokenizer = tokenizer
- self._model_kwargs = model_kwargs
- if self._model_kwargs is None:
- self._model_kwargs = DictConfig({})
- with flag_override(self._model_kwargs, "allow_objects", True):
- if hasattr(self._model_kwargs, "torch_dtype"):
- self._model_kwargs.torch_dtype = parse_dtype(
- self._model_kwargs.torch_dtype
- )
- self.load_lazy = load_lazy
+ if model_kwargs is None:
+ self.model_kwargs = DictConfig({})

  def get_model_path(self, model_name: str):
  model_name_or_config = self._models[model_name]
@@ -66,6 +55,16 @@ class CausalLMPool(BaseModelPool):
  else:
  raise RuntimeError("Invalid model configuration")

+ def get_model_kwargs(self):
+ model_kwargs = (
+ OmegaConf.to_container(self.model_kwargs, resolve=True)
+ if isinstance(self.model_kwargs, DictConfig)
+ else self.model_kwargs
+ )
+ if "torch_dtype" in model_kwargs:
+ model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+ return model_kwargs
+
  @override
  def load_model(
  self,
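With this release the stored `model_kwargs` config is no longer mutated in `__init__`; the new `get_model_kwargs` resolves the OmegaConf container and converts `torch_dtype` on demand. A rough sketch of the same conversion outside the class, assuming `parse_dtype` is re-exported from `fusion_bench` as in the import block above (the config values are placeholders):

```python
from omegaconf import DictConfig, OmegaConf
from fusion_bench import parse_dtype

cfg = DictConfig({"torch_dtype": "bfloat16"})

# Resolve the OmegaConf container to a plain dict, then convert the dtype string.
kwargs = OmegaConf.to_container(cfg, resolve=True)
kwargs["torch_dtype"] = parse_dtype(kwargs["torch_dtype"])  # e.g. torch.bfloat16
print(kwargs)
```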
@@ -98,7 +97,7 @@
  pretrained_model_name_or_path: path_to_model_b
  ```
  """
- model_kwargs = deepcopy(self._model_kwargs)
+ model_kwargs = self.get_model_kwargs()
  model_kwargs.update(kwargs)

  if isinstance(model_name_or_config, str):
@@ -108,7 +107,7 @@
  model_config = self._models[model_name_or_config]
  if isinstance(model_config, str):
  # model_config is a string
- if not self.load_lazy:
+ if not self.enable_lazy_loading:
  model = AutoModelForCausalLM.from_pretrained(
  model_config,
  *args,
@@ -126,7 +125,7 @@
  elif isinstance(model_name_or_config, (DictConfig, Dict)):
  model_config = model_name_or_config

- if not self.load_lazy:
+ if not self.enable_lazy_loading:
  model = instantiate(model_config, *args, **model_kwargs)
  else:
  meta_module_class = model_config.pop("_target_")
@@ -158,12 +157,12 @@
  Returns:
  PreTrainedTokenizer: The tokenizer.
  """
- assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+ assert self.tokenizer is not None, "Tokenizer is not defined in the config"
  log.info("Loading tokenizer.", stacklevel=2)
- if isinstance(self._tokenizer, str):
- tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
+ if isinstance(self.tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
  else:
- tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+ tokenizer = instantiate(self.tokenizer, *args, **kwargs)
  return tokenizer

  @override
@@ -213,12 +212,12 @@ class CausalLMBackbonePool(CausalLMPool):
  def load_model(
  self, model_name_or_config: str | DictConfig, *args, **kwargs
  ) -> Module:
- if self.load_lazy:
+ if self.enable_lazy_loading:
  log.warning(
  "CausalLMBackbonePool does not support lazy loading. "
  "Falling back to normal loading."
  )
- self.load_lazy = False
+ self.enable_lazy_loading = False
  model: AutoModelForCausalLM = super().load_model(
  model_name_or_config, *args, **kwargs
  )
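Taken together, these hunks rename the pool's lazy-loading flag from `load_lazy` to `enable_lazy_loading`, drop the manual `_config_mapping` in favor of `@auto_register_config`, and let `tokenizer` be a plain path string. A hedged usage sketch (the model paths are placeholders taken from the docstring above, and the module path is assumed from the file layout in the file list):

```python
from omegaconf import DictConfig
from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

pool = CausalLMPool(
    models={
        "model_a": "path_to_model_a",  # placeholder paths
        "model_b": "path_to_model_b",
    },
    tokenizer="path_to_model_a",       # a plain string is now accepted
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
    enable_lazy_loading=True,          # formerly `load_lazy`
)

tokenizer = pool.load_tokenizer()
model = pool.load_model("model_a")
```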
@@ -2,4 +2,9 @@
  from fusion_bench.utils import LazyStateDict

  from . import separate_io, utils
+ from .hf_utils import (
+ create_default_model_card,
+ load_model_card_template,
+ save_pretrained_with_remote_code,
+ )
  from .parameter_dict import ParameterDictModel
@@ -5,23 +5,63 @@ This module contains utilities for working with Hugging Face models.
  import inspect
  import os
  import shutil
- from typing import Optional, cast
+ from typing import List, Optional, cast

- from omegaconf import OmegaConf
+ from omegaconf import DictConfig, OmegaConf
  from transformers.modeling_utils import PreTrainedModel

- from fusion_bench import BaseAlgorithm, BaseModelPool
- from fusion_bench.utils.pylogger import getRankZeroLogger
+ from fusion_bench.utils.pylogger import get_rankzero_logger

- log = getRankZeroLogger(__name__)
+ log = get_rankzero_logger(__name__)

  __all__ = [
+ "load_model_card_template",
  "save_pretrained_with_remote_code",
- "generate_readme_head",
- "generate_readme_body",
- "generate_complete_readme",
+ "create_default_model_card",
  ]

+ MODEL_CARD_TEMPLATE_DIRS = [
+ os.path.join(os.path.dirname(__file__), "model_card_templates")
+ ]
+
+
+ def load_model_card_template(basename: str) -> str:
+ """
+ Load a model card template from file.
+
+ Searches for a template file by name, first checking if the name is a direct file path,
+ then searching through predefined template directories.
+
+ Args:
+ name (str): The name of the template file or a direct file path to the template.
+
+ Returns:
+ str: The contents of the template file as a string.
+
+ Raises:
+ FileNotFoundError: If the template file is not found in any of the search locations.
+ """
+ if os.path.exists(basename):
+ return open(basename).read()
+
+ for template_dir in MODEL_CARD_TEMPLATE_DIRS:
+ template_path = os.path.join(template_dir, basename)
+ if os.path.exists(template_path):
+ return open(template_path).read()
+
+ raise FileNotFoundError(f"Model card template '{basename}' not found.")
+
+
+ def try_to_yaml(config):
+ if config is None:
+ return None
+
+ try:
+ return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
+ except Exception as e:
+ log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
+ return None
+

  def save_pretrained_with_remote_code(
  model: PreTrainedModel,
@@ -99,84 +139,22 @@ def save_pretrained_with_remote_code(
  f.write(f"from .{base_name} import {auto_map[key].__name__}\n")


- def generate_readme_head(
- models: list[str] | BaseModelPool,
- library_name: str = "transformers",
- tags: list[str] = ["fusion-bench", "merge"],
- ):
- text = "---\nbase_model:\n"
- for model_name in models:
- text += f"- {model_name}\n"
- if library_name:
- text += f"library_name: {library_name}\n"
- text += "tags:\n"
- for tag in tags:
- text += f"- {tag}\n"
- text += "---\n"
- return text
-
-
- def generate_readme_body(
- algorithm: BaseAlgorithm,
- models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
- models: list[str] = None,
+ def create_default_model_card(
+ models: list[str],
+ description=None,
+ algorithm_config: DictConfig = None,
+ modelpool_config: DictConfig = None,
  ):
- text = """\
- # Merge
-
- This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).
-
- """
-
- if models is not None:
- text += """
- ## Models Merged
-
- The following models were included in the merge:
-
- """
- for model_name in models:
- text += f"- {model_name}\n"
- text += "\n"
-
- try:
- text += f"""\
- ## Configuration
-
- The following YAML configuration was used to produce this model:
-
- ```yaml
- {OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
- ```
- """
- except Exception as e:
- return (
- text # If the algorithm config cannot be converted to YAML, we skip it.
- )
-
- if isinstance(models_or_modelpool, BaseModelPool):
- try:
- text += f"""
- ```yaml
- {OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
- ```
- """
- except Exception as e:
- pass # If the model pool config cannot be converted to YAML, we skip it.
- return text
-
-
- def generate_complete_readme(
- algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
- ):
- # Generate the complete README content
- text = generate_readme_head(
- [modelpool.get_model_path(m) for m in modelpool.model_names]
- )
- readme_body = generate_readme_body(
- algorithm,
- models_or_modelpool=modelpool,
- models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+ from jinja2 import Template
+
+ template: Template = Template(load_model_card_template("default.md"))
+ card = template.render(
+ models=models,
+ library_name="transformers",
+ tags=["fusion-bench", "merge"],
+ title="Deep Model Fusion",
+ description=description,
+ algorithm_config_str=try_to_yaml(algorithm_config),
+ modelpool_config_str=try_to_yaml(modelpool_config),
  )
- complete_readme = text + "\n" + readme_body
- return complete_readme
+ return card
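The three `generate_readme_*` helpers are consolidated into a single Jinja2-backed `create_default_model_card`. A minimal usage sketch (the model paths, description, and config are placeholders; `jinja2` must be installed since the helper renders the bundled `default.md` template):

```python
from omegaconf import DictConfig
from fusion_bench.models import create_default_model_card  # re-exported above

card = create_default_model_card(
    models=["path_to_model_a", "path_to_model_b"],
    description="A merged model produced with fusion-bench.",  # placeholder text
    algorithm_config=DictConfig({"name": "simple_average"}),   # hypothetical config
    modelpool_config=None,  # sections whose config is None are omitted from the card
)

with open("README.md", "w") as f:
    f.write(card)
```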
@@ -0,0 +1,46 @@
+ ---
+ base_model:
+ {%- for model in models %}
+ - {{ model }}
+ {%- endfor %}
+ library_name: {{ library_name }}
+ tags:
+ {%- for tag in tags %}
+ - {{ tag }}
+ {%- endfor %}
+ ---
+ # {{ title }}
+
+ {% if description is not none %}{{ description }}{% endif %}
+
+ ## Models Merged
+
+ This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
+
+ The following models were included in the merge:
+ {% for model in models %}
+ - {{ model }}
+ {%- endfor %}
+
+ {% if algorithm_config_str is not none or modelpool_config_str is not none %}
+ ## Configuration
+
+ The following YAML configuration was used to produce this model:
+
+ {% if algorithm_config_str is not none -%}
+ ### Algorithm Configuration
+
+ ```yaml
+ {{ algorithm_config_str -}}
+ ```
+ {%- endif %}
+
+ {% if modelpool_config_str is not none -%}
+ ### Model Pool Configuration
+
+ ```yaml
+ {{ modelpool_config_str -}}
+ ```
+ {%- endif %}
+
+ {% endif %}
@@ -0,0 +1,7 @@
+ from . import register
+ from .configuration_smile_llama import SmileLlamaConfig
+ from .modeling_smile_llama import (
+ SmileLlamaDecoderLayer,
+ SmileLlamaForCausalLM,
+ SmileLlamaModel,
+ )
@@ -17,7 +17,6 @@ from transformers.modeling_outputs import (
  )
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
  from transformers.models.llama.modeling_llama import (
- LLAMA_INPUTS_DOCSTRING,
  LlamaRMSNorm,
  LlamaRotaryEmbedding,
  apply_rotary_pos_emb,
@@ -25,7 +24,6 @@ from transformers.models.llama.modeling_llama import (
  )
  from transformers.processing_utils import Unpack
  from transformers.utils import (
- LossKwargs,
  add_start_docstrings_to_model_forward,
  can_return_tuple,
  is_torch_flex_attn_available,
@@ -296,7 +294,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
  self.embed_tokens = value

  @can_return_tuple
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
@@ -566,9 +563,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
  return causal_mask


- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
  class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
  _tied_weights_keys = ["lm_head.weight"]
  _tp_plan = {"lm_head": "colwise_rep"}
@@ -603,7 +597,6 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):

  @can_return_tuple
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(
  output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
  )
@@ -620,7 +613,7 @@
  output_hidden_states: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs,
  ) -> CausalLMOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1,6 +1,6 @@
+ from . import register
  from .configuration_smile_mistral import SmileMistralConfig
  from .modeling_smile_mistral import (
  SmileMistralForCausalLM,
  SmileMistralModel,
  )
- from . import register
@@ -31,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
  )
  from transformers.processing_utils import Unpack
  from transformers.utils import (
- LossKwargs,
  add_code_sample_docstrings,
  add_start_docstrings,
  add_start_docstrings_to_model_forward,
@@ -607,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
  return causal_mask


- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
  class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
  _tied_weights_keys = ["lm_head.weight"]
  _tp_plan = {"lm_head": "colwise_rep"}
@@ -660,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
  output_hidden_states: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs,
  ) -> CausalLMOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1,6 +1,7 @@
  import json
  import logging
  import os
+ from pathlib import Path
  from typing import Any, Callable, Dict, Iterable, List, Optional, Union # noqa: F401

  import lightning as L
@@ -9,19 +10,24 @@ from omegaconf import DictConfig, OmegaConf
  from torch import nn
  from tqdm.auto import tqdm

- import fusion_bench.utils.instantiate_utils
- from fusion_bench.method import BaseAlgorithm
+ import fusion_bench
+ from fusion_bench import (
+ BaseAlgorithm,
+ BaseHydraProgram,
+ BaseModelPool,
+ BaseTaskPool,
+ RuntimeConstants,
+ import_object,
+ instantiate,
+ timeit_context,
+ )
  from fusion_bench.mixins import LightningFabricMixin
- from fusion_bench.modelpool import BaseModelPool
- from fusion_bench.programs import BaseHydraProgram
- from fusion_bench.taskpool import BaseTaskPool
- from fusion_bench.utils import import_object, instantiate, timeit_context
  from fusion_bench.utils.hydra_utils import get_hydra_output_dir
  from fusion_bench.utils.json import print_json
- from fusion_bench.utils.pylogger import getRankZeroLogger
+ from fusion_bench.utils.path import create_symlink
  from fusion_bench.utils.rich_utils import print_bordered, print_config_tree

- log = getRankZeroLogger(__name__)
+ log = fusion_bench.get_rankzero_logger(__name__)


  class FabricModelFusionProgram(
@@ -60,6 +66,7 @@
  path: DictConfig = None,
  **kwargs,
  ):
+ super().__init__(**kwargs)
  self._method = method
  self._modelpool = modelpool
  self._taskpool = taskpool
@@ -70,8 +77,10 @@
  self.fast_dev_run = fast_dev_run
  self.seed = seed
  self.path = path
- fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
- super().__init__(**kwargs)
+ RuntimeConstants.debug = fast_dev_run
+ RuntimeConstants.print_function_call = print_function_call
+ if path is not None:
+ RuntimeConstants.cache_dir = path.get("cache_dir", None)

  if print_config:
  print_config_tree(
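Run-wide settings now live on `RuntimeConstants` (introduced in `fusion_bench/constants/runtime.py`, see the file list) instead of the module-level `PRINT_FUNCTION_CALL` patch. A hedged sketch of the flags the constructor sets, mirroring the assignments above (the cache path is a placeholder and the attribute semantics are assumed from those assignments):

```python
from fusion_bench import RuntimeConstants

# Mirrors FabricModelFusionProgram.__init__:
RuntimeConstants.debug = True                 # taken from `fast_dev_run`
RuntimeConstants.print_function_call = False  # replaces the old module-level flag
RuntimeConstants.cache_dir = "outputs/cache"  # hypothetical path from `path.cache_dir`
```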
@@ -224,8 +233,16 @@
  fabric = self.fabric
  if self.seed is not None:
  L.seed_everything(self.seed)
- if fabric.global_rank == 0:
- self._link_hydra_output()
+
+ # create symbol link to hydra output directory
+ if (
+ self.fabric.is_global_zero
+ and self.log_dir is not None
+ and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
+ ):
+ create_symlink(
+ get_hydra_output_dir(), self.log_dir, link_name="hydra_output"
+ )

  log.info("Running the model fusion program.")
  # setup the modelpool, method, and taskpool
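The platform-specific symlink logic that used to live in `_link_hydra_output` (removed in the next hunk) is now delegated to `create_symlink` from the new `fusion_bench/utils/path.py`. A hedged call sketch mirroring the invocation above (the directories are placeholders; only the argument shape shown in the diff is assumed):

```python
from fusion_bench.utils.path import create_symlink

# Link the Hydra output directory into the run's log directory under the
# name "hydra_output", as FabricModelFusionProgram does above.
create_symlink("outputs/hydra_run_dir", "logs/run_0", link_name="hydra_output")
```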
@@ -278,51 +295,3 @@
  json.dump(report, open(self.report_save_path, "w"))
  else:
  log.info("No task pool specified. Skipping evaluation.")
-
- @rank_zero_only
- def _link_hydra_output(self):
- """
- Creates a symbolic link to the Hydra output directory within the specified log directory.
-
- If `self.log_dir` is not None, this method will:
- 1. Retrieve the Hydra output directory using `get_hydra_output_dir()`.
- 2. Create the log directory if it does not already exist.
- 3. Create a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
- within the log directory, pointing to the Hydra output directory.
-
- Note:
- - The symbolic link is created only if the Hydra output directory is not None.
- - The `target_is_directory` parameter is set to True to indicate that the target is a directory.
-
- Raises:
- OSError: If the symbolic link creation fails.
- """
- if self.log_dir is not None:
- # make symlink to the hydra output directory
- try:
- hydra_output_dir = get_hydra_output_dir()
- except Exception as e:
- hydra_output_dir = None
-
- if hydra_output_dir is not None:
- if os.path.abspath(hydra_output_dir) == os.path.abspath(self.log_dir):
- return
-
- os.makedirs(self.log_dir, exist_ok=True)
- try:
- # if the system is windows, use the `mklink` command in "CMD" to create the symlink
- if os.name == "nt":
- os.system(
- f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}"
- )
- else:
- os.symlink(
- hydra_output_dir,
- os.path.join(
- self.log_dir,
- "hydra_output_" + os.path.basename(hydra_output_dir),
- ),
- target_is_directory=True,
- )
- except OSError as e:
- log.warning(f"Failed to create symbolic link: {e}")