fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/mixins/serialization.py
@@ -1,20 +1,148 @@
+import inspect
 import logging
+from copy import deepcopy
+from functools import wraps
+from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Optional, Union
 
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import FUSION_BENCH_VERSION
 from fusion_bench.utils import import_object, instantiate
+from fusion_bench.utils.instantiate_utils import set_print_function_call
 
 log = logging.getLogger(__name__)
 
+__all__ = [
+    "YAMLSerializationMixin",
+    "auto_register_config",
+    "BaseYAMLSerializable",
+]
+
+
+def auto_register_config(cls):
+    """
+    Decorator to automatically register __init__ parameters in _config_mapping.
+
+    This decorator enhances classes that inherit from YAMLSerializationMixin by
+    automatically mapping constructor parameters to configuration keys and
+    dynamically setting instance attributes based on provided arguments.
+
+    The decorator performs the following operations:
+
+    1. Inspects the class's __init__ method signature
+    2. Automatically populates the _config_mapping dictionary with parameter names
+    3. Wraps the __init__ method to handle both positional and keyword arguments
+    4. Sets instance attributes for all constructor parameters
+    5. Applies default values when parameters are not provided
+
+    Args:
+        cls (YAMLSerializationMixin): The class to be decorated. Must inherit from
+            YAMLSerializationMixin to ensure proper serialization capabilities.
+
+    Returns:
+        YAMLSerializationMixin: The decorated class with enhanced auto-registration
+            functionality and modified __init__ behavior.
+
+    Behavior:
+        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+          from the __init__ method are automatically added to _config_mapping
+        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
+        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
+        - **Default Values**: Applied when parameters are not provided via arguments
+        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
+
+    Example:
+        ```python
+        @auto_register_config
+        class MyAlgorithm(BaseYAMLSerializable):
+            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
+                super().__init__()
+
+        # All instantiation methods work automatically:
+        algo1 = MyAlgorithm(0.01, 64)  # positional args
+        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
+        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
+
+        # Attributes are automatically set and can be serialized:
+        print(algo1.learning_rate)  # 0.01
+        print(algo1.batch_size)  # 64
+        print(algo1.model_name)  # "default" (from default value)
+
+        config = algo1.config
+        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
+        ```
+
+    Note:
+        - The decorator wraps the original __init__ method while preserving its signature for IDE support
+        - Parameters with *args or **kwargs signatures are ignored during registration
+        - The attributes are auto-registered, then the original __init__ method is called
+        - Type hints, method name, and other metadata are preserved using functools.wraps
+        - This decorator is designed to work seamlessly with the YAML serialization system
+
+    Raises:
+        AttributeError: If the class does not have the required _config_mapping attribute
+            infrastructure (should inherit from YAMLSerializationMixin)
+    """
+    original_init = cls.__init__
+    sig = inspect.signature(original_init)
+
+    # Auto-register parameters in _config_mapping
+    if not "_config_mapping" in cls.__dict__:
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
+        if sig.parameters[param_name].kind not in [
+            _ParameterKind.VAR_POSITIONAL,
+            _ParameterKind.VAR_KEYWORD,
+        ]:
+            cls._config_mapping[param_name] = param_name
+
+    def __init__(self, *args, **kwargs):
+        # auto-register the attributes based on the signature
+        sig = inspect.signature(original_init)
+        param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
+
+        # Handle positional arguments
+        for i, arg_value in enumerate(args):
+            if i < len(param_names):
+                param_name = param_names[i]
+                if sig.parameters[param_name].kind not in [
+                    _ParameterKind.VAR_POSITIONAL,
+                    _ParameterKind.VAR_KEYWORD,
+                ]:
+                    setattr(self, param_name, arg_value)
+
+        # Handle keyword arguments and defaults
+        for param_name in param_names:
+            if sig.parameters[param_name].kind not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]:
+                # Skip if already set by positional argument
+                param_index = param_names.index(param_name)
+                if param_index >= 0 and param_index < len(args):
+                    continue
+
+                if param_name in kwargs:
+                    setattr(self, param_name, kwargs[param_name])
+                else:
+                    # Set default value if available and attribute doesn't exist
+                    default_value = sig.parameters[param_name].default
+                    if default_value is not Parameter.empty:
+                        setattr(self, param_name, default_value)
+
+        # Call the original __init__
+        result = original_init(self, *args, **kwargs)
+        return result
+
+    # Replace the original __init__ method while preserving its signature
+    cls.__init__ = __init__
+    return cls
+
 
 class YAMLSerializationMixin:
-    _recursive_: bool = False
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {
-        "_recursive_": "_recursive_",
-    }
+    _config_mapping: Dict[str, str] = {}
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
@@ -47,46 +175,50 @@ class YAMLSerializationMixin:
     By default, the `_target_` key is set to the class name as `type(self).__name__`.
     """
 
-    def __init__(
-        self,
-        _recursive_: bool = False,
-        **kwargs,
-    ) -> None:
-        self._recursive_ = _recursive_
+    def __init__(self, **kwargs) -> None:
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")
 
     @property
-    def config(self):
+    def config(self) -> DictConfig:
         R"""
         Returns the configuration of the model pool as a DictConfig.
 
-        This property calls the `to_config` method to convert the model pool
-        instance into a dictionary configuration, which can be used for
-        serialization or other purposes.
+        This property converts the model pool instance into a dictionary
+        configuration, which can be used for serialization or other purposes.
 
         Example:
-            >>> model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-            >>> config = model.config
-            >>> print(config)
-            DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
+
+        ```python
+        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
+        config = model.config
+        print(config)
+        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
+        ```
 
         This is useful for serializing the object to a YAML file or for debugging.
 
         Returns:
            DictConfig: The configuration of the model pool.
         """
-        return self.to_config()
+        config = {"_target_": f"{type(self).__module__}.{type(self).__qualname__}"}
+        for attr, key in self._config_mapping.items():
+            if hasattr(self, attr):
+                config[key] = getattr(self, attr)
 
-    def to_yaml(self, path: Union[str, Path]):
+        try:
+            return OmegaConf.create(config)
+        except Exception as e:
+            return OmegaConf.create(config, flags={"allow_objects": True})
+
+    def to_yaml(self, path: Union[str, Path], resolve: bool = True):
         """
         Save the model pool to a YAML file.
 
         Args:
            path (Union[str, Path]): The path to save the model pool to.
         """
-        config = self.to_config()
-        OmegaConf.save(config, path, resolve=True)
+        OmegaConf.save(self.config, path, resolve=resolve)
 
     @classmethod
     def from_yaml(cls, path: Union[str, Path]):
@@ -108,41 +240,126 @@ class YAMLSerializationMixin:
                 f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
                 f"Instantiating the class {target_cls.__name__} instead."
             )
-        return instantiate(
-            config,
-            _recursive_=(
-                cls._recursive_
-                if config.get("_recursive_") is None
-                else config.get("_recursive_")
-            ),
-        )
+        with set_print_function_call(False):
+            return instantiate(config)
 
-    def to_config(self):
+    def register_parameter_to_config(
+        self,
+        attr_name: str,
+        param_name: str,
+        value,
+    ):
         """
-        Convert the model pool to a DictConfig.
+        Set an attribute value and register its config mapping.
 
-        Returns:
-            Dict: The model pool as a DictConfig.
+        This method allows dynamic setting of object attributes while simultaneously
+        updating the configuration mapping that defines how the attribute should
+        be serialized in the configuration output.
+
+        Args:
+            attr_name (str): The name of the attribute to set on this object.
+            param_name (str): The corresponding parameter name to use in the config
+                serialization. This is how the attribute will appear in YAML output.
+            value: The value to assign to the attribute.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable()
+            model.register_parameter_to_config("learning_rate", "lr", 0.001)
+
+            # This sets model.learning_rate = 0.001
+            # and maps it to "lr" in the config output
+            config = model.config
+            # config will contain: {"lr": 0.001, ...}
+            ```
         """
-        config = {"_target_": type(self).__name__}
-        for attr, key in self._config_mapping.items():
-            if hasattr(self, attr):
-                config[key] = getattr(self, attr)
-        return OmegaConf.create(config)
+        setattr(self, attr_name, value)
+        self._config_mapping[attr_name] = param_name
+
+
+@auto_register_config
+class BaseYAMLSerializable(YAMLSerializationMixin):
+    """
+    A base class for YAML-serializable classes with enhanced metadata support.
+
+    This class extends `YAMLSerializationMixin` to provide additional metadata
+    fields commonly used in FusionBench classes, including usage information
+    and version tracking. It serves as a foundation for all serializable
+    model components in the framework.
+
+    The class automatically handles serialization of usage and version metadata
+    alongside the standard configuration parameters, making it easier to track
+    model provenance and intended usage patterns.
 
+    Attributes:
+        _usage_ (Optional[str]): Description of the model's intended usage or purpose.
+        _version_ (Optional[str]): Version information for the model or configuration.
 
-class BaseYAMLSerializableModel(YAMLSerializationMixin):
-    _config_mapping = YAMLSerializationMixin._config_mapping | {
-        "_usage_": "_usage_",
-        "_version_": "_version_",
-    }
+    Example:
+        ```python
+        class MyAlgorithm(BaseYAMLSerializable):
+            _config_mapping = BaseYAMLSerializable._config_mapping | {
+                "model_name": "model_name",
+                "num_layers": "num_layers",
+            }
+
+            def __init__(self, _usage_: str = None, _version_: str = None):
+                super().__init__(_usage_=_usage_, _version_=_version_)
+
+        # Usage with metadata
+        model = MyAlgorithm(
+            _usage_="Text classification fine-tuning",
+            _version_="1.0.0"
+        )
+
+        # Serialization includes metadata
+        config = model.config
+        # DictConfig({
+        #     '_target_': 'MyModel',
+        #     '_usage_': 'Text classification fine-tuning',
+        #     '_version_': '1.0.0'
+        # })
+        ```
+
+    Note:
+        The underscore prefix in `_usage_` and `_version_` follows the convention
+        for metadata fields that are not core model parameters but provide
+        important contextual information for model management and tracking.
+    """
 
     def __init__(
         self,
+        _recursive_: bool = False,
         _usage_: Optional[str] = None,
-        _version_: Optional[str] = None,
+        _version_: Optional[str] = FUSION_BENCH_VERSION,
         **kwargs,
     ):
+        """
+        Initialize a base YAML-serializable model with metadata support.
+
+        Args:
+            _usage_ (Optional[str], optional): Description of the model's intended
+                usage or purpose. This can include information about the training
+                domain, expected input types, or specific use cases. Defaults to None.
+            _version_ (Optional[str], optional): Version information for the model
+                or configuration. Can be used to track model iterations, dataset
+                versions, or compatibility information. Defaults to None.
+            **kwargs: Additional keyword arguments passed to the parent class.
+                Unused arguments will trigger warnings via the parent's initialization.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable(
+                _usage_="Image classification on CIFAR-10",
+                _version_="2.1.0"
+            )
+            ```
+        """
         super().__init__(**kwargs)
-        self._usage_ = _usage_
-        self._version_ = _version_
+        if _version_ != FUSION_BENCH_VERSION:
+            log.warning(
+                f"Current fusion-bench version is {FUSION_BENCH_VERSION}, but the serialized version is {_version_}. "
+                "Attempting to use current version."
+            )
+            # override _version_ with current fusion-bench version
+            self._version_ = FUSION_BENCH_VERSION
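
The serialization rework above replaces the old `to_config`/`BaseYAMLSerializableModel` pattern with the `auto_register_config` decorator and the renamed `BaseYAMLSerializable` base class. Below is a minimal sketch of how the pieces are meant to fit together; `ToyAlgorithm` and the output file name are made up for illustration, and the `from_yaml` round trip assumes the class is importable at the recorded `_target_` path.

```python
from fusion_bench.mixins.serialization import BaseYAMLSerializable, auto_register_config


@auto_register_config
class ToyAlgorithm(BaseYAMLSerializable):
    """Hypothetical class used only to illustrate the new serialization flow."""

    def __init__(self, scaling_factor: float = 0.3, use_gpu: bool = True, **kwargs):
        super().__init__(**kwargs)


algo = ToyAlgorithm(scaling_factor=0.5)

# The decorator sets the attributes and registers them in _config_mapping
# before the original __init__ runs.
print(algo.scaling_factor)  # 0.5
print(algo.use_gpu)         # True (default applied)

# `config` now records a fully qualified `_target_` (module + qualname).
print(algo.config)

# `to_yaml` serializes `self.config`; `from_yaml` re-instantiates via `_target_`.
algo.to_yaml("toy_algorithm.yaml")
restored = ToyAlgorithm.from_yaml("toy_algorithm.yaml")
```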

fusion_bench/modelpool/__init__.py
@@ -17,7 +17,7 @@ _import_structure = {
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
     ],
-    "seq_classification_lm": ["SeqenceClassificationModelPool"],
+    "seq_classification_lm": ["SequenceClassificationModelPool"],
 }
 
 
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
-    from .seq_classification_lm import SeqenceClassificationModelPool
+    from .seq_classification_lm import SequenceClassificationModelPool
 
 else:
     sys.modules[__name__] = LazyImporter(
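
Because the lazy-import table and the `TYPE_CHECKING` block both switch to the corrected spelling, downstream code must use the new name; the old symbol is no longer exported from `fusion_bench.modelpool`. A quick sketch:

```python
# fusion-bench >= 0.2.21: corrected class name
from fusion_bench.modelpool import SequenceClassificationModelPool

# Code written against 0.2.19 used the misspelled name and needs updating:
# from fusion_bench.modelpool import SeqenceClassificationModelPool  # no longer exported
```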

fusion_bench/modelpool/base_pool.py
@@ -1,13 +1,13 @@
 import logging
 from copy import deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import Dataset
 
-from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
+from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import instantiate, timeit_context
 
 __all__ = ["BaseModelPool"]
@@ -15,7 +15,10 @@ __all__ = ["BaseModelPool"]
 log = logging.getLogger(__name__)
 
 
-class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
+class BaseModelPool(
+    HydraConfigMixin,
+    BaseYAMLSerializable,
+):
     """
     A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
 
@@ -31,7 +34,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     _program = None
     _config_key = "modelpool"
     _models: Union[DictConfig, Dict[str, nn.Module]]
-    _config_mapping = BaseYAMLSerializableModel._config_mapping | {
+    _config_mapping = BaseYAMLSerializable._config_mapping | {
         "_models": "models",
         "_train_datasets": "train_datasets",
         "_val_datasets": "val_datasets",
@@ -56,7 +59,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         super().__init__(**kwargs)
 
     @property
-    def has_pretrained(self):
+    def has_pretrained(self) -> bool:
         """
         Check if the model pool contains a pretrained model.
 
@@ -125,7 +128,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         return len(self.model_names)
 
     @staticmethod
-    def is_special_model(model_name: str):
+    def is_special_model(model_name: str) -> bool:
         """
         Determine if a model is special based on its name.
 
@@ -152,6 +155,23 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         model_config = deepcopy(model_config)
         return model_config
 
+    def get_model_path(self, model_name: str) -> str:
+        """
+        Get the path for the specified model.
+
+        Args:
+            model_name (str): The name of the model.
+
+        Returns:
+            str: The path for the specified model.
+        """
+        if isinstance(self._models[model_name], str):
+            return self._models[model_name]
+        else:
+            raise ValueError(
+                "Model path is not a string. Try to override this method in derived modelpool class."
+            )
+
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> nn.Module:
@@ -159,7 +179,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         Load a model from the pool based on the provided configuration.
 
         Args:
-            model (Union[str, DictConfig]): The model name or configuration.
+            model_name_or_config (Union[str, DictConfig]): The model name or configuration.
 
         Returns:
             nn.Module: The instantiated model.
@@ -201,11 +221,11 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         model = self.load_model(self.model_names[0], *args, **kwargs)
         return model
 
-    def models(self):
+    def models(self) -> Generator[nn.Module, None, None]:
         for model_name in self.model_names:
             yield self.load_model(model_name)
 
-    def named_models(self):
+    def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
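
The base pool gains a `get_model_path` hook and typed generator signatures for `models()` / `named_models()`. A minimal sketch of the intended behavior; the pool entries below are hypothetical repo ids, and actually iterating the generators still requires a pool subclass that knows how to load such entries:

```python
from fusion_bench.modelpool.base_pool import BaseModelPool

# Hypothetical pool whose entries are plain repo-id strings.
pool = BaseModelPool(
    models={
        "_pretrained_": "openai/clip-vit-base-patch32",
        "svhn": "tanganke/clip-vit-base-patch32_svhn",
    }
)

# The base implementation returns string entries directly and raises
# ValueError for anything else (subclasses override it for dict-style configs).
print(pool.get_model_path("svhn"))  # "tanganke/clip-vit-base-patch32_svhn"

# `models()` and `named_models()` are generators: each iteration calls
# `load_model(name)` lazily rather than loading everything up front.
for name in pool.model_names:
    print(name)
```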

fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -57,6 +57,15 @@ class CausalLMPool(BaseModelPool):
         )
         self.load_lazy = load_lazy
 
+    def get_model_path(self, model_name: str):
+        model_name_or_config = self._models[model_name]
+        if isinstance(model_name_or_config, str):
+            return model_name_or_config
+        elif isinstance(model_name_or_config, (DictConfig, dict)):
+            return model_name_or_config.get("pretrained_model_name_or_path")
+        else:
+            raise RuntimeError("Invalid model configuration")
+
     @override
     def load_model(
         self,
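
`CausalLMPool` overrides the new hook so that both entry styles found in its YAML configs resolve to a path: plain strings are returned as-is, and dict-style entries are looked up by their `pretrained_model_name_or_path` key. A rough sketch; the model ids and constructor arguments shown are illustrative assumptions only:

```python
from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

# Hypothetical pool mixing the two supported entry styles.
pool = CausalLMPool(
    models={
        "math": "Qwen/Qwen2.5-Math-1.5B-Instruct",
        "coder": {
            "_target_": "transformers.AutoModelForCausalLM.from_pretrained",
            "pretrained_model_name_or_path": "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        },
    },
    tokenizer="Qwen/Qwen2.5-Math-1.5B-Instruct",
)

print(pool.get_model_path("math"))   # "Qwen/Qwen2.5-Math-1.5B-Instruct"
print(pool.get_model_path("coder"))  # "Qwen/Qwen2.5-Coder-1.5B-Instruct"
```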

fusion_bench/modelpool/clip_vision/modelpool.py
@@ -1,6 +1,6 @@
 import logging
 from copy import deepcopy
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from datasets import load_dataset
 from lightning.fabric.utilities import rank_zero_only
@@ -11,6 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override
 
 from fusion_bench.utils import instantiate, timeit_context
+from fusion_bench.utils.modelscope import resolve_repo_path
 
 from ..base_pool import BaseModelPool
 
@@ -25,25 +26,32 @@ class CLIPVisionModelPool(BaseModelPool):
     the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
     """
 
-    _config_mapping = BaseModelPool._config_mapping | {"_processor": "processor"}
+    _config_mapping = BaseModelPool._config_mapping | {
+        "_processor": "processor",
+        "_platform": "hf",
+    }
 
     def __init__(
         self,
         models: DictConfig,
         *,
         processor: Optional[DictConfig] = None,
+        platform: Literal["hf", "huggingface", "modelscope"] = "hf",
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
         self._processor = processor
+        self._platform = platform
 
     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
         if isinstance(self._processor, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
-            processor = CLIPProcessor.from_pretrained(self._processor)
+            repo_path = resolve_repo_path(
+                repo_id=self._processor, repo_type="model", platform=self._platform
+            )
+            processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
         else:
             processor = instantiate(self._processor, *args, **kwargs)
         return processor
@@ -54,7 +62,10 @@ class CLIPVisionModelPool(BaseModelPool):
         if isinstance(model_config, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPModel`: {model_config}")
-            clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
+            repo_path = resolve_repo_path(
+                repo_id=model_config, repo_type="model", platform=self._platform
+            )
+            clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
             return clip_model
         else:
             assert isinstance(
@@ -107,14 +118,17 @@ class CLIPVisionModelPool(BaseModelPool):
         if isinstance(model, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
-            return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
+            repo_path = resolve_repo_path(
+                model, repo_type="model", platform=self._platform
+            )
+            return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
         if isinstance(model, nn.Module):
             if rank_zero_only.rank == 0:
                 log.info(f"Returning existing model: {model}")
             return model
-
-        # If the model is not a string, we use the default load_model method
-        return super().load_model(model_name_or_config, *args, **kwargs)
+        else:
+            # If the model is not a string, we use the default load_model method
+            return super().load_model(model_name_or_config, *args, **kwargs)
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
@@ -123,7 +137,7 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset = load_dataset(dataset_config, split="train")
+            dataset = self._load_dataset(dataset_config, split="train")
         else:
             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
         return dataset
@@ -135,7 +149,7 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
            )
-            dataset = load_dataset(dataset_config, split="validation")
+            dataset = self._load_dataset(dataset_config, split="validation")
         else:
             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
         return dataset
@@ -147,7 +161,24 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset = load_dataset(dataset_config, split="test")
+            dataset = self._load_dataset(dataset_config, split="test")
         else:
             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
         return dataset
+
+    def _load_dataset(self, name: str, split: str):
+        """
+        Load a dataset by its name and split.
+
+        Args:
+            name (str): The name of the dataset.
+            split (str): The split of the dataset to load (e.g., "train", "validation", "test").
+
+        Returns:
+            Dataset: The loaded dataset.
+        """
+        dataset_dir = resolve_repo_path(
+            name, repo_type="dataset", platform=self._platform
+        )
+        dataset = load_dataset(dataset_dir, split=split)
+        return dataset
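
With the new `platform` option, every processor, model, and dataset load in `CLIPVisionModelPool` goes through `resolve_repo_path`, so the same repo ids can be resolved against either the Hugging Face Hub or ModelScope. A minimal sketch, assuming the usual CLIP repo id; mirrored ids on ModelScope may differ in practice:

```python
from fusion_bench.modelpool.clip_vision.modelpool import CLIPVisionModelPool

# Hypothetical pool; platform defaults to "hf" (Hugging Face Hub).
pool = CLIPVisionModelPool(
    models={"_pretrained_": "openai/clip-vit-base-patch32"},
    processor="openai/clip-vit-base-patch32",
    platform="modelscope",  # resolve repo ids via ModelScope instead
)

vision_model = pool.load_model("_pretrained_")  # CLIPVisionModel loaded from the resolved path
processor = pool.load_processor()
```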

fusion_bench/modelpool/seq_classification_lm/__init__.py
@@ -1,2 +1,2 @@
 from .reward_model import create_reward_model_from_pretrained
-from .seq_classification_lm import SeqenceClassificationModelPool
+from .seq_classification_lm import SequenceClassificationModelPool

fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-class SeqenceClassificationModelPool(BaseModelPool):
+class SequenceClassificationModelPool(BaseModelPool):
 
     def __init__(
         self,

fusion_bench/models/__init__.py
@@ -1,4 +1,5 @@
 # flake8: noqa F401
+from fusion_bench.utils import LazyStateDict
+
 from . import separate_io, utils
 from .parameter_dict import ParameterDictModel
-from fusion_bench.utils import LazyStateDict

fusion_bench/models/expert_sparsity/mixtral/__init__.py
@@ -10,6 +10,6 @@ Reference:
 """
 
 from .wrapper import (
-    PrunableMixtralSparseMoeBlockWrapper,
     DynamicSkippingMixtralSparseMoeBlockWrapper,
+    PrunableMixtralSparseMoeBlockWrapper,
 )