fusion_bench-0.2.20-py3-none-any.whl → fusion_bench-0.2.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -1,48 +1,6 @@
1
- # flake8: noqa F401
2
- from typing import TYPE_CHECKING
3
-
4
- from transformers.utils.import_utils import (
5
- OptionalDependencyNotAvailable,
6
- _LazyModule,
7
- is_flax_available,
8
- is_tf_available,
9
- is_torch_available,
1
+ from . import register
2
+ from .configuration_smile_mistral import SmileMistralConfig
3
+ from .modeling_smile_mistral import (
4
+ SmileMistralForCausalLM,
5
+ SmileMistralModel,
10
6
  )
11
-
12
- _import_structure = {
13
- "configuration_smile_mistral": ["SmileMistralConfig"],
14
- }
15
-
16
- try:
17
- if not is_torch_available():
18
- raise OptionalDependencyNotAvailable()
19
- except OptionalDependencyNotAvailable:
20
- pass
21
- else:
22
- _import_structure["modeling_smile_mistral"] = [
23
- "SmileMistralForCausalLM",
24
- "SmileMistralModel",
25
- "SmileMistralPreTrainedModel",
26
- ]
27
-
28
- if TYPE_CHECKING:
29
- from .configuration_smile_mistral import SmileMistralConfig
30
-
31
- try:
32
- if not is_torch_available():
33
- raise OptionalDependencyNotAvailable()
34
- except OptionalDependencyNotAvailable:
35
- pass
36
- else:
37
- from .modeling_smile_mistral import (
38
- SmileMistralForCausalLM,
39
- SmileMistralModel,
40
- SmileMistralPreTrainedModel,
41
- )
42
-
43
- else:
44
- import sys
45
-
46
- sys.modules[__name__] = _LazyModule(
47
- __name__, globals()["__file__"], _import_structure, module_spec=__spec__
48
- )
@@ -5,4 +5,4 @@ from .modeling_smile_qwen2 import (
5
5
  SmileQwen2ForQuestionAnswering,
6
6
  SmileQwen2ForSequenceClassification,
7
7
  SmileQwen2Model,
8
- )
8
+ )
@@ -24,7 +24,6 @@ from transformers.modeling_outputs import (
24
24
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
25
25
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
26
26
  from transformers.models.qwen2.modeling_qwen2 import (
27
- QWEN2_INPUTS_DOCSTRING,
28
27
  Qwen2RMSNorm,
29
28
  Qwen2RotaryEmbedding,
30
29
  apply_rotary_pos_emb,
@@ -32,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
32
31
  )
33
32
  from transformers.processing_utils import Unpack
34
33
  from transformers.utils import (
35
- LossKwargs,
36
34
  add_code_sample_docstrings,
37
35
  add_start_docstrings,
38
36
  add_start_docstrings_to_model_forward,
@@ -314,7 +312,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
314
312
  self.embed_tokens = value
315
313
 
316
314
  @can_return_tuple
317
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
318
315
  def forward(
319
316
  self,
320
317
  input_ids: Optional[torch.LongTensor] = None,
@@ -609,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
609
606
  return causal_mask
610
607
 
611
608
 
612
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
613
-
614
-
615
609
  class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
616
610
  _tied_weights_keys = ["lm_head.weight"]
617
611
  _tp_plan = {"lm_head": "colwise_rep"}
@@ -646,7 +640,6 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
646
640
 
647
641
  @can_return_tuple
648
642
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
649
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
650
643
  @replace_return_docstrings(
651
644
  output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
652
645
  )
@@ -663,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
663
656
  output_hidden_states: Optional[bool] = None,
664
657
  cache_position: Optional[torch.LongTensor] = None,
665
658
  logits_to_keep: Union[int, torch.Tensor] = 0,
666
- **kwargs: Unpack[KwargsForCausalLM],
659
+ **kwargs,
667
660
  ) -> CausalLMOutputWithPast:
668
661
  r"""
669
662
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -752,7 +745,9 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
752
745
  def __init__(self, config):
753
746
  super().__init__(config)
754
747
  self.num_labels = config.num_labels
755
- self.model = SmileQwen2Model(config) #* replace Qwen2Model with SmileQwen2Model
748
+ self.model = SmileQwen2Model(
749
+ config
750
+ ) # * replace Qwen2Model with SmileQwen2Model
756
751
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
757
752
 
758
753
  # Initialize weights and apply final processing
@@ -765,7 +760,6 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
765
760
  self.model.embed_tokens = value
766
761
 
767
762
  @can_return_tuple
768
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
769
763
  def forward(
770
764
  self,
771
765
  input_ids: Optional[torch.LongTensor] = None,
@@ -852,7 +846,9 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
852
846
 
853
847
  def __init__(self, config):
854
848
  super().__init__(config)
855
- self.transformer = SmileQwen2Model(config) #* replace Qwen2Model with SmileQwen2Model
849
+ self.transformer = SmileQwen2Model(
850
+ config
851
+ ) # * replace Qwen2Model with SmileQwen2Model
856
852
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
857
853
 
858
854
  # Initialize weights and apply final processing
@@ -865,7 +861,6 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
865
861
  self.transformer.embed_tokens = value
866
862
 
867
863
  @can_return_tuple
868
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
869
864
  def forward(
870
865
  self,
871
866
  input_ids: Optional[torch.LongTensor] = None,
@@ -1,10 +1,7 @@
1
1
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2
2
 
3
3
  from .configuration_smile_qwen2 import SmileQwen2Config
4
- from .modeling_smile_qwen2 import (
5
- SmileQwen2ForCausalLM,
6
- SmileQwen2Model,
7
- )
4
+ from .modeling_smile_qwen2 import SmileQwen2ForCausalLM, SmileQwen2Model
8
5
 
9
6
  AutoConfig.register("smile_qwen2", SmileQwen2Config)
10
7
  AutoModel.register(SmileQwen2Config, SmileQwen2Model)
@@ -74,7 +74,7 @@ class ParameterDictModel(nn.Module):
74
74
  name.split("."),
75
75
  param,
76
76
  check_parent=True,
77
- parent_builder=self.__class__,
77
+ parent_builder=__class__,
78
78
  )
79
79
 
80
80
  def __repr__(self):
@@ -11,6 +11,7 @@ from torch.func import functional_call
11
11
  from torch.nn import functional as F
12
12
  from tqdm.auto import tqdm
13
13
 
14
+ from fusion_bench.models.utils import del_attr, get_attr, set_attr
14
15
  from fusion_bench.utils.state_dict_arithmetic import (
15
16
  state_dict_sub,
16
17
  state_dict_weighted_sum,
@@ -20,59 +21,6 @@ from fusion_bench.utils.type import StateDictType
20
21
  log = logging.getLogger(__name__)
21
22
 
22
23
 
23
- def join_list(list_of_list: List[List]):
24
- ans = []
25
- for l in list_of_list:
26
- ans.extend(l)
27
- return ans
28
-
29
-
30
- def del_attr(obj, names: List[str]):
31
- """
32
- Deletes an attribute from an object recursively.
33
-
34
- Args:
35
- obj (object): Object to delete attribute from.
36
- names (list): List of attribute names to delete recursively.
37
- """
38
- if len(names) == 1:
39
- delattr(obj, names[0])
40
- else:
41
- del_attr(getattr(obj, names[0]), names[1:])
42
-
43
-
44
- def set_attr(obj, names: List[str], val):
45
- """
46
- Sets an attribute of an object recursively.
47
-
48
- Args:
49
- obj (object): Object to set attribute of.
50
- names (list): List of attribute names to set recursively.
51
- val (object): Value to set the attribute to.
52
- """
53
- if len(names) == 1:
54
- setattr(obj, names[0], val)
55
- else:
56
- set_attr(getattr(obj, names[0]), names[1:], val)
57
-
58
-
59
- def get_attr(obj, names: List[str]):
60
- """
61
- Gets an attribute of an object recursively.
62
-
63
- Args:
64
- obj (object): Object to get attribute of.
65
- names (list): List of attribute names to get recursively.
66
-
67
- Returns:
68
- object: The attribute of the object.
69
- """
70
- if len(names) == 1:
71
- return getattr(obj, names[0])
72
- else:
73
- return get_attr(getattr(obj, names[0]), names[1:])
74
-
75
-
76
24
  class Depth_0_Gate(nn.Module):
77
25
  def __init__(self, num_experts: int):
78
26
  super().__init__()
@@ -3,6 +3,8 @@ from typing import List
3
3
  import torch
4
4
  from torch import nn
5
5
 
6
+ from fusion_bench.utils.type import StateDictType
7
+
6
8
 
7
9
  def del_attr(obj, names: List[str]):
8
10
  """
@@ -50,6 +52,30 @@ def get_attr(obj, names: List[str]):
50
52
  return get_attr(getattr(obj, names[0]), names[1:])
51
53
 
52
54
 
55
+ def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
56
+ """
57
+ Checks that the parameter names of the given checkpoints match.
58
+
59
+ Args:
60
+ checkpoints (List[Dict[str, float]]): A list of checkpoints, where each checkpoint is a dictionary of parameter names and their corresponding values.
61
+
62
+ Raises:
63
+ ValueError: If the number of checkpoints is less than 2 or if the parameter names of any two checkpoints differ.
64
+
65
+ """
66
+ parameter_names = set(checkpoints[0].keys())
67
+
68
+ if len(checkpoints) >= 2:
69
+ # raise ValueError("Number of models is less than 2.")
70
+ for checkpoint in checkpoints[1:]:
71
+ current_parameterNames = set(checkpoint.keys())
72
+ if current_parameterNames != parameter_names:
73
+ raise ValueError(
74
+ "Differing parameter names in models. "
75
+ f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
76
+ )
77
+
78
+
53
79
  def find_layers_with_type(
54
80
  module: nn.Module,
55
81
  layer_types=[nn.Linear],
@@ -8,64 +8,12 @@ from torch import Tensor, nn
8
8
  from torch.func import functional_call
9
9
  from torch.nn import functional as F
10
10
 
11
+ from fusion_bench.models.utils import del_attr, get_attr, set_attr
11
12
  from fusion_bench.utils.type import StateDictType
12
13
 
13
14
  log = logging.getLogger(__name__)
14
15
 
15
16
 
16
- def join_list(list_of_list: List[List]):
17
- ans = []
18
- for l in list_of_list:
19
- ans.extend(l)
20
- return ans
21
-
22
-
23
- def del_attr(obj, names: List[str]):
24
- """
25
- Deletes an attribute from an object recursively.
26
-
27
- Args:
28
- obj (object): Object to delete attribute from.
29
- names (list): List of attribute names to delete recursively.
30
- """
31
- if len(names) == 1:
32
- delattr(obj, names[0])
33
- else:
34
- del_attr(getattr(obj, names[0]), names[1:])
35
-
36
-
37
- def set_attr(obj, names: List[str], val):
38
- """
39
- Sets an attribute of an object recursively.
40
-
41
- Args:
42
- obj (object): Object to set attribute of.
43
- names (list): List of attribute names to set recursively.
44
- val (object): Value to set the attribute to.
45
- """
46
- if len(names) == 1:
47
- setattr(obj, names[0], val)
48
- else:
49
- set_attr(getattr(obj, names[0]), names[1:], val)
50
-
51
-
52
- def get_attr(obj, names: List[str]):
53
- """
54
- Gets an attribute of an object recursively.
55
-
56
- Args:
57
- obj (object): Object to get attribute of.
58
- names (list): List of attribute names to get recursively.
59
-
60
- Returns:
61
- object: The attribute of the object.
62
- """
63
- if len(names) == 1:
64
- return getattr(obj, names[0])
65
- else:
66
- return get_attr(getattr(obj, names[0]), names[1:])
67
-
68
-
69
17
  class Depth_0_Gate(nn.Module):
70
18
  def __init__(self, num_experts: int):
71
19
  super().__init__()
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Dict, List, cast
1
+ from typing import Any, Callable, Dict, List, Union, cast
2
2
 
3
3
  import numpy as np
4
4
  import torch
@@ -6,7 +6,9 @@ from omegaconf import ListConfig
6
6
  from torch import Tensor, nn
7
7
 
8
8
 
9
- def aggregate_tensors(outputs: List[Any], aggregate_fn: Callable) -> Tensor:
9
+ def aggregate_tensors(
10
+ outputs: List[Any], aggregate_fn: Callable
11
+ ) -> Union[Tensor, Dict, List, None]:
10
12
  """
11
13
  Aggregates a list of outputs using the provided aggregation function.
12
14
 
@@ -84,7 +86,7 @@ class EnsembleModule(nn.Module):
84
86
  """
85
87
  return torch.stack(outputs).mean(dim=0)
86
88
 
87
- def forward(self, *args, **kwargs):
89
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
88
90
  """
89
91
  Performs a forward pass by averaging the outputs of the models.
90
92
 
@@ -150,7 +152,7 @@ class WeightedEnsembleModule(nn.Module):
150
152
  weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
151
153
  return (torch.stack(outputs) * weights).sum(dim=0)
152
154
 
153
- def forward(self, *args, **kwargs):
155
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
154
156
  """
155
157
  Performs a forward pass by computing the weighted average of the models' outputs.
156
158
 
@@ -49,7 +49,7 @@ def get_layer_wise_weights(
49
49
  return torch.full((num_models, num_layers), init_values, dtype=dtype)
50
50
 
51
51
 
52
- def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
52
+ def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]) -> Tensor:
53
53
  """
54
54
  Fuse the layer-wise weights with the given state dictionaries.
55
55