fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/utils/lazy_state_dict.py
@@ -1,3 +1,11 @@
+"""
+Utilities for handling model checkpoints and state dictionaries.
+
+This module provides classes and functions for lazily loading state dictionaries
+from various checkpoint formats, including PyTorch .bin files, SafeTensors files,
+and sharded checkpoints.
+"""
+
 import json
 import logging
 import os
@@ -43,6 +51,21 @@ def resolve_checkpoint_path(
     hf_cache_dir: Optional[str] = None,
     hf_proxies: Optional[Dict] = None,
 ):
+    """
+    Resolve a checkpoint path, downloading from Hugging Face Hub if necessary.
+
+    Args:
+        checkpoint: Path to local checkpoint or Hugging Face model ID.
+        hf_revision: Specific revision to download from HF Hub.
+        hf_cache_dir: Local cache directory for HF downloads.
+        hf_proxies: Proxy settings for HF downloads.
+
+    Returns:
+        Local path to the checkpoint.
+
+    Raises:
+        FileNotFoundError: If the checkpoint cannot be resolved.
+    """
     # If it's a local file or directory, return as is
     if os.path.exists(checkpoint):
         return checkpoint
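As an aside, a minimal usage sketch of the behavior documented above (the local directory and hub ID below are illustrative, not taken from the package):

    from fusion_bench.utils.lazy_state_dict import resolve_checkpoint_path

    # A local file or directory is returned unchanged.
    path = resolve_checkpoint_path("outputs/my_model")

    # A Hugging Face model ID is downloaded (honoring revision, cache dir,
    # and proxies) and the local snapshot path is returned.
    path = resolve_checkpoint_path("gpt2", hf_revision="main")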
@@ -64,11 +87,11 @@ def resolve_checkpoint_path(
 
 class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     """
-    Dictionary-like object that lazily loads a state dict from a checkpoint path.
+    A dictionary-like object that lazily loads tensors from model checkpoints.
     """
 
     _local_path: str
-    """local path to the checkpoint."""
+    """Local path to the checkpoint."""
     _state_dict_cache: Optional[Dict]
     """Cache for the state dict, if enabled."""
     _index_filename: Optional[str]
@@ -76,6 +99,9 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     _index: Optional[Dict[str, str]]
     """Mapping of parameter names to checkpoint files."""
 
+    meta_module: TorchModelType = None
+    meta_module_class: Optional[Type[TorchModelType]] = None
+
     def __init__(
         self,
         checkpoint: str,
@@ -89,6 +115,8 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         hf_proxies: Optional[Dict] = None,
     ):
         """
+        Initialize LazyStateDict with a checkpoint path.
+
         Args:
             checkpoint (str): Path to the checkpoint file or directory.
             meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
@@ -113,6 +141,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             self.meta_module_class = import_object(self.meta_module_class)
         self.meta_module = meta_module
 
+        # Instantiate meta module if class provided
         if self.meta_module_class is not None:
             with init_empty_weights():
                 self.meta_module = self.meta_module_class.from_pretrained(
@@ -123,6 +152,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     proxies=hf_proxies,
                 )
 
+        # Store original checkpoint path and resolve to local path
         self._checkpoint = checkpoint
         self._local_path = resolve_checkpoint_path(
             checkpoint,
@@ -131,10 +161,12 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             hf_proxies=hf_proxies,
         )
 
+        # Detect checkpoint file type and set up indexing
         self._index, self._index_filename, self._checkpoint_files = (
             self._resolve_checkpoint_files(self._local_path)
         )
 
+        # Set up based on checkpoint type
         if self._index is not None:
             # if meta_module is provided, remove the keys that are not in the meta_module
             if self.meta_module is not None:
@@ -149,7 +181,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             SAFE_WEIGHTS_NAME
         ):
-            # let the keys of self._index be the keys of the state dict, the values are the checkpoint file
+            # SafeTensors file: create index mapping all keys to this file
             with safe_open(
                 self._checkpoint_files[0], framework="pt", device=device
             ) as f:
@@ -161,6 +193,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             WEIGHTS_NAME
         ):
+            # PyTorch .bin file: load entire state dict immediately
             log.info(f"Loading full state dict from {WEIGHTS_NAME}")
             self._state_dict_cache = torch.load(self._checkpoint_files[0])
             # if meta_module is provided, remove the keys that are not in the meta_module
@@ -170,6 +203,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                 if key not in meta_module_state_dict:
                     self._state_dict_cache.pop(key)
         else:
+            # Unsupported checkpoint format
             raise ValueError(
                 f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
             )
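For reference, WEIGHTS_NAME and SAFE_WEIGHTS_NAME here are presumably the usual transformers constants, which makes the three accepted layouts concrete:

    from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME

    print(WEIGHTS_NAME)       # "pytorch_model.bin": full state dict, loaded eagerly
    print(SAFE_WEIGHTS_NAME)  # "model.safetensors": indexed and loaded per tensor
    # The third layout is a sharded checkpoint described by a *.index.json file.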
@@ -206,10 +240,19 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         return deepcopy(self)
 
     def _resolve_checkpoint_files(self, checkpoint: str):
-        # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
+        """
+        Detect and resolve checkpoint files based on the checkpoint path.
+
+        Handles single files, directories with state dict files, and sharded checkpoints.
+
+        Returns:
+            Tuple of (index_dict, index_filename, checkpoint_files)
+        """
+        # Reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
         checkpoint_files = None
         index_filename = None
         if os.path.isfile(checkpoint):
+            # Single file: check if it's an index or a state dict
             if str(checkpoint).endswith(".json"):
                 index_filename = checkpoint
             else:
@@ -229,7 +272,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     os.path.join(checkpoint, potential_state_safetensor[0])
                 ]
             else:
-                # otherwise check for sharded checkpoints
+                # Check for sharded checkpoints
                 potential_index = [
                     f for f in os.listdir(checkpoint) if f.endswith(".index.json")
                 ]
@@ -244,18 +287,22 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                         f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
                     )
         else:
+            # Invalid checkpoint path
             raise ValueError(
                 "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
                 f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
             )
 
+        # Load index file if present
         if index_filename is not None:
             checkpoint_folder = os.path.split(index_filename)[0]
             with open(index_filename) as f:
                 index = json.loads(f.read())
 
+            # Extract weight_map if present (standard format)
             if "weight_map" in index:
                 index = index["weight_map"]
+            # Get list of unique checkpoint files
             checkpoint_files = sorted(list(set(index.values())))
             checkpoint_files = [
                 os.path.join(checkpoint_folder, f) for f in checkpoint_files
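The index handling above follows the standard Hugging Face sharded-checkpoint layout: a *.index.json file with a top-level "weight_map" from parameter names to shard files. An illustrative sketch (file names and sizes invented):

    # Shape of e.g. model.safetensors.index.json:
    index = {
        "metadata": {"total_size": 13476839424},
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        },
    }
    # The code keeps only the weight_map and derives the unique shard files:
    checkpoint_files = sorted(set(index["weight_map"].values()))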
@@ -267,6 +314,11 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     def _load_tensor_from_checkpoint_file(
         self, checkpoint_file: str, key: str, update_cache: bool = True
     ) -> torch.Tensor:
+        """
+        Load a tensor from the checkpoint file.
+        For safetensors, loads only the requested tensor.
+        For PyTorch files, loads the entire state dict on first access.
+        """
         if checkpoint_file.endswith(".safetensors"):
             with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
                 tensor = f.get_tensor(key)
@@ -276,6 +328,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                 self._state_dict_cache[key] = tensor
             return tensor
         else:
+            # PyTorch .bin file: load entire state dict
             state_dict = torch.load(checkpoint_file, map_location=self._device)
             if update_cache:
                 if self._state_dict_cache is not None:
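Taken together, LazyStateDict still behaves as a read-only Mapping whose tensors materialize on first access. A hedged usage sketch (the checkpoint path and parameter name are placeholders):

    import torch
    from fusion_bench.utils.lazy_state_dict import LazyStateDict

    sd = LazyStateDict("path/to/checkpoint")  # local dir, file, or HF model ID

    names = list(sd)              # keys come from the index; no tensors loaded yet
    w = sd["model.layer.weight"]  # loads just this tensor from a safetensors shard
    assert isinstance(w, torch.Tensor)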
fusion_bench/utils/misc.py
@@ -1,35 +1,126 @@
 from difflib import get_close_matches
-from typing import Any, Iterable, List, Optional
+from typing import Any, Iterable, List, Optional, TypeVar, Union
+
+T = TypeVar("T")
 
 __all__ = [
     "first",
     "has_length",
-    "join_list",
+    "join_lists",
     "attr_equal",
     "validate_and_suggest_corrections",
 ]
 
 
-def first(iterable: Iterable):
-    return next(iter(iterable))
+def first(iterable: Iterable[T], default: Optional[T] = None) -> Optional[T]:
+    """
+    Return the first element of an iterable.
 
+    Args:
+        iterable: The iterable to get the first element from.
+        default: The value to return if the iterable is empty. If None and
+            the iterable is empty, raises StopIteration.
 
-def has_length(dataset):
+    Returns:
+        The first element of the iterable, or the default value if empty.
+
+    Raises:
+        StopIteration: If the iterable is empty and no default is provided.
+        TypeError: If the object is not iterable.
     """
-    Checks if the dataset implements __len__() and it doesn't raise an error
+    try:
+        iterator = iter(iterable)
+        return next(iterator)
+    except StopIteration:
+        if default is not None:
+            return default
+        raise
+    except TypeError as e:
+        raise TypeError(
+            f"Object of type {type(iterable).__name__} is not iterable"
+        ) from e
+
+
+def has_length(obj: Any) -> bool:
     """
+    Check if an object has a length (implements __len__) and len() works correctly.
+
+    Args:
+        obj: The object to check for length support.
+
+    Returns:
+        bool: True if the object supports len() and doesn't raise an error,
+            False otherwise.
+    """
+    if obj is None:
+        return False
+
     try:
-        return len(dataset) is not None
-    except TypeError:
+        # Check if __len__ method exists
+        if not hasattr(obj, "__len__"):
+            return False
+
+        # Try to get the length - this will raise TypeError for unsized objects
+        length = len(obj)
+
+        # Verify the length is a non-negative integer
+        return isinstance(length, int) and length >= 0
+    except (TypeError, AttributeError):
         # TypeError: len() of unsized object
+        # AttributeError: if __len__ is not callable somehow
         return False
+    except Exception:
+        # Any other unexpected error
+        return False
+
+
+def join_lists(list_of_lists: Iterable[Iterable[T]]) -> List[T]:
+    """
+    Flatten a collection of iterables into a single list.
+
+    Args:
+        list_of_lists: An iterable containing iterables to be flattened.
 
+    Returns:
+        List[T]: A new list containing all elements from the input iterables
+            in order.
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for item in list_of_list:
-        ans.extend(item)
-    return ans
+    Raises:
+        TypeError: If any item in list_of_lists is not iterable.
+
+    Examples:
+        >>> join_lists([[1, 2], [3, 4], [5]])
+        [1, 2, 3, 4, 5]
+        >>> join_lists([])
+        []
+        >>> join_lists([[], [1], [], [2, 3]])
+        [1, 2, 3]
+    """
+    if not list_of_lists:
+        return []
+
+    result = []
+    for i, item in enumerate(list_of_lists):
+        try:
+            # Check if item is iterable (but not string, which is iterable but
+            # usually not what we want to flatten character by character)
+            if isinstance(item, (str, bytes)):
+                raise TypeError(
+                    f"Item at index {i} is a string/bytes, not a list-like iterable"
+                )
+
+            # Try to extend with the item
+            result.extend(item)
+        except TypeError as e:
+            if "not iterable" in str(e):
+                raise TypeError(
+                    f"Item at index {i} (type: {type(item).__name__}) is not iterable"
+                ) from e
+            else:
+                # Re-raise our custom error or other TypeError
+                raise
+
+    return result
 
 
 def attr_equal(obj, attr: str, value):
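Since the new signatures are fully visible in this hunk, their behavior can be demonstrated directly; note that the rename join_list -> join_lists is a breaking change for downstream imports:

    from fusion_bench.utils.misc import first, has_length, join_lists

    first([3, 1, 2])           # -> 3
    first([], default=0)       # -> 0 (previously raised StopIteration)

    has_length([1, 2, 3])      # -> True
    has_length(iter([1, 2]))   # -> False: iterators define no __len__

    join_lists([[1, 2], [3]])  # -> [1, 2, 3]
    join_lists([])             # -> []

One quirk worth knowing: first([], default=None) still raises StopIteration, because the implementation checks `default is not None` rather than whether a default was supplied.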
fusion_bench/utils/packages.py
@@ -40,6 +40,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
+def is_open_clip_available():
+    return _is_package_available("open_clip")
+
+
 def is_pillow_available():
     return _is_package_available("PIL")
 
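A typical guard built on the new helper (the import site is illustrative):

    from fusion_bench.utils.packages import is_open_clip_available

    if is_open_clip_available():
        import open_clip  # optional dependency, imported only when present
    else:
        open_clip = None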
fusion_bench/utils/path.py
@@ -2,6 +2,8 @@ import logging
 import os
 from typing import List
 
+from lightning_utilities.core.rank_zero import rank_zero_only
+
 log = logging.getLogger(__name__)
 
 
@@ -25,6 +27,7 @@ def listdir_fullpath(dir: str) -> List[str]:
     return [os.path.join(dir, name) for name in names]
 
 
+@rank_zero_only
 def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
     """
     Creates a symbolic link from src_dir to dst_dir.
@@ -59,6 +62,10 @@ def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
         link_name = os.path.basename(src_dir)
 
     link_path = os.path.join(dst_dir, link_name)
+    # if the link already exists, skip
+    if os.path.exists(link_path):
+        log.warning(f"Symbolic link already exists, skipping: {link_path}")
+        return
 
     try:
         # if the system is windows, use the `mklink` command in "CMD" to create the symlink
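With @rank_zero_only, only rank 0 creates the link in a multi-process run, and the new existence check makes repeated calls safe. A sketch of the resulting call pattern (paths are invented):

    from fusion_bench.utils.path import create_symlink

    create_symlink("outputs/run_001", "outputs/latest")  # creates outputs/latest/run_001
    create_symlink("outputs/run_001", "outputs/latest")  # logs a warning and returns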
fusion_bench/utils/pylogger.py
@@ -3,6 +3,12 @@ from typing import Mapping, Optional
 
 from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
 
+__all__ = [
+    "RankedLogger",
+    "RankZeroLogger",
+    "get_rankzero_logger",
+]
+
 
 class RankedLogger(logging.LoggerAdapter):
     """A multi-GPU-friendly python command line logger."""
fusion_bench/utils/rich_utils.py
@@ -16,6 +16,7 @@ from rich.panel import Panel
 from rich.prompt import Prompt
 from rich.syntax import Syntax
 from rich.text import Text
+from rich.traceback import install as install_rich_traceback
 
 from fusion_bench.utils import pylogger
 
@@ -188,17 +189,21 @@ if __name__ == "__main__":
     display_available_styles()
 
 
-def setup_colorlogging(force=False, **config_kwargs):
+def setup_colorlogging(
+    force=False,
+    level=logging.INFO,
+    **kwargs,
+):
     """
     Sets up color logging for the application.
     """
     FORMAT = "%(message)s"
 
     logging.basicConfig(
-        level=logging.INFO,
+        level=level,
         format=FORMAT,
         datefmt="[%X]",
         handlers=[RichHandler()],
         force=force,
-        **config_kwargs,
+        **kwargs,
     )
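Since level is now forwarded to logging.basicConfig, callers can adjust verbosity without touching the handler setup:

    import logging
    from fusion_bench.utils.rich_utils import setup_colorlogging

    setup_colorlogging(force=True, level=logging.DEBUG)
    logging.getLogger(__name__).debug("Rendered through RichHandler at DEBUG level.")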