fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__main__.py +4 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/__init__.py +26 -4
- fusion_bench/method/classification/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -3
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/dare/__init__.py +1 -0
- fusion_bench/method/dare/task_arithmetic.py +14 -7
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/slerp/slerp.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
- fusion_bench/method/ties_merging/ties_merging.py +10 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/mixins/clip_classification.py +4 -1
- fusion_bench/programs/fabric_fusion_program.py +22 -11
- fusion_bench/scripts/cli.py +1 -0
- fusion_bench/taskpool/base_pool.py +1 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/instantiate.py +7 -1
- fusion_bench/utils/json.py +30 -0
- fusion_bench/utils/parameters.py +27 -7
- fusion_bench/utils/path.py +15 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/rich_utils.py +15 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/tensorboard.py +51 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +2 -2
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
class pattern_query:
    R"""
    Callable predicate that can be used as a pattern in ``is_expr_match``.

    An optional ``type`` restricts which objects are accepted, and an optional
    ``func`` is then applied to the object; its result is returned as-is.
    With neither set, every object is accepted.

    Examples:

        >>> f = pattern_query(lambda x: x==1)
        >>> f(1)
        True
        >>> f(2)
        False

    """

    def __init__(self, func=None, type=None):
        # `type` shadows the builtin only inside this constructor; the name is
        # part of the public keyword interface and is kept as-is.
        self.func, self.type = func, type

    def __call__(self, expr) -> bool:
        # Type gate first: a mismatch rejects without ever calling `func`.
        type_ok = self.type is None or isinstance(expr, self.type)
        if not type_ok:
            return False
        # No predicate means "anything of the right type matches".
        return True if self.func is None else self.func(expr)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def is_expr_match(pattern, expr):
    R"""Match ``pattern`` against a python object ``expr``.

    Matching rules:

    - a class (including one created by a custom metaclass, such as an ABC)
      matches any instance of that class;
    - a ``pattern_query`` instance is called with ``expr`` and its result is
      returned;
    - any other pattern must have exactly the same type as ``expr``; then
      ``int``/``float``/``str`` are compared by value, ``tuple``/``list`` are
      matched element-wise, and ``dict`` is matched key by key (same number
      of keys, every pattern key present in ``expr``).

    Examples:

        >>> is_expr_match('a', 'a')
        True
        >>> is_expr_match((object, 1), ('s',1))
        True
        >>> is_expr_match((object, 1), ('s',2))
        False
        >>> is_expr_match(((int, (int,)), (int, (int,)), (-1,)),
        ...               ((2146, (6,)), (1124, (97,)), (-1,)))
        True

        match a numpy array whose shape is (1,2)

        >>> import numpy as np
        >>> is_expr_match(
        ...     pattern_query(lambda arr: arr.shape==(1,2), np.ndarray),
        ...     np.zeros((1,2)))
        True

    Args:
        pattern: pattern to match
        expr: python object

    Raises:
        NotImplementedError: Unsupported type. This is also raised when the
            unsupported pattern is nested inside a tuple, list or dict.

    Returns:
        bool
    """
    # `isinstance(pattern, type)` rather than `type(pattern) == type`, so that
    # classes with a custom metaclass are treated as type patterns too.
    if isinstance(pattern, type):
        return isinstance(expr, pattern)

    if isinstance(pattern, pattern_query):
        return pattern(expr)

    # Instance patterns require an exact type match (so True never matches 1).
    if type(pattern) != type(expr):
        return False

    if isinstance(pattern, (int, float, str)):
        if pattern != expr:
            return False
        return True

    if isinstance(pattern, (tuple, list)):
        if len(pattern) != len(expr):
            return False
        return all(
            is_expr_match(sub_pattern, sub_expr)
            for sub_pattern, sub_expr in zip(pattern, expr)
        )

    if isinstance(pattern, dict):
        if len(pattern) != len(expr):
            return False
        for key, sub_pattern in pattern.items():
            # Explicit membership test instead of a bare `except:` around
            # `expr[key]`: a missing key is a mismatch, while real errors from
            # nested matching are no longer silently swallowed. It also
            # avoids inserting keys into a `defaultdict`.
            if key not in expr:
                return False
            if not is_expr_match(sub_pattern, expr[key]):
                return False
        return True

    raise NotImplementedError("Unsupported type: {}".format(type(pattern)))
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
import lightning as L
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def seed_everything_by_time(fabric: L.Fabric):
    """
    Seed all processes with one shared, time-derived seed.

    The global-zero process takes the current time (truncated to whole
    seconds) as the seed, and every other process starts with ``None``. After
    a barrier, the value is broadcast from ``src=0`` and passed to
    ``L.seed_everything`` on every process.

    Args:
        fabric (L.Fabric): the fabric used for the barrier and the broadcast.
    """
    # Only the global-zero process proposes a seed; the rest receive it below.
    seed = int(time.time()) if fabric.is_global_zero else None
    fabric.barrier()
    seed = fabric.broadcast(seed, src=0)
    L.seed_everything(seed)
|
|
@@ -10,7 +10,7 @@ from hydra._internal.utils import _locate
|
|
|
10
10
|
from hydra.errors import InstantiationException
|
|
11
11
|
from hydra.types import ConvertMode, TargetConf
|
|
12
12
|
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
13
|
-
from omegaconf import OmegaConf, SCMode
|
|
13
|
+
from omegaconf import DictConfig, OmegaConf, SCMode
|
|
14
14
|
from omegaconf._utils import is_structured_config
|
|
15
15
|
from rich import print
|
|
16
16
|
from rich.panel import Panel
|
|
@@ -30,6 +30,12 @@ Function to be used for printing function calls.
|
|
|
30
30
|
CATCH_EXCEPTION = True
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
def is_instantiable(config: Union[DictConfig, Any]) -> bool:
    """
    Tell whether ``config`` is a dict-style OmegaConf config with a ``_target_`` key.

    Args:
        config (Union[DictConfig, Any]): the object to inspect.

    Returns:
        bool: ``True`` only when ``OmegaConf.is_dict(config)`` holds and
            ``"_target_"`` is one of its keys; ``False`` otherwise.
    """
    if not OmegaConf.is_dict(config):
        return False
    return "_target_" in config
|
|
37
|
+
|
|
38
|
+
|
|
33
39
|
def _resolve_callable_name(f: Callable[..., Any]) -> str:
|
|
34
40
|
# Get the module name
|
|
35
41
|
module_name = f.__module__
|
fusion_bench/utils/json.py
CHANGED
|
@@ -1,3 +1,33 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def save_to_json(obj, path: Union[str, Path]):
    """
    save an object to a json file

    The object is serialized to a string *before* the file is opened, so a
    non-serializable object raises without truncating or corrupting an
    existing file at ``path``.

    Args:
        obj (Any): the object to save
        path (Union[str, Path]): the path to save the object

    Raises:
        TypeError: if ``obj`` is not JSON serializable.
        ValueError: if ``obj`` contains circular references (raised by ``json.dumps``).
    """
    # Serialize first: opening with "w" truncates immediately, and a failure
    # halfway through `json.dump` would leave a partial or empty file behind.
    text = json.dumps(obj)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def load_from_json(path: Union[str, Path]) -> Union[dict, list]:
    """load an object from a json file

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so files containing non-ASCII text load the same everywhere.

    Args:
        path (Union[str, Path]): the path to load the object

    Returns:
        dict: the loaded object

    Raises:
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
29
|
+
|
|
30
|
+
|
|
1
31
|
def _is_list_of_dict(obj) -> bool:
|
|
2
32
|
if not isinstance(obj, list):
|
|
3
33
|
return False
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -14,6 +14,7 @@ __all__ = [
|
|
|
14
14
|
"get_parameter_statistics",
|
|
15
15
|
"state_dict_to_vector",
|
|
16
16
|
"vector_to_state_dict",
|
|
17
|
+
"trainable_state_dict",
|
|
17
18
|
]
|
|
18
19
|
|
|
19
20
|
# Model conversion utils
|
|
@@ -44,33 +45,43 @@ def trainable_state_dict(
|
|
|
44
45
|
|
|
45
46
|
|
|
46
47
|
def state_dict_to_vector(
|
|
47
|
-
state_dict: StateDictType,
|
|
48
|
+
state_dict: Union[StateDictType, nn.Module],
|
|
48
49
|
remove_keys: Optional[List[str]] = None,
|
|
49
50
|
):
|
|
50
51
|
"""
|
|
51
52
|
Convert a state dictionary to a vector.
|
|
52
53
|
|
|
53
54
|
Args:
|
|
54
|
-
state_dict (dict): The state dictionary to convert.
|
|
55
|
+
state_dict (Union[dict[str, torch.Tensor], nn.Module]): The state dictionary to convert.
|
|
55
56
|
remove_keys (list, optional): List of keys to remove from the state dictionary. Defaults to [].
|
|
56
57
|
|
|
57
58
|
Returns:
|
|
58
59
|
torch.Tensor: The converted vector.
|
|
59
60
|
"""
|
|
60
61
|
remove_keys = remove_keys if remove_keys is not None else []
|
|
61
|
-
|
|
62
|
+
|
|
63
|
+
if isinstance(state_dict, nn.Module):
|
|
64
|
+
shared_state_dict = state_dict.state_dict()
|
|
65
|
+
else:
|
|
66
|
+
shared_state_dict = copy.copy(state_dict)
|
|
67
|
+
|
|
68
|
+
# remove the keys to be removed
|
|
62
69
|
for key in remove_keys:
|
|
63
70
|
if key in shared_state_dict:
|
|
64
71
|
del shared_state_dict[key]
|
|
72
|
+
|
|
73
|
+
# sort the reference dict
|
|
65
74
|
sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
|
|
66
|
-
|
|
75
|
+
|
|
76
|
+
vector = nn.utils.parameters_to_vector(
|
|
67
77
|
[value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
|
|
68
78
|
)
|
|
79
|
+
return vector
|
|
69
80
|
|
|
70
81
|
|
|
71
82
|
def vector_to_state_dict(
|
|
72
83
|
vector: torch.Tensor,
|
|
73
|
-
state_dict: StateDictType,
|
|
84
|
+
state_dict: Union[StateDictType, nn.Module],
|
|
74
85
|
remove_keys: Optional[List[str]] = None,
|
|
75
86
|
):
|
|
76
87
|
"""
|
|
@@ -78,18 +89,27 @@ def vector_to_state_dict(
|
|
|
78
89
|
|
|
79
90
|
Args:
|
|
80
91
|
vector (torch.Tensor): The vector to convert.
|
|
81
|
-
state_dict (dict): The reference state dictionary to define the order of the vector.
|
|
92
|
+
state_dict (Union[dict[str, torch.Tensor], nn.Module]): The reference state dictionary to define the order of the vector.
|
|
82
93
|
remove_keys (list, optional): List of keys to remove from the reference state dictionary. Defaults to [].
|
|
83
94
|
|
|
84
95
|
Returns:
|
|
85
96
|
dict: The converted state dictionary.
|
|
86
97
|
"""
|
|
87
98
|
remove_keys = remove_keys if remove_keys is not None else []
|
|
99
|
+
|
|
88
100
|
# create a reference dict to define the order of the vector
|
|
89
|
-
|
|
101
|
+
if isinstance(state_dict, nn.Module):
|
|
102
|
+
reference_dict = state_dict.state_dict()
|
|
103
|
+
else:
|
|
104
|
+
# shallow copy the state_dict
|
|
105
|
+
reference_dict = copy.copy(state_dict)
|
|
106
|
+
|
|
107
|
+
# remove the keys to be removed
|
|
90
108
|
for key in remove_keys:
|
|
91
109
|
if key in reference_dict:
|
|
92
110
|
del reference_dict[key]
|
|
111
|
+
|
|
112
|
+
# sort the reference dict
|
|
93
113
|
sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
|
|
94
114
|
|
|
95
115
|
# create a shared state dict using the reference dict
|
fusion_bench/utils/path.py
CHANGED
|
@@ -1,7 +1,22 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from typing import List
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
def path_is_dir_and_not_empty(path: str):
    """Return True when ``path`` is an existing directory with at least one entry.

    ``None``, non-directories and empty directories all yield ``False``.
    """
    if path is None or not os.path.isdir(path):
        return False
    entries = os.listdir(path)
    return len(entries) > 0
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def listdir_fullpath(dir: str) -> List[str]:
    """List the entries of directory `dir` as paths joined onto `dir`.

    Args:
        dir (str): directory name

    Returns:
        List[str]: one ``os.path.join(dir, name)`` per entry, in the order
            returned by ``os.listdir``
    """
    # NOTE(review): this assert is skipped under `python -O`; kept because
    # callers may rely on AssertionError for non-directories.
    assert os.path.isdir(dir), "Argument 'dir' must be a Directory"
    fullpaths = []
    for entry in os.listdir(dir):
        fullpaths.append(os.path.join(dir, entry))
    return fullpaths
|