fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +10 -2
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +7 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/scripts/cli.py +14 -0
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/rich_utils.py +123 -6
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
fusion_bench/constants/__init__.py
CHANGED

@@ -5,4 +5,8 @@ from .paths import *
 from .runtime import RuntimeConstants
 
 # fusionbench version
-
+try:
+    FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
+except importlib.metadata.PackageNotFoundError:
+    # Fallback when package is not installed (e.g., during development)
+    FUSION_BENCH_VERSION = "0.0.0.dev"
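The version lookup above follows the standard `importlib.metadata` pattern: query the installed distribution and fall back to a sentinel when the package is not installed. A minimal sketch of the same idea in downstream code (assuming the constant is read the same way; the print statement is illustrative):

```python
import importlib.metadata

try:
    # Resolves the version recorded in the installed wheel's metadata.
    version = importlib.metadata.version("fusion-bench")
except importlib.metadata.PackageNotFoundError:
    # Running from a source checkout that was never pip-installed.
    version = "0.0.0.dev"

print(f"fusion-bench version: {version}")
```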
fusion_bench/constants/runtime.py
CHANGED

@@ -1,3 +1,29 @@
+"""
+Runtime Constants Module.
+
+This module provides a thread-safe singleton class for managing runtime configuration
+and constants across the Fusion Bench framework. It centralizes access to runtime
+settings like cache directories, debug flags, and logging preferences.
+
+Example:
+    ```python
+    from fusion_bench.constants.runtime import RuntimeConstants
+
+    # Get the singleton instance
+    runtime = RuntimeConstants()
+
+    # Configure cache directory
+    runtime.cache_dir = "/custom/cache/path"
+
+    # Enable debug mode
+    runtime.debug = True
+
+    # Control function call logging
+    runtime.print_function_call = True
+    ```
+"""
+
+import os
 import threading
 from pathlib import Path
 from typing import Optional, Union
@@ -5,18 +31,46 @@ from typing import Optional, Union
 
 class RuntimeConstants:
     """
-
-
+    Thread-safe singleton for managing runtime configuration and constants.
+
+    This class provides centralized access to runtime settings that affect the
+    behavior of the entire Fusion Bench framework. It ensures consistent
+    configuration across all modules and supports thread-safe access in
+    multi-threaded environments.
+
+    Attributes:
+        debug: Global debug flag for enabling verbose logging and debugging features.
+
+    Example:
+        ```python
+        runtime = RuntimeConstants()
+
+        # Configure caching
+        runtime.cache_dir = Path.home() / ".cache" / "fusion_bench"
+
+        # Enable debugging
+        runtime.debug = True
+        runtime.print_function_call = True
+        ```
 
-
-
+    Note:
+        This class implements the singleton pattern with thread-safe initialization.
+        Multiple calls to the constructor will return the same instance.
     """
 
     _instance: Optional["RuntimeConstants"] = None
     _lock = threading.Lock()
 
     def __new__(cls) -> "RuntimeConstants":
-        """
+        """
+        Create or return the singleton instance with thread safety.
+
+        Uses double-check locking pattern to ensure thread-safe singleton creation
+        while minimizing synchronization overhead.
+
+        Returns:
+            The singleton RuntimeConstants instance.
+        """
         with cls._lock:
             # Double-check locking pattern
             if cls._instance is None:
@@ -25,33 +79,83 @@ class RuntimeConstants:
         return cls._instance
 
     def __init__(self):
-        """
+        """
+        Initialize the singleton instance only once.
+
+        Subsequent calls to __init__ are no-ops to maintain singleton behavior.
+        """
         if not self._initialized:
-            #
+            # Initialize default values
             self._initialized = True
 
     debug = False
+    """Global debug flag for enabling verbose logging and debugging features."""
 
     @property
    def cache_dir(self) -> Path:
+        """
+        Get the default cache directory for models and datasets.
+
+        Returns:
+            Path object pointing to the cache directory.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            print(f"Cache directory: {runtime.cache_dir}")
+            ```
+        """
         from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
 
         return DEFAULT_CACHE_DIR
 
     @cache_dir.setter
     def cache_dir(self, path: Union[str, Path]) -> None:
+        """
+        Set the default cache directory for models and datasets.
+
+        Args:
+            path: New cache directory path as string or Path object.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            runtime.cache_dir = "/data/fusion_bench_cache"
+            ```
+        """
         from fusion_bench.utils.cache_utils import set_default_cache_dir
 
         set_default_cache_dir(path)
 
     @property
     def print_function_call(self) -> bool:
+        """
+        Get whether function calls are printed during instantiation.
+
+        Returns:
+            True if function call printing is enabled, False otherwise.
+        """
         from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
 
         return PRINT_FUNCTION_CALL
 
     @print_function_call.setter
     def print_function_call(self, enable: bool) -> None:
+        """
+        Set whether to print function calls during instantiation.
+
+        Useful for debugging to see which functions are being called
+        when instantiating objects from configuration.
+
+        Args:
+            enable: True to enable printing, False to disable.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            runtime.print_function_call = True  # Enable verbose logging
+            ```
+        """
         from fusion_bench.utils.instantiate_utils import set_print_function_call
 
         set_print_function_call(enable)
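The new docstrings describe the double-checked-locking singleton used by `RuntimeConstants.__new__`, but only a few lines of the actual body appear in this hunk. A generic, self-contained sketch of the pattern (illustrative only, not the fusion-bench implementation):

```python
import threading
from typing import Optional


class Singleton:
    """Minimal double-checked-locking singleton (illustrative sketch)."""

    _instance: Optional["Singleton"] = None
    _lock = threading.Lock()

    def __new__(cls) -> "Singleton":
        # First check without the lock keeps the common path cheap.
        if cls._instance is None:
            with cls._lock:
                # Second check: another thread may have created the instance
                # while we were waiting for the lock.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every Singleton() call; guard it so setup happens once.
        if not self._initialized:
            self.debug = False
            self._initialized = True


assert Singleton() is Singleton()  # every call returns the same object
```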
fusion_bench/dataset/gsm8k.py
CHANGED
@@ -13,8 +13,12 @@ def load_gsm8k_question_label_data(
 
     An example in the dataset:
 
-
-
+    ```python
+    {
+        'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
+        'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'
+    }
+    ```
 
     Args:
         dataset_name (Literal["train", "test", "train_socratic", "test_socratic"]): The name of the dataset to load.
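The `answer` field in the example above follows the standard GSM8K convention: intermediate calculator annotations in `<<...>>` and the final answer after `####`. A small sketch of pulling the final numeric answer out of such a string (the helper name and regex are illustrative, not part of fusion-bench):

```python
import re


def extract_final_answer(answer: str) -> str:
    """Return the text after the '####' marker, e.g. '72' for the example above."""
    match = re.search(r"####\s*(.+)", answer)
    if match is None:
        raise ValueError("No '#### <answer>' marker found")
    return match.group(1).strip()


answer = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)
assert extract_final_answer(answer) == "72"
```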
fusion_bench/dataset/image_corruption/make_corruption.py
CHANGED

@@ -1,4 +1,31 @@
 # -*- coding: utf-8 -*-
+"""
+Image Corruption Module for Robustness Testing.
+
+This module provides various image corruption functions to test model robustness.
+It implements common corruptions such as noise, blur, compression artifacts, and
+weather effects. These corruptions are commonly used in benchmark datasets like
+ImageNet-C and CIFAR-10-C.
+
+The corruptions can be applied at different severity levels (1-5), where higher
+levels indicate stronger corruption effects.
+
+Example:
+    ```python
+    from PIL import Image
+    from fusion_bench.dataset.image_corruption.make_corruption import gaussian_noise, motion_blur
+
+    # Load an image
+    img = Image.open("example.jpg")
+
+    # Apply gaussian noise at severity level 3
+    corrupted_img = gaussian_noise(img, severity=3)
+
+    # Apply motion blur at severity level 2
+    blurred_img = motion_blur(img, severity=2)
+    ```
+"""
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -37,11 +64,39 @@ warnings.simplefilter("ignore", UserWarning)
 
 # /////////////// Distortions ///////////////
 class MotionImage(WandImage):
+    """
+    Extended WandImage class with motion blur capability.
+
+    This class wraps ImageMagick's motion blur functionality through the Wand library.
+    """
+
     def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        """
+        Apply motion blur effect to the image.
+
+        Args:
+            radius: The radius of the Gaussian, in pixels, not counting the center pixel.
+            sigma: The standard deviation of the Gaussian, in pixels.
+            angle: Apply the effect along this angle in degrees.
+        """
         wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
 
 
 def gaussian_noise(x, severity=1):
+    """
+    Apply Gaussian noise corruption to an image.
+
+    Adds random Gaussian noise to the image, simulating sensor noise or
+    environmental interference.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
 
     x = np.array(x) / 255.0
@@ -49,6 +104,20 @@ def gaussian_noise(x, severity=1):
 
 
 def impulse_noise(x, severity=1):
+    """
+    Apply impulse (salt-and-pepper) noise corruption to an image.
+
+    Randomly replaces pixels with either maximum or minimum intensity values,
+    simulating transmission errors or faulty pixels.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
 
     x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
@@ -56,6 +125,21 @@ def impulse_noise(x, severity=1):
 
 
 def motion_blur(x, severity=1):
+    """
+    Apply motion blur corruption to an image.
+
+    Simulates camera shake or object motion during image capture by applying
+    directional blur at a random angle.
+
+    Args:
+        x: Input PIL Image.
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity increases blur radius and sigma.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+            Returns RGB image regardless of input format.
+    """
     c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
 
     output = BytesIO()
@@ -73,6 +157,21 @@ def motion_blur(x, severity=1):
 
 
 def spatter(x, severity=1):
+    """
+    Apply spatter corruption to an image.
+
+    Simulates liquid splatter effects (water or mud) on the image, creating
+    realistic occlusions similar to raindrops or dirt on a camera lens.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Levels 1-3 simulate water splatter, levels 4-5 simulate mud splatter.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [
         (0.62, 0.1, 0.7, 0.7, 0.5, 0),
         (0.65, 0.1, 0.8, 0.7, 0.5, 0),
@@ -140,6 +239,21 @@ def spatter(x, severity=1):
 
 
 def contrast(x, severity=1):
+    """
+    Apply contrast reduction corruption to an image.
+
+    Reduces image contrast by blending pixels toward their mean values,
+    simulating poor lighting conditions or low-quality image sensors.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity results in lower contrast.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
 
     x = np.array(x) / 255.0
@@ -148,6 +262,20 @@ def contrast(x, severity=1):
 
 
 def jpeg_compression(x, severity=1):
+    """
+    Apply JPEG compression artifacts to an image.
+
+    Simulates compression artifacts from lossy JPEG encoding at various
+    quality levels, commonly seen in heavily compressed images.
+
+    Args:
+        x: Input PIL Image.
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Lower severity uses higher JPEG quality (less compression).
+
+    Returns:
+        PIL.Image: Corrupted image as PIL Image.
+    """
     c = [80, 65, 58, 50, 40][severity - 1]
 
     output = BytesIO()
@@ -158,6 +286,23 @@ def jpeg_compression(x, severity=1):
 
 
 def pixelate(x, severity=1):
+    """
+    Apply pixelation corruption to an image.
+
+    Reduces image resolution by downsampling and then upsampling,
+    creating a blocky, pixelated appearance.
+
+    Args:
+        x: Input PIL Image with size (32, 32).
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity results in more pixelation.
+
+    Returns:
+        PIL.Image: Corrupted image as PIL Image with original size (32, 32).
+
+    Note:
+        This function is designed for 32x32 images (e.g., CIFAR-10).
+    """
     c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
 
     x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
@@ -170,6 +315,29 @@ def pixelate(x, severity=1):
 
 
 distortion_methods = collections.OrderedDict()
+"""
+OrderedDict mapping corruption names to their corresponding functions.
+
+Available corruptions:
+- "Gaussian Noise": Additive Gaussian noise
+- "Impulse Noise": Salt-and-pepper noise
+- "Motion Blur": Directional motion blur
+- "Contrast": Reduced contrast
+- "Pixelate": Resolution reduction
+- "JPEG": JPEG compression artifacts
+- "Spatter": Water or mud splatter effects
+
+Example:
+    ```python
+    from PIL import Image
+    from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods
+
+    img = Image.open("example.jpg")
+    for name, corruption_fn in distortion_methods.items():
+        corrupted = corruption_fn(img, severity=3)
+        # Process corrupted image
+    ```
+"""
 distortion_methods["Gaussian Noise"] = gaussian_noise
 distortion_methods["Impulse Noise"] = impulse_noise
 distortion_methods["Motion Blur"] = motion_blur
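Several of the corruption functions above return numpy arrays in [0, 255] rather than PIL Images, so a conversion step is usually needed before saving. A small sketch of sweeping every registered corruption over all severity levels (the input file, output directory, and normalization step are illustrative; the image is resized to 32x32 because `pixelate` is written for CIFAR-10-sized inputs):

```python
import os

import numpy as np
from PIL import Image

from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods

# pixelate assumes 32x32 inputs (e.g. CIFAR-10), so resize up front.
img = Image.open("example.jpg").convert("RGB").resize((32, 32))
os.makedirs("corrupted", exist_ok=True)

for name, corruption_fn in distortion_methods.items():
    for severity in range(1, 6):
        out = corruption_fn(img, severity=severity)
        # Some corruptions return numpy arrays, others PIL Images; normalize to uint8 PIL.
        if isinstance(out, np.ndarray):
            out = Image.fromarray(np.uint8(np.clip(out, 0, 255)))
        out.save(f"corrupted/{name.replace(' ', '_')}_s{severity}.png")
```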
fusion_bench/method/__init__.py
CHANGED
@@ -144,7 +144,15 @@ _extra_objects = {
 
 if TYPE_CHECKING:
     from .ada_svd import AdaSVDMergingForCLIPVisionModel
-    from .adamerging import
+    from .adamerging import (
+        CLIPLayerWiseAdaMergingAlgorithm,
+        CLIPTaskWiseAdaMergingAlgorithm,
+        FlanT5LayerWiseAdaMergingAlgorithm,
+        GPT2LayerWiseAdaMergingAlgorithm,
+        LayerWiseAdaMergingForLlamaSFT,
+        ResNetLayerWiseAdamerging,
+        ResNetTaskWiseAdamerging,
+    )
     from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
     from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
     from .bitdelta import BitDeltaAlgorithm
@@ -167,6 +175,7 @@ if TYPE_CHECKING:
     from .dawe import DataAdaptiveWeightEnsemblingForCLIP
     from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
     from .doge_ta import DOGE_TA_Algorithm
+    from .dop import ContinualDOPForCLIP
     from .dummy import DummyAlgorithm
     from .ensemble import (
         MaxModelPredictorAlgorithm,
@@ -215,7 +224,6 @@ if TYPE_CHECKING:
     from .model_recombination import ModelRecombinationAlgorithm
     from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
-    from .dop import ContinualDOPForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
         MagnitudePruningForLlama,
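The imports above sit under `if TYPE_CHECKING:` next to an `_extra_objects` mapping, which is the usual lazy-import layout: names are imported eagerly only for type checkers and resolved on demand at runtime. A generic sketch of that pattern using PEP 562's module-level `__getattr__` (illustrative only; fusion-bench's actual lazy loader may differ):

```python
# mypackage/__init__.py -- illustrative layout, not fusion_bench's actual loader
import importlib
from typing import TYPE_CHECKING

# Maps attribute name -> submodule that defines it.
_lazy_objects = {
    "CLIPTaskWiseAdaMergingAlgorithm": ".adamerging",
    "ContinualDOPForCLIP": ".dop",
}

if TYPE_CHECKING:
    # Static type checkers see real imports; nothing is imported here at runtime.
    from .adamerging import CLIPTaskWiseAdaMergingAlgorithm
    from .dop import ContinualDOPForCLIP


def __getattr__(name: str):
    # PEP 562: called only when `name` is not already in the module namespace.
    if name in _lazy_objects:
        module = importlib.import_module(_lazy_objects[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```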
fusion_bench/method/base_algorithm.py
CHANGED

@@ -40,6 +40,7 @@ from typing import Optional # noqa: F401
 
 from fusion_bench.mixins import BaseYAMLSerializable
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.misc import DeprecationWarningMeta
 
 __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
 
@@ -202,27 +203,36 @@ class BaseAlgorithm(BaseYAMLSerializable):
         pass
 
 
-
-
-
+# Create a deprecated wrapper class that inherits from BaseAlgorithm
+class BaseModelFusionAlgorithm(BaseAlgorithm, metaclass=DeprecationWarningMeta):
+    """
+    Alias for BaseAlgorithm class.
 
-
-
-
-    tasks, while others may prefer the shorter 'BaseAlgorithm' name.
+    .. deprecated::
+        BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
+        Use :class:`BaseAlgorithm` instead.
 
-
+    This alias was provided for backward compatibility and semantic clarity.
+    Both names refer to the same base class and can be used interchangeably,
+    but BaseAlgorithm is now the preferred name for all implementations.
 
-    Examples:
-
-        >>> class MyAlgorithm(BaseAlgorithm):
-        ...     def run(self, modelpool): pass
+    Examples:
+        Preferred (using BaseAlgorithm):
 
-
-
-        ...     def run(self, modelpool): pass
+        >>> class MyAlgorithm(BaseAlgorithm):
+        ...     def run(self, modelpool): pass
 
-
-
-
-
+        Deprecated (using BaseModelFusionAlgorithm):
+
+        >>> class MyAlgorithm(BaseModelFusionAlgorithm):  # Will trigger deprecation warning
+        ...     def run(self, modelpool): pass
+
+    Note:
+        New implementations should use :class:`BaseAlgorithm` exclusively.
+        The BaseModelFusionAlgorithm alias will be removed in a future release.
+
+    Warning:
+        Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
+    """
+
+    pass
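`DeprecationWarningMeta` is imported from `fusion_bench.utils.misc`, but its implementation is not part of this diff. A hypothetical sketch of what such a metaclass typically looks like (an assumption for illustration, not the actual fusion-bench code): it emits a `DeprecationWarning` when the deprecated class is instantiated or subclassed, matching the behavior described in the docstring above.

```python
import warnings


class DeprecationWarningMeta(type):
    """Hypothetical metaclass: warn when a deprecated class is used."""

    def __call__(cls, *args, **kwargs):
        # Instantiating the deprecated class (or a subclass of it) emits a warning.
        warnings.warn(
            f"{cls.__name__} is deprecated; use its replacement instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__call__(*args, **kwargs)

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Subclassing a class that already uses this metaclass also emits a warning,
        # so `class MyAlgorithm(BaseModelFusionAlgorithm)` would warn while defining
        # the alias itself would not.
        if any(isinstance(base, DeprecationWarningMeta) for base in bases):
            warnings.warn(
                f"{name} inherits from a deprecated base class.",
                DeprecationWarning,
                stacklevel=2,
            )
```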
fusion_bench/method/classification/image_classification_finetune.py
CHANGED

@@ -15,7 +15,7 @@ from lightning_utilities.core.rank_zero import rank_zero_only
 from lit_learn.lit_modules import ERM_LitModule
 from omegaconf import DictConfig
 from torch import nn
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, random_split
 from torchmetrics.classification import Accuracy
 
 from fusion_bench import (
@@ -29,7 +29,6 @@ from fusion_bench import (
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.modelpool import ResNetForImageClassificationPool
 from fusion_bench.tasks.clip_classification import get_num_classes
-from torch.utils.data import random_split
 
 log = get_rankzero_logger(__name__)
 
fusion_bench/method/gossip/clip_task_wise_gossip.py
CHANGED

@@ -13,41 +13,13 @@ from fusion_bench.modelpool import CLIPVisionModelPool
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader
 
 from .task_wise_gossip import TaskWiseGossipAlgorithm
 
 log = logging.getLogger(__name__)
 
 
-class InfiniteDataLoader:
-    """
-    A wrapper class for DataLoader to create an infinite data loader.
-    This is useful in case we are only interested in the number of steps and not the number of epochs.
-
-    This class wraps a DataLoader and provides an iterator that resets
-    when the end of the dataset is reached, creating an infinite loop.
-
-    Attributes:
-        data_loader (DataLoader): The DataLoader to wrap.
-        data_iter (iterator): An iterator over the DataLoader.
-    """
-
-    def __init__(self, data_loader):
-        self.data_loader = data_loader
-        self.data_iter = iter(data_loader)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            data = next(self.data_iter)
-        except StopIteration:
-            self.data_iter = iter(self.data_loader)  # Reset the data loader
-            data = next(self.data_iter)
-        return data
-
-
 class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
     """
     A class for task-wise adaptive merging of CLIP models.
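The local `InfiniteDataLoader` wrapper removed here is now imported from `fusion_bench.utils.data` instead. A brief usage sketch of the wrapper under that new import path (the toy dataset and the step count are illustrative): it is handy when training is driven by a fixed number of steps rather than epochs, since iteration silently restarts at the end of the dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader  # new import path per this diff

dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
loader = InfiniteDataLoader(DataLoader(dataset, batch_size=4, shuffle=True))

for step, (x, y) in zip(range(10), loader):
    # 10 optimization steps, regardless of the dataset length (8 samples here).
    pass
```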
fusion_bench/metrics/model_kinship/calculate.py
ADDED

@@ -0,0 +1,77 @@
+import logging
+from typing import List
+
+import numpy
+import torch
+
+from .utility import Metric
+
+
+def cosine_similarity(a, b):
+    similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
+    return similarity
+
+
+def calculate_model_kinship(
+    delta1: numpy.ndarray, delta2: numpy.ndarray, metrics: List[str]
+) -> dict:
+    """
+    Calculate model kinship using specified metrics.
+
+    Args:
+        delta1: Delta parameters for first model
+        delta2: Delta parameters for second model
+        metrics: List of metrics to calculate
+
+    Returns:
+        dict: Dictionary of metric names and their calculated values
+    """
+    results = {}
+    for metric in metrics:
+        try:
+            if metric not in Metric.list():
+                raise ValueError(f"Unsupported metric: {metric}")
+            results[metric] = calculate_metric(delta1, delta2, metric)
+        except Exception as e:
+            results[metric] = f"Error calculating {metric}: {str(e)}"
+    return results
+
+
+def calculate_metric(
+    d_vector_1: torch.Tensor, d_vector_2: torch.Tensor, metric: str
+) -> str:
+    """
+    Calculate the specified metric between two delta vectors.
+
+    Args:
+        d_vector_1 (torch.Tensor): Delta parameters for model 1.
+        d_vector_2 (torch.Tensor): Delta parameters for model 2.
+        metric (str): The metric to calculate ('pcc', 'ed', 'cs').
+
+    Returns:
+        str: A formatted string with the result of the chosen metric.
+    """
+    logging.info(f"Starting calculation of {metric.upper()} metric...")
+
+    # Pearson Correlation Coefficient (PCC)
+    if metric == "pcc":
+        # Stack the two vectors and calculate the Pearson correlation coefficient
+        stack = torch.stack((d_vector_1, d_vector_2), dim=0)
+        pcc = torch.corrcoef(stack)[0, 1].item()
+        return f"Model Kinship based on Pearson Correlation Coefficient: {pcc}"
+
+    # Euclidean Distance (ED)
+    elif metric == "ed":
+        # Compute the Euclidean distance between the vectors
+        distance = torch.dist(d_vector_1, d_vector_2).item()
+        return f"Model Kinship based on Euclidean Distance: {distance}"
+
+    # Cosine Similarity (CS)
+    elif metric == "cs":
+        # Compute cosine similarity
+        cs = cosine_similarity(d_vector_1, d_vector_2)
+        return f"Model Kinship based on Cosine Similarity: {cs}"
+
+    # If metric is not recognized
+    else:
+        return "Invalid metric specified."
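`calculate_model_kinship` expects flattened "delta" vectors, i.e. the difference between a fine-tuned model's parameters and the shared base model's parameters. A sketch of driving it end to end (the flattening helper and the model names are illustrative placeholders; the deltas are built as torch tensors because the PCC and ED branches above use torch ops, despite the numpy annotation in the signature):

```python
import torch
from transformers import AutoModelForCausalLM

from fusion_bench.metrics.model_kinship.calculate import calculate_model_kinship


def flatten_delta(model, base_model) -> torch.Tensor:
    """Illustrative helper: concatenate (theta_finetuned - theta_base) into one vector."""
    base_params = dict(base_model.named_parameters())
    deltas = [
        (param.detach() - base_params[name].detach()).flatten()
        for name, param in model.named_parameters()
    ]
    return torch.cat(deltas)


# Placeholder checkpoint names: one base model and two fine-tuned variants of it.
base = AutoModelForCausalLM.from_pretrained("base-model")
model_a = AutoModelForCausalLM.from_pretrained("finetuned-model-a")
model_b = AutoModelForCausalLM.from_pretrained("finetuned-model-b")

delta_a = flatten_delta(model_a, base)
delta_b = flatten_delta(model_b, base)

# 'pcc', 'ed' and 'cs' are the metric keys handled by calculate_metric above.
results = calculate_model_kinship(delta_a, delta_b, metrics=["pcc", "ed", "cs"])
for name, value in results.items():
    print(name, "->", value)
```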