fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. fusion_bench/__main__.py +4 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/__init__.py +26 -4
  4. fusion_bench/method/classification/__init__.py +1 -0
  5. fusion_bench/method/classification/clip_finetune.py +1 -3
  6. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  7. fusion_bench/method/dare/__init__.py +1 -0
  8. fusion_bench/method/dare/task_arithmetic.py +14 -7
  9. fusion_bench/method/dare/ties_merging.py +100 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  11. fusion_bench/method/isotropic_merging/iso.py +114 -0
  12. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  13. fusion_bench/method/opcm/__init__.py +4 -0
  14. fusion_bench/method/opcm/opcm.py +277 -0
  15. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  16. fusion_bench/method/opcm/ties_merging.py +156 -0
  17. fusion_bench/method/opcm/utils.py +73 -0
  18. fusion_bench/method/opcm/weight_average.py +120 -0
  19. fusion_bench/method/slerp/slerp.py +1 -1
  20. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  21. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
  22. fusion_bench/method/ties_merging/ties_merging.py +10 -0
  23. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  24. fusion_bench/mixins/clip_classification.py +4 -1
  25. fusion_bench/programs/fabric_fusion_program.py +22 -11
  26. fusion_bench/scripts/cli.py +1 -0
  27. fusion_bench/taskpool/base_pool.py +1 -1
  28. fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
  29. fusion_bench/utils/__init__.py +2 -1
  30. fusion_bench/utils/dict.py +43 -0
  31. fusion_bench/utils/expr.py +90 -0
  32. fusion_bench/utils/fabric.py +17 -0
  33. fusion_bench/utils/instantiate.py +7 -1
  34. fusion_bench/utils/json.py +30 -0
  35. fusion_bench/utils/parameters.py +27 -7
  36. fusion_bench/utils/path.py +15 -0
  37. fusion_bench/utils/plot/color_data.py +1726 -0
  38. fusion_bench/utils/rich_utils.py +15 -0
  39. fusion_bench/utils/set.py +8 -0
  40. fusion_bench/utils/tensorboard.py +51 -0
  41. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
  42. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
  43. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
  44. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  45. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  46. fusion_bench_config/method/clip_finetune.yaml +2 -2
  47. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  48. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  49. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  50. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  51. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  52. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  53. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  54. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  56. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
  57. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
1
class pattern_query:
    R"""Callable predicate usable as a leaf inside :func:`is_expr_match` patterns.

    An expression matches when it is an instance of ``type`` (if given) and
    ``func(expr)`` is truthy (if given). With neither argument set, every
    expression matches.

    Examples:

    >>> f = pattern_query(lambda x: x==1)
    >>> f(1)
    True
    >>> f(2)
    False

    Args:
        func: optional predicate called with the expression; its return value
            is returned as-is.
        type: optional type (or tuple of types) the expression must be an
            instance of. The name shadows the builtin but is kept so keyword
            callers keep working.
    """

    def __init__(self, func=None, type=None):
        self.func = func
        self.type = type

    def __call__(self, expr) -> bool:
        # The type filter runs first; a failed isinstance check skips `func`.
        if self.type is not None and not isinstance(expr, self.type):
            return False
        if self.func is not None:
            return self.func(expr)
        return True


# Leaf types compared with plain equality once their exact types agree.
_SCALAR_TYPES = (int, float, complex, str, bytes, type(None))


def is_expr_match(pattern, expr):
    R"""Match ``pattern`` against a Python expression ``expr``.

    Matching rules:

    * a class (any instance of ``type``, including ABCs such as
      ``collections.abc.Sequence``) matches every instance of that class;
    * a :class:`pattern_query` matches whatever its predicate accepts;
    * otherwise ``pattern`` and ``expr`` must have exactly the same type, and

      - scalars (``int``, ``float``, ``complex``, ``str``, ``bytes``,
        ``None``) must compare equal,
      - tuples/lists must have equal length and match element-wise,
      - dicts must have the same keys and match value-by-value.

    Examples:

    >>> is_expr_match('a', 'a')
    True
    >>> is_expr_match((object, 1), ('s',1))
    True
    >>> is_expr_match((object, 1), ('s',2))
    False
    >>> is_expr_match(((int, (int,)), (int, (int,)), (-1,)),
    ...               ((2146, (6,)), (1124, (97,)), (-1,)))
    True

    match a numpy array whose shape is (1,2)

    >>> import numpy as np
    >>> is_expr_match(
    ...     pattern_query(lambda arr: arr.shape == (1, 2), np.ndarray),
    ...     np.zeros((1, 2)))
    True

    Args:
        pattern: pattern to match
        expr: python object

    Raises:
        NotImplementedError: if ``pattern`` or any nested sub-pattern that is
            reached during matching has an unsupported type.

    Returns:
        bool
    """
    if isinstance(pattern, type):
        # `isinstance(pattern, type)` instead of `type(pattern) == type` so
        # classes with a custom metaclass (e.g. ABCMeta) are treated as types.
        return isinstance(expr, pattern)

    if isinstance(pattern, pattern_query):
        return pattern(expr)

    if type(pattern) != type(expr):
        return False

    if isinstance(pattern, _SCALAR_TYPES):
        return pattern == expr

    if isinstance(pattern, (tuple, list)):
        return len(pattern) == len(expr) and all(
            is_expr_match(sub_pattern, sub_expr)
            for sub_pattern, sub_expr in zip(pattern, expr)
        )

    if isinstance(pattern, dict):
        if len(pattern) != len(expr):
            return False
        for key, sub_pattern in pattern.items():
            # A membership test instead of a bare `except:` around `expr[key]`:
            # it does not mask NotImplementedError from nested patterns and
            # never triggers `defaultdict.__missing__` on `expr`.
            if key not in expr:
                return False
            if not is_expr_match(sub_pattern, expr[key]):
                return False
        return True

    raise NotImplementedError("Unsupported type: {}".format(type(pattern)))
@@ -0,0 +1,17 @@
1
+ import time
2
+
3
+ import lightning as L
4
+
5
+
6
def seed_everything_by_time(fabric: L.Fabric):
    """
    Seed every process with a wall-clock-derived value chosen on rank zero.

    Only the global-zero rank reads ``time.time()``; every other rank starts
    with ``None``. After a barrier, rank 0's value is broadcast to all ranks,
    and each process passes the received value to ``L.seed_everything``.

    Args:
        fabric: the Lightning ``Fabric`` whose processes should be seeded.
    """
    # NOTE(review): the rendered diff lost indentation; the barrier is assumed
    # to run on every rank (not only inside the non-zero branch).
    local_seed = int(time.time()) if fabric.is_global_zero else None
    fabric.barrier()
    shared_seed = fabric.broadcast(local_seed, src=0)
    L.seed_everything(shared_seed)
@@ -10,7 +10,7 @@ from hydra._internal.utils import _locate
10
10
  from hydra.errors import InstantiationException
11
11
  from hydra.types import ConvertMode, TargetConf
12
12
  from lightning_utilities.core.rank_zero import rank_zero_only
13
- from omegaconf import OmegaConf, SCMode
13
+ from omegaconf import DictConfig, OmegaConf, SCMode
14
14
  from omegaconf._utils import is_structured_config
15
15
  from rich import print
16
16
  from rich.panel import Panel
@@ -30,6 +30,12 @@ Function to be used for printing function calls.
30
30
  # NOTE(review): module-level switch whose uses are not visible in this
  # excerpt; judging by the name it decides whether exceptions are caught
  # rather than propagated — confirm against its call sites.
  CATCH_EXCEPTION = True
31
31
 
32
32
 
33
def is_instantiable(config: Union[DictConfig, Any]) -> bool:
    """Return whether ``config`` is an OmegaConf dict config with a ``_target_`` key.

    Anything for which ``OmegaConf.is_dict`` is false yields ``False``.
    """
    if not OmegaConf.is_dict(config):
        return False
    return "_target_" in config
37
+
38
+
33
39
  def _resolve_callable_name(f: Callable[..., Any]) -> str:
34
40
  # Get the module name
35
41
  module_name = f.__module__
@@ -1,3 +1,33 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Union
4
+
5
+
6
def save_to_json(obj, path: Union[str, Path], **json_kwargs: Any):
    """
    Serialize ``obj`` as JSON and write it to ``path``.

    The file is opened in write mode, truncating existing content. It uses an
    explicit UTF-8 encoding (the encoding RFC 8259 mandates for JSON
    interchange), so the bytes on disk never depend on the platform's locale
    encoding. This matters once callers pass ``ensure_ascii=False``.

    Args:
        obj (Any): the object to save; must be JSON-serializable.
        path (Union[str, Path]): the path to save the object to.
        **json_kwargs: extra keyword arguments forwarded to :func:`json.dump`
            (e.g. ``indent=2``, ``ensure_ascii=False``, ``sort_keys=True``).

    Raises:
        TypeError: propagated from :func:`json.dump` when ``obj`` contains a
            value that is not JSON-serializable.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **json_kwargs)
16
+
17
+
18
def load_from_json(path: Union[str, Path]) -> Union[dict, list]:
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding RFC 8259 mandates for JSON)
    instead of the platform's locale encoding. This way, non-ASCII content
    reads back correctly regardless of OS or locale settings.

    Args:
        path (Union[str, Path]): the path to load the object from.

    Returns:
        Union[dict, list]: the decoded top-level value. It is usually a dict
        or a list, although a JSON file may also hold a bare scalar.

    Raises:
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
29
+
30
+
1
31
  def _is_list_of_dict(obj) -> bool:
2
32
  if not isinstance(obj, list):
3
33
  return False
@@ -14,6 +14,7 @@ __all__ = [
14
14
  "get_parameter_statistics",
15
15
  "state_dict_to_vector",
16
16
  "vector_to_state_dict",
17
+ "trainable_state_dict",
17
18
  ]
18
19
 
19
20
  # Model conversion utils
@@ -44,33 +45,43 @@ def trainable_state_dict(
44
45
 
45
46
 
46
47
  def state_dict_to_vector(
47
- state_dict: StateDictType,
48
+ state_dict: Union[StateDictType, nn.Module],
48
49
  remove_keys: Optional[List[str]] = None,
49
50
  ):
50
51
  """
51
52
  Convert a state dictionary to a vector.
52
53
 
53
54
  Args:
54
- state_dict (dict): The state dictionary to convert.
55
+ state_dict (Union[dict[str, torch.Tensor], nn.Module]): The state dictionary to convert.
55
56
  remove_keys (list, optional): List of keys to remove from the state dictionary. Defaults to [].
56
57
 
57
58
  Returns:
58
59
  torch.Tensor: The converted vector.
59
60
  """
60
61
  remove_keys = remove_keys if remove_keys is not None else []
61
- shared_state_dict = copy.deepcopy(state_dict)
62
+
63
+ if isinstance(state_dict, nn.Module):
64
+ shared_state_dict = state_dict.state_dict()
65
+ else:
66
+ shared_state_dict = copy.copy(state_dict)
67
+
68
+ # remove the keys to be removed
62
69
  for key in remove_keys:
63
70
  if key in shared_state_dict:
64
71
  del shared_state_dict[key]
72
+
73
+ # sort the reference dict
65
74
  sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
66
- return nn.utils.parameters_to_vector(
75
+
76
+ vector = nn.utils.parameters_to_vector(
67
77
  [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
68
78
  )
79
+ return vector
69
80
 
70
81
 
71
82
  def vector_to_state_dict(
72
83
  vector: torch.Tensor,
73
- state_dict: StateDictType,
84
+ state_dict: Union[StateDictType, nn.Module],
74
85
  remove_keys: Optional[List[str]] = None,
75
86
  ):
76
87
  """
@@ -78,18 +89,27 @@ def vector_to_state_dict(
78
89
 
79
90
  Args:
80
91
  vector (torch.Tensor): The vector to convert.
81
- state_dict (dict): The reference state dictionary to define the order of the vector.
92
+ state_dict (Union[dict[str, torch.Tensor], nn.Module]): The reference state dictionary to define the order of the vector.
82
93
  remove_keys (list, optional): List of keys to remove from the reference state dictionary. Defaults to [].
83
94
 
84
95
  Returns:
85
96
  dict: The converted state dictionary.
86
97
  """
87
98
  remove_keys = remove_keys if remove_keys is not None else []
99
+
88
100
  # create a reference dict to define the order of the vector
89
- reference_dict = copy.deepcopy(state_dict)
101
+ if isinstance(state_dict, nn.Module):
102
+ reference_dict = state_dict.state_dict()
103
+ else:
104
+ # shallow copy the state_dict
105
+ reference_dict = copy.copy(state_dict)
106
+
107
+ # remove the keys to be removed
90
108
  for key in remove_keys:
91
109
  if key in reference_dict:
92
110
  del reference_dict[key]
111
+
112
+ # sort the reference dict
93
113
  sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
94
114
 
95
115
  # create a shared state dict using the reference dict
@@ -1,7 +1,22 @@
1
1
  import os
2
+ from typing import List
2
3
 
3
4
 
4
5
  def path_is_dir_and_not_empty(path: str):
5
6
  if path is None:
6
7
  return False
7
8
  return os.path.isdir(path) and len(os.listdir(path)) > 0
9
+
10
+
11
def listdir_fullpath(dir: str) -> List[str]:
    """List directory ``dir`` and return the full path of every entry.

    Each name returned by :func:`os.listdir` is joined onto ``dir``. The order
    is whatever ``os.listdir`` yields (arbitrary, not sorted).

    Args:
        dir (str): directory name. The name shadows the builtin but is kept
            so keyword callers keep working.

    Returns:
        List[str]: a list of fullpaths

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # AssertionError is kept so existing callers that catch it still work.
    if not os.path.isdir(dir):
        raise AssertionError("Argument 'dir' must be a Directory")
    return [os.path.join(dir, name) for name in os.listdir(dir)]