fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. The information is provided for informational purposes only.
Files changed (75)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +12 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/clip_finetune.py +6 -4
  7. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  8. fusion_bench/method/dop/__init__.py +1 -0
  9. fusion_bench/method/dop/dop.py +366 -0
  10. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  11. fusion_bench/method/dop/utils.py +73 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  15. fusion_bench/mixins/__init__.py +2 -0
  16. fusion_bench/mixins/pyinstrument.py +174 -0
  17. fusion_bench/mixins/simple_profiler.py +106 -23
  18. fusion_bench/modelpool/__init__.py +2 -0
  19. fusion_bench/modelpool/base_pool.py +77 -14
  20. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  21. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  22. fusion_bench/models/__init__.py +35 -9
  23. fusion_bench/optim/__init__.py +40 -2
  24. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  25. fusion_bench/optim/muon.py +339 -0
  26. fusion_bench/programs/__init__.py +2 -0
  27. fusion_bench/programs/fabric_fusion_program.py +2 -2
  28. fusion_bench/programs/fusion_program.py +271 -0
  29. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  30. fusion_bench/utils/__init__.py +167 -21
  31. fusion_bench/utils/lazy_imports.py +91 -12
  32. fusion_bench/utils/lazy_state_dict.py +55 -5
  33. fusion_bench/utils/misc.py +104 -13
  34. fusion_bench/utils/packages.py +4 -0
  35. fusion_bench/utils/path.py +7 -0
  36. fusion_bench/utils/pylogger.py +6 -0
  37. fusion_bench/utils/rich_utils.py +1 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  39. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
  40. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
  41. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  42. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  43. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  44. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  45. fusion_bench_config/method/dop/dop.yaml +30 -0
  46. fusion_bench_config/method/dummy.yaml +6 -0
  47. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  48. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  49. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  50. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  51. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  52. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
  53. fusion_bench_config/method/model_recombination.yaml +8 -0
  54. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  55. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  56. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  57. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  58. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  59. fusion_bench_config/method/simple_average.yaml +9 -0
  60. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  61. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  62. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  63. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  64. fusion_bench_config/method/ties_merging.yaml +3 -0
  65. fusion_bench_config/model_fusion.yaml +45 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  72. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
  73. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
  74. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
  75. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/utils/lazy_state_dict.py
@@ -1,3 +1,11 @@
+"""
+Utilities for handling model checkpoints and state dictionaries.
+
+This module provides classes and functions for lazily loading state dictionaries
+from various checkpoint formats, including PyTorch .bin files, SafeTensors files,
+and sharded checkpoints.
+"""
+
 import json
 import logging
 import os
@@ -43,6 +51,21 @@ def resolve_checkpoint_path(
     hf_cache_dir: Optional[str] = None,
     hf_proxies: Optional[Dict] = None,
 ):
+    """
+    Resolve a checkpoint path, downloading from Hugging Face Hub if necessary.
+
+    Args:
+        checkpoint: Path to local checkpoint or Hugging Face model ID.
+        hf_revision: Specific revision to download from HF Hub.
+        hf_cache_dir: Local cache directory for HF downloads.
+        hf_proxies: Proxy settings for HF downloads.
+
+    Returns:
+        Local path to the checkpoint.
+
+    Raises:
+        FileNotFoundError: If the checkpoint cannot be resolved.
+    """
     # If it's a local file or directory, return as is
     if os.path.exists(checkpoint):
         return checkpoint
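
In plain terms: a path that exists on disk is returned untouched, and anything else is treated as a Hugging Face Hub model ID and fetched into the cache. A minimal sketch of that contract (the model ID below is hypothetical):

    from fusion_bench.utils.lazy_state_dict import resolve_checkpoint_path

    local = resolve_checkpoint_path("/tmp/my_model")  # exists locally -> returned as-is
    cached = resolve_checkpoint_path(
        "some-org/some-model",  # hypothetical HF Hub ID, downloaded on demand
        hf_revision="main",
    )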
@@ -64,11 +87,11 @@ def resolve_checkpoint_path(
 
 class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     """
-    Dictionary-like object that lazily loads a state dict from a checkpoint path.
+    A dictionary-like object that lazily loads tensors from model checkpoints.
     """
 
     _local_path: str
-    """local path to the checkpoint."""
+    """Local path to the checkpoint."""
     _state_dict_cache: Optional[Dict]
     """Cache for the state dict, if enabled."""
     _index_filename: Optional[str]
@@ -92,6 +115,8 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         hf_proxies: Optional[Dict] = None,
     ):
         """
+        Initialize LazyStateDict with a checkpoint path.
+
         Args:
             checkpoint (str): Path to the checkpoint file or directory.
             meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
@@ -116,6 +141,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             self.meta_module_class = import_object(self.meta_module_class)
         self.meta_module = meta_module
 
+        # Instantiate meta module if class provided
         if self.meta_module_class is not None:
             with init_empty_weights():
                 self.meta_module = self.meta_module_class.from_pretrained(
@@ -126,6 +152,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     proxies=hf_proxies,
                 )
 
+        # Store original checkpoint path and resolve to local path
         self._checkpoint = checkpoint
         self._local_path = resolve_checkpoint_path(
             checkpoint,
@@ -134,10 +161,12 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             hf_proxies=hf_proxies,
         )
 
+        # Detect checkpoint file type and set up indexing
         self._index, self._index_filename, self._checkpoint_files = (
            self._resolve_checkpoint_files(self._local_path)
        )
 
+        # Set up based on checkpoint type
        if self._index is not None:
            # if meta_module is provided, remove the keys that are not in the meta_module
            if self.meta_module is not None:
@@ -152,7 +181,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             SAFE_WEIGHTS_NAME
         ):
-            # let the keys of self._index be the keys of the state dict, the values are the checkpoint file
+            # SafeTensors file: create index mapping all keys to this file
             with safe_open(
                 self._checkpoint_files[0], framework="pt", device=device
             ) as f:
@@ -164,6 +193,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             WEIGHTS_NAME
         ):
+            # PyTorch .bin file: load entire state dict immediately
             log.info(f"Loading full state dict from {WEIGHTS_NAME}")
             self._state_dict_cache = torch.load(self._checkpoint_files[0])
             # if meta_module is provided, remove the keys that are not in the meta_module
@@ -173,6 +203,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     if key not in meta_module_state_dict:
                         self._state_dict_cache.pop(key)
         else:
+            # Unsupported checkpoint format
            raise ValueError(
                f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
            )
@@ -209,10 +240,19 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         return deepcopy(self)
 
     def _resolve_checkpoint_files(self, checkpoint: str):
-        # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
+        """
+        Detect and resolve checkpoint files based on the checkpoint path.
+
+        Handles single files, directories with state dict files, and sharded checkpoints.
+
+        Returns:
+            Tuple of (index_dict, index_filename, checkpoint_files)
+        """
+        # Reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
        checkpoint_files = None
        index_filename = None
        if os.path.isfile(checkpoint):
+            # Single file: check if it's an index or a state dict
            if str(checkpoint).endswith(".json"):
                index_filename = checkpoint
            else:
@@ -232,7 +272,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     os.path.join(checkpoint, potential_state_safetensor[0])
                 ]
             else:
-                # otherwise check for sharded checkpoints
+                # Check for sharded checkpoints
                 potential_index = [
                     f for f in os.listdir(checkpoint) if f.endswith(".index.json")
                 ]
@@ -247,18 +287,22 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
                 )
         else:
+            # Invalid checkpoint path
            raise ValueError(
                "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
            )
 
+        # Load index file if present
        if index_filename is not None:
            checkpoint_folder = os.path.split(index_filename)[0]
            with open(index_filename) as f:
                index = json.loads(f.read())
 
+            # Extract weight_map if present (standard format)
            if "weight_map" in index:
                index = index["weight_map"]
+            # Get list of unique checkpoint files
            checkpoint_files = sorted(list(set(index.values())))
            checkpoint_files = [
                os.path.join(checkpoint_folder, f) for f in checkpoint_files
@@ -270,6 +314,11 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     def _load_tensor_from_checkpoint_file(
         self, checkpoint_file: str, key: str, update_cache: bool = True
     ) -> torch.Tensor:
+        """
+        Load a tensor from the checkpoint file.
+        For safetensors, loads only the requested tensor.
+        For PyTorch files, loads the entire state dict on first access.
+        """
         if checkpoint_file.endswith(".safetensors"):
             with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
                 tensor = f.get_tensor(key)
@@ -279,6 +328,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     self._state_dict_cache[key] = tensor
             return tensor
         else:
+            # PyTorch .bin file: load entire state dict
            state_dict = torch.load(checkpoint_file, map_location=self._device)
            if update_cache:
                if self._state_dict_cache is not None:
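
Taken together, these hunks describe LazyStateDict as a read-only Mapping whose keys come from the checkpoint index and whose tensors are materialized on first access (safetensors entries individually, .bin files wholesale). A usage sketch under those assumptions, with a hypothetical checkpoint path:

    from fusion_bench.utils.lazy_state_dict import LazyStateDict

    sd = LazyStateDict("path/to/checkpoint")  # local dir, file, or HF model ID
    for name in sd:              # keys are known without loading any tensors
        tensor = sd[name]        # the tensor is loaded lazily on this access
        print(name, tuple(tensor.shape))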
fusion_bench/utils/misc.py
@@ -1,35 +1,126 @@
 from difflib import get_close_matches
-from typing import Any, Iterable, List, Optional
+from typing import Any, Iterable, List, Optional, TypeVar, Union
+
+T = TypeVar("T")
 
 __all__ = [
     "first",
     "has_length",
-    "join_list",
+    "join_lists",
     "attr_equal",
     "validate_and_suggest_corrections",
 ]
 
 
-def first(iterable: Iterable):
-    return next(iter(iterable))
+def first(iterable: Iterable[T], default: Optional[T] = None) -> Optional[T]:
+    """
+    Return the first element of an iterable.
 
+    Args:
+        iterable: The iterable to get the first element from.
+        default: The value to return if the iterable is empty. If None and
+            the iterable is empty, raises StopIteration.
 
-def has_length(dataset):
+    Returns:
+        The first element of the iterable, or the default value if empty.
+
+    Raises:
+        StopIteration: If the iterable is empty and no default is provided.
+        TypeError: If the object is not iterable.
     """
-    Checks if the dataset implements __len__() and it doesn't raise an error
+    try:
+        iterator = iter(iterable)
+        return next(iterator)
+    except StopIteration:
+        if default is not None:
+            return default
+        raise
+    except TypeError as e:
+        raise TypeError(
+            f"Object of type {type(iterable).__name__} is not iterable"
+        ) from e
+
+
+def has_length(obj: Any) -> bool:
     """
+    Check if an object has a length (implements __len__) and len() works correctly.
+
+    Args:
+        obj: The object to check for length support.
+
+    Returns:
+        bool: True if the object supports len() and doesn't raise an error,
+        False otherwise.
+    """
+    if obj is None:
+        return False
+
     try:
-        return len(dataset) is not None
-    except TypeError:
+        # Check if __len__ method exists
+        if not hasattr(obj, "__len__"):
+            return False
+
+        # Try to get the length - this will raise TypeError for unsized objects
+        length = len(obj)
+
+        # Verify the length is a non-negative integer
+        return isinstance(length, int) and length >= 0
+    except (TypeError, AttributeError):
         # TypeError: len() of unsized object
+        # AttributeError: if __len__ is not callable somehow
         return False
+    except Exception:
+        # Any other unexpected error
+        return False
+
+
+def join_lists(list_of_lists: Iterable[Iterable[T]]) -> List[T]:
+    """
+    Flatten a collection of iterables into a single list.
+
+    Args:
+        list_of_lists: An iterable containing iterables to be flattened.
 
+    Returns:
+        List[T]: A new list containing all elements from the input iterables
+        in order.
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for item in list_of_list:
-        ans.extend(item)
-    return ans
+    Raises:
+        TypeError: If any item in list_of_lists is not iterable.
+
+    Examples:
+        >>> join_lists([[1, 2], [3, 4], [5]])
+        [1, 2, 3, 4, 5]
+        >>> join_lists([])
+        []
+        >>> join_lists([[], [1], [], [2, 3]])
+        [1, 2, 3]
+    """
+    if not list_of_lists:
+        return []
+
+    result = []
+    for i, item in enumerate(list_of_lists):
+        try:
+            # Check if item is iterable (but not string, which is iterable but
+            # usually not what we want to flatten character by character)
+            if isinstance(item, (str, bytes)):
+                raise TypeError(
+                    f"Item at index {i} is a string/bytes, not a list-like iterable"
+                )
+
+            # Try to extend with the item
+            result.extend(item)
+        except TypeError as e:
+            if "not iterable" in str(e):
+                raise TypeError(
+                    f"Item at index {i} (type: {type(item).__name__}) is not iterable"
+                ) from e
+            else:
+                # Re-raise our custom error or other TypeError
+                raise
+
+    return result
 
 
 def attr_equal(obj, attr: str, value):
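
Note that the rename from join_list to join_lists is a breaking change for any caller importing the old name. The new behavior of all three helpers follows directly from the code above; assuming they remain importable from fusion_bench.utils.misc as the file list suggests:

    from fusion_bench.utils.misc import first, has_length, join_lists

    first([3, 4, 5])           # -> 3
    first([], default=0)       # -> 0 (new `default` parameter)
    has_length([1, 2])         # -> True
    has_length(iter([1, 2]))   # -> False: iterators have no __len__
    join_lists([[1, 2], [3]])  # -> [1, 2, 3]; str/bytes items raise TypeError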
fusion_bench/utils/packages.py
@@ -40,6 +40,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
+def is_open_clip_available():
+    return _is_package_available("open_clip")
+
+
 def is_pillow_available():
     return _is_package_available("PIL")
 
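
The new helper follows the module's existing _is_package_available pattern; the usual way such checks are consumed is as an import guard (illustrative only, since this hunk shows only the helper itself):

    from fusion_bench.utils.packages import is_open_clip_available

    if is_open_clip_available():
        import open_clip  # safe: the package is importable
    else:
        raise ImportError("open_clip is required for this feature")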
fusion_bench/utils/path.py
@@ -2,6 +2,8 @@ import logging
 import os
 from typing import List
 
+from lightning_utilities.core.rank_zero import rank_zero_only
+
 log = logging.getLogger(__name__)
 
 
@@ -25,6 +27,7 @@ def listdir_fullpath(dir: str) -> List[str]:
     return [os.path.join(dir, name) for name in names]
 
 
+@rank_zero_only
 def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
     """
     Creates a symbolic link from src_dir to dst_dir.
@@ -59,6 +62,10 @@ def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
         link_name = os.path.basename(src_dir)
 
     link_path = os.path.join(dst_dir, link_name)
+    # if the link already exists, skip
+    if os.path.exists(link_path):
+        log.warning(f"Symbolic link already exists, skipping: {link_path}")
+        return
 
     try:
         # if the system is windows, use the `mklink` command in "CMD" to create the symlink
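
The @rank_zero_only decorator from lightning_utilities makes the wrapped function body execute only on global rank 0 in a multi-process run, returning None on every other rank, so the symlink is created once rather than once per process. A self-contained illustration of the same decorator (the function below is hypothetical):

    from lightning_utilities.core.rank_zero import rank_zero_only

    @rank_zero_only
    def write_marker(path: str) -> None:
        # Runs on rank 0 only; other ranks skip the body and receive None.
        with open(path, "w") as f:
            f.write("done")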
fusion_bench/utils/pylogger.py
@@ -3,6 +3,12 @@ from typing import Mapping, Optional
 
 from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
 
+__all__ = [
+    "RankedLogger",
+    "RankZeroLogger",
+    "get_rankzero_logger",
+]
+
 
 class RankedLogger(logging.LoggerAdapter):
     """A multi-GPU-friendly python command line logger."""
fusion_bench/utils/rich_utils.py
@@ -16,6 +16,7 @@ from rich.panel import Panel
 from rich.prompt import Prompt
 from rich.syntax import Syntax
 from rich.text import Text
+from rich.traceback import install as install_rich_traceback
 
 from fusion_bench.utils import pylogger
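
The aliased import suggests rich's pretty traceback handler is now wired up somewhere in rich_utils, though the call site is not part of this hunk. For reference, enabling it is a one-liner:

    from rich.traceback import install as install_rich_traceback

    install_rich_traceback(show_locals=False)  # syntax-highlighted tracebacks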