fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +12 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/clip_finetune.py +6 -4
  7. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  8. fusion_bench/method/dop/__init__.py +1 -0
  9. fusion_bench/method/dop/dop.py +366 -0
  10. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  11. fusion_bench/method/dop/utils.py +73 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  15. fusion_bench/mixins/__init__.py +2 -0
  16. fusion_bench/mixins/pyinstrument.py +174 -0
  17. fusion_bench/mixins/simple_profiler.py +106 -23
  18. fusion_bench/modelpool/__init__.py +2 -0
  19. fusion_bench/modelpool/base_pool.py +77 -14
  20. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  21. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  22. fusion_bench/models/__init__.py +35 -9
  23. fusion_bench/optim/__init__.py +40 -2
  24. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  25. fusion_bench/optim/muon.py +339 -0
  26. fusion_bench/programs/__init__.py +2 -0
  27. fusion_bench/programs/fabric_fusion_program.py +2 -2
  28. fusion_bench/programs/fusion_program.py +271 -0
  29. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  30. fusion_bench/utils/__init__.py +167 -21
  31. fusion_bench/utils/lazy_imports.py +91 -12
  32. fusion_bench/utils/lazy_state_dict.py +55 -5
  33. fusion_bench/utils/misc.py +104 -13
  34. fusion_bench/utils/packages.py +4 -0
  35. fusion_bench/utils/path.py +7 -0
  36. fusion_bench/utils/pylogger.py +6 -0
  37. fusion_bench/utils/rich_utils.py +1 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  39. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
  40. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
  41. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  42. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  43. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  44. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  45. fusion_bench_config/method/dop/dop.yaml +30 -0
  46. fusion_bench_config/method/dummy.yaml +6 -0
  47. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  48. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  49. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  50. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  51. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  52. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
  53. fusion_bench_config/method/model_recombination.yaml +8 -0
  54. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  55. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  56. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  57. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  58. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  59. fusion_bench_config/method/simple_average.yaml +9 -0
  60. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  61. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  62. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  63. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  64. fusion_bench_config/method/ties_merging.yaml +3 -0
  65. fusion_bench_config/model_fusion.yaml +45 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  72. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
  73. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
  74. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
  75. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/__init__.py
@@ -11,6 +11,7 @@ _import_structure = {
     "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
     "openclip_classification": ["OpenCLIPClassificationMixin"],
+    "pyinstrument": ["PyinstrumentProfilerMixin"],
     "serialization": [
         "BaseYAMLSerializable",
         "YAMLSerializationMixin",
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
    from .hydra_config import HydraConfigMixin
    from .lightning_fabric import LightningFabricMixin
    from .openclip_classification import OpenCLIPClassificationMixin
+   from .pyinstrument import PyinstrumentProfilerMixin
    from .serialization import (
        BaseYAMLSerializable,
        YAMLSerializationMixin,
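These two hunks register the new `PyinstrumentProfilerMixin` with the lazy-import machinery of `fusion_bench/mixins/__init__.py`: the `_import_structure` entry exposes the symbol at package level at runtime, while the `TYPE_CHECKING` import keeps static type checkers aware of it. A minimal import sketch, assuming the lazy-module wiring behaves like the existing mixin entries:

```python
# Lazy import via the package, resolved through _import_structure on attribute access
# (assumed to mirror the behaviour of the existing mixin entries):
from fusion_bench.mixins import PyinstrumentProfilerMixin

# Equivalent direct import from the new submodule added in this release:
# from fusion_bench.mixins.pyinstrument import PyinstrumentProfilerMixin
```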
fusion_bench/mixins/pyinstrument.py (new file)
@@ -0,0 +1,174 @@
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator, Optional, Union
+
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+from pyinstrument import Profiler
+
+__all__ = ["PyinstrumentProfilerMixin"]
+
+
+class PyinstrumentProfilerMixin:
+    """
+    A mixin class that provides statistical profiling capabilities using pyinstrument.
+
+    This mixin allows for easy profiling of code blocks using a context manager.
+    It provides methods to start and stop profiling actions, save profiling results
+    to files, and print profiling summaries.
+
+    Note:
+        This mixin requires the `pyinstrument` package to be installed.
+        If not available, an ImportError will be raised when importing this module.
+
+    Examples:
+
+    ```python
+    class MyClass(PyinstrumentProfilerMixin):
+        def do_something(self):
+            with self.profile("work"):
+                # do some work here
+                ...
+
+            # save the profiling results
+            self.save_profile_report("profile_report.html")
+
+            # or print the summary
+            self.print_profile_summary()
+    ```
+
+    Attributes:
+        _profiler (Profiler): An instance of the pyinstrument Profiler class.
+    """
+
+    _profiler: Optional[Profiler] = None
+    _is_profiling: bool = False
+
+    @property
+    def profiler(self) -> Optional[Profiler]:
+        """Get the profiler instance, creating it if necessary."""
+        if self._profiler is None:
+            self._profiler = Profiler()
+        return self._profiler
+
+    @contextmanager
+    def profile(self, action_name: Optional[str] = None) -> Generator:
+        """
+        Context manager for profiling a code block.
+
+        Args:
+            action_name: Optional name for the profiling action (for logging purposes).
+
+        Example:
+
+        ```python
+        with self.profile("expensive_operation"):
+            # do some expensive work here
+            expensive_function()
+        ```
+        """
+        try:
+            self.start_profile(action_name)
+            yield action_name
+        finally:
+            self.stop_profile(action_name)
+
+    def start_profile(self, action_name: Optional[str] = None):
+        """
+        Start profiling.
+
+        Args:
+            action_name: Optional name for the profiling action.
+        """
+        if self._is_profiling:
+            return
+
+        self.profiler.start()
+        self._is_profiling = True
+        if action_name:
+            print(f"Started profiling: {action_name}")
+
+    def stop_profile(self, action_name: Optional[str] = None):
+        """
+        Stop profiling.
+
+        Args:
+            action_name: Optional name for the profiling action.
+        """
+        if not self._is_profiling:
+            return
+
+        self.profiler.stop()
+        self._is_profiling = False
+        if action_name:
+            print(f"Stopped profiling: {action_name}")
+
+    @rank_zero_only
+    def print_profile_summary(
+        self, title: Optional[str] = None, unicode: bool = True, color: bool = True
+    ):
+        """
+        Print a summary of the profiling results.
+
+        Args:
+            title: Optional title to print before the summary.
+            unicode: Whether to use unicode characters in the output.
+            color: Whether to use color in the output.
+        """
+        if self.profiler is None:
+            print("No profiling data available.")
+            return
+
+        if title is not None:
+            print(title)
+
+        print(self.profiler.output_text(unicode=unicode, color=color))
+
+    @rank_zero_only
+    def save_profile_report(
+        self,
+        output_path: Union[str, Path] = "profile_report.html",
+        format: str = "html",
+        title: Optional[str] = None,
+    ):
+        """
+        Save the profiling results to a file.
+
+        Args:
+            output_path: Path where to save the profiling report.
+            format: Output format ('html', or 'text').
+            title: Optional title for the report.
+        """
+        if self.profiler is None:
+            print("No profiling data available.")
+            return
+
+        output_path = Path(output_path)
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if format.lower() == "html":
+            content = self.profiler.output_html()
+        elif format.lower() == "text":
+            content = self.profiler.output_text(unicode=True, color=False)
+        else:
+            raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")
+
+        with open(output_path, "w", encoding="utf-8") as f:
+            f.write(content)
+
+        print(f"Profile report saved to: {output_path}")
+
+    def reset_profile(self):
+        """Reset the profiler to start fresh."""
+        if self._is_profiling:
+            self.stop_profile()
+
+        self._profiler = None
+
+    def __del__(self):
+        """Cleanup when the object is destroyed."""
+        if self._is_profiling:
+            self.stop_profile()
+
+        if self._profiler is not None:
+            del self._profiler
+            self._profiler = None
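The new `fusion_bench/mixins/pyinstrument.py` module shown above wraps pyinstrument's statistical profiler behind a small start/stop/report API. Beyond the docstring example, here is a sketch of profiling two phases with a reset in between; the `MergePipeline` class, its method bodies, and the report path are illustrative only:

```python
from fusion_bench.mixins.pyinstrument import PyinstrumentProfilerMixin


class MergePipeline(PyinstrumentProfilerMixin):
    """Hypothetical pipeline used only to illustrate the mixin's API."""

    def run(self):
        # Phase 1: profile checkpoint loading and save a plain-text report.
        with self.profile("load_checkpoints"):
            ...  # load the models to be merged
        self.save_profile_report("reports/load_checkpoints.txt", format="text")

        # Phase 2: discard the previous samples and profile the merge itself.
        self.reset_profile()
        with self.profile("merge"):
            ...  # merge the weights
        self.print_profile_summary(title="Merge phase")
```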
fusion_bench/mixins/simple_profiler.py
@@ -9,27 +9,33 @@ __all__ = ["SimpleProfilerMixin"]
 
 class SimpleProfilerMixin:
     """
-    A mixin class that provides simple profiling capabilities.
+    A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.
 
-    This mixin allows for easy profiling of code blocks using a context manager.
-    It also provides methods to start and stop profiling actions, and to print
-    a summary of the profiling results.
+    This mixin allows for easy profiling of code blocks using a context manager or manual
+    start/stop methods. It measures the execution time of named actions and provides
+    a summary of the profiling results. Unlike statistical profilers, this provides
+    precise timing measurements for specific code blocks.
 
-    Examples:
-
-    ```python
-    class MyClass(SimpleProfilerMixin):
-        def do_something(self):
-            with self.profile("work"):
-                # do some work here
-                ...
-            with self.profile("more work"):
-                # do more work here
-                ...
+    Note:
+        This mixin uses Lightning's SimpleProfiler which measures wall-clock time
+        for named actions. It's suitable for timing discrete operations rather than
+        detailed function-level profiling.
 
-            # print the profiling summary
-            self.print_profile_summary()
-    ```
+    Examples:
+        ```python
+        class MyClass(SimpleProfilerMixin):
+            def do_something(self):
+                with self.profile("data_loading"):
+                    # Load data here
+                    data = load_data()
+
+                with self.profile("model_training"):
+                    # Train model here
+                    model.train(data)
+
+                # Print the profiling summary
+                self.print_profile_summary("Training Profile")
+        ```
 
     Attributes:
         _profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
@@ -38,7 +44,13 @@ class SimpleProfilerMixin:
     _profiler: SimpleProfiler = None
 
     @property
-    def profiler(self):
+    def profiler(self) -> SimpleProfiler:
+        """
+        Get the SimpleProfiler instance, creating it if necessary.
+
+        Returns:
+            SimpleProfiler: The profiler instance used for timing measurements.
+        """
         # Lazy initialization of the profiler instance
         if self._profiler is None:
             self._profiler = SimpleProfiler()
@@ -47,14 +59,24 @@ class SimpleProfilerMixin:
     @contextmanager
     def profile(self, action_name: str) -> Generator:
         """
-        Context manager for profiling a code block
+        Context manager for profiling a code block.
+
+        This context manager automatically starts profiling when entering the block
+        and stops profiling when exiting the block (even if an exception occurs).
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Yields:
+            str: The action name that was provided.
 
         Example:
 
         ```python
-        with self.profile("work"):
-            # do some work here
-            ...
+        with self.profile("data_processing"):
+            # Process data here
+            result = process_large_dataset()
         ```
         """
         try:
@@ -64,18 +86,79 @@ class SimpleProfilerMixin:
             self.stop_profile(action_name)
 
     def start_profile(self, action_name: str):
+        """
+        Start profiling for a named action.
+
+        This method begins timing for the specified action. You must call
+        stop_profile() with the same action name to complete the measurement.
+
+        Args:
+            action_name: A descriptive name for the action being profiled.
+                This name will appear in the profiling summary.
+
+        Example:
+            ```python
+            self.start_profile("model_inference")
+            result = model.predict(data)
+            self.stop_profile("model_inference")
+            ```
+        """
         self.profiler.start(action_name)
 
     def stop_profile(self, action_name: str):
+        """
+        Stop profiling for a named action.
+
+        This method ends timing for the specified action that was previously
+        started with start_profile().
+
+        Args:
+            action_name: The name of the action to stop profiling.
+                Must match the name used in start_profile().
+
+        Example:
+            ```python
+            self.start_profile("data_loading")
+            data = load_data()
+            self.stop_profile("data_loading")
+            ```
+        """
         self.profiler.stop(action_name)
 
     @rank_zero_only
     def print_profile_summary(self, title: Optional[str] = None):
+        """
+        Print a summary of all profiled actions.
+
+        This method outputs a formatted summary showing the timing information
+        for all actions that have been profiled. The output includes action names
+        and their execution times.
+
+        Args:
+            title: Optional title to print before the profiling summary.
+                If provided, this will be printed as a header.
+
+        Note:
+            This method is decorated with @rank_zero_only, meaning it will only
+            execute on the main process in distributed training scenarios.
+
+        Example:
+            ```python
+            # After profiling some actions
+            self.print_profile_summary("Training Performance Summary")
+            ```
+        """
         if title is not None:
             print(title)
         print(self.profiler.summary())
 
     def __del__(self):
+        """
+        Cleanup when the object is destroyed.
+
+        Ensures that the profiler instance is properly cleaned up to prevent
+        memory leaks when the mixin instance is garbage collected.
+        """
         if self._profiler is not None:
             del self._profiler
             self._profiler = None
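The expanded docstrings above also cover the manual `start_profile`/`stop_profile` pair on `SimpleProfilerMixin` (in `fusion_bench/mixins/simple_profiler.py`), which is useful when the start and stop points live in different methods. A short sketch, assuming the mixin is re-exported from `fusion_bench.mixins` like the others; the `Evaluator` class is illustrative:

```python
from fusion_bench.mixins import SimpleProfilerMixin


class Evaluator(SimpleProfilerMixin):
    """Hypothetical class showing manual start/stop pairing across methods."""

    def begin(self):
        # Start timing; the same action name must be passed to stop_profile().
        self.start_profile("evaluation")

    def end(self):
        self.stop_profile("evaluation")
        # Printed only on the rank-zero process in distributed runs.
        self.print_profile_summary("Evaluation timing")
```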
fusion_bench/modelpool/__init__.py
@@ -18,6 +18,7 @@ _import_structure = {
        "GPT2ForSequenceClassificationPool",
    ],
    "seq_classification_lm": ["SequenceClassificationModelPool"],
+   "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
 }
 
 
@@ -33,6 +34,7 @@ if TYPE_CHECKING:
    from .nyuv2_modelpool import NYUv2ModelPool
    from .openclip_vision import OpenCLIPVisionModelPool
    from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
+   from .resnet_for_image_classification import ResNetForImageClassificationPool
    from .seq2seq_lm import Seq2SeqLMPool
    from .seq_classification_lm import SequenceClassificationModelPool
 
fusion_bench/modelpool/base_pool.py
@@ -180,26 +180,59 @@ class BaseModelPool(
 
         Args:
             model_name_or_config (Union[str, DictConfig]): The model name or configuration.
+                - If str: should be a key in self._models
+                - If DictConfig: should be a configuration dict for instantiation
+            *args: Additional positional arguments passed to model instantiation.
+            **kwargs: Additional keyword arguments passed to model instantiation.
 
         Returns:
-            nn.Module: The instantiated model.
+            nn.Module: The instantiated or retrieved model.
         """
         log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
-        if isinstance(self._models, DictConfig):
-            model_config = (
-                self._models[model_name_or_config]
-                if isinstance(model_name_or_config, str)
-                else model_name_or_config
-            )
-            model = instantiate(model_config, *args, **kwargs)
-        elif isinstance(self._models, Dict) and isinstance(model_name_or_config, str):
-            model = self._models[model_name_or_config]
+
+        if isinstance(model_name_or_config, str):
+            model_name = model_name_or_config
+            # Handle string model names - lookup in the model pool
+            if model_name not in self._models:
+                raise KeyError(
+                    f"Model '{model_name}' not found in model pool. "
+                    f"Available models: {list(self._models.keys())}"
+                )
+            model_config = self._models[model_name]
+
+            # Handle different types of model configurations
+            match model_config:
+                case dict() | DictConfig() as config:
+                    # Configuration that needs instantiation
+                    log.debug(f"Instantiating model '{model_name}' from configuration")
+                    return instantiate(config, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Pre-instantiated model - return directly
+                    log.debug(
+                        f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
+                    )
+                    return model
+
+                case _:
+                    # Unsupported model configuration type
+                    raise ValueError(
+                        f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
+                        f"Expected nn.Module, dict, or DictConfig."
+                    )
+
+        elif isinstance(model_name_or_config, (dict, DictConfig)):
+            # Direct configuration - instantiate directly
+            log.debug("Instantiating model from direct DictConfig")
+            model_config = model_name_or_config
+            return instantiate(model_config, *args, **kwargs)
+
         else:
-            raise ValueError(
-                "The model pool configuration is not in the expected format."
-                f"We expected a DictConfig or Dict, but got {type(self._models)}."
+            # Unsupported input type
+            raise TypeError(
+                f"Unsupported input type: {type(model_name_or_config)}. "
+                f"Expected str or DictConfig."
             )
-        return model
 
     def load_pretrained_model(self, *args, **kwargs):
         assert (
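The rewritten `BaseModelPool.load_model` in `fusion_bench/modelpool/base_pool.py` now dispatches on its argument: a string is looked up in the pool and may resolve to either a config (instantiated via Hydra) or a pre-instantiated `nn.Module`, a dict/DictConfig is instantiated directly, and anything else raises `TypeError`. A hedged sketch of the resulting call patterns; the pool instance `pool`, the model key, and the inline config are illustrative, not fixtures shipped with the package:

```python
from omegaconf import DictConfig

# `pool` is assumed to be an existing BaseModelPool instance (construction omitted).

# 1) String key: looked up in pool._models; raises KeyError if unknown,
#    instantiates a config entry, or returns a pre-built nn.Module as-is.
model = pool.load_model("cifar10")

# 2) Inline DictConfig: instantiated directly through Hydra's `instantiate`.
inline_cfg = DictConfig(
    {"_target_": "torchvision.models.resnet18", "num_classes": 10}
)
model = pool.load_model(inline_cfg)

# 3) Any other input type now raises TypeError instead of a generic ValueError.
# pool.load_model(42)  # -> TypeError
```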
@@ -229,6 +262,36 @@ class BaseModelPool(
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    @property
+    def has_train_dataset(self) -> bool:
+        """
+        Check if the model pool contains training datasets.
+
+        Returns:
+            bool: True if training datasets are available, False otherwise.
+        """
+        return self._train_datasets is not None and len(self._train_datasets) > 0
+
+    @property
+    def has_val_dataset(self) -> bool:
+        """
+        Check if the model pool contains validation datasets.
+
+        Returns:
+            bool: True if validation datasets are available, False otherwise.
+        """
+        return self._val_datasets is not None and len(self._val_datasets) > 0
+
+    @property
+    def has_test_dataset(self) -> bool:
+        """
+        Check if the model pool contains testing datasets.
+
+        Returns:
+            bool: True if testing datasets are available, False otherwise.
+        """
+        return self._test_datasets is not None and len(self._test_datasets) > 0
+
     def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
         """
         Load the training dataset for the specified model.
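The new `has_train_dataset`, `has_val_dataset`, and `has_test_dataset` properties let callers test for dataset availability up front instead of handling lookup failures. An illustrative guard, again assuming an existing pool instance and a dataset name that exists in its config:

```python
# `pool` and the dataset name "cifar10" are assumed for illustration.
if pool.has_train_dataset:
    train_dataset = pool.load_train_dataset("cifar10")
else:
    train_dataset = None  # e.g., fall back to a data-free merging method
```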
fusion_bench/modelpool/clip_vision/modelpool.py
@@ -93,42 +93,79 @@ class CLIPVisionModelPool(BaseModelPool):
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> CLIPVisionModel:
         """
-        This method is used to load a CLIPVisionModel from the model pool.
+        Load a CLIPVisionModel from the model pool with support for various configuration formats.
 
-        Example configuration could be:
+        This method provides flexible model loading capabilities, handling different types of model
+        configurations including string paths, pre-instantiated models, and complex configurations.
 
+        Supported configuration formats:
+        1. String model paths (e.g., Hugging Face model IDs)
+        2. Pre-instantiated nn.Module objects
+        3. DictConfig objects for complex configurations
+
+        Example configuration:
         ```yaml
         models:
+          # Simple string paths to Hugging Face models
          cifar10: tanganke/clip-vit-base-patch32_cifar10
          sun397: tanganke/clip-vit-base-patch32_sun397
          stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
+
+          # Complex configuration with additional parameters
+          custom_model:
+            _target_: transformers.CLIPVisionModel.from_pretrained
+            pretrained_model_name_or_path: openai/clip-vit-base-patch32
+            torch_dtype: float16
         ```
 
         Args:
-            model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.
+            model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
+                or a configuration dictionary for instantiating the model.
+            *args: Additional positional arguments passed to model loading/instantiation.
+            **kwargs: Additional keyword arguments passed to model loading/instantiation.
 
         Returns:
-            CLIPVisionModel: The loaded CLIPVisionModel.
+            CLIPVisionModel: The loaded CLIPVisionModel instance.
         """
+        # Check if we have a string model name that exists in our model pool
         if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
-            model = self._models[model_name_or_config]
-            if isinstance(model, str):
-                if rank_zero_only.rank == 0:
-                    log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
-                repo_path = resolve_repo_path(
-                    model, repo_type="model", platform=self._platform
-                )
-                return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
-            if isinstance(model, nn.Module):
-                if rank_zero_only.rank == 0:
-                    log.info(f"Returning existing model: {model}")
-                return model
-            else:
-                # If the model is not a string, we use the default load_model method
-                return super().load_model(model_name_or_config, *args, **kwargs)
+            model_name = model_name_or_config
+
+            # handle different model configuration types
+            match self._models[model_name_or_config]:
+                case str() as model_path:
+                    # Handle string model paths (e.g., Hugging Face model IDs)
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
+                        )
+                    # Resolve the repository path (supports both HuggingFace and ModelScope)
+                    repo_path = resolve_repo_path(
+                        model_path, repo_type="model", platform=self._platform
+                    )
+                    # Load and return the CLIPVisionModel from the resolved path
+                    return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
+
+                case nn.Module() as model:
+                    # Handle pre-instantiated model objects
+                    if rank_zero_only.rank == 0:
+                        log.info(
+                            f"Returning existing model `{model_name}` of type {type(model)}"
+                        )
+                    return model
+
+                case _:
+                    # Handle other configuration types (e.g., DictConfig) via parent class
+                    # This fallback prevents returning None when the model config doesn't
+                    # match the expected string or nn.Module patterns
+                    return super().load_model(model_name_or_config, *args, **kwargs)
+
+        # If model_name_or_config is not a string in our pool, delegate to parent class
+        # This handles cases where model_name_or_config is a DictConfig directly
+        return super().load_model(model_name_or_config, *args, **kwargs)
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
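For `CLIPVisionModelPool.load_model` in `fusion_bench/modelpool/clip_vision/modelpool.py`, the `match` statement above means a string entry is loaded with `CLIPVisionModel.from_pretrained`, a pre-instantiated module is returned as-is, and any other entry (such as a `_target_` config) falls back to `BaseModelPool.load_model`. Illustrative calls, assuming a pool built from a config like the YAML in the docstring above:

```python
# `pool` is assumed to be a CLIPVisionModelPool (construction omitted);
# the entry names come from the docstring example above.
vision_model = pool.load_model("cifar10")       # str path -> CLIPVisionModel.from_pretrained
custom_model = pool.load_model("custom_model")  # DictConfig entry -> BaseModelPool.load_model fallback
```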