fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/models/hf_utils.py

@@ -5,23 +5,65 @@ This module contains utilities for working with Hugging Face models.
 import inspect
 import os
 import shutil
-from typing import Optional, cast
+from typing import List, Optional, cast

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from transformers.modeling_utils import PreTrainedModel

-from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.utils.pylogger import getRankZeroLogger
+from fusion_bench.utils.pylogger import get_rankzero_logger

-log = getRankZeroLogger(__name__)
+log = get_rankzero_logger(__name__)

 __all__ = [
+    "load_model_card_template",
     "save_pretrained_with_remote_code",
-    "generate_readme_head",
-    "generate_readme_body",
-    "generate_complete_readme",
+    "create_default_model_card",
 ]

+MODEL_CARD_TEMPLATE_DIRS = [
+    os.path.join(os.path.dirname(__file__), "model_card_templates")
+]
+
+
+def load_model_card_template(basename: str) -> str:
+    """
+    Load a model card template from file.
+
+    Searches for a template file by name, first checking if the name is a direct file path,
+    then searching through predefined template directories.
+
+    Args:
+        name (str): The name of the template file or a direct file path to the template.
+
+    Returns:
+        str: The contents of the template file as a string.
+
+    Raises:
+        FileNotFoundError: If the template file is not found in any of the search locations.
+    """
+    if os.path.exists(basename):
+        with open(basename, "r") as f:
+            return f.read()
+
+    for template_dir in MODEL_CARD_TEMPLATE_DIRS:
+        template_path = os.path.join(template_dir, basename)
+        if os.path.exists(template_path):
+            with open(template_path, "r") as f:
+                return f.read()
+
+    raise FileNotFoundError(f"Model card template '{basename}' not found.")
+
+
+def try_to_yaml(config):
+    if config is None:
+        return None
+
+    try:
+        return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
+    except Exception as e:
+        log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
+        return None
+

 def save_pretrained_with_remote_code(
     model: PreTrainedModel,
@@ -99,84 +141,25 @@ def save_pretrained_with_remote_code(
         f.write(f"from .{base_name} import {auto_map[key].__name__}\n")


-def generate_readme_head(
-    models: list[str] | BaseModelPool,
-    library_name: str = "transformers",
+def create_default_model_card(
+    models: list[str],
+    *,
+    title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
+    description=None,
+    algorithm_config: DictConfig = None,
+    modelpool_config: DictConfig = None,
 ):
-    text = "---\nbase_model:\n"
-    for model_name in models:
-        text += f"- {model_name}\n"
-    if library_name:
-        text += f"library_name: {library_name}\n"
-    text += "tags:\n"
-    for tag in tags:
-        text += f"- {tag}\n"
-    text += "---\n"
-    return text
-
-
-def generate_readme_body(
-    algorithm: BaseAlgorithm,
-    models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
-    models: list[str] = None,
-):
-    text = """\
-# Merge
-
-This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).
-
-"""
-
-    if models is not None:
-        text += """
-## Models Merged
-
-The following models were included in the merge:
-
-"""
-        for model_name in models:
-            text += f"- {model_name}\n"
-        text += "\n"
-
-    try:
-        text += f"""\
-## Configuration
-
-The following YAML configuration was used to produce this model:
-
-```yaml
-{OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
-```
-"""
-    except Exception as e:
-        return (
-            text  # If the algorithm config cannot be converted to YAML, we skip it.
-        )
-
-    if isinstance(models_or_modelpool, BaseModelPool):
-        try:
-            text += f"""
-```yaml
-{OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
-```
-"""
-        except Exception as e:
-            pass  # If the model pool config cannot be converted to YAML, we skip it.
-    return text
-
-
-def generate_complete_readme(
-    algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
-):
-    # Generate the complete README content
-    text = generate_readme_head(
-        [modelpool.get_model_path(m) for m in modelpool.model_names]
-    )
-    readme_body = generate_readme_body(
-        algorithm,
-        models_or_modelpool=modelpool,
-        models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+    from jinja2 import Template
+
+    template: Template = Template(load_model_card_template("default.md"))
+    card = template.render(
+        models=models,
+        library_name="transformers",
+        title=title,
+        tags=tags,
+        description=description,
+        algorithm_config_str=try_to_yaml(algorithm_config),
+        modelpool_config_str=try_to_yaml(modelpool_config),
     )
-    complete_readme = text + "\n" + readme_body
-    return complete_readme
+    return card
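The hunk above replaces the string-concatenation helpers (`generate_readme_*`) with a single Jinja2-backed `create_default_model_card`. A minimal usage sketch based on the signature shown in the diff; the model identifiers and config values below are placeholders, not part of the release:

```python
from omegaconf import OmegaConf

from fusion_bench.models.hf_utils import create_default_model_card

# Placeholder model identifiers; substitute the models that actually went into the merge.
merged_models = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-Math-1.5B-Instruct",
]

card = create_default_model_card(
    models=merged_models,
    title="My Merged Model",
    description="Merged with the simple-average method.",
    # Any DictConfig works here; try_to_yaml renders it into the card as YAML.
    algorithm_config=OmegaConf.create({"name": "simple_average"}),
)

# Write the rendered card next to the saved checkpoint.
with open("README.md", "w") as f:
    f.write(card)
```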
fusion_bench/models/linearized/vision_model.py

@@ -45,21 +45,21 @@ def linearize_lora_model_(model):


 def load_fft_vision_model_hf(
-    model_name: str, return_vison_model=True
+    model_name: str, return_vision_model=True
 ) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
     """
     Load a CLIP vision model from Hugging Face.

     Args:
         model_name (str): The name of the CLIP vision model to load from Hugging Face.
-        return_vison_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.

     Returns:
         Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
     """
     model = CLIPVisionModel.from_pretrained(model_name)

-    if return_vison_model:
+    if return_vision_model:
         return CLIPVisionModel.from_pretrained(model_name).vision_model
     else:
         return model
@@ -69,7 +69,7 @@ def load_lora_vision_model_hf(
     base_model_name: str,
     peft_name: str,
     merge_and_unload: bool = False,
-    return_vison_model=True,
+    return_vision_model=True,
 ) -> PeftModel:
     """
     Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
@@ -80,7 +80,7 @@ def load_lora_vision_model_hf(
         base_model_name (str): The name of the base vision model to load from Hugging Face.
         peft_name (str): The name of the LoRA adaptation to apply to the base model.
         merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
-        return_vison_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.

     Returns:
         PeftModel: The adapted vision model, optionally merged and unloaded.
@@ -97,7 +97,7 @@ def load_lora_vision_model_hf(
         vision_model = peft_model

     # Return the vision model
-    if return_vison_model:
+    if return_vision_model:
         return vision_model
     else:
         model.vision_model = vision_model
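The `return_vison_model` to `return_vision_model` rename fixes a long-standing typo, but it is also a breaking change for callers that passed the old keyword. A minimal sketch of the corrected call; the checkpoint name is only an example:

```python
from fusion_bench.models.linearized.vision_model import load_fft_vision_model_hf

# Default behaviour: returns only the inner CLIPVisionTransformer.
vision_transformer = load_fft_vision_model_hf("openai/clip-vit-base-patch32")

# Corrected keyword (previously `return_vison_model`): returns the full CLIPVisionModel.
full_model = load_fft_vision_model_hf(
    "openai/clip-vit-base-patch32",
    return_vision_model=False,
)
```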
fusion_bench/models/model_card_templates/default.md (new file)

@@ -0,0 +1,46 @@
+---
+base_model:
+{%- for model in models %}
+- {{ model }}
+{%- endfor %}
+library_name: {{ library_name }}
+tags:
+{%- for tag in tags %}
+- {{ tag }}
+{%- endfor %}
+---
+# {{ title }}
+
+{% if description is not none %}{{ description }}{% endif %}
+
+## Models Merged
+
+This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
+
+The following models were included in the merge:
+{% for model in models %}
+- {{ model }}
+{%- endfor %}
+
+{% if algorithm_config_str is not none or modelpool_config_str is not none %}
+## Configuration
+
+The following YAML configuration was used to produce this model:
+
+{% if algorithm_config_str is not none -%}
+### Algorithm Configuration
+
+```yaml
+{{ algorithm_config_str -}}
+```
+{%- endif %}
+
+{% if modelpool_config_str is not none -%}
+### Model Pool Configuration
+
+```yaml
+{{ modelpool_config_str -}}
+```
+{%- endif %}
+
+{% endif %}
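Per the `load_model_card_template` hunk earlier, a bare template name is resolved against the bundled `model_card_templates` directory, while an existing filesystem path is read directly. A small sketch of both lookups; `my_card.md` is a hypothetical project-local template:

```python
import os

from fusion_bench.models.hf_utils import load_model_card_template

# Bare filenames are resolved against MODEL_CARD_TEMPLATE_DIRS, so this loads
# the bundled default template shown above.
default_template = load_model_card_template("default.md")
assert "{{ title }}" in default_template

# An existing path is used as-is, so a customised template can live anywhere
# on disk without being copied into the package.
if os.path.exists("my_card.md"):
    custom_template = load_model_card_template("my_card.md")
```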
fusion_bench/models/modeling_smile_llama/__init__.py (new file)

@@ -0,0 +1,7 @@
+from . import register
+from .configuration_smile_llama import SmileLlamaConfig
+from .modeling_smile_llama import (
+    SmileLlamaDecoderLayer,
+    SmileLlamaForCausalLM,
+    SmileLlamaModel,
+)
fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py

@@ -17,7 +17,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.llama.modeling_llama import (
-    LLAMA_INPUTS_DOCSTRING,
     LlamaRMSNorm,
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
@@ -25,7+24,6 @@ from transformers.models.llama.modeling_llama import (
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    LossKwargs,
     add_start_docstrings_to_model_forward,
     can_return_tuple,
     is_torch_flex_attn_available,
@@ -296,7 +294,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
         self.embed_tokens = value

     @can_return_tuple
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -566,9 +563,6 @@ class SmileLlamaModel(SmileLlamaPreTrainedModel):
         return causal_mask


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -603,7 +597,6 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):

     @can_return_tuple
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(
         output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
     )
@@ -620,7 +613,7 @@ class SmileLlamaForCausalLM(SmileLlamaPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
fusion_bench/models/modeling_smile_mistral/__init__.py

@@ -1,6 +1,7 @@
+from . import register
 from .configuration_smile_mistral import SmileMistralConfig
 from .modeling_smile_mistral import (
+    SmileMistralDecoderLayer,
     SmileMistralForCausalLM,
     SmileMistralModel,
 )
-from . import register
fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py

@@ -31,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -607,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
         return causal_mask


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -660,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
fusion_bench/models/we_moe.py

@@ -1,6 +1,6 @@
 import functools
 import logging
-from typing import List
+from typing import Generic, List

 import torch
 import torch.func
@@ -9,7 +9,7 @@ from torch.func import functional_call
 from torch.nn import functional as F

 from fusion_bench.models.utils import del_attr, get_attr, set_attr
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType

 log = logging.getLogger(__name__)

@@ -76,15 +76,15 @@ def construct_weight_ensembling_gate(
     return gate


-class WeightEnsemblingMoE(nn.Module):
+class WeightEnsemblingMoE(nn.Module, Generic[TorchModelType]):
     # variable to store the merged state dict temporarily
     _merged_state_dict: StateDictType = None

     def __init__(
         self,
         hidden_size: int,
-        base_model: nn.Module,
-        expert_models: List[nn.Module],
+        base_model: TorchModelType,
+        expert_models: List[TorchModelType],
         init_lambda: float = 0.2,
         batch_first: bool = False,
         router_hidden_layers: int = 2,
@@ -101,8 +101,8 @@ class WeightEnsemblingMoE(nn.Module):
         Args:

             hidden_size (int): The size of the hidden layer in the models.
-            base_model (nn.Module): The base model that will be used as a reference for the expert models.
-            expert_models (List[nn.Module]): A list of expert models that will be combined.
+            base_model (TorchModelType): The base model that will be used as a reference for the expert models.
+            expert_models (List[TorchModelType]): A list of expert models that will be combined.
             init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
             batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
             router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
@@ -145,7 +145,7 @@ class WeightEnsemblingMoE(nn.Module):
             self._merged_state_dict,
         )

-    def merge_weights(self, expert_weights):
+    def merge_weights(self, expert_weights) -> StateDictType:
         state_dict = self.base_model.state_dict(keep_vars=True)
         for weight, task_vector in zip(expert_weights, self.task_vectors):
             for name, param in task_vector.named_parameters():
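`WeightEnsemblingMoE` is now generic over the wrapped module type (`TorchModelType`), so annotations can record which concrete model class flows through `base_model` and `expert_models`. A small typing sketch that only relies on the `base_model` attribute visible in this hunk; `get_expert_base` is a hypothetical helper, not part of the package:

```python
from transformers import CLIPVisionModel

from fusion_bench.models.we_moe import WeightEnsemblingMoE


def get_expert_base(moe: "WeightEnsemblingMoE[CLIPVisionModel]") -> CLIPVisionModel:
    """Hypothetical helper: the generic parameter tells type checkers that
    ``base_model`` is a CLIPVisionModel rather than a bare nn.Module."""
    return moe.base_model
```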
fusion_bench/programs/fabric_fusion_program.py

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union  # noqa: F401

 import lightning as L
@@ -9,19 +10,24 @@ from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm

-import fusion_bench.utils.instantiate_utils
-from fusion_bench.method import BaseAlgorithm
+import fusion_bench
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseHydraProgram,
+    BaseModelPool,
+    BaseTaskPool,
+    RuntimeConstants,
+    import_object,
+    instantiate,
+    timeit_context,
+)
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.modelpool import BaseModelPool
-from fusion_bench.programs import BaseHydraProgram
-from fusion_bench.taskpool import BaseTaskPool
-from fusion_bench.utils import import_object, instantiate, timeit_context
 from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.json import print_json
-from fusion_bench.utils.pylogger import getRankZeroLogger
+from fusion_bench.utils.path import create_symlink
 from fusion_bench.utils.rich_utils import print_bordered, print_config_tree

-log = getRankZeroLogger(__name__)
+log = fusion_bench.get_rankzero_logger(__name__)


 class FabricModelFusionProgram(
@@ -60,6 +66,7 @@ class FabricModelFusionProgram(
         path: DictConfig = None,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self._method = method
         self._modelpool = modelpool
         self._taskpool = taskpool
@@ -70,8 +77,10 @@ class FabricModelFusionProgram(
         self.fast_dev_run = fast_dev_run
         self.seed = seed
         self.path = path
-        fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
-        super().__init__(**kwargs)
+        RuntimeConstants.debug = fast_dev_run
+        RuntimeConstants.print_function_call = print_function_call
+        if path is not None:
+            RuntimeConstants.cache_dir = path.get("cache_dir", None)

         if print_config:
             print_config_tree(
@@ -224,8 +233,16 @@ class FabricModelFusionProgram(
         fabric = self.fabric
         if self.seed is not None:
             L.seed_everything(self.seed)
-        if fabric.global_rank == 0:
-            self._link_hydra_output()
+
+        # create symbol link to hydra output directory
+        if (
+            self.fabric.is_global_zero
+            and self.log_dir is not None
+            and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
+        ):
+            create_symlink(
+                get_hydra_output_dir(), self.log_dir, link_name="hydra_output"
+            )

         log.info("Running the model fusion program.")
         # setup the modelpool, method, and taskpool
@@ -278,51 +295,3 @@ class FabricModelFusionProgram(
             json.dump(report, open(self.report_save_path, "w"))
         else:
             log.info("No task pool specified. Skipping evaluation.")
-
-    @rank_zero_only
-    def _link_hydra_output(self):
-        """
-        Creates a symbolic link to the Hydra output directory within the specified log directory.
-
-        If `self.log_dir` is not None, this method will:
-        1. Retrieve the Hydra output directory using `get_hydra_output_dir()`.
-        2. Create the log directory if it does not already exist.
-        3. Create a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
-           within the log directory, pointing to the Hydra output directory.
-
-        Note:
-            - The symbolic link is created only if the Hydra output directory is not None.
-            - The `target_is_directory` parameter is set to True to indicate that the target is a directory.
-
-        Raises:
-            OSError: If the symbolic link creation fails.
-        """
-        if self.log_dir is not None:
-            # make symlink to the hydra output directory
-            try:
-                hydra_output_dir = get_hydra_output_dir()
-            except Exception as e:
-                hydra_output_dir = None
-
-            if hydra_output_dir is not None:
-                if os.path.abspath(hydra_output_dir) == os.path.abspath(self.log_dir):
-                    return
-
-                os.makedirs(self.log_dir, exist_ok=True)
-                try:
-                    # if the system is windows, use the `mklink` command in "CMD" to create the symlink
-                    if os.name == "nt":
-                        os.system(
-                            f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}"
-                        )
-                    else:
-                        os.symlink(
-                            hydra_output_dir,
-                            os.path.join(
-                                self.log_dir,
-                                "hydra_output_" + os.path.basename(hydra_output_dir),
-                            ),
-                            target_is_directory=True,
-                        )
-                except OSError as e:
-                    log.warning(f"Failed to create symbolic link: {e}")
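The constructor now routes run-time switches through the new `fusion_bench.constants.RuntimeConstants` (see `fusion_bench/constants/runtime.py` in the file list) instead of the module-level `PRINT_FUNCTION_CALL` flag. A minimal sketch of the pattern, under the assumption that it is a plain class-attribute holder; the real class is larger, so treat this as an illustration only:

```python
from typing import Optional


class RuntimeConstantsSketch:
    """Process-wide switches set once by the program and read anywhere.

    Hypothetical stand-in for fusion_bench.constants.RuntimeConstants; only the
    three attributes assigned in the diff above are shown.
    """

    debug: bool = False                # mirrors fast_dev_run
    print_function_call: bool = False  # replaces instantiate_utils.PRINT_FUNCTION_CALL
    cache_dir: Optional[str] = None    # populated from the `path` config


# The program sets them during start-up ...
RuntimeConstantsSketch.debug = True

# ... and library code can read them without importing the program module.
if RuntimeConstantsSketch.debug:
    print("fast-dev-run enabled: shorten loops and skip expensive steps")
```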
fusion_bench/scripts/cli.py

@@ -12,9 +12,9 @@ import os
 import hydra
 from omegaconf import DictConfig, OmegaConf

+from fusion_bench.constants import PROJECT_ROOT_PATH
 from fusion_bench.programs import BaseHydraProgram
 from fusion_bench.utils import instantiate
-from fusion_bench.constants import PROJECT_ROOT_PATH

 log = logging.getLogger(__name__)

@@ -34,6 +34,39 @@ def _get_default_config_path():
     version_base=None,
 )
 def main(cfg: DictConfig) -> None:
+    """
+    Main entry point for the FusionBench command-line interface.
+
+    This function serves as the primary entry point for the `fusion_bench` CLI command.
+    It is decorated with Hydra's main decorator to handle configuration management,
+    command-line argument parsing, and configuration file loading.
+
+    The function performs the following operations:
+    1. Resolves any interpolations in the configuration using OmegaConf
+    2. Instantiates the appropriate program class based on the configuration
+    3. Executes the program's run method to perform the fusion task
+
+    Args:
+        cfg (DictConfig): The Hydra configuration object containing all settings
+            for the fusion task. This includes method configuration, model pool
+            configuration, task pool configuration, and other runtime parameters.
+            The configuration is automatically loaded by Hydra from the specified
+            config files and command-line overrides.
+
+    Returns:
+        None: This function doesn't return a value but executes the fusion
+            program which may save results, log outputs, or perform other
+            side effects as configured.
+
+    Example:
+        This function is typically called automatically when running:
+        ```bash
+        fusion_bench method=... modelpool=... taskpool=...
+        ```
+
+        The Hydra decorator handles parsing these command-line arguments and
+        loading the corresponding configuration files to populate the cfg parameter.
+    """
     OmegaConf.resolve(cfg)
     program: BaseHydraProgram = instantiate(cfg)
     program.run()