fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/dataset/clip_dataset.py +1 -0
  3. fusion_bench/method/__init__.py +2 -0
  4. fusion_bench/method/adamerging/__init__.py +28 -5
  5. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  6. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  7. fusion_bench/method/adamerging/utils.py +58 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +168 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/simple_average.py +6 -4
  12. fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
  13. fusion_bench/mixins/lightning_fabric.py +9 -0
  14. fusion_bench/modelpool/__init__.py +24 -2
  15. fusion_bench/modelpool/base_pool.py +8 -1
  16. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  17. fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
  18. fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
  19. fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
  20. fusion_bench/models/hf_clip.py +4 -7
  21. fusion_bench/models/hf_utils.py +4 -1
  22. fusion_bench/models/model_card_templates/default.md +1 -1
  23. fusion_bench/taskpool/__init__.py +2 -0
  24. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  25. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  26. fusion_bench/utils/json.py +49 -8
  27. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  28. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
  29. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
  30. fusion_bench_config/fabric/auto.yaml +1 -1
  31. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  32. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  33. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  34. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  35. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  36. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  37. fusion_bench_config/method/linear/expo.yaml +5 -0
  38. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  39. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  40. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  41. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  42. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  43. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
  44. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  45. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  46. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  47. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  48. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  49. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  50. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  51. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  52. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  53. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  54. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  55. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  56. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  57. fusion_bench_config/model_fusion.yaml +2 -1
  58. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
  59. fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
  60. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  61. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  62. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  63. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  64. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  65. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  73. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  74. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  75. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  77. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  78. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  81. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  82. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  97. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  98. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  101. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  102. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  117. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  118. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  121. fusion_bench_config/method/clip_finetune.yaml +0 -26
  122. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
  123. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
  124. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
  125. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
@@ -1,31 +1,72 @@
1
1
  import json
2
2
  from pathlib import Path
3
- from typing import Any, Union
3
+ from typing import TYPE_CHECKING, Any, Union
4
4
 
5
+ if TYPE_CHECKING:
6
+ from pyarrow.fs import FileSystem
5
7
 
6
- def save_to_json(obj, path: Union[str, Path]):
8
+
9
+ def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
7
10
  """
8
11
  save an object to a json file
9
12
 
10
13
  Args:
11
14
  obj (Any): the object to save
12
15
  path (Union[str, Path]): the path to save the object
16
+ filesystem (FileSystem, optional): PyArrow FileSystem to use for writing.
17
+ If None, uses local filesystem via standard Python open().
18
+ Can also be an s3fs.S3FileSystem or fsspec filesystem.
13
19
  """
14
- with open(path, "w") as f:
15
- json.dump(obj, f)
20
+ if filesystem is not None:
21
+ json_str = json.dumps(obj)
22
+ # Check if it's an fsspec-based filesystem (like s3fs)
23
+ if hasattr(filesystem, "open"):
24
+ # Direct fsspec/s3fs usage - more reliable for some endpoints
25
+ path_str = str(path)
26
+ with filesystem.open(path_str, "w") as f:
27
+ f.write(json_str)
28
+ else:
29
+ # Use PyArrow filesystem
30
+ path_str = str(path)
31
+ with filesystem.open_output_stream(path_str) as f:
32
+ f.write(json_str.encode("utf-8"))
33
+ else:
34
+ # Use standard Python file operations
35
+ with open(path, "w") as f:
36
+ json.dump(obj, f)
16
37
 
17
38
 
18
- def load_from_json(path: Union[str, Path]) -> Union[dict, list]:
39
+ def load_from_json(
40
+ path: Union[str, Path], filesystem: "FileSystem" = None
41
+ ) -> Union[dict, list]:
19
42
  """load an object from a json file
20
43
 
21
44
  Args:
22
45
  path (Union[str, Path]): the path to load the object
46
+ filesystem (FileSystem, optional): PyArrow FileSystem to use for reading.
47
+ If None, uses local filesystem via standard Python open().
48
+ Can also be an s3fs.S3FileSystem or fsspec filesystem.
23
49
 
24
50
  Returns:
25
- dict: the loaded object
51
+ Union[dict, list]: the loaded object
26
52
  """
27
- with open(path, "r") as f:
28
- return json.load(f)
53
+ if filesystem is not None:
54
+ # Check if it's an fsspec-based filesystem (like s3fs)
55
+ if hasattr(filesystem, "open"):
56
+ # Direct fsspec/s3fs usage
57
+ path_str = str(path)
58
+ with filesystem.open(path_str, "r") as f:
59
+ return json.load(f)
60
+ else:
61
+ # Use PyArrow filesystem
62
+ path_str = str(path)
63
+ with filesystem.open_input_stream(path_str) as f:
64
+ json_data = f.read().decode("utf-8")
65
+ return json.loads(json_data)
66
+ else:
67
+ # Use standard Python file operations
68
+ with open(path, "r") as f:
69
+ return json.load(f)
29
70
 
30
71
 
31
72
  def _is_list_of_dict(obj) -> bool:
@@ -6,10 +6,13 @@ import torch
6
6
  from torch import Tensor
7
7
  from tqdm.auto import tqdm
8
8
 
9
+ from fusion_bench.utils.type import TorchModelType
10
+
9
11
  from .type import BoolStateDictType, StateDictType
10
12
 
11
13
  __all__ = [
12
14
  "ArithmeticStateDict",
15
+ "load_state_dict_with_prefix",
13
16
  "state_dicts_check_keys",
14
17
  "state_dict_to_device",
15
18
  "num_params_of_state_dict",
@@ -646,6 +649,48 @@ def _validate_list_lengths_equal(
646
649
  pass
647
650
 
648
651
 
652
+ def load_state_dict_with_prefix(
653
+ model: TorchModelType,
654
+ state_dict: StateDictType,
655
+ strict: bool = True,
656
+ assign: bool = False,
657
+ key_prefix: str = "model.",
658
+ operation: Literal["add", "remove"] = "remove",
659
+ ) -> TorchModelType:
660
+ """
661
+ Load a state dict into a model, adding or removing a prefix from the keys.
662
+
663
+ This is useful when loading state dicts saved with DataParallel, pytorch lightning or similar wrappers.
664
+
665
+ Args:
666
+ model: The model to load the state dict into.
667
+ state_dict: The state dictionary to load.
668
+ key_prefix: The prefix to add or remove from the keys.
669
+ operation: 'add' to add the prefix, 'remove' to remove it.
670
+
671
+ Returns:
672
+ The model with the loaded state dict.
673
+ """
674
+ if operation not in ("add", "remove"):
675
+ raise ValueError("operation must be either 'add' or 'remove'")
676
+
677
+ modified_state_dict = OrderedDict()
678
+ for key, value in state_dict.items():
679
+ if operation == "add":
680
+ new_key = f"{key_prefix}{key}"
681
+ else: # operation == "remove"
682
+ if key.startswith(key_prefix):
683
+ new_key = key[len(key_prefix) :]
684
+ else:
685
+ raise ValueError(
686
+ f"Key '{key}' does not start with prefix '{key_prefix}'"
687
+ )
688
+ modified_state_dict[new_key] = value
689
+
690
+ model.load_state_dict(modified_state_dict, strict=strict, assign=assign)
691
+ return model
692
+
693
+
649
694
  def state_dict_to_device(
650
695
  state_dict: StateDictType,
651
696
  device: Union[torch.device, str],
@@ -851,22 +896,48 @@ def state_dict_add_scalar(state_dict: StateDictType, scalar: Number) -> StateDic
851
896
  return OrderedDict((key, tensor + scalar) for key, tensor in state_dict.items())
852
897
 
853
898
 
854
- def state_dict_mul(state_dict: StateDictType, scalar: float) -> StateDictType:
899
+ def state_dict_mul(
900
+ state_dict: StateDictType,
901
+ scalar: float,
902
+ *,
903
+ keep_dtype_when_zero: bool = True,
904
+ show_pbar: bool = False,
905
+ ) -> StateDictType:
855
906
  """
856
907
  Multiply all parameters in a state dict by a scalar.
857
908
 
858
909
  Args:
859
910
  state_dict: The state dict to multiply.
860
- scalar: The scalar value to multiply each parameter by.
911
+ scalar (float): The scalar value to multiply each parameter by.
912
+ keep_dtype_when_zero (bool): Whether to keep the original data type of the tensors if either the tensor is all zeros or the scalar is zero.
913
+ show_pbar (bool): Whether to show a progress bar during computation.
861
914
 
862
915
  Returns:
863
916
  A new state dict with each parameter multiplied by the scalar.
864
917
  """
865
- return OrderedDict((key, scalar * tensor) for key, tensor in state_dict.items())
918
+ new_state_dict = OrderedDict()
919
+ for key, tensor in (
920
+ state_dict.items()
921
+ if not show_pbar
922
+ else tqdm(state_dict.items(), desc="Multiplying state dict")
923
+ ):
924
+ if (
925
+ keep_dtype_when_zero
926
+ and not tensor.is_floating_point() # when tensor is not floating point, multiplication by 0 keeps dtype
927
+ and (scalar == 0 or torch.all(tensor == 0))
928
+ ):
929
+ new_state_dict[key] = tensor.clone()
930
+ else:
931
+ new_state_dict[key] = scalar * tensor
932
+ return new_state_dict
866
933
 
867
934
 
868
935
  def state_dict_div(
869
- state_dict: StateDictType, scalar: float, show_pbar: bool = False
936
+ state_dict: StateDictType,
937
+ scalar: float,
938
+ *,
939
+ keep_dtype_when_zero: bool = True,
940
+ show_pbar: bool = False,
870
941
  ) -> StateDictType:
871
942
  """
872
943
  Divide all parameters in a state dict by a scalar.
@@ -874,6 +945,7 @@ def state_dict_div(
874
945
  Args:
875
946
  state_dict: The state dict to divide.
876
947
  scalar: The scalar value to divide each parameter by.
948
+ keep_dtype_when_zero: Whether to keep the original data type of the tensors if the tensor is all zeros.
877
949
  show_pbar: Whether to show a progress bar during computation.
878
950
 
879
951
  Returns:
@@ -885,12 +957,21 @@ def state_dict_div(
885
957
  if scalar == 0:
886
958
  raise ZeroDivisionError("Cannot divide state dict by zero")
887
959
 
888
- keys_iter = (
889
- tqdm(state_dict.keys(), desc="Dividing state dict")
890
- if show_pbar
891
- else state_dict.keys()
892
- )
893
- return OrderedDict((key, state_dict[key] / scalar) for key in keys_iter)
960
+ new_state_dict = OrderedDict()
961
+ for key, tensor in (
962
+ state_dict.items()
963
+ if not show_pbar
964
+ else tqdm(state_dict.items(), desc="Dividing state dict")
965
+ ):
966
+ if (
967
+ keep_dtype_when_zero
968
+ and not tensor.is_floating_point() # when tensor is not floating point, division by any scalar keeps dtype
969
+ and torch.all(tensor == 0) # only check tensor for zero
970
+ ):
971
+ new_state_dict[key] = tensor.clone()
972
+ else:
973
+ new_state_dict[key] = tensor / scalar
974
+ return new_state_dict
894
975
 
895
976
 
896
977
  def state_dict_power(state_dict: StateDictType, p: float) -> StateDictType:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
- Name: fusion_bench
3
- Version: 0.2.26
2
+ Name: fusion-bench
3
+ Version: 0.2.28
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  Project-URL: Repository, https://github.com/tanganke/fusion_bench