fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -19,9 +19,32 @@ from . import (
19
19
  tasks,
20
20
  utils,
21
21
  )
22
+ from .constants import RuntimeConstants
22
23
  from .method import BaseAlgorithm, BaseModelFusionAlgorithm
23
24
  from .mixins import auto_register_config
24
25
  from .modelpool import BaseModelPool
25
- from .models import separate_io
26
+ from .models import (
27
+ create_default_model_card,
28
+ load_model_card_template,
29
+ save_pretrained_with_remote_code,
30
+ separate_io,
31
+ )
32
+ from .programs import BaseHydraProgram
26
33
  from .taskpool import BaseTaskPool
27
- from .utils import parse_dtype, print_parameters, timeit_context
34
+ from .utils import (
35
+ BoolStateDictType,
36
+ LazyStateDict,
37
+ StateDictType,
38
+ TorchModelType,
39
+ cache_with_joblib,
40
+ get_rankzero_logger,
41
+ import_object,
42
+ instantiate,
43
+ parse_dtype,
44
+ print_parameters,
45
+ seed_everything_by_time,
46
+ set_default_cache_dir,
47
+ set_print_function_call,
48
+ set_print_function_call_permeanent,
49
+ timeit_context,
50
+ )
@@ -1,4 +1,5 @@
1
1
  import warnings
2
+ from typing import Any, List, Type
2
3
 
3
4
  from omegaconf import DictConfig
4
5
 
@@ -76,7 +77,9 @@ class AlgorithmFactory:
76
77
  return algorithm_cls(method_config)
77
78
 
78
79
  @staticmethod
79
- def register_algorithm(name: str, algorithm_cls):
80
+ def register_algorithm(
81
+ name: str, algorithm_cls: Type[ModelFusionAlgorithm]
82
+ ) -> None:
80
83
  """
81
84
  Register a new algorithm with the factory.
82
85
 
@@ -87,7 +90,7 @@ class AlgorithmFactory:
87
90
  AlgorithmFactory._aglorithms[name] = algorithm_cls
88
91
 
89
92
  @classmethod
90
- def available_algorithms(cls):
93
+ def available_algorithms(cls) -> List[str]:
91
94
  """
92
95
  Get a list of available algorithms.
93
96
 
@@ -1,9 +1,10 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import TYPE_CHECKING, Optional
2
+ from typing import TYPE_CHECKING, Any, Optional
3
3
 
4
4
  from omegaconf import DictConfig
5
5
 
6
6
  if TYPE_CHECKING:
7
+ from fusion_bench import BaseModelPool
7
8
  from fusion_bench.programs.base_program import BaseHydraProgram
8
9
 
9
10
  __all__ = ["ModelFusionAlgorithm"]
@@ -51,7 +52,7 @@ class ModelFusionAlgorithm(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def run(self, modelpool):
55
+ def run(self, modelpool: "BaseModelPool") -> Any:
55
56
  """
56
57
  Fuse the models in the given model pool.
57
58
 
@@ -42,7 +42,7 @@ class ModelPool(ABC):
42
42
  ), "Duplicate model names found in model pool"
43
43
  self._model_names = model_names
44
44
 
45
- def __len__(self):
45
+ def __len__(self) -> int:
46
46
  """
47
47
  Return the number of models in the model pool, exclude special models such as `_pretrained_`.
48
48
 
@@ -66,7 +66,7 @@ class ModelPool(ABC):
66
66
  return names
67
67
 
68
68
  @property
69
- def has_pretrained(self):
69
+ def has_pretrained(self) -> bool:
70
70
  """
71
71
  Check if the pretrained model is available in the model pool.
72
72
 
@@ -78,7 +78,7 @@ class ModelPool(ABC):
78
78
  return True
79
79
  return False
80
80
 
81
- def get_model_config(self, model_name: str):
81
+ def get_model_config(self, model_name: str) -> Dict:
82
82
  """
83
83
  Retrieves the configuration for a specific model from the model pool.
84
84
 
@@ -169,7 +169,7 @@ class CLIPImageClassificationTaskPool(TaskPool):
169
169
  self._fabric = L.Fabric(devices=1)
170
170
  self._fabric.launch()
171
171
 
172
- # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
172
+ # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
173
173
  self.clip_model.vision_model = model
174
174
  report = {}
175
175
  training_params, all_params = count_parameters(model)
@@ -2,6 +2,7 @@
2
2
  import importlib.metadata
3
3
 
4
4
  from .paths import *
5
+ from .runtime import RuntimeConstants
5
6
 
6
7
  # fusionbench version
7
8
  FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
@@ -0,0 +1,57 @@
1
+ import threading
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+
6
+ class RuntimeConstants:
7
+ """
8
+ This class holds constants related to the runtime environment of the Fusion Bench framework.
9
+ It includes default values for cache directories and other runtime configurations.
10
+
11
+ Implemented as a thread-safe singleton to ensure consistent runtime configuration
12
+ across the entire application.
13
+ """
14
+
15
+ _instance: Optional["RuntimeConstants"] = None
16
+ _lock = threading.Lock()
17
+
18
+ def __new__(cls) -> "RuntimeConstants":
19
+ """Create a new instance using singleton pattern with thread safety."""
20
+ with cls._lock:
21
+ # Double-check locking pattern
22
+ if cls._instance is None:
23
+ cls._instance = super(RuntimeConstants, cls).__new__(cls)
24
+ cls._instance._initialized = False
25
+ return cls._instance
26
+
27
+ def __init__(self):
28
+ """Initialize the singleton instance only once."""
29
+ if not self._initialized:
30
+ # Add your runtime constants here
31
+ self._initialized = True
32
+
33
+ debug = False
34
+
35
+ @property
36
+ def cache_dir(self) -> Path:
37
+ from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
38
+
39
+ return DEFAULT_CACHE_DIR
40
+
41
+ @cache_dir.setter
42
+ def cache_dir(self, path: Union[str, Path]) -> None:
43
+ from fusion_bench.utils.cache_utils import set_default_cache_dir
44
+
45
+ set_default_cache_dir(path)
46
+
47
+ @property
48
+ def print_function_call(self) -> bool:
49
+ from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
50
+
51
+ return PRINT_FUNCTION_CALL
52
+
53
+ @print_function_call.setter
54
+ def print_function_call(self, enable: bool) -> None:
55
+ from fusion_bench.utils.instantiate_utils import set_print_function_call
56
+
57
+ set_print_function_call(enable)
@@ -121,7 +121,7 @@ class TokenizedGLUE:
121
121
 
122
122
  def load_dataset(
123
123
  self, name: Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]
124
- ):
124
+ ) -> Dataset:
125
125
  """
126
126
  Load and tokenize a GLUE dataset.
127
127
 
@@ -30,7 +30,7 @@ _import_structure = {
30
30
  "TaskArithmeticForLlama",
31
31
  "LinearInterpolationAlgorithm",
32
32
  ],
33
- "slerp": ["SlerpMergeAlgorithm"],
33
+ "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
34
34
  "simple_average": ["SimpleAverageAlgorithm"],
35
35
  "weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
36
36
  "task_arithmetic": ["TaskArithmeticAlgorithm"],
@@ -71,6 +71,7 @@ _import_structure = {
71
71
  ],
72
72
  "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
73
73
  "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
74
+ "model_stock": ["ModelStock"],
74
75
  # plug-and-play model merging methods
75
76
  "concrete_subspace": [
76
77
  "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -90,7 +91,10 @@ _import_structure = {
90
91
  "MixtralForCausalLMMergingAlgorithm",
91
92
  ],
92
93
  "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
93
- "we_moe": ["CLIPWeightEnsemblingMoEAlgorithm"],
94
+ "we_moe": [
95
+ "CLIPWeightEnsemblingMoEAlgorithm",
96
+ "FlanT5WeightEnsemblingMoEAlgorithm",
97
+ ],
94
98
  "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
95
99
  "sparse_we_moe": [
96
100
  "SparseWeightEnsemblingMoEAlgorithm",
@@ -191,6 +195,7 @@ if TYPE_CHECKING:
191
195
  MixtralUpscalingAlgorithm,
192
196
  )
193
197
  from .model_recombination import ModelRecombinationAlgorithm
198
+ from .model_stock import ModelStock
194
199
  from .opcm import OPCMForCLIP
195
200
  from .pruning import (
196
201
  MagnitudeDiffPruningAlgorithm,
@@ -210,7 +215,7 @@ if TYPE_CHECKING:
210
215
  RegMeanAlgorithmPlusPlus,
211
216
  )
212
217
  from .simple_average import SimpleAverageAlgorithm
213
- from .slerp import SlerpMergeAlgorithm
218
+ from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
214
219
  from .smile_upscaling import (
215
220
  SingularProjectionMergingAlgorithm,
216
221
  SmileUpscalingAlgorithm,
@@ -228,7 +233,10 @@ if TYPE_CHECKING:
228
233
  from .task_arithmetic import TaskArithmeticAlgorithm
229
234
  from .task_singular_vector import TaskSingularVectorMerging
230
235
  from .ties_merging import TiesMergingAlgorithm
231
- from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
236
+ from .we_moe import (
237
+ CLIPWeightEnsemblingMoEAlgorithm,
238
+ FlanT5WeightEnsemblingMoEAlgorithm,
239
+ )
232
240
  from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
233
241
 
234
242
  else:
@@ -11,7 +11,7 @@ from torch import nn
11
11
  from tqdm.auto import tqdm
12
12
 
13
13
  from fusion_bench.method import BaseAlgorithm
14
- from fusion_bench.mixins import LightningFabricMixin
14
+ from fusion_bench.mixins import LightningFabricMixin, auto_register_config
15
15
  from fusion_bench.modelpool import BaseModelPool
16
16
  from fusion_bench.utils.parameters import (
17
17
  StateDictType,
@@ -23,14 +23,50 @@ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
23
23
  log = logging.getLogger(__name__)
24
24
 
25
25
 
26
- class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
26
+ @auto_register_config
27
+ class TaskVectorCosSimilarity(
28
+ LightningFabricMixin,
29
+ BaseAlgorithm,
30
+ ):
27
31
  """
28
- This class is similar to the Dummy algorithm,
29
- but it also print (or save) the cosine similarity matrix between the task vectors of the models in the model pool.
32
+ Computes and analyzes cosine similarity between task vectors of models in a model pool.
33
+
34
+ This algorithm extracts task vectors from fine-tuned models by computing the difference
35
+ between their parameters and a pretrained base model. It then calculates the pairwise
36
+ cosine similarity between all task vectors to understand the relationships and overlap
37
+ between different tasks.
38
+
39
+ The task vector for a model is defined as:
40
+ task_vector = finetuned_model_params - pretrained_model_params
41
+
42
+ Args:
43
+ plot_heatmap (bool): Whether to generate and save a heatmap visualization
44
+ trainable_only (bool, optional): If True, only consider trainable parameters
45
+ when computing task vectors. Defaults to True.
46
+ max_points_per_model (int, optional): Maximum number of parameters to sample
47
+ per model for memory efficiency. If None, uses all parameters.
48
+ output_path (str, optional): Directory to save outputs. If None, uses the
49
+ fabric logger directory.
50
+
51
+ Outputs:
52
+ - task_vector_cos_similarity.csv: Pairwise cosine similarity matrix
53
+ - task_vector_cos_similarity.pdf: Heatmap visualization (if plot_heatmap=True)
54
+
55
+ Returns:
56
+ The pretrained model from the model pool.
57
+
58
+ Example:
59
+ ```python
60
+ >>> algorithm = TaskVectorCosSimilarity(
61
+ ... plot_heatmap=True,
62
+ ... trainable_only=True,
63
+ ... output_path="/path/to/outputs"
64
+ ... )
65
+ >>> result = algorithm.run(modelpool)
66
+ ```
30
67
  """
31
68
 
32
69
  _config_mapping = BaseAlgorithm._config_mapping | {
33
- "plot_heatmap": "plot_heatmap",
34
70
  "_output_path": "output_path",
35
71
  }
36
72
 
@@ -42,11 +78,8 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
42
78
  output_path: Optional[str] = None,
43
79
  **kwargs,
44
80
  ):
45
- self.plot_heatmap = plot_heatmap
46
- self.trainable_only = trainable_only
47
- self.max_points_per_model = max_points_per_model
48
- self._output_path = output_path
49
81
  super().__init__(**kwargs)
82
+ self._output_path = output_path
50
83
 
51
84
  @property
52
85
  def output_path(self):
@@ -57,6 +90,22 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
57
90
 
58
91
  @torch.no_grad()
59
92
  def run(self, modelpool: BaseModelPool):
93
+ """
94
+ Execute the task vector cosine similarity analysis.
95
+
96
+ This method:
97
+ 1. Loads the pretrained base model from the model pool
98
+ 2. Computes task vectors for each fine-tuned model
99
+ 3. Calculates pairwise cosine similarities between all task vectors
100
+ 4. Saves the similarity matrix as a CSV file
101
+ 5. Optionally generates and saves a heatmap visualization
102
+
103
+ Args:
104
+ modelpool (BaseModelPool): Pool containing pretrained and fine-tuned models
105
+
106
+ Returns:
107
+ nn.Module: The pretrained model from the model pool
108
+ """
60
109
  pretrained_model = modelpool.load_pretrained_model()
61
110
 
62
111
  task_vectors = []
@@ -103,11 +152,14 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
103
152
 
104
153
  def _plot_heatmap(self, data: pd.DataFrame):
105
154
  """
106
- This function plots a heatmap of the provided data using seaborn.
155
+ Generate and save a heatmap visualization of the cosine similarity matrix.
156
+
157
+ Creates a color-coded heatmap showing pairwise cosine similarities between
158
+ task vectors. The heatmap is saved as a PDF file in the output directory.
107
159
 
108
160
  Args:
109
- data (pd.DataFrame): A pandas DataFrame containing the data to be plotted.
110
- figsize (tuple): A tuple specifying the size of the figure. Default is (4, 3).
161
+ data (pd.DataFrame): Symmetric matrix of cosine similarities between
162
+ task vectors, with model names as both index and columns.
111
163
 
112
164
  Returns:
113
165
  None
@@ -141,6 +193,26 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
141
193
  def get_task_vector(
142
194
  self, pretrained_model: nn.Module, finetuned_model: nn.Module
143
195
  ) -> torch.Tensor:
196
+ """
197
+ Compute the task vector for a fine-tuned model.
198
+
199
+ The task vector represents the parameter changes from pretraining to
200
+ fine-tuning and is computed as:
201
+ task_vector = finetuned_params - pretrained_params
202
+
203
+ Args:
204
+ pretrained_model (nn.Module): The base pretrained model
205
+ finetuned_model (nn.Module): The fine-tuned model for a specific task
206
+
207
+ Returns:
208
+ torch.Tensor: Flattened task vector containing parameter differences.
209
+ If max_points_per_model is set, the vector may be downsampled.
210
+
211
+ Note:
212
+ - Converts parameters to float64 for numerical precision
213
+ - Supports optional downsampling for memory efficiency
214
+ - Uses only trainable parameters if trainable_only=True
215
+ """
144
216
  task_vector = state_dict_sub(
145
217
  self.get_state_dict(finetuned_model),
146
218
  self.get_state_dict(pretrained_model),
@@ -166,6 +238,17 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
166
238
  return task_vector
167
239
 
168
240
  def get_state_dict(self, model: nn.Module):
241
+ """
242
+ Extract the state dictionary from a model.
243
+
244
+ Args:
245
+ model (nn.Module): The model to extract parameters from
246
+
247
+ Returns:
248
+ Dict[str, torch.Tensor]: State dictionary containing model parameters.
249
+ Returns only trainable parameters if trainable_only=True,
250
+ otherwise returns all parameters.
251
+ """
169
252
  if self.trainable_only:
170
253
  return trainable_state_dict(model)
171
254
  else: