fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +10 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,11 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for handling model checkpoints and state dictionaries.
|
|
3
|
+
|
|
4
|
+
This module provides classes and functions for lazily loading state dictionaries
|
|
5
|
+
from various checkpoint formats, including PyTorch .bin files, SafeTensors files,
|
|
6
|
+
and sharded checkpoints.
|
|
7
|
+
"""
|
|
8
|
+
|
|
1
9
|
import json
|
|
2
10
|
import logging
|
|
3
11
|
import os
|
|
@@ -43,6 +51,21 @@ def resolve_checkpoint_path(
|
|
|
43
51
|
hf_cache_dir: Optional[str] = None,
|
|
44
52
|
hf_proxies: Optional[Dict] = None,
|
|
45
53
|
):
|
|
54
|
+
"""
|
|
55
|
+
Resolve a checkpoint path, downloading from Hugging Face Hub if necessary.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
checkpoint: Path to local checkpoint or Hugging Face model ID.
|
|
59
|
+
hf_revision: Specific revision to download from HF Hub.
|
|
60
|
+
hf_cache_dir: Local cache directory for HF downloads.
|
|
61
|
+
hf_proxies: Proxy settings for HF downloads.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
Local path to the checkpoint.
|
|
65
|
+
|
|
66
|
+
Raises:
|
|
67
|
+
FileNotFoundError: If the checkpoint cannot be resolved.
|
|
68
|
+
"""
|
|
46
69
|
# If it's a local file or directory, return as is
|
|
47
70
|
if os.path.exists(checkpoint):
|
|
48
71
|
return checkpoint
|
|
@@ -64,11 +87,11 @@ def resolve_checkpoint_path(
|
|
|
64
87
|
|
|
65
88
|
class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
66
89
|
"""
|
|
67
|
-
|
|
90
|
+
A dictionary-like object that lazily loads tensors from model checkpoints.
|
|
68
91
|
"""
|
|
69
92
|
|
|
70
93
|
_local_path: str
|
|
71
|
-
"""
|
|
94
|
+
"""Local path to the checkpoint."""
|
|
72
95
|
_state_dict_cache: Optional[Dict]
|
|
73
96
|
"""Cache for the state dict, if enabled."""
|
|
74
97
|
_index_filename: Optional[str]
|
|
@@ -92,6 +115,8 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
92
115
|
hf_proxies: Optional[Dict] = None,
|
|
93
116
|
):
|
|
94
117
|
"""
|
|
118
|
+
Initialize LazyStateDict with a checkpoint path.
|
|
119
|
+
|
|
95
120
|
Args:
|
|
96
121
|
checkpoint (str): Path to the checkpoint file or directory.
|
|
97
122
|
meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
|
|
@@ -116,6 +141,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
116
141
|
self.meta_module_class = import_object(self.meta_module_class)
|
|
117
142
|
self.meta_module = meta_module
|
|
118
143
|
|
|
144
|
+
# Instantiate meta module if class provided
|
|
119
145
|
if self.meta_module_class is not None:
|
|
120
146
|
with init_empty_weights():
|
|
121
147
|
self.meta_module = self.meta_module_class.from_pretrained(
|
|
@@ -126,6 +152,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
126
152
|
proxies=hf_proxies,
|
|
127
153
|
)
|
|
128
154
|
|
|
155
|
+
# Store original checkpoint path and resolve to local path
|
|
129
156
|
self._checkpoint = checkpoint
|
|
130
157
|
self._local_path = resolve_checkpoint_path(
|
|
131
158
|
checkpoint,
|
|
@@ -134,10 +161,12 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
134
161
|
hf_proxies=hf_proxies,
|
|
135
162
|
)
|
|
136
163
|
|
|
164
|
+
# Detect checkpoint file type and set up indexing
|
|
137
165
|
self._index, self._index_filename, self._checkpoint_files = (
|
|
138
166
|
self._resolve_checkpoint_files(self._local_path)
|
|
139
167
|
)
|
|
140
168
|
|
|
169
|
+
# Set up based on checkpoint type
|
|
141
170
|
if self._index is not None:
|
|
142
171
|
# if meta_module is provided, remove the keys that are not in the meta_module
|
|
143
172
|
if self.meta_module is not None:
|
|
@@ -152,7 +181,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
152
181
|
elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
|
|
153
182
|
SAFE_WEIGHTS_NAME
|
|
154
183
|
):
|
|
155
|
-
#
|
|
184
|
+
# SafeTensors file: create index mapping all keys to this file
|
|
156
185
|
with safe_open(
|
|
157
186
|
self._checkpoint_files[0], framework="pt", device=device
|
|
158
187
|
) as f:
|
|
@@ -164,6 +193,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
164
193
|
elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
|
|
165
194
|
WEIGHTS_NAME
|
|
166
195
|
):
|
|
196
|
+
# PyTorch .bin file: load entire state dict immediately
|
|
167
197
|
log.info(f"Loading full state dict from {WEIGHTS_NAME}")
|
|
168
198
|
self._state_dict_cache = torch.load(self._checkpoint_files[0])
|
|
169
199
|
# if meta_module is provided, remove the keys that are not in the meta_module
|
|
@@ -173,6 +203,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
173
203
|
if key not in meta_module_state_dict:
|
|
174
204
|
self._state_dict_cache.pop(key)
|
|
175
205
|
else:
|
|
206
|
+
# Unsupported checkpoint format
|
|
176
207
|
raise ValueError(
|
|
177
208
|
f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
|
|
178
209
|
)
|
|
@@ -209,10 +240,19 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
209
240
|
return deepcopy(self)
|
|
210
241
|
|
|
211
242
|
def _resolve_checkpoint_files(self, checkpoint: str):
|
|
212
|
-
|
|
243
|
+
"""
|
|
244
|
+
Detect and resolve checkpoint files based on the checkpoint path.
|
|
245
|
+
|
|
246
|
+
Handles single files, directories with state dict files, and sharded checkpoints.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
Tuple of (index_dict, index_filename, checkpoint_files)
|
|
250
|
+
"""
|
|
251
|
+
# Reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
|
|
213
252
|
checkpoint_files = None
|
|
214
253
|
index_filename = None
|
|
215
254
|
if os.path.isfile(checkpoint):
|
|
255
|
+
# Single file: check if it's an index or a state dict
|
|
216
256
|
if str(checkpoint).endswith(".json"):
|
|
217
257
|
index_filename = checkpoint
|
|
218
258
|
else:
|
|
@@ -232,7 +272,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
232
272
|
os.path.join(checkpoint, potential_state_safetensor[0])
|
|
233
273
|
]
|
|
234
274
|
else:
|
|
235
|
-
#
|
|
275
|
+
# Check for sharded checkpoints
|
|
236
276
|
potential_index = [
|
|
237
277
|
f for f in os.listdir(checkpoint) if f.endswith(".index.json")
|
|
238
278
|
]
|
|
@@ -247,18 +287,22 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
247
287
|
f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
|
|
248
288
|
)
|
|
249
289
|
else:
|
|
290
|
+
# Invalid checkpoint path
|
|
250
291
|
raise ValueError(
|
|
251
292
|
"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
|
|
252
293
|
f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
|
|
253
294
|
)
|
|
254
295
|
|
|
296
|
+
# Load index file if present
|
|
255
297
|
if index_filename is not None:
|
|
256
298
|
checkpoint_folder = os.path.split(index_filename)[0]
|
|
257
299
|
with open(index_filename) as f:
|
|
258
300
|
index = json.loads(f.read())
|
|
259
301
|
|
|
302
|
+
# Extract weight_map if present (standard format)
|
|
260
303
|
if "weight_map" in index:
|
|
261
304
|
index = index["weight_map"]
|
|
305
|
+
# Get list of unique checkpoint files
|
|
262
306
|
checkpoint_files = sorted(list(set(index.values())))
|
|
263
307
|
checkpoint_files = [
|
|
264
308
|
os.path.join(checkpoint_folder, f) for f in checkpoint_files
|
|
@@ -270,6 +314,11 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
270
314
|
def _load_tensor_from_checkpoint_file(
|
|
271
315
|
self, checkpoint_file: str, key: str, update_cache: bool = True
|
|
272
316
|
) -> torch.Tensor:
|
|
317
|
+
"""
|
|
318
|
+
Load a tensor from the checkpoint file.
|
|
319
|
+
For safetensors, loads only the requested tensor.
|
|
320
|
+
For PyTorch files, loads the entire state dict on first access.
|
|
321
|
+
"""
|
|
273
322
|
if checkpoint_file.endswith(".safetensors"):
|
|
274
323
|
with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
|
|
275
324
|
tensor = f.get_tensor(key)
|
|
@@ -279,6 +328,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
|
279
328
|
self._state_dict_cache[key] = tensor
|
|
280
329
|
return tensor
|
|
281
330
|
else:
|
|
331
|
+
# PyTorch .bin file: load entire state dict
|
|
282
332
|
state_dict = torch.load(checkpoint_file, map_location=self._device)
|
|
283
333
|
if update_cache:
|
|
284
334
|
if self._state_dict_cache is not None:
|
fusion_bench/utils/misc.py
CHANGED
|
@@ -1,35 +1,126 @@
|
|
|
1
1
|
from difflib import get_close_matches
|
|
2
|
-
from typing import Any, Iterable, List, Optional
|
|
2
|
+
from typing import Any, Iterable, List, Optional, TypeVar, Union
|
|
3
|
+
|
|
4
|
+
T = TypeVar("T")
|
|
3
5
|
|
|
4
6
|
__all__ = [
|
|
5
7
|
"first",
|
|
6
8
|
"has_length",
|
|
7
|
-
"
|
|
9
|
+
"join_lists",
|
|
8
10
|
"attr_equal",
|
|
9
11
|
"validate_and_suggest_corrections",
|
|
10
12
|
]
|
|
11
13
|
|
|
12
14
|
|
|
13
|
-
def first(iterable: Iterable):
|
|
14
|
-
|
|
15
|
+
def first(iterable: Iterable[T], default: Optional[T] = None) -> Optional[T]:
|
|
16
|
+
"""
|
|
17
|
+
Return the first element of an iterable.
|
|
15
18
|
|
|
19
|
+
Args:
|
|
20
|
+
iterable: The iterable to get the first element from.
|
|
21
|
+
default: The value to return if the iterable is empty. If None and
|
|
22
|
+
the iterable is empty, raises StopIteration.
|
|
16
23
|
|
|
17
|
-
|
|
24
|
+
Returns:
|
|
25
|
+
The first element of the iterable, or the default value if empty.
|
|
26
|
+
|
|
27
|
+
Raises:
|
|
28
|
+
StopIteration: If the iterable is empty and no default is provided.
|
|
29
|
+
TypeError: If the object is not iterable.
|
|
18
30
|
"""
|
|
19
|
-
|
|
31
|
+
try:
|
|
32
|
+
iterator = iter(iterable)
|
|
33
|
+
return next(iterator)
|
|
34
|
+
except StopIteration:
|
|
35
|
+
if default is not None:
|
|
36
|
+
return default
|
|
37
|
+
raise
|
|
38
|
+
except TypeError as e:
|
|
39
|
+
raise TypeError(
|
|
40
|
+
f"Object of type {type(iterable).__name__} is not iterable"
|
|
41
|
+
) from e
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def has_length(obj: Any) -> bool:
|
|
20
45
|
"""
|
|
46
|
+
Check if an object has a length (implements __len__) and len() works correctly.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
obj: The object to check for length support.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
bool: True if the object supports len() and doesn't raise an error,
|
|
53
|
+
False otherwise.
|
|
54
|
+
"""
|
|
55
|
+
if obj is None:
|
|
56
|
+
return False
|
|
57
|
+
|
|
21
58
|
try:
|
|
22
|
-
|
|
23
|
-
|
|
59
|
+
# Check if __len__ method exists
|
|
60
|
+
if not hasattr(obj, "__len__"):
|
|
61
|
+
return False
|
|
62
|
+
|
|
63
|
+
# Try to get the length - this will raise TypeError for unsized objects
|
|
64
|
+
length = len(obj)
|
|
65
|
+
|
|
66
|
+
# Verify the length is a non-negative integer
|
|
67
|
+
return isinstance(length, int) and length >= 0
|
|
68
|
+
except (TypeError, AttributeError):
|
|
24
69
|
# TypeError: len() of unsized object
|
|
70
|
+
# AttributeError: if __len__ is not callable somehow
|
|
25
71
|
return False
|
|
72
|
+
except Exception:
|
|
73
|
+
# Any other unexpected error
|
|
74
|
+
return False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def join_lists(list_of_lists: Iterable[Iterable[T]]) -> List[T]:
|
|
78
|
+
"""
|
|
79
|
+
Flatten a collection of iterables into a single list.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
list_of_lists: An iterable containing iterables to be flattened.
|
|
26
83
|
|
|
84
|
+
Returns:
|
|
85
|
+
List[T]: A new list containing all elements from the input iterables
|
|
86
|
+
in order.
|
|
27
87
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
88
|
+
Raises:
|
|
89
|
+
TypeError: If any item in list_of_lists is not iterable.
|
|
90
|
+
|
|
91
|
+
Examples:
|
|
92
|
+
>>> join_lists([[1, 2], [3, 4], [5]])
|
|
93
|
+
[1, 2, 3, 4, 5]
|
|
94
|
+
>>> join_lists([])
|
|
95
|
+
[]
|
|
96
|
+
>>> join_lists([[], [1], [], [2, 3]])
|
|
97
|
+
[1, 2, 3]
|
|
98
|
+
"""
|
|
99
|
+
if not list_of_lists:
|
|
100
|
+
return []
|
|
101
|
+
|
|
102
|
+
result = []
|
|
103
|
+
for i, item in enumerate(list_of_lists):
|
|
104
|
+
try:
|
|
105
|
+
# Check if item is iterable (but not string, which is iterable but
|
|
106
|
+
# usually not what we want to flatten character by character)
|
|
107
|
+
if isinstance(item, (str, bytes)):
|
|
108
|
+
raise TypeError(
|
|
109
|
+
f"Item at index {i} is a string/bytes, not a list-like iterable"
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
# Try to extend with the item
|
|
113
|
+
result.extend(item)
|
|
114
|
+
except TypeError as e:
|
|
115
|
+
if "not iterable" in str(e):
|
|
116
|
+
raise TypeError(
|
|
117
|
+
f"Item at index {i} (type: {type(item).__name__}) is not iterable"
|
|
118
|
+
) from e
|
|
119
|
+
else:
|
|
120
|
+
# Re-raise our custom error or other TypeError
|
|
121
|
+
raise
|
|
122
|
+
|
|
123
|
+
return result
|
|
33
124
|
|
|
34
125
|
|
|
35
126
|
def attr_equal(obj, attr: str, value):
|
fusion_bench/utils/packages.py
CHANGED
fusion_bench/utils/path.py
CHANGED
|
@@ -2,6 +2,8 @@ import logging
|
|
|
2
2
|
import os
|
|
3
3
|
from typing import List
|
|
4
4
|
|
|
5
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
6
|
+
|
|
5
7
|
log = logging.getLogger(__name__)
|
|
6
8
|
|
|
7
9
|
|
|
@@ -25,6 +27,7 @@ def listdir_fullpath(dir: str) -> List[str]:
|
|
|
25
27
|
return [os.path.join(dir, name) for name in names]
|
|
26
28
|
|
|
27
29
|
|
|
30
|
+
@rank_zero_only
|
|
28
31
|
def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
|
|
29
32
|
"""
|
|
30
33
|
Creates a symbolic link from src_dir to dst_dir.
|
|
@@ -59,6 +62,10 @@ def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
|
|
|
59
62
|
link_name = os.path.basename(src_dir)
|
|
60
63
|
|
|
61
64
|
link_path = os.path.join(dst_dir, link_name)
|
|
65
|
+
# if the link already exists, skip
|
|
66
|
+
if os.path.exists(link_path):
|
|
67
|
+
log.warning(f"Symbolic link already exists, skipping: {link_path}")
|
|
68
|
+
return
|
|
62
69
|
|
|
63
70
|
try:
|
|
64
71
|
# if the system is windows, use the `mklink` command in "CMD" to create the symlink
|
fusion_bench/utils/pylogger.py
CHANGED
|
@@ -3,6 +3,12 @@ from typing import Mapping, Optional
|
|
|
3
3
|
|
|
4
4
|
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
|
|
5
5
|
|
|
6
|
+
__all__ = [
|
|
7
|
+
"RankedLogger",
|
|
8
|
+
"RankZeroLogger",
|
|
9
|
+
"get_rankzero_logger",
|
|
10
|
+
]
|
|
11
|
+
|
|
6
12
|
|
|
7
13
|
class RankedLogger(logging.LoggerAdapter):
|
|
8
14
|
"""A multi-GPU-friendly python command line logger."""
|
fusion_bench/utils/rich_utils.py
CHANGED