fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. fusion_bench/compat/method/base_algorithm.py +1 -1
  2. fusion_bench/dataset/clip_dataset.py +3 -0
  3. fusion_bench/dataset/fer2013.py +12 -0
  4. fusion_bench/dataset/llama/preference_700k.py +1 -1
  5. fusion_bench/method/__init__.py +2 -0
  6. fusion_bench/method/classification/clip_finetune.py +10 -13
  7. fusion_bench/method/surgery/__init__.py +1 -3
  8. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
  9. fusion_bench/method/tall_mask/__init__.py +0 -0
  10. fusion_bench/method/tall_mask/utils.py +234 -0
  11. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  12. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  13. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  14. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  15. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  16. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  17. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  18. fusion_bench/mixins/clip_classification.py +6 -6
  19. fusion_bench/mixins/lightning_fabric.py +3 -1
  20. fusion_bench/modelpool/base_pool.py +0 -1
  21. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  22. fusion_bench/models/surgery/__init__.py +1 -0
  23. fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
  24. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  25. fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
  26. fusion_bench/programs/fabric_fusion_program.py +7 -4
  27. fusion_bench/taskpool/llama/reward_model.py +1 -1
  28. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  29. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  30. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  31. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  32. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  33. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  34. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  35. fusion_bench/tasks/clip_classification/food101.py +105 -0
  36. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  37. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  38. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  39. fusion_bench/utils/parameters.py +12 -3
  40. fusion_bench/utils/type.py +10 -1
  41. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  42. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
  43. fusion_bench_config/dataset/image_classification/README.md +6 -0
  44. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  45. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  47. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  48. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  49. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  50. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  51. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  52. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  53. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  54. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  55. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  56. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  57. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  58. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  59. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  60. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  61. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  62. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  63. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  64. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  65. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  66. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  67. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  68. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  69. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  70. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  71. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  72. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  73. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  74. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  75. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  76. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  77. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  78. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  79. fusion_bench_config/model/clip-vit/README.md +38 -0
  80. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  81. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  82. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  83. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  84. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  85. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  86. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  87. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  88. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  89. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  90. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  91. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  92. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  93. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  94. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  95. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  96. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  97. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  98. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  99. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  100. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  101. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  102. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  103. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  104. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  105. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  106. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  107. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  108. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  109. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  110. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  111. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  112. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  113. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  114. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  115. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  116. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  117. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  118. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  119. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  120. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  121. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  122. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  123. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  124. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  125. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  126. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  127. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  128. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  129. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  130. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  131. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  132. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  133. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  134. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  135. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  150. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  168. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  169. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  170. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  171. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  172. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  173. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  174. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  175. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  176. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  177. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  178. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  183. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  185. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  190. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  191. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  192. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  193. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  194. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  195. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,642 @@
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+ from fusion_bench.utils.type import StateDictType
7
+
8
+
9
def compute_svd_dict(task_vectors, config):
    """
    Decompose every 2-D parameter of each task vector with an SVD and store
    the (zero-padded) factors per dataset.

    Each dataset is assigned its own block of ``int(rank * 1/len(DATASETS))``
    singular triplets: the top singular vectors/values of dataset ``i`` are
    written into columns/rows ``[i*r, (i+1)*r)`` of otherwise-zero tensors, so
    the per-dataset factors can later be summed without overlapping.

    Args:
        task_vectors (list): Objects exposing a ``.vector`` dict of tensors.
        config (object): Must provide the dataset names in ``config.DATASETS``.

    Returns:
        dict: ``{dataset: {key: {"u": ..., "s": ..., "v": ...}}}`` for 2-D
        parameters; parameters that are not 2-D, or whose key contains
        ``"text_projection"``, are stored unchanged under ``{"dim1": tensor}``.
    """
    fraction = 1 / len(config.DATASETS)
    with torch.no_grad():
        result = {}
        for idx, (tv, name) in enumerate(zip(task_vectors, config.DATASETS)):
            print(f"Computing SVD for {name}...")
            per_task = {}
            for key, mat in tv.vector.items():
                entry = {}
                if mat.dim() == 2 and "text_projection" not in key:
                    u, s, vh = torch.linalg.svd(mat, full_matrices=False)
                    rank = int(s.shape[0] * fraction)
                    lo, hi = idx * rank, (idx + 1) * rank

                    # zero-pad so that each dataset occupies its own block
                    padded_u = torch.zeros_like(u)
                    padded_u[:, lo:hi] = u[:, :rank]
                    entry["u"] = padded_u

                    padded_s = torch.zeros_like(s)
                    padded_s[lo:hi] = s[:rank]
                    entry["s"] = padded_s

                    padded_v = torch.zeros_like(vh)
                    padded_v[lo:hi, :] = vh[:rank, :]
                    entry["v"] = padded_v
                else:
                    # non-matrix (or excluded) parameter: keep as-is
                    entry["dim1"] = mat
                per_task[key] = entry
            result[name] = per_task
        return result
64
+
65
+
66
def sum_svd_dict(svd_dict, config):
    """
    Merge the per-dataset SVD factors produced by ``compute_svd_dict``.

    For keys that carry ``"u"/"s"/"v"`` factors, the zero-padded factors are
    summed across datasets, the summed ``u`` and ``v`` blocks are
    orthogonalized with a second SVD, and the merged matrix is rebuilt as
    ``(u_u @ v_u) @ diag(sum_s) @ (u_v @ v_v)``. Keys stored under ``"dim1"``
    are combined with a running mean across datasets.

    Args:
        svd_dict (dict): ``{dataset: {key: {"u"/"s"/"v"} or {"dim1"}}}``.
        config (object): Provides the dataset names in ``config.DATASETS``.

    Returns:
        dict: Merged tensor per key.
    """
    print("Summing SVD...")
    datasets = config.DATASETS
    merged = {}
    for key, first_entry in svd_dict[datasets[0]].items():
        if "u" in first_entry:
            total_u = sum(svd_dict[d][key]["u"] for d in datasets)
            total_s = sum(svd_dict[d][key]["s"] for d in datasets)
            total_v = sum(svd_dict[d][key]["v"] for d in datasets)

            # orthogonalize the summed factors via a second SVD
            left_q, _, left_r = torch.linalg.svd(total_u, full_matrices=False)
            right_q, _, right_r = torch.linalg.svd(total_v, full_matrices=False)
            merged[key] = torch.linalg.multi_dot(
                (left_q, left_r, torch.diag(total_s), right_q, right_r)
            )
        else:
            # running mean over the "dim1" tensors; note the first dataset's
            # tensor is reused (and updated in place) as the accumulator
            for count, d in enumerate(datasets, start=1):
                value = svd_dict[d][key]["dim1"]
                if count == 1:
                    merged[key] = value
                else:
                    merged[key] += (value - merged[key]) / count
    return merged
115
+
116
+
117
+ ###############
118
+ ##### LOSSLESS Orthogonalization
119
def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
    """
    Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.

    This function performs the following steps:
    1. Iterates over each layer in the task vectors.
    2. For each layer, it computes the SVD of the corresponding matrix if it is a 2D tensor excluding "text_projection".
    3. Concatenate the U_i, S_i, and V_i matrices from the SVD across all tasks.
    4. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors.
    5. After concatenating the SVD components, recomputes the SVD of the summed U and V matrices and constructs the merged layer.

    Args:
        task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
        config (object): A configuration object containing the device and dataset information.

    Returns:
        dict: A dictionary containing the new vectors after summing the SVD components.
    """
    # NOTE: ViT-L on 20 tasks does not fit in GPU memory or in 64 GB RAM
    # (try without the last layer) — the buffers below are num_tasks times
    # wider than a single layer ("lossless": no singular values are dropped).
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            # placeholder; overwritten below by either the merged matrix or
            # the running mean
            new_vector[key] = {}
            for i, (task_vector, dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                vec = task_vector.vector[key].to(device)

                if (
                    len(task_vector.vector[key].shape) == 2
                    and "text_projection" not in key
                ):

                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        # allocate concatenation buffers sized for ALL tasks;
                        # they persist across the remaining loop iterations
                        sum_u = torch.zeros(
                            u.shape[0], u.shape[1] * config.num_tasks, device=device
                        )
                        sum_s = torch.zeros(
                            s.shape[0] * config.num_tasks, device=device
                        )
                        sum_v = torch.zeros(
                            v.shape[0] * config.num_tasks, v.shape[1], device=device
                        )
                        # full rank kept for every task (lossless)
                        reduced_index_s = s.shape[0]

                    # select only the first reduced_index_s columns of u and place them
                    sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                        :, :reduced_index_s
                    ]
                    sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                        :reduced_index_s
                    ]
                    # select only the first reduced_index_s rows of v and place them
                    sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                        :reduced_index_s, :
                    ]

                else:
                    # non-2D (or excluded) parameter: running mean across tasks
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)

            # task_vector here is the last item of the loop above; all task
            # vectors are assumed to share shapes, so the check is equivalent
            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
                # orthogonalize the concatenated factors with a second SVD,
                # then rebuild the merged layer
                u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
                u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)

                new_vector[key] = torch.linalg.multi_dot(
                    (
                        u_u,
                        v_u,
                        torch.diag(sum_s),
                        u_v,
                        v_v,
                    )
                )

    return new_vector
202
+
203
+
204
+ ###############
205
+ ##### LOSSLESS EIGENDECOMP
206
def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
    """
    Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.

    This function performs the following steps:
    1. Iterates over each layer in the task vectors.
    2. For each layer, it computes the SVD of the corresponding matrix if it is a 2D tensor excluding "text_projection".
    3. Concatenate the U_i, S_i, and V_i matrices from the SVD across all tasks.
    4. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors.
    5. After concatenating the SVD components, recomputes the eigendecomposition of the summed U and V matrices and constructs the merged layer.

    Args:
        task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
        config (object): A configuration object containing the device and dataset information.

    Returns:
        dict: A dictionary containing the new vectors after merging the SVD components.
    """
    # NOTE: ViT-L on 20 tasks does not fit in GPU memory or in 64 GB RAM
    # (try without the last layer) — the buffers below are num_tasks times
    # wider than a single layer ("lossless": no singular values are dropped).
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            # placeholder; overwritten below by either the merged matrix or
            # the running mean
            new_vector[key] = {}
            for i, (task_vector, dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                vec = task_vector.vector[key].to(device)

                if (
                    len(task_vector.vector[key].shape) == 2
                    and "text_projection" not in key
                ):

                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        # allocate concatenation buffers sized for ALL tasks;
                        # they persist across the remaining loop iterations
                        sum_u = torch.zeros(
                            u.shape[0], u.shape[1] * config.num_tasks, device=device
                        )
                        sum_s = torch.zeros(
                            s.shape[0] * config.num_tasks, device=device
                        )
                        sum_v = torch.zeros(
                            v.shape[0] * config.num_tasks, v.shape[1], device=device
                        )
                        # full rank kept for every task (lossless)
                        reduced_index_s = s.shape[0]

                    # select only the first reduced_index_s columns of u and place them
                    sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                        :, :reduced_index_s
                    ]
                    sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                        :reduced_index_s
                    ]
                    # select only the first reduced_index_s rows of v and place them
                    sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                        :reduced_index_s, :
                    ]

                else:
                    # non-2D (or excluded) parameter: running mean across tasks
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)

            # task_vector here is the last item of the loop above; all task
            # vectors are assumed to share shapes, so the check is equivalent
            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
                # sort singular values ascending — presumably to align with
                # torch.linalg.eigh's ascending eigenvalue order; confirm
                sum_s, indices = torch.sort(sum_s, stable=True)

                # orthogonalize the column space via eigendecomposition of
                # the Gram matrix: u_orth = Q diag(|l|^-1/2) Q^T
                sum_u = torch.index_select(sum_u, 1, indices)
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                sum_v = torch.index_select(sum_v, 0, indices)

                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                # NOTE(review): unlike compute_and_sum_svd_mem_reduction_2,
                # this product omits sum_u and sum_v, so the result is
                # (rank*num_tasks, rank*num_tasks) rather than the original
                # layer shape — verify against the upstream TSV reference.
                new_vector[key] = torch.linalg.multi_dot(  # bool_mask *
                    (
                        u_orth,
                        torch.diag(sum_s),
                        v_orth,
                    )
                )

    return new_vector
303
+
304
+
305
+ ###############
306
+ #### TSV Merge Orthogonalization
307
@torch.no_grad()
def compute_and_sum_svd_mem_reduction(
    task_vectors: List[StateDictType],
    exclude_keys: Optional[List[str]] = None,
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> StateDictType:
    """
    Merge task vectors by Task Singular Vector (TSV) merging with orthogonalization.

    For every 2-D parameter not listed in ``exclude_keys``, each task's SVD is
    truncated to the top ``1/len(task_vectors)`` fraction of singular triplets;
    the truncated factors are written into disjoint blocks of shared
    ``sum_u``/``sum_s``/``sum_v`` tensors, which are then orthogonalized with a
    second SVD and multiplied back into a single merged matrix. Parameters that
    are not 2-D (or are excluded) are merged with a running mean.

    Args:
        task_vectors (list): State dicts (parameter name -> tensor), one per task.
        exclude_keys (list): A list of keys to exclude from the TSVM (merged by mean).
        accelerator (torch.device): The device to use for the computation.

    Returns:
        dict: A dictionary containing the new vectors after SVD computation and merging,
        restored to each parameter's original device and dtype.
    """
    if exclude_keys is None:
        exclude_keys = []
    sv_reduction = 1 / len(task_vectors)

    new_vector = {}
    for key in task_vectors[0]:
        original_device = task_vectors[0][key].device
        original_dtype = task_vectors[0][key].dtype
        # torch.linalg.svd does not support half precision, so work in
        # float32 and cast back to the original dtype at the end.
        needs_cast = original_dtype not in (torch.float32, torch.float64)

        for i, task_vector in enumerate(task_vectors):
            vec = task_vector[key].to(accelerator)

            if len(task_vector[key].shape) == 2 and key not in exclude_keys:
                if needs_cast:
                    vec = vec.to(dtype=torch.float32)

                u, s, v = torch.linalg.svd(vec, full_matrices=False)

                if i == 0:
                    print(f"Computed SVD for {key}...")
                    # shared buffers filled block-by-block across tasks
                    sum_u = torch.zeros_like(u)
                    sum_s = torch.zeros_like(s)
                    sum_v = torch.zeros_like(v)
                    # NOTE: if len(task_vectors) exceeds the number of
                    # singular values, this truncates to 0 and the merged
                    # matrix silently stays all-zero.
                    reduced_index_s = int(s.shape[0] * sv_reduction)

                # place task i's top singular triplets in its own block
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]
            else:
                # non-2D tensor or excluded key: running mean across tasks
                if i == 0:
                    new_vector[key] = vec.clone()
                else:
                    new_vector[key] += (vec - new_vector[key]) / (i + 1)

        if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
            # orthogonalize the concatenated factors with a second SVD and
            # rebuild the merged layer
            u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)

            new_vector[key] = torch.linalg.multi_dot(
                (
                    u_u,
                    v_u,
                    torch.diag(sum_s),
                    u_v,
                    v_v,
                )
            )
        new_vector[key] = new_vector[key].to(
            device=original_device, dtype=original_dtype, non_blocking=True
        )
    return new_vector
393
+
394
+
395
+ ###############
396
+ #### TSV Merge Eigendecomp
397
def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
    """
    Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
    reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
    the low-rank matrices. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors.
    Computation of the eigendecomposition is performed instead of the SVD for the second operation.

    Args:
        task_vectors (list): A list of task vector objects, where each object contains a
            dictionary of vectors.
        config (object): Configuration object containing the following attributes:
            - DATASETS (list): List of datasets.
            - device (torch.device): The device to perform computations on.

    Returns:
        dict: A dictionary containing the new vectors after SVD computation and merging.
    """
    sv_reduction = 1 / len(config.DATASETS)
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            # placeholder; overwritten below by either the merged matrix or
            # the running mean
            new_vector[key] = {}
            for i, (task_vector, dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                vec = task_vector.vector[key].to(device)

                if (
                    len(task_vector.vector[key].shape) == 2
                    and "text_projection" not in key
                ):
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        # shared buffers filled block-by-block across tasks
                        sum_u = torch.zeros_like(u, device=device)
                        sum_s = torch.zeros_like(s, device=device)
                        sum_v = torch.zeros_like(v, device=device)
                        # each task keeps only this fraction of its spectrum
                        reduced_index_s = int(s.shape[0] * sv_reduction)

                    # select only the first reduced_index_s columns of u and place them
                    sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                        :, :reduced_index_s
                    ]
                    sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                        :reduced_index_s
                    ]
                    # select only the first reduced_index_s rows of v and place them
                    sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                        :reduced_index_s, :
                    ]

                else:
                    # non-2D (or excluded) parameter: running mean across tasks
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)

            # task_vector here is the last item of the loop above; all task
            # vectors are assumed to share shapes, so the check is equivalent
            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
                # sort singular values ascending — presumably to align with
                # torch.linalg.eigh's ascending eigenvalue order; confirm
                sum_s, indices = torch.sort(sum_s, stable=True)

                # orthogonalize the column space via eigendecomposition of
                # the Gram matrix: u_orth = Q diag(|l|^-1/2) Q^T
                sum_u = torch.index_select(sum_u, 1, indices)
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                sum_v = torch.index_select(sum_v, 0, indices)

                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                # rebuild the merged layer from the orthogonalized factors
                new_vector[key] = torch.linalg.multi_dot(  # bool_mask *
                    (
                        sum_u,
                        u_orth,
                        torch.diag(sum_s),
                        v_orth,
                        sum_v,
                    )
                )

    return new_vector
488
+
489
+
490
+ ###############
491
+ #### Rank Reduction TV
492
def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
    """
    Compute and sum the Singular Value Decomposition (SVD) of task vectors with rank reduction.

    For every 2D parameter whose key does not contain "text_projection", the top
    ``1/len(config.DATASETS)`` fraction of each task's singular triplets is kept
    and the per-task factors are stacked into disjoint slices of shared U/S/V
    buffers; the merged parameter is reconstructed as ``U @ diag(S) @ V``, which
    equals the sum of each task's rank-reduced reconstruction. All other
    parameters are merged with a running arithmetic mean.

    Args:
        task_vectors (list): A list of task vector objects. Each object should have a
            ``vector`` attribute mapping parameter names to tensors. All tasks are
            assumed to share the same shape per key.
        config (object): Configuration object containing the following attributes:
            - DATASETS (list): List of datasets (one entry per task; iteration
              stops at the shorter of the two lists).
            - device (torch.device): The device to perform computations on.

    Returns:
        dict: A dictionary containing the new vectors after SVD computation and
            summation, on ``config.device``.
    """
    # Each task keeps the top 1/num_tasks fraction of its singular triplets.
    sv_reduction = 1 / len(config.DATASETS)
    device = config.device
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            # Decide once per key whether this parameter is SVD-merged or averaged
            # (the original recomputed this check per task from the raw tensors).
            merge_by_svd = (
                task_vectors[0].vector[key].ndim == 2 and "text_projection" not in key
            )
            for i, (task_vector, _dataset) in enumerate(
                zip(task_vectors, config.DATASETS)
            ):
                vec = task_vector.vector[key].to(device)

                if merge_by_svd:
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        # Accumulators shaped like one full SVD; each task fills
                        # its own disjoint slice of columns/rows below.
                        sum_u = torch.zeros_like(u, device=device)
                        sum_s = torch.zeros_like(s, device=device)
                        sum_v = torch.zeros_like(v, device=device)
                        # Number of singular triplets kept per task (floor).
                        reduced_index_s = int(s.shape[0] * sv_reduction)

                    lo, hi = i * reduced_index_s, (i + 1) * reduced_index_s
                    # Top singular directions of task i go into slice [lo:hi).
                    sum_u[:, lo:hi] = u[:, :reduced_index_s]
                    sum_s[lo:hi] = s[:reduced_index_s]
                    sum_v[lo:hi, :] = v[:reduced_index_s, :]
                else:
                    # Running mean keeps memory flat across tasks.
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)

            if merge_by_svd:
                # Since the per-task slices are disjoint, this product equals the
                # sum over tasks of U_i[:, :r] @ diag(S_i[:r]) @ V_i[:r, :].
                new_vector[key] = torch.linalg.multi_dot(
                    (sum_u, torch.diag(sum_s), sum_v)
                )
    return new_vector
561
+
562
+
563
def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
    """To perform dummy operations.

    Ablation/sanity-check variant: only the FIRST task vector is decomposed;
    the remaining 7 slots are filled with random unit-norm directions that
    reuse the first task's top singular values. The number of tasks is
    hard-coded to 8 and ``config`` is accepted but never read.

    NOTE(review): non-deterministic — draws from the global torch RNG via
    ``randn_like``; seed externally for reproducible runs.
    """
    # Hard-coded for 8 tasks; each slot keeps 1/8 of the singular triplets.
    sv_reduction = 1 / 8
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0].vector:
            new_vector[key] = {}
            for i in range(0, 8):
                if (
                    len(task_vectors[0].vector[key].shape) == 2
                    and "text_projection" not in key
                ):
                    if i == 0:
                        # Real SVD only for the first task.
                        u, s, v = torch.linalg.svd(
                            task_vectors[0].vector[key], full_matrices=False
                        )
                        reduced_index_s = int(s.shape[0] * sv_reduction)

                        print(f"Computed SVD for {key}...")
                        sum_u = torch.zeros_like(u)
                        sum_s = torch.zeros_like(s)
                        sum_v = torch.zeros_like(v)

                        # select only the first reduced_index_s columns of u and place them
                        sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                            :, :reduced_index_s
                        ]
                        sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                            :reduced_index_s
                        ]
                        # select only the first reduced_index_s rows of v and place them
                        sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                            :reduced_index_s, :
                        ]
                    else:
                        # Fill the remaining slots with random unit-norm columns
                        # of u / rows of v. NOTE(review): despite the original
                        # intent of drawing directions orthogonal to previous
                        # ones, these draws are merely normalized, not
                        # orthogonalized.
                        print("dummy")
                        u = torch.nn.functional.normalize(
                            torch.randn_like(sum_u), p=2, dim=-2
                        )
                        v = torch.nn.functional.normalize(
                            torch.randn_like(sum_v), p=2, dim=-1
                        )

                        # select only the first reduced_index_s columns of u and place them
                        sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                            :, :reduced_index_s
                        ]
                        # `s` still holds the FIRST task's singular values, so
                        # every dummy slot reuses the same top values.
                        sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                            :reduced_index_s
                        ]
                        # select only the first reduced_index_s rows of v and place them
                        sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                            :reduced_index_s, :
                        ]

                else:
                    # Non-matrix keys: take the first task's tensor as-is
                    # (no clone — aliases the input tensor). The running-mean
                    # branch is intentionally disabled in this dummy variant.
                    if i == 0:
                        new_vector[key] = task_vectors[0].vector[key]
                    # else:
                    #     new_vector[key] += (
                    #         task_vector.vector[key] - new_vector[key]
                    #     ) / (i + 1)

            if (
                len(task_vectors[0].vector[key].shape) == 2
                and "text_projection" not in key
            ):

                # Reconstruct from the stacked (1 real + 7 random) factors.
                new_vector[key] = torch.linalg.multi_dot(
                    (
                        sum_u,
                        torch.diag(sum_s),
                        sum_v,
                    )
                )

    return new_vector
@@ -0,0 +1,7 @@
1
+ from fusion_bench.method.ties_merging.ties_merging_utils import (
2
+ check_parameterNamesMatch,
3
+ check_state_dicts_equal,
4
+ )
5
+ from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
6
+
7
+ from . import TSVC_utils, TSVM_utils
@@ -4,10 +4,13 @@ This is modified based on https://github.com/EnnengYang/AdaMerging/blob/main/src
4
4
 
5
5
  import copy
6
6
  from collections import OrderedDict
7
+ from typing import List
7
8
 
8
9
  import torch
9
10
  from torch import Tensor, nn
10
11
 
12
+ from fusion_bench.utils.type import StateDictType
13
+
11
14
 
12
15
  # Model conversion utils
13
16
  def state_dict_to_vector(state_dict, remove_keys=[]):
@@ -82,7 +85,7 @@ def add_ptm_to_tv(tv_dict, ptm_dict):
82
85
  return final_dict
83
86
 
84
87
 
85
- def check_parameterNamesMatch(checkpoints):
88
+ def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
86
89
  """
87
90
  Check if the parameter names match across multiple checkpoints.
88
91
 
@@ -105,7 +108,9 @@ def check_parameterNamesMatch(checkpoints):
105
108
  )
106
109
 
107
110
 
108
- def check_state_dicts_equal(state_dict1, state_dict2):
111
+ def check_state_dicts_equal(
112
+ state_dict1: StateDictType, state_dict2: StateDictType
113
+ ) -> bool:
109
114
  """
110
115
  Check if two state dictionaries are equal.
111
116