fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/mixins/hydra_config.py
@@ -1,8 +1,20 @@
+"""
+Hydra Configuration Mixin for FusionBench.
+
+This module provides a mixin class that enables easy instantiation of objects
+from Hydra configuration files. It's designed to work seamlessly with the
+FusionBench configuration system and supports dynamic object creation based
+on YAML configuration files.
+
+The mixin integrates with Hydra's configuration management system to provide
+a clean interface for creating objects from structured configurations.
+"""
+
 import logging
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, TypeVar, Union

 import hydra.core.global_hydra
 from hydra import compose, initialize
@@ -13,10 +25,39 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call

 log = logging.getLogger(__name__)

+T = TypeVar("T", bound="HydraConfigMixin")
+

 class HydraConfigMixin:
-    """
-    A mixin for classes that need to be instantiated from a config file.
+    R"""
+    A mixin class that provides configuration-based instantiation capabilities.
+
+    This mixin enables classes to be instantiated directly from Hydra configuration
+    files, supporting both direct instantiation and target-based instantiation patterns.
+    It's particularly useful in FusionBench for creating model pools, task pools,
+    and fusion algorithms from YAML configurations.
+
+    The mixin handles:
+    - Configuration loading and composition
+    - Target class validation
+    - Nested configuration group navigation
+    - Object instantiation with proper error handling
+
+    Example:
+
+    ```python
+    class MyAlgorithm(HydraConfigMixin):
+        def __init__(self, param1: str, param2: int = 10):
+            self.param1 = param1
+            self.param2 = param2
+
+    # Instantiate from config
+    algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
+    ```
+
+    Note:
+        This mixin requires Hydra to be properly initialized before use.
+        Typically, this is handled by the main FusionBench CLI application.
     """

     @classmethod
@@ -24,26 +65,83 @@ class HydraConfigMixin:
         cls,
         config_name: Union[str, Path],
         overrides: Optional[List[str]] = None,
-    ):
+    ) -> T:
+        """
+        Create an instance of the class from a Hydra configuration.
+
+        This method loads a Hydra configuration file and instantiates the class
+        using the configuration parameters. It supports both direct parameter
+        passing and target-based instantiation patterns.
+
+        Args:
+            config_name: The name/path of the configuration file to load.
+                Can be a string like "algorithms/simple_average" or
+                a Path object. The .yaml extension is optional.
+            overrides: Optional list of configuration overrides in the format
+                ["key=value", "nested.key=value"]. These allow runtime
+                modification of configuration parameters.
+
+        Returns:
+            An instance of the class configured according to the loaded configuration.
+
+        Raises:
+            RuntimeError: If Hydra is not properly initialized.
+            ImportError: If a target class specified in the config cannot be imported.
+            ValueError: If required configuration parameters are missing.
+
+        Example:
+            ```python
+            # Load with basic config
+            obj = MyClass.from_config("my_config")
+
+            # Load with overrides
+            obj = MyClass.from_config(
+                "my_config",
+                overrides=["param1=new_value", "param2=42"]
+            )
+
+            # Load nested config
+            obj = MyClass.from_config("category/subcategory/my_config")
+            ```
+
+        Note:
+            The method automatically handles nested configuration groups by
+            navigating through the configuration hierarchy based on the
+            config_name path structure.
+        """
+        # Verify Hydra initialization
         if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
-            raise RuntimeError("Hydra is not initialized.")
+            raise RuntimeError(
+                "Hydra is not initialized. Please ensure Hydra is properly "
+                "initialized before calling from_config(). This is typically "
+                "handled by the FusionBench CLI application."
+            )
         else:
+            # Compose the configuration with any provided overrides
             cfg = compose(config_name=config_name, overrides=overrides)

+            # Navigate through nested configuration groups
+            # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
             config_groups = config_name.split("/")[:-1]
             for config_group in config_groups:
                 cfg = cfg[config_group]

+        # Handle target-based instantiation
         if "_target_" in cfg:
-            # if the config has a _target_ key, check if it is equal to the class name
+            # Validate that the target class matches the calling class
             target_cls = import_object(cfg["_target_"])
             if target_cls != cls:
                 log.warning(
-                    f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
+                    f"Configuration target mismatch: config specifies "
+                    f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
+                    f"This may indicate a configuration error."
                 )
+
+            # Instantiate using the target pattern with function call logging disabled
             with set_print_function_call(False):
                 obj = instantiate(cfg)
         else:
+            # Direct instantiation using configuration as keyword arguments
             obj = cls(**cfg)

         return obj
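
The new `from_config` classmethod composes a config, descends into its config group, and then instantiates either via `_target_` or by calling the class directly. A minimal sketch of driving it outside the FusionBench CLI (which normally initializes Hydra itself) might look like the following; the `conf/` directory and `algorithms/my_algorithm.yaml` config are hypothetical stand-ins, not files shipped in this release:

```python
from hydra import initialize

from fusion_bench.mixins import HydraConfigMixin


class MyAlgorithm(HydraConfigMixin):
    def __init__(self, param1: str, param2: int = 10):
        self.param1 = param1
        self.param2 = param2


# Hypothetical config directory containing algorithms/my_algorithm.yaml.
with initialize(version_base=None, config_path="conf"):
    # Composes the config, walks into the "algorithms" group, then either
    # instantiates via _target_ or calls MyAlgorithm(**cfg) directly.
    algorithm = MyAlgorithm.from_config(
        "algorithms/my_algorithm",
        overrides=["param2=42"],
    )
```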
fusion_bench/mixins/lightning_fabric.py
@@ -52,9 +52,11 @@ class LightningFabricMixin:
     and nodes, with support for custom logging via TensorBoard.

     Attributes:
+
         - _fabric (L.Fabric): The Lightning Fabric instance used for distributed computing.

     Note:
+
         This mixin is designed to be used with classes that require distributed computing capabilities and wish to
         leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
         in the consuming class for configuration.
fusion_bench/mixins/serialization.py
@@ -1,20 +1,158 @@
+import inspect
 import logging
+from copy import deepcopy
+from functools import wraps
+from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Optional, Union

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

+from fusion_bench.constants import FUSION_BENCH_VERSION
 from fusion_bench.utils import import_object, instantiate
+from fusion_bench.utils.instantiate_utils import set_print_function_call

 log = logging.getLogger(__name__)

+__all__ = [
+    "YAMLSerializationMixin",
+    "auto_register_config",
+    "BaseYAMLSerializable",
+]
+
+
+def auto_register_config(cls):
+    """
+    Decorator to automatically register __init__ parameters in _config_mapping.
+
+    This decorator enhances classes that inherit from YAMLSerializationMixin by
+    automatically mapping constructor parameters to configuration keys and
+    dynamically setting instance attributes based on provided arguments.
+
+    The decorator performs the following operations:
+    1. Inspects the class's __init__ method signature
+    2. Automatically populates the _config_mapping dictionary with parameter names
+    3. Wraps the __init__ method to handle both positional and keyword arguments
+    4. Sets instance attributes for all constructor parameters
+    5. Applies default values when parameters are not provided
+
+    Args:
+        cls (YAMLSerializationMixin): The class to be decorated. Must inherit from
+            YAMLSerializationMixin to ensure proper serialization capabilities.
+
+    Returns:
+        YAMLSerializationMixin: The decorated class with enhanced auto-registration
+            functionality and modified __init__ behavior.
+
+    Behavior:
+        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+          from the __init__ method are automatically added to _config_mapping
+        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
+        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
+        - **Default Values**: Applied when parameters are not provided via arguments
+        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
+
+    Example:
+        ```python
+        @auto_register_config
+        class MyAlgorithm(BaseYAMLSerializable):
+            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
+                super().__init__()

+        # All instantiation methods work automatically:
+        algo1 = MyAlgorithm(0.01, 64)  # positional args
+        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
+        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
+
+        # Attributes are automatically set and can be serialized:
+        print(algo1.learning_rate)  # 0.01
+        print(algo1.batch_size)  # 64
+        print(algo1.model_name)  # "default" (from default value)
+
+        config = algo1.config
+        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
+        ```
+
+    Note:
+        - The decorator wraps the original __init__ method while preserving its signature for IDE support
+        - Parameters with *args or **kwargs signatures are ignored during registration
+        - The attributes are auto-registered, then the original __init__ method is called,
+        - Type hints, method name, and other metadata are preserved using functools.wraps
+        - This decorator is designed to work seamlessly with the YAML serialization system
+
+    Raises:
+        AttributeError: If the class does not have the required _config_mapping attribute
+            infrastructure (should inherit from YAMLSerializationMixin)
+    """
+    original_init = cls.__init__
+    sig = inspect.signature(original_init)
+
+    # Auto-register parameters in _config_mapping
+    if not "_config_mapping" in cls.__dict__:
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    registered_parameters = tuple(cls._config_mapping.values())
+
+    for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
+        if (
+            sig.parameters[param_name].kind
+            not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]
+        ) and (param_name not in registered_parameters):
+            cls._config_mapping[param_name] = param_name
+
+    def __init__(self, *args, **kwargs):
+        nonlocal original_init, registered_parameters
+
+        # auto-register the attributes based on the signature
+        sig = inspect.signature(original_init)
+        param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
+
+        # Handle positional arguments
+        for i, arg_value in enumerate(args):
+            if i < len(param_names):
+                param_name = param_names[i]
+                if sig.parameters[param_name].kind not in [
+                    _ParameterKind.VAR_POSITIONAL,
+                    _ParameterKind.VAR_KEYWORD,
+                ]:
+                    setattr(self, param_name, arg_value)
+
+        # Handle keyword arguments and defaults
+        for param_name in param_names:
+            if (
+                sig.parameters[param_name].kind
+                not in [
+                    _ParameterKind.VAR_POSITIONAL,
+                    _ParameterKind.VAR_KEYWORD,
+                ]
+            ) and (param_name not in registered_parameters):
+                # Skip if already set by positional argument
+                param_index = param_names.index(param_name)
+                if param_index >= 0 and param_index < len(args):
+                    continue
+
+                if param_name in kwargs:
+                    setattr(self, param_name, kwargs[param_name])
+                else:
+                    # Set default value if available and attribute doesn't exist
+                    default_value = sig.parameters[param_name].default
+                    if default_value is not Parameter.empty:
+                        setattr(self, param_name, default_value)
+
+        # Call the original __init__
+        result = original_init(self, *args, **kwargs)
+        return result
+
+    # Replace the original __init__ method while preserving its signature
+    cls.__init__ = __init__
+    return cls
+

 class YAMLSerializationMixin:
-    _recursive_: bool = False
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {
-        "_recursive_": "_recursive_",
-    }
+    _config_mapping: Dict[str, str] = {}
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.

@@ -47,46 +185,50 @@ class YAMLSerializationMixin:
     By default, the `_target_` key is set to the class name as `type(self).__name__`.
     """

-    def __init__(
-        self,
-        _recursive_: bool = False,
-        **kwargs,
-    ) -> None:
-        self._recursive_ = _recursive_
+    def __init__(self, **kwargs) -> None:
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")

     @property
-    def config(self):
+    def config(self) -> DictConfig:
         R"""
         Returns the configuration of the model pool as a DictConfig.

-        This property calls the `to_config` method to convert the model pool
-        instance into a dictionary configuration, which can be used for
-        serialization or other purposes.
+        This property converts the model pool instance into a dictionary
+        configuration, which can be used for serialization or other purposes.

         Example:
-            >>> model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-            >>> config = model.config
-            >>> print(config)
-            DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
+
+        ```python
+        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
+        config = model.config
+        print(config)
+        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
+        ```

         This is useful for serializing the object to a YAML file or for debugging.

         Returns:
             DictConfig: The configuration of the model pool.
         """
-        return self.to_config()
+        config = {"_target_": f"{type(self).__module__}.{type(self).__qualname__}"}
+        for attr, key in self._config_mapping.items():
+            if hasattr(self, attr):
+                config[key] = getattr(self, attr)
+
+        try:
+            return OmegaConf.create(config)
+        except Exception as e:
+            return OmegaConf.create(config, flags={"allow_objects": True})

-    def to_yaml(self, path: Union[str, Path]):
+    def to_yaml(self, path: Union[str, Path], resolve: bool = True):
         """
         Save the model pool to a YAML file.

         Args:
             path (Union[str, Path]): The path to save the model pool to.
         """
-        config = self.to_config()
-        OmegaConf.save(config, path, resolve=True)
+        OmegaConf.save(self.config, path, resolve=resolve)

     @classmethod
     def from_yaml(cls, path: Union[str, Path]):
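
Taken together with the `auto_register_config` decorator introduced above, the reworked `config` property and the `to_yaml`/`from_yaml` pair give a full serialization round trip. A rough sketch is below; `ScalingMethod` and its parameters are made up for illustration, and `auto_register_config` is imported from the module where this diff defines it (it may also be re-exported from `fusion_bench.mixins`):

```python
from fusion_bench.mixins import BaseYAMLSerializable
from fusion_bench.mixins.serialization import auto_register_config


@auto_register_config
class ScalingMethod(BaseYAMLSerializable):  # hypothetical example class
    def __init__(self, scaling_factor: float = 0.3, device: str = "cpu", **kwargs):
        super().__init__(**kwargs)


method = ScalingMethod(scaling_factor=0.5)
print(method.device)                  # "cpu", filled in from the default value
print(ScalingMethod._config_mapping)  # now contains "scaling_factor" and "device"

# The config property builds a DictConfig with a fully qualified _target_,
# to_yaml() writes it out, and from_yaml() re-instantiates the object.
method.to_yaml("scaling_method.yaml")
restored = ScalingMethod.from_yaml("scaling_method.yaml")
print(restored.scaling_factor)        # 0.5
```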
@@ -108,41 +250,126 @@ class YAMLSerializationMixin:
                 f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
                 f"Instantiating the class {target_cls.__name__} instead."
             )
-        return instantiate(
-            config,
-            _recursive_=(
-                cls._recursive_
-                if config.get("_recursive_") is None
-                else config.get("_recursive_")
-            ),
-        )
+        with set_print_function_call(False):
+            return instantiate(config)

-    def to_config(self):
+    def register_parameter_to_config(
+        self,
+        attr_name: str,
+        param_name: str,
+        value,
+    ):
         """
-        Convert the model pool to a DictConfig.
+        Set an attribute value and register its config mapping.

-        Returns:
-            Dict: The model pool as a DictConfig.
+        This method allows dynamic setting of object attributes while simultaneously
+        updating the configuration mapping that defines how the attribute should
+        be serialized in the configuration output.
+
+        Args:
+            attr_name (str): The name of the attribute to set on this object.
+            arg_name (str): The corresponding parameter name to use in the config
+                serialization. This is how the attribute will appear in YAML output.
+            value: The value to assign to the attribute.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable()
+            model.set_option("learning_rate", "lr", 0.001)
+
+            # This sets model.learning_rate = 0.001
+            # and maps it to "lr" in the config output
+            config = model.config
+            # config will contain: {"lr": 0.001, ...}
+            ```
         """
-        config = {"_target_": type(self).__name__}
-        for attr, key in self._config_mapping.items():
-            if hasattr(self, attr):
-                config[key] = getattr(self, attr)
-        return OmegaConf.create(config)
+        setattr(self, attr_name, value)
+        self._config_mapping[attr_name] = param_name
+
+
+@auto_register_config
+class BaseYAMLSerializable(YAMLSerializationMixin):
+    """
+    A base class for YAML-serializable classes with enhanced metadata support.
+
+    This class extends `YAMLSerializationMixin` to provide additional metadata
+    fields commonly used in FusionBench classes, including usage information
+    and version tracking. It serves as a foundation for all serializable
+    model components in the framework.
+
+    The class automatically handles serialization of usage and version metadata
+    alongside the standard configuration parameters, making it easier to track
+    model provenance and intended usage patterns.

+    Attributes:
+        _usage_ (Optional[str]): Description of the model's intended usage or purpose.
+        _version_ (Optional[str]): Version information for the model or configuration.

-class BaseYAMLSerializableModel(YAMLSerializationMixin):
-    _config_mapping = YAMLSerializationMixin._config_mapping | {
-        "_usage_": "_usage_",
-        "_version_": "_version_",
-    }
+    Example:
+        ```python
+        class MyAlgorithm(BaseYAMLSerializable):
+            _config_mapping = BaseYAMLSerializable._config_mapping | {
+                "model_name": "model_name",
+                "num_layers": "num_layers",
+            }
+
+            def __init__(self, _usage_: str = None, _version_: str = None):
+                super().__init__(_usage_=_usage_, _version_=_version_)
+
+        # Usage with metadata
+        model = MyAlgorithm(
+            _usage_="Text classification fine-tuning",
+            _version_="1.0.0"
+        )
+
+        # Serialization includes metadata
+        config = model.config
+        # DictConfig({
+        #     '_target_': 'MyModel',
+        #     '_usage_': 'Text classification fine-tuning',
+        #     '_version_': '1.0.0'
+        # })
+        ```
+
+    Note:
+        The underscore prefix in `_usage_` and `_version_` follows the convention
+        for metadata fields that are not core model parameters but provide
+        important contextual information for model management and tracking.
+    """

     def __init__(
         self,
+        _recursive_: bool = False,
         _usage_: Optional[str] = None,
-        _version_: Optional[str] = None,
+        _version_: Optional[str] = FUSION_BENCH_VERSION,
         **kwargs,
     ):
+        """
+        Initialize a base YAML-serializable model with metadata support.
+
+        Args:
+            _usage_ (Optional[str], optional): Description of the model's intended
+                usage or purpose. This can include information about the training
+                domain, expected input types, or specific use cases. Defaults to None.
+            _version_ (Optional[str], optional): Version information for the model
+                or configuration. Can be used to track model iterations, dataset
+                versions, or compatibility information. Defaults to None.
+            **kwargs: Additional keyword arguments passed to the parent class.
+                Unused arguments will trigger warnings via the parent's initialization.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable(
+                _usage_="Image classification on CIFAR-10",
+                _version_="2.1.0"
+            )
+            ```
+        """
         super().__init__(**kwargs)
-        self._usage_ = _usage_
-        self._version_ = _version_
+        if _version_ != FUSION_BENCH_VERSION:
+            log.warning(
+                f"Current fusion-bench version is {FUSION_BENCH_VERSION}, but the serialized version is {_version_}. "
+                "Attempting to use current version."
+            )
+        # override _version_ with current fusion-bench version
+        self._version_ = FUSION_BENCH_VERSION
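
One behavioral change worth noting: `BaseYAMLSerializable` now defaults `_version_` to the installed `FUSION_BENCH_VERSION` and overwrites a mismatched value after logging a warning. A small sketch of what that looks like in practice, assuming the class is re-exported from `fusion_bench.mixins` as the other diffs in this release suggest:

```python
from fusion_bench.constants import FUSION_BENCH_VERSION
from fusion_bench.mixins import BaseYAMLSerializable

# Re-creating an object that was serialized by an older release: the mismatch
# is logged as a warning and _version_ is replaced with the running version.
obj = BaseYAMLSerializable(_usage_="demo", _version_="0.2.20")
assert obj._version_ == FUSION_BENCH_VERSION
print(obj.config["_version_"])  # always the currently installed fusion-bench version
```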
fusion_bench/modelpool/__init__.py
@@ -17,7 +17,7 @@ _import_structure = {
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
     ],
-    "seq_classification_lm": ["SeqenceClassificationModelPool"],
+    "seq_classification_lm": ["SequenceClassificationModelPool"],
 }


@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
-    from .seq_classification_lm import SeqenceClassificationModelPool
+    from .seq_classification_lm import SequenceClassificationModelPool

 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/modelpool/base_pool.py
@@ -1,13 +1,13 @@
 import logging
 from copy import deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, List, Optional, Tuple, Union

 import torch
 from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import Dataset

-from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
+from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import instantiate, timeit_context

 __all__ = ["BaseModelPool"]
@@ -15,7 +15,10 @@ __all__ = ["BaseModelPool"]
 log = logging.getLogger(__name__)


-class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
+class BaseModelPool(
+    HydraConfigMixin,
+    BaseYAMLSerializable,
+):
     """
     A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.

@@ -31,7 +34,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     _program = None
     _config_key = "modelpool"
     _models: Union[DictConfig, Dict[str, nn.Module]]
-    _config_mapping = BaseYAMLSerializableModel._config_mapping | {
+    _config_mapping = BaseYAMLSerializable._config_mapping | {
         "_models": "models",
         "_train_datasets": "train_datasets",
         "_val_datasets": "val_datasets",
@@ -56,7 +59,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         super().__init__(**kwargs)

     @property
-    def has_pretrained(self):
+    def has_pretrained(self) -> bool:
         """
         Check if the model pool contains a pretrained model.

@@ -125,7 +128,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         return len(self.model_names)

     @staticmethod
-    def is_special_model(model_name: str):
+    def is_special_model(model_name: str) -> bool:
         """
         Determine if a model is special based on its name.

@@ -152,6 +155,23 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
             model_config = deepcopy(model_config)
         return model_config

+    def get_model_path(self, model_name: str) -> str:
+        """
+        Get the path for the specified model.
+
+        Args:
+            model_name (str): The name of the model.
+
+        Returns:
+            str: The path for the specified model.
+        """
+        if isinstance(self._models[model_name], str):
+            return self._models[model_name]
+        else:
+            raise ValueError(
+                "Model path is not a string. Try to override this method in derived modelpool class."
+            )
+
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> nn.Module:
@@ -159,7 +179,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         Load a model from the pool based on the provided configuration.

         Args:
-            model (Union[str, DictConfig]): The model name or configuration.
+            model_name_or_config (Union[str, DictConfig]): The model name or configuration.

         Returns:
             nn.Module: The instantiated model.
@@ -201,11 +221,11 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         model = self.load_model(self.model_names[0], *args, **kwargs)
         return model

-    def models(self):
+    def models(self) -> Generator[nn.Module, None, None]:
         for model_name in self.model_names:
             yield self.load_model(model_name)

-    def named_models(self):
+    def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
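
The new `get_model_path` helper only handles the case where a pool entry is a plain string; pools whose entries are structured configs are expected to override it, as the raised `ValueError` suggests. A hedged sketch of such an override follows; the derived class and the `pretrained_model_name_or_path` key are assumptions about how a pool might store checkpoints, not something this diff defines:

```python
from fusion_bench.modelpool.base_pool import BaseModelPool


class MyStructuredModelPool(BaseModelPool):  # hypothetical derived pool
    def get_model_path(self, model_name: str) -> str:
        entry = self._models[model_name]
        if isinstance(entry, str):
            return entry
        # Assumed layout: the entry is a mapping that carries its checkpoint path.
        return entry["pretrained_model_name_or_path"]
```

The newly typed `models()` and `named_models()` generators iterate the pool lazily, e.g. `for name, model in pool.named_models(): ...`, loading each model only when it is yielded.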