fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/image_classification_finetune.py +168 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +8 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/json.py +49 -8
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
fusion_bench/utils/json.py
CHANGED
|
@@ -1,31 +1,72 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Union
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
4
4
|
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from pyarrow.fs import FileSystem
|
|
5
7
|
|
|
6
|
-
|
|
8
|
+
|
|
9
|
+
def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
|
|
7
10
|
"""
|
|
8
11
|
save an object to a json file
|
|
9
12
|
|
|
10
13
|
Args:
|
|
11
14
|
obj (Any): the object to save
|
|
12
15
|
path (Union[str, Path]): the path to save the object
|
|
16
|
+
filesystem (FileSystem, optional): PyArrow FileSystem to use for writing.
|
|
17
|
+
If None, uses local filesystem via standard Python open().
|
|
18
|
+
Can also be an s3fs.S3FileSystem or fsspec filesystem.
|
|
13
19
|
"""
|
|
14
|
-
|
|
15
|
-
json.
|
|
20
|
+
if filesystem is not None:
|
|
21
|
+
json_str = json.dumps(obj)
|
|
22
|
+
# Check if it's an fsspec-based filesystem (like s3fs)
|
|
23
|
+
if hasattr(filesystem, "open"):
|
|
24
|
+
# Direct fsspec/s3fs usage - more reliable for some endpoints
|
|
25
|
+
path_str = str(path)
|
|
26
|
+
with filesystem.open(path_str, "w") as f:
|
|
27
|
+
f.write(json_str)
|
|
28
|
+
else:
|
|
29
|
+
# Use PyArrow filesystem
|
|
30
|
+
path_str = str(path)
|
|
31
|
+
with filesystem.open_output_stream(path_str) as f:
|
|
32
|
+
f.write(json_str.encode("utf-8"))
|
|
33
|
+
else:
|
|
34
|
+
# Use standard Python file operations
|
|
35
|
+
with open(path, "w") as f:
|
|
36
|
+
json.dump(obj, f)
|
|
16
37
|
|
|
17
38
|
|
|
18
|
-
def load_from_json(
|
|
39
|
+
def load_from_json(
|
|
40
|
+
path: Union[str, Path], filesystem: "FileSystem" = None
|
|
41
|
+
) -> Union[dict, list]:
|
|
19
42
|
"""load an object from a json file
|
|
20
43
|
|
|
21
44
|
Args:
|
|
22
45
|
path (Union[str, Path]): the path to load the object
|
|
46
|
+
filesystem (FileSystem, optional): PyArrow FileSystem to use for reading.
|
|
47
|
+
If None, uses local filesystem via standard Python open().
|
|
48
|
+
Can also be an s3fs.S3FileSystem or fsspec filesystem.
|
|
23
49
|
|
|
24
50
|
Returns:
|
|
25
|
-
dict: the loaded object
|
|
51
|
+
Union[dict, list]: the loaded object
|
|
26
52
|
"""
|
|
27
|
-
|
|
28
|
-
|
|
53
|
+
if filesystem is not None:
|
|
54
|
+
# Check if it's an fsspec-based filesystem (like s3fs)
|
|
55
|
+
if hasattr(filesystem, "open"):
|
|
56
|
+
# Direct fsspec/s3fs usage
|
|
57
|
+
path_str = str(path)
|
|
58
|
+
with filesystem.open(path_str, "r") as f:
|
|
59
|
+
return json.load(f)
|
|
60
|
+
else:
|
|
61
|
+
# Use PyArrow filesystem
|
|
62
|
+
path_str = str(path)
|
|
63
|
+
with filesystem.open_input_stream(path_str) as f:
|
|
64
|
+
json_data = f.read().decode("utf-8")
|
|
65
|
+
return json.loads(json_data)
|
|
66
|
+
else:
|
|
67
|
+
# Use standard Python file operations
|
|
68
|
+
with open(path, "r") as f:
|
|
69
|
+
return json.load(f)
|
|
29
70
|
|
|
30
71
|
|
|
31
72
|
def _is_list_of_dict(obj) -> bool:
|
|
@@ -6,10 +6,13 @@ import torch
|
|
|
6
6
|
from torch import Tensor
|
|
7
7
|
from tqdm.auto import tqdm
|
|
8
8
|
|
|
9
|
+
from fusion_bench.utils.type import TorchModelType
|
|
10
|
+
|
|
9
11
|
from .type import BoolStateDictType, StateDictType
|
|
10
12
|
|
|
11
13
|
__all__ = [
|
|
12
14
|
"ArithmeticStateDict",
|
|
15
|
+
"load_state_dict_with_prefix",
|
|
13
16
|
"state_dicts_check_keys",
|
|
14
17
|
"state_dict_to_device",
|
|
15
18
|
"num_params_of_state_dict",
|
|
@@ -646,6 +649,48 @@ def _validate_list_lengths_equal(
|
|
|
646
649
|
pass
|
|
647
650
|
|
|
648
651
|
|
|
652
|
+
def load_state_dict_with_prefix(
|
|
653
|
+
model: TorchModelType,
|
|
654
|
+
state_dict: StateDictType,
|
|
655
|
+
strict: bool = True,
|
|
656
|
+
assign: bool = False,
|
|
657
|
+
key_prefix: str = "model.",
|
|
658
|
+
operation: Literal["add", "remove"] = "remove",
|
|
659
|
+
) -> TorchModelType:
|
|
660
|
+
"""
|
|
661
|
+
Load a state dict into a model, adding or removing a prefix from the keys.
|
|
662
|
+
|
|
663
|
+
This is useful when loading state dicts saved with DataParallel, pytorch lightning or similar wrappers.
|
|
664
|
+
|
|
665
|
+
Args:
|
|
666
|
+
model: The model to load the state dict into.
|
|
667
|
+
state_dict: The state dictionary to load.
|
|
668
|
+
key_prefix: The prefix to add or remove from the keys.
|
|
669
|
+
operation: 'add' to add the prefix, 'remove' to remove it.
|
|
670
|
+
|
|
671
|
+
Returns:
|
|
672
|
+
The model with the loaded state dict.
|
|
673
|
+
"""
|
|
674
|
+
if operation not in ("add", "remove"):
|
|
675
|
+
raise ValueError("operation must be either 'add' or 'remove'")
|
|
676
|
+
|
|
677
|
+
modified_state_dict = OrderedDict()
|
|
678
|
+
for key, value in state_dict.items():
|
|
679
|
+
if operation == "add":
|
|
680
|
+
new_key = f"{key_prefix}{key}"
|
|
681
|
+
else: # operation == "remove"
|
|
682
|
+
if key.startswith(key_prefix):
|
|
683
|
+
new_key = key[len(key_prefix) :]
|
|
684
|
+
else:
|
|
685
|
+
raise ValueError(
|
|
686
|
+
f"Key '{key}' does not start with prefix '{key_prefix}'"
|
|
687
|
+
)
|
|
688
|
+
modified_state_dict[new_key] = value
|
|
689
|
+
|
|
690
|
+
model.load_state_dict(modified_state_dict, strict=strict, assign=assign)
|
|
691
|
+
return model
|
|
692
|
+
|
|
693
|
+
|
|
649
694
|
def state_dict_to_device(
|
|
650
695
|
state_dict: StateDictType,
|
|
651
696
|
device: Union[torch.device, str],
|
|
@@ -851,22 +896,48 @@ def state_dict_add_scalar(state_dict: StateDictType, scalar: Number) -> StateDic
|
|
|
851
896
|
return OrderedDict((key, tensor + scalar) for key, tensor in state_dict.items())
|
|
852
897
|
|
|
853
898
|
|
|
854
|
-
def state_dict_mul(
|
|
899
|
+
def state_dict_mul(
|
|
900
|
+
state_dict: StateDictType,
|
|
901
|
+
scalar: float,
|
|
902
|
+
*,
|
|
903
|
+
keep_dtype_when_zero: bool = True,
|
|
904
|
+
show_pbar: bool = False,
|
|
905
|
+
) -> StateDictType:
|
|
855
906
|
"""
|
|
856
907
|
Multiply all parameters in a state dict by a scalar.
|
|
857
908
|
|
|
858
909
|
Args:
|
|
859
910
|
state_dict: The state dict to multiply.
|
|
860
|
-
scalar: The scalar value to multiply each parameter by.
|
|
911
|
+
scalar (float): The scalar value to multiply each parameter by.
|
|
912
|
+
keep_dtype_when_zero (bool): Whether to keep the original data type of the tensors if either the tensor is all zeros or the scalar is zero.
|
|
913
|
+
show_pbar (bool): Whether to show a progress bar during computation.
|
|
861
914
|
|
|
862
915
|
Returns:
|
|
863
916
|
A new state dict with each parameter multiplied by the scalar.
|
|
864
917
|
"""
|
|
865
|
-
|
|
918
|
+
new_state_dict = OrderedDict()
|
|
919
|
+
for key, tensor in (
|
|
920
|
+
state_dict.items()
|
|
921
|
+
if not show_pbar
|
|
922
|
+
else tqdm(state_dict.items(), desc="Multiplying state dict")
|
|
923
|
+
):
|
|
924
|
+
if (
|
|
925
|
+
keep_dtype_when_zero
|
|
926
|
+
and not tensor.is_floating_point() # when tensor is not floating point, multiplication by 0 keeps dtype
|
|
927
|
+
and (scalar == 0 or torch.all(tensor == 0))
|
|
928
|
+
):
|
|
929
|
+
new_state_dict[key] = tensor.clone()
|
|
930
|
+
else:
|
|
931
|
+
new_state_dict[key] = scalar * tensor
|
|
932
|
+
return new_state_dict
|
|
866
933
|
|
|
867
934
|
|
|
868
935
|
def state_dict_div(
|
|
869
|
-
state_dict: StateDictType,
|
|
936
|
+
state_dict: StateDictType,
|
|
937
|
+
scalar: float,
|
|
938
|
+
*,
|
|
939
|
+
keep_dtype_when_zero: bool = True,
|
|
940
|
+
show_pbar: bool = False,
|
|
870
941
|
) -> StateDictType:
|
|
871
942
|
"""
|
|
872
943
|
Divide all parameters in a state dict by a scalar.
|
|
@@ -874,6 +945,7 @@ def state_dict_div(
|
|
|
874
945
|
Args:
|
|
875
946
|
state_dict: The state dict to divide.
|
|
876
947
|
scalar: The scalar value to divide each parameter by.
|
|
948
|
+
keep_dtype_when_zero: Whether to keep the original data type of the tensors if the tensor is all zeros.
|
|
877
949
|
show_pbar: Whether to show a progress bar during computation.
|
|
878
950
|
|
|
879
951
|
Returns:
|
|
@@ -885,12 +957,21 @@ def state_dict_div(
|
|
|
885
957
|
if scalar == 0:
|
|
886
958
|
raise ZeroDivisionError("Cannot divide state dict by zero")
|
|
887
959
|
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
960
|
+
new_state_dict = OrderedDict()
|
|
961
|
+
for key, tensor in (
|
|
962
|
+
state_dict.items()
|
|
963
|
+
if not show_pbar
|
|
964
|
+
else tqdm(state_dict.items(), desc="Dividing state dict")
|
|
965
|
+
):
|
|
966
|
+
if (
|
|
967
|
+
keep_dtype_when_zero
|
|
968
|
+
and not tensor.is_floating_point() # when tensor is not floating point, division by any scalar keeps dtype
|
|
969
|
+
and torch.all(tensor == 0) # only check tensor for zero
|
|
970
|
+
):
|
|
971
|
+
new_state_dict[key] = tensor.clone()
|
|
972
|
+
else:
|
|
973
|
+
new_state_dict[key] = tensor / scalar
|
|
974
|
+
return new_state_dict
|
|
894
975
|
|
|
895
976
|
|
|
896
977
|
def state_dict_power(state_dict: StateDictType, p: float) -> StateDictType:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
|
-
Name: fusion-bench
|
|
3
|
-
Version: 0.2.26
|
|
2
|
+
Name: fusion-bench
|
|
3
|
+
Version: 0.2.28
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
Project-URL: Repository, https://github.com/tanganke/fusion_bench
|