fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +10 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Generator, Optional, Union
|
|
4
|
+
|
|
5
|
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
6
|
+
from pyinstrument import Profiler
|
|
7
|
+
|
|
8
|
+
__all__ = ["PyinstrumentProfilerMixin"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PyinstrumentProfilerMixin:
    """
    A mixin class that provides statistical profiling capabilities using pyinstrument.

    This mixin allows for easy profiling of code blocks using a context manager.
    It provides methods to start and stop profiling actions, save profiling results
    to files, and print profiling summaries.

    All profiled blocks of an instance share a single, lazily created pyinstrument
    ``Profiler``. Nested ``profile()`` blocks are supported: only the outermost
    block starts and stops the profiler, so leaving an inner block does not end
    the measurement of the enclosing one.

    Note:
        This mixin requires the `pyinstrument` package to be installed.
        If not available, an ImportError will be raised when importing this module.

    Examples:

    ```python
    class MyClass(PyinstrumentProfilerMixin):
        def do_something(self):
            with self.profile("work"):
                # do some work here
                ...

            # save the profiling results
            self.save_profile_report("profile_report.html")

            # or print the summary
            self.print_profile_summary()
    ```

    Attributes:
        _profiler (Optional[Profiler]): The pyinstrument Profiler, or None if
            profiling has not been used yet (or was reset).
        _is_profiling (bool): Whether the profiler is currently running.
    """

    _profiler: Optional[Profiler] = None
    _is_profiling: bool = False

    @property
    def profiler(self) -> Profiler:
        """Get the profiler instance, creating it if necessary (never returns None)."""
        if self._profiler is None:
            self._profiler = Profiler()
        return self._profiler

    @contextmanager
    def profile(
        self, action_name: Optional[str] = None
    ) -> Generator[Optional[str], None, None]:
        """
        Context manager for profiling a code block.

        If the profiler is already running (e.g. an enclosing ``profile()`` block
        or a manual ``start_profile()`` call), this block neither starts nor stops
        it. The code is then simply included in the ongoing run.

        Args:
            action_name: Optional name for the profiling action (for logging purposes).

        Yields:
            The ``action_name`` that was passed in.

        Example:

        ```python
        with self.profile("expensive_operation"):
            # do some expensive work here
            expensive_function()
        ```
        """
        # Only the block that actually started the run may stop it. Otherwise an
        # inner block's exit would stop the profiler while the outer block is
        # still executing, silently truncating the outer measurement.
        owns_run = not self._is_profiling
        if owns_run:
            self.start_profile(action_name)
        try:
            yield action_name
        finally:
            if owns_run:
                self.stop_profile(action_name)

    def start_profile(self, action_name: Optional[str] = None):
        """
        Start profiling. Does nothing if the profiler is already running.

        Args:
            action_name: Optional name for the profiling action; when given, a
                "Started profiling" message is printed.
        """
        if self._is_profiling:
            return

        self.profiler.start()
        self._is_profiling = True
        if action_name:
            print(f"Started profiling: {action_name}")

    def stop_profile(self, action_name: Optional[str] = None):
        """
        Stop profiling. Does nothing if the profiler is not running.

        Args:
            action_name: Optional name for the profiling action; when given, a
                "Stopped profiling" message is printed.
        """
        if not self._is_profiling:
            return

        self.profiler.stop()
        self._is_profiling = False
        if action_name:
            print(f"Stopped profiling: {action_name}")

    @rank_zero_only
    def print_profile_summary(
        self, title: Optional[str] = None, unicode: bool = True, color: bool = True
    ):
        """
        Print a summary of the profiling results.

        Only runs on rank zero (``@rank_zero_only``).

        Args:
            title: Optional title to print before the summary.
            unicode: Whether to use unicode characters in the output.
            color: Whether to use color in the output.
        """
        # Check the backing attribute rather than the `profiler` property: the
        # property lazily creates a Profiler, so it can never be None.
        if self._profiler is None:
            print("No profiling data available.")
            return

        if title is not None:
            print(title)

        print(self._profiler.output_text(unicode=unicode, color=color))

    @rank_zero_only
    def save_profile_report(
        self,
        output_path: Union[str, Path] = "profile_report.html",
        format: str = "html",
        title: Optional[str] = None,
    ):
        """
        Save the profiling results to a file.

        Only runs on rank zero (``@rank_zero_only``). Parent directories of
        ``output_path`` are created as needed.

        Args:
            output_path: Path where to save the profiling report.
            format: Output format, either 'html' or 'text' (case-insensitive).
            title: Optional title for the report. It is written as the first
                line of 'text' reports and ignored for 'html' reports.

        Raises:
            ValueError: If ``format`` is neither 'html' nor 'text'.
        """
        # See print_profile_summary: the `profiler` property is never None.
        if self._profiler is None:
            print("No profiling data available.")
            return

        fmt = format.lower()
        if fmt == "html":
            content = self._profiler.output_html()
        elif fmt == "text":
            content = self._profiler.output_text(unicode=True, color=False)
            if title is not None:
                content = f"{title}\n{content}"
        else:
            raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")

        # Create directories only once we know there is something to write.
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content, encoding="utf-8")

        print(f"Profile report saved to: {output_path}")

    def reset_profile(self):
        """Reset the profiler to start fresh, stopping it first if it is running."""
        if self._is_profiling:
            self.stop_profile()

        self._profiler = None

    def __del__(self):
        """Cleanup when the object is destroyed: stop a running profiler and drop it."""
        if self._is_profiling:
            self.stop_profile()

        if self._profiler is not None:
            self._profiler = None
|
|
@@ -9,27 +9,33 @@ __all__ = ["SimpleProfilerMixin"]
|
|
|
9
9
|
|
|
10
10
|
class SimpleProfilerMixin:
|
|
11
11
|
"""
|
|
12
|
-
A mixin class that provides simple profiling capabilities.
|
|
12
|
+
A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.
|
|
13
13
|
|
|
14
|
-
This mixin allows for easy profiling of code blocks using a context manager
|
|
15
|
-
|
|
16
|
-
a summary of the profiling results.
|
|
14
|
+
This mixin allows for easy profiling of code blocks using a context manager or manual
|
|
15
|
+
start/stop methods. It measures the execution time of named actions and provides
|
|
16
|
+
a summary of the profiling results. Unlike statistical profilers, this provides
|
|
17
|
+
precise timing measurements for specific code blocks.
|
|
17
18
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def do_something(self):
|
|
23
|
-
with self.profile("work"):
|
|
24
|
-
# do some work here
|
|
25
|
-
...
|
|
26
|
-
with self.profile("more work"):
|
|
27
|
-
# do more work here
|
|
28
|
-
...
|
|
19
|
+
Note:
|
|
20
|
+
This mixin uses Lightning's SimpleProfiler which measures wall-clock time
|
|
21
|
+
for named actions. It's suitable for timing discrete operations rather than
|
|
22
|
+
detailed function-level profiling.
|
|
29
23
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
24
|
+
Examples:
|
|
25
|
+
```python
|
|
26
|
+
class MyClass(SimpleProfilerMixin):
|
|
27
|
+
def do_something(self):
|
|
28
|
+
with self.profile("data_loading"):
|
|
29
|
+
# Load data here
|
|
30
|
+
data = load_data()
|
|
31
|
+
|
|
32
|
+
with self.profile("model_training"):
|
|
33
|
+
# Train model here
|
|
34
|
+
model.train(data)
|
|
35
|
+
|
|
36
|
+
# Print the profiling summary
|
|
37
|
+
self.print_profile_summary("Training Profile")
|
|
38
|
+
```
|
|
33
39
|
|
|
34
40
|
Attributes:
|
|
35
41
|
_profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
|
|
@@ -38,7 +44,13 @@ class SimpleProfilerMixin:
|
|
|
38
44
|
_profiler: SimpleProfiler = None
|
|
39
45
|
|
|
40
46
|
@property
|
|
41
|
-
def profiler(self):
|
|
47
|
+
def profiler(self) -> SimpleProfiler:
|
|
48
|
+
"""
|
|
49
|
+
Get the SimpleProfiler instance, creating it if necessary.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
SimpleProfiler: The profiler instance used for timing measurements.
|
|
53
|
+
"""
|
|
42
54
|
# Lazy initialization of the profiler instance
|
|
43
55
|
if self._profiler is None:
|
|
44
56
|
self._profiler = SimpleProfiler()
|
|
@@ -47,14 +59,24 @@ class SimpleProfilerMixin:
|
|
|
47
59
|
@contextmanager
|
|
48
60
|
def profile(self, action_name: str) -> Generator:
|
|
49
61
|
"""
|
|
50
|
-
Context manager for profiling a code block
|
|
62
|
+
Context manager for profiling a code block.
|
|
63
|
+
|
|
64
|
+
This context manager automatically starts profiling when entering the block
|
|
65
|
+
and stops profiling when exiting the block (even if an exception occurs).
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
action_name: A descriptive name for the action being profiled.
|
|
69
|
+
This name will appear in the profiling summary.
|
|
70
|
+
|
|
71
|
+
Yields:
|
|
72
|
+
str: The action name that was provided.
|
|
51
73
|
|
|
52
74
|
Example:
|
|
53
75
|
|
|
54
76
|
```python
|
|
55
|
-
with self.profile("
|
|
56
|
-
#
|
|
57
|
-
|
|
77
|
+
with self.profile("data_processing"):
|
|
78
|
+
# Process data here
|
|
79
|
+
result = process_large_dataset()
|
|
58
80
|
```
|
|
59
81
|
"""
|
|
60
82
|
try:
|
|
@@ -64,18 +86,79 @@ class SimpleProfilerMixin:
|
|
|
64
86
|
self.stop_profile(action_name)
|
|
65
87
|
|
|
66
88
|
def start_profile(self, action_name: str):
    """
    Begin timing the action called ``action_name``.

    Every call has to be paired with a later ``stop_profile()`` that uses the
    same name. The elapsed time is then recorded under that name and shown in
    the profiling summary.

    Args:
        action_name: Label of the timed action, as it will appear in the summary.

    Example:
        ```python
        self.start_profile("model_inference")
        result = model.predict(data)
        self.stop_profile("model_inference")
        ```
    """
    timer = self.profiler
    timer.start(action_name)
|
|
68
107
|
|
|
69
108
|
def stop_profile(self, action_name: str):
    """
    Finish timing the action called ``action_name``.

    Closes a measurement that was opened by ``start_profile()``.

    Args:
        action_name: Label of the action to stop. It has to be identical to the
            label passed to ``start_profile()``.

    Example:
        ```python
        self.start_profile("data_loading")
        data = load_data()
        self.stop_profile("data_loading")
        ```
    """
    timer = self.profiler
    timer.stop(action_name)
|
|
71
127
|
|
|
72
128
|
@rank_zero_only
def print_profile_summary(self, title: Optional[str] = None):
    """
    Print the timing summary of every action profiled so far.

    The report comes from the underlying profiler's ``summary()``.

    Args:
        title: Header line printed before the summary. Nothing is printed as a
            header when it is ``None``.

    Note:
        Because of ``@rank_zero_only``, this only runs on the main process in
        distributed training scenarios.

    Example:
        ```python
        # After profiling some actions
        self.print_profile_summary("Training Performance Summary")
        ```
    """
    # The header goes out before the summary is computed, matching the original order.
    if title is not None:
        print(title)
    report = self.profiler.summary()
    print(report)
|
|
77
154
|
|
|
78
155
|
def __del__(self):
    """
    Release the profiler when this object is garbage collected.

    Drops the reference to the profiler instance, if one was ever created, so
    that it does not outlive the mixin instance.
    """
    if self._profiler is None:
        return
    del self._profiler
    self._profiler = None
|
|
@@ -18,6 +18,7 @@ _import_structure = {
|
|
|
18
18
|
"GPT2ForSequenceClassificationPool",
|
|
19
19
|
],
|
|
20
20
|
"seq_classification_lm": ["SequenceClassificationModelPool"],
|
|
21
|
+
"resnet_for_image_classification": ["ResNetForImageClassificationPool"],
|
|
21
22
|
}
|
|
22
23
|
|
|
23
24
|
|
|
@@ -33,6 +34,7 @@ if TYPE_CHECKING:
|
|
|
33
34
|
from .nyuv2_modelpool import NYUv2ModelPool
|
|
34
35
|
from .openclip_vision import OpenCLIPVisionModelPool
|
|
35
36
|
from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
|
|
37
|
+
from .resnet_for_image_classification import ResNetForImageClassificationPool
|
|
36
38
|
from .seq2seq_lm import Seq2SeqLMPool
|
|
37
39
|
from .seq_classification_lm import SequenceClassificationModelPool
|
|
38
40
|
|
|
@@ -180,26 +180,59 @@ class BaseModelPool(
|
|
|
180
180
|
|
|
181
181
|
Args:
|
|
182
182
|
model_name_or_config (Union[str, DictConfig]): The model name or configuration.
|
|
183
|
+
- If str: should be a key in self._models
|
|
184
|
+
- If DictConfig: should be a configuration dict for instantiation
|
|
185
|
+
*args: Additional positional arguments passed to model instantiation.
|
|
186
|
+
**kwargs: Additional keyword arguments passed to model instantiation.
|
|
183
187
|
|
|
184
188
|
Returns:
|
|
185
|
-
nn.Module: The instantiated model.
|
|
189
|
+
nn.Module: The instantiated or retrieved model.
|
|
186
190
|
"""
|
|
187
191
|
log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
192
|
+
|
|
193
|
+
if isinstance(model_name_or_config, str):
|
|
194
|
+
model_name = model_name_or_config
|
|
195
|
+
# Handle string model names - lookup in the model pool
|
|
196
|
+
if model_name not in self._models:
|
|
197
|
+
raise KeyError(
|
|
198
|
+
f"Model '{model_name}' not found in model pool. "
|
|
199
|
+
f"Available models: {list(self._models.keys())}"
|
|
200
|
+
)
|
|
201
|
+
model_config = self._models[model_name]
|
|
202
|
+
|
|
203
|
+
# Handle different types of model configurations
|
|
204
|
+
match model_config:
|
|
205
|
+
case dict() | DictConfig() as config:
|
|
206
|
+
# Configuration that needs instantiation
|
|
207
|
+
log.debug(f"Instantiating model '{model_name}' from configuration")
|
|
208
|
+
return instantiate(config, *args, **kwargs)
|
|
209
|
+
|
|
210
|
+
case nn.Module() as model:
|
|
211
|
+
# Pre-instantiated model - return directly
|
|
212
|
+
log.debug(
|
|
213
|
+
f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
|
|
214
|
+
)
|
|
215
|
+
return model
|
|
216
|
+
|
|
217
|
+
case _:
|
|
218
|
+
# Unsupported model configuration type
|
|
219
|
+
raise ValueError(
|
|
220
|
+
f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
|
|
221
|
+
f"Expected nn.Module, dict, or DictConfig."
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
elif isinstance(model_name_or_config, (dict, DictConfig)):
|
|
225
|
+
# Direct configuration - instantiate directly
|
|
226
|
+
log.debug("Instantiating model from direct DictConfig")
|
|
227
|
+
model_config = model_name_or_config
|
|
228
|
+
return instantiate(model_config, *args, **kwargs)
|
|
229
|
+
|
|
197
230
|
else:
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
f"
|
|
231
|
+
# Unsupported input type
|
|
232
|
+
raise TypeError(
|
|
233
|
+
f"Unsupported input type: {type(model_name_or_config)}. "
|
|
234
|
+
f"Expected str or DictConfig."
|
|
201
235
|
)
|
|
202
|
-
return model
|
|
203
236
|
|
|
204
237
|
def load_pretrained_model(self, *args, **kwargs):
|
|
205
238
|
assert (
|
|
@@ -229,6 +262,36 @@ class BaseModelPool(
|
|
|
229
262
|
for model_name in self.model_names:
|
|
230
263
|
yield model_name, self.load_model(model_name)
|
|
231
264
|
|
|
265
|
+
@property
def has_train_dataset(self) -> bool:
    """
    Report whether this pool has any training datasets configured.

    Returns:
        bool: ``True`` when ``_train_datasets`` is set and non-empty, ``False`` otherwise.
    """
    datasets = self._train_datasets
    if datasets is None:
        return False
    return len(datasets) > 0
|
|
274
|
+
|
|
275
|
+
@property
def has_val_dataset(self) -> bool:
    """
    Report whether this pool has any validation datasets configured.

    Returns:
        bool: ``True`` when ``_val_datasets`` is set and non-empty, ``False`` otherwise.
    """
    datasets = self._val_datasets
    if datasets is None:
        return False
    return len(datasets) > 0
|
|
284
|
+
|
|
285
|
+
@property
def has_test_dataset(self) -> bool:
    """
    Report whether this pool has any test datasets configured.

    Returns:
        bool: ``True`` when ``_test_datasets`` is set and non-empty, ``False`` otherwise.
    """
    datasets = self._test_datasets
    if datasets is None:
        return False
    return len(datasets) > 0
|
|
294
|
+
|
|
232
295
|
def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
|
|
233
296
|
"""
|
|
234
297
|
Load the training dataset for the specified model.
|
|
@@ -93,42 +93,79 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
93
93
|
self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
|
|
94
94
|
) -> CLIPVisionModel:
|
|
95
95
|
"""
|
|
96
|
-
|
|
96
|
+
Load a CLIPVisionModel from the model pool with support for various configuration formats.
|
|
97
97
|
|
|
98
|
-
|
|
98
|
+
This method provides flexible model loading capabilities, handling different types of model
|
|
99
|
+
configurations including string paths, pre-instantiated models, and complex configurations.
|
|
99
100
|
|
|
101
|
+
Supported configuration formats:
|
|
102
|
+
1. String model paths (e.g., Hugging Face model IDs)
|
|
103
|
+
2. Pre-instantiated nn.Module objects
|
|
104
|
+
3. DictConfig objects for complex configurations
|
|
105
|
+
|
|
106
|
+
Example configuration:
|
|
100
107
|
```yaml
|
|
101
108
|
models:
|
|
109
|
+
# Simple string paths to Hugging Face models
|
|
102
110
|
cifar10: tanganke/clip-vit-base-patch32_cifar10
|
|
103
111
|
sun397: tanganke/clip-vit-base-patch32_sun397
|
|
104
112
|
stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
|
|
113
|
+
|
|
114
|
+
# Complex configuration with additional parameters
|
|
115
|
+
custom_model:
|
|
116
|
+
_target_: transformers.CLIPVisionModel.from_pretrained
|
|
117
|
+
pretrained_model_name_or_path: openai/clip-vit-base-patch32
|
|
118
|
+
torch_dtype: float16
|
|
105
119
|
```
|
|
106
120
|
|
|
107
121
|
Args:
|
|
108
|
-
model_name_or_config (Union[str, DictConfig]):
|
|
122
|
+
model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
|
|
123
|
+
or a configuration dictionary for instantiating the model.
|
|
124
|
+
*args: Additional positional arguments passed to model loading/instantiation.
|
|
125
|
+
**kwargs: Additional keyword arguments passed to model loading/instantiation.
|
|
109
126
|
|
|
110
127
|
Returns:
|
|
111
|
-
CLIPVisionModel: The loaded CLIPVisionModel.
|
|
128
|
+
CLIPVisionModel: The loaded CLIPVisionModel instance.
|
|
112
129
|
"""
|
|
130
|
+
# Check if we have a string model name that exists in our model pool
|
|
113
131
|
if (
|
|
114
132
|
isinstance(model_name_or_config, str)
|
|
115
133
|
and model_name_or_config in self._models
|
|
116
134
|
):
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
model
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
135
|
+
model_name = model_name_or_config
|
|
136
|
+
|
|
137
|
+
# handle different model configuration types
|
|
138
|
+
match self._models[model_name_or_config]:
|
|
139
|
+
case str() as model_path:
|
|
140
|
+
# Handle string model paths (e.g., Hugging Face model IDs)
|
|
141
|
+
if rank_zero_only.rank == 0:
|
|
142
|
+
log.info(
|
|
143
|
+
f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
|
|
144
|
+
)
|
|
145
|
+
# Resolve the repository path (supports both HuggingFace and ModelScope)
|
|
146
|
+
repo_path = resolve_repo_path(
|
|
147
|
+
model_path, repo_type="model", platform=self._platform
|
|
148
|
+
)
|
|
149
|
+
# Load and return the CLIPVisionModel from the resolved path
|
|
150
|
+
return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
|
|
151
|
+
|
|
152
|
+
case nn.Module() as model:
|
|
153
|
+
# Handle pre-instantiated model objects
|
|
154
|
+
if rank_zero_only.rank == 0:
|
|
155
|
+
log.info(
|
|
156
|
+
f"Returning existing model `{model_name}` of type {type(model)}"
|
|
157
|
+
)
|
|
158
|
+
return model
|
|
159
|
+
|
|
160
|
+
case _:
|
|
161
|
+
# Handle other configuration types (e.g., DictConfig) via parent class
|
|
162
|
+
# This fallback prevents returning None when the model config doesn't
|
|
163
|
+
# match the expected string or nn.Module patterns
|
|
164
|
+
return super().load_model(model_name_or_config, *args, **kwargs)
|
|
165
|
+
|
|
166
|
+
# If model_name_or_config is not a string in our pool, delegate to parent class
|
|
167
|
+
# This handles cases where model_name_or_config is a DictConfig directly
|
|
168
|
+
return super().load_model(model_name_or_config, *args, **kwargs)
|
|
132
169
|
|
|
133
170
|
def load_train_dataset(self, dataset_name: str, *args, **kwargs):
|
|
134
171
|
dataset_config = self._train_datasets[dataset_name]
|