fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
@@ -5,4 +5,8 @@ from .paths import *
5
5
  from .runtime import RuntimeConstants
6
6
 
7
7
  # fusionbench version
8
- FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
8
+ try:
9
+ FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
10
+ except importlib.metadata.PackageNotFoundError:
11
+ # Fallback when package is not installed (e.g., during development)
12
+ FUSION_BENCH_VERSION = "0.0.0.dev"
@@ -1,3 +1,29 @@
1
+ """
2
+ Runtime Constants Module.
3
+
4
+ This module provides a thread-safe singleton class for managing runtime configuration
5
+ and constants across the Fusion Bench framework. It centralizes access to runtime
6
+ settings like cache directories, debug flags, and logging preferences.
7
+
8
+ Example:
9
+ ```python
10
+ from fusion_bench.constants.runtime import RuntimeConstants
11
+
12
+ # Get the singleton instance
13
+ runtime = RuntimeConstants()
14
+
15
+ # Configure cache directory
16
+ runtime.cache_dir = "/custom/cache/path"
17
+
18
+ # Enable debug mode
19
+ runtime.debug = True
20
+
21
+ # Control function call logging
22
+ runtime.print_function_call = True
23
+ ```
24
+ """
25
+
26
+ import os
1
27
  import threading
2
28
  from pathlib import Path
3
29
  from typing import Optional, Union
@@ -5,18 +31,46 @@ from typing import Optional, Union
5
31
 
6
32
  class RuntimeConstants:
7
33
  """
8
- This class holds constants related to the runtime environment of the Fusion Bench framework.
9
- It includes default values for cache directories and other runtime configurations.
34
+ Thread-safe singleton for managing runtime configuration and constants.
35
+
36
+ This class provides centralized access to runtime settings that affect the
37
+ behavior of the entire Fusion Bench framework. It ensures consistent
38
+ configuration across all modules and supports thread-safe access in
39
+ multi-threaded environments.
40
+
41
+ Attributes:
42
+ debug: Global debug flag for enabling verbose logging and debugging features.
43
+
44
+ Example:
45
+ ```python
46
+ runtime = RuntimeConstants()
47
+
48
+ # Configure caching
49
+ runtime.cache_dir = Path.home() / ".cache" / "fusion_bench"
50
+
51
+ # Enable debugging
52
+ runtime.debug = True
53
+ runtime.print_function_call = True
54
+ ```
10
55
 
11
- Implemented as a thread-safe singleton to ensure consistent runtime configuration
12
- across the entire application.
56
+ Note:
57
+ This class implements the singleton pattern with thread-safe initialization.
58
+ Multiple calls to the constructor will return the same instance.
13
59
  """
14
60
 
15
61
  _instance: Optional["RuntimeConstants"] = None
16
62
  _lock = threading.Lock()
17
63
 
18
64
  def __new__(cls) -> "RuntimeConstants":
19
- """Create a new instance using singleton pattern with thread safety."""
65
+ """
66
+ Create or return the singleton instance with thread safety.
67
+
68
+ Uses double-check locking pattern to ensure thread-safe singleton creation
69
+ while minimizing synchronization overhead.
70
+
71
+ Returns:
72
+ The singleton RuntimeConstants instance.
73
+ """
20
74
  with cls._lock:
21
75
  # Double-check locking pattern
22
76
  if cls._instance is None:
@@ -25,33 +79,83 @@ class RuntimeConstants:
25
79
  return cls._instance
26
80
 
27
81
  def __init__(self):
28
- """Initialize the singleton instance only once."""
82
+ """
83
+ Initialize the singleton instance only once.
84
+
85
+ Subsequent calls to __init__ are no-ops to maintain singleton behavior.
86
+ """
29
87
  if not self._initialized:
30
- # Add your runtime constants here
88
+ # Initialize default values
31
89
  self._initialized = True
32
90
 
33
91
  debug = False
92
+ """Global debug flag for enabling verbose logging and debugging features."""
34
93
 
35
94
  @property
36
95
  def cache_dir(self) -> Path:
96
+ """
97
+ Get the default cache directory for models and datasets.
98
+
99
+ Returns:
100
+ Path object pointing to the cache directory.
101
+
102
+ Example:
103
+ ```python
104
+ runtime = RuntimeConstants()
105
+ print(f"Cache directory: {runtime.cache_dir}")
106
+ ```
107
+ """
37
108
  from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
38
109
 
39
110
  return DEFAULT_CACHE_DIR
40
111
 
41
112
  @cache_dir.setter
42
113
  def cache_dir(self, path: Union[str, Path]) -> None:
114
+ """
115
+ Set the default cache directory for models and datasets.
116
+
117
+ Args:
118
+ path: New cache directory path as string or Path object.
119
+
120
+ Example:
121
+ ```python
122
+ runtime = RuntimeConstants()
123
+ runtime.cache_dir = "/data/fusion_bench_cache"
124
+ ```
125
+ """
43
126
  from fusion_bench.utils.cache_utils import set_default_cache_dir
44
127
 
45
128
  set_default_cache_dir(path)
46
129
 
47
130
  @property
48
131
  def print_function_call(self) -> bool:
132
+ """
133
+ Get whether function calls are printed during instantiation.
134
+
135
+ Returns:
136
+ True if function call printing is enabled, False otherwise.
137
+ """
49
138
  from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
50
139
 
51
140
  return PRINT_FUNCTION_CALL
52
141
 
53
142
  @print_function_call.setter
54
143
  def print_function_call(self, enable: bool) -> None:
144
+ """
145
+ Set whether to print function calls during instantiation.
146
+
147
+ Useful for debugging to see which functions are being called
148
+ when instantiating objects from configuration.
149
+
150
+ Args:
151
+ enable: True to enable printing, False to disable.
152
+
153
+ Example:
154
+ ```python
155
+ runtime = RuntimeConstants()
156
+ runtime.print_function_call = True # Enable verbose logging
157
+ ```
158
+ """
55
159
  from fusion_bench.utils.instantiate_utils import set_print_function_call
56
160
 
57
161
  set_print_function_call(enable)
@@ -13,8 +13,12 @@ def load_gsm8k_question_label_data(
13
13
 
14
14
  An example in the dataset:
15
15
 
16
- {'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
17
- 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}
16
+ ```python
17
+ {
18
+ 'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
19
+ 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'
20
+ }
21
+ ```
18
22
 
19
23
  Args:
20
24
  dataset_name (Literal["train", "test", "train_socratic", "test_socratic"]): The name of the dataset to load.
@@ -1,4 +1,31 @@
1
1
  # -*- coding: utf-8 -*-
2
+ """
3
+ Image Corruption Module for Robustness Testing.
4
+
5
+ This module provides various image corruption functions to test model robustness.
6
+ It implements common corruptions such as noise, blur, compression artifacts, and
7
+ weather effects. These corruptions are commonly used in benchmark datasets like
8
+ ImageNet-C and CIFAR-10-C.
9
+
10
+ The corruptions can be applied at different severity levels (1-5), where higher
11
+ levels indicate stronger corruption effects.
12
+
13
+ Example:
14
+ ```python
15
+ from PIL import Image
16
+ from fusion_bench.dataset.image_corruption.make_corruption import gaussian_noise, motion_blur
17
+
18
+ # Load an image
19
+ img = Image.open("example.jpg")
20
+
21
+ # Apply gaussian noise at severity level 3
22
+ corrupted_img = gaussian_noise(img, severity=3)
23
+
24
+ # Apply motion blur at severity level 2
25
+ blurred_img = motion_blur(img, severity=2)
26
+ ```
27
+ """
28
+
2
29
  import logging
3
30
 
4
31
  logger = logging.getLogger(__name__)
@@ -37,11 +64,39 @@ warnings.simplefilter("ignore", UserWarning)
37
64
 
38
65
  # /////////////// Distortions ///////////////
39
66
  class MotionImage(WandImage):
67
+ """
68
+ Extended WandImage class with motion blur capability.
69
+
70
+ This class wraps ImageMagick's motion blur functionality through the Wand library.
71
+ """
72
+
40
73
  def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
74
+ """
75
+ Apply motion blur effect to the image.
76
+
77
+ Args:
78
+ radius: The radius of the Gaussian, in pixels, not counting the center pixel.
79
+ sigma: The standard deviation of the Gaussian, in pixels.
80
+ angle: Apply the effect along this angle in degrees.
81
+ """
41
82
  wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
42
83
 
43
84
 
44
85
  def gaussian_noise(x, severity=1):
86
+ """
87
+ Apply Gaussian noise corruption to an image.
88
+
89
+ Adds random Gaussian noise to the image, simulating sensor noise or
90
+ environmental interference.
91
+
92
+ Args:
93
+ x: Input image as PIL Image or numpy array. If numpy array, should be in
94
+ range [0, 255].
95
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
96
+
97
+ Returns:
98
+ numpy.ndarray: Corrupted image as numpy array in range [0, 255].
99
+ """
45
100
  c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
46
101
 
47
102
  x = np.array(x) / 255.0
@@ -49,6 +104,20 @@ def gaussian_noise(x, severity=1):
49
104
 
50
105
 
51
106
  def impulse_noise(x, severity=1):
107
+ """
108
+ Apply impulse (salt-and-pepper) noise corruption to an image.
109
+
110
+ Randomly replaces pixels with either maximum or minimum intensity values,
111
+ simulating transmission errors or faulty pixels.
112
+
113
+ Args:
114
+ x: Input image as PIL Image or numpy array. If numpy array, should be in
115
+ range [0, 255].
116
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
117
+
118
+ Returns:
119
+ numpy.ndarray: Corrupted image as numpy array in range [0, 255].
120
+ """
52
121
  c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
53
122
 
54
123
  x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
@@ -56,6 +125,21 @@ def impulse_noise(x, severity=1):
56
125
 
57
126
 
58
127
  def motion_blur(x, severity=1):
128
+ """
129
+ Apply motion blur corruption to an image.
130
+
131
+ Simulates camera shake or object motion during image capture by applying
132
+ directional blur at a random angle.
133
+
134
+ Args:
135
+ x: Input PIL Image.
136
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
137
+ Higher severity increases blur radius and sigma.
138
+
139
+ Returns:
140
+ numpy.ndarray: Corrupted image as numpy array in range [0, 255].
141
+ Returns RGB image regardless of input format.
142
+ """
59
143
  c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
60
144
 
61
145
  output = BytesIO()
@@ -73,6 +157,21 @@ def motion_blur(x, severity=1):
73
157
 
74
158
 
75
159
  def spatter(x, severity=1):
160
+ """
161
+ Apply spatter corruption to an image.
162
+
163
+ Simulates liquid splatter effects (water or mud) on the image, creating
164
+ realistic occlusions similar to raindrops or dirt on a camera lens.
165
+
166
+ Args:
167
+ x: Input image as PIL Image or numpy array. If numpy array, should be in
168
+ range [0, 255].
169
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
170
+ Levels 1-3 simulate water splatter, levels 4-5 simulate mud splatter.
171
+
172
+ Returns:
173
+ numpy.ndarray: Corrupted image as numpy array in range [0, 255].
174
+ """
76
175
  c = [
77
176
  (0.62, 0.1, 0.7, 0.7, 0.5, 0),
78
177
  (0.65, 0.1, 0.8, 0.7, 0.5, 0),
@@ -140,6 +239,21 @@ def spatter(x, severity=1):
140
239
 
141
240
 
142
241
  def contrast(x, severity=1):
242
+ """
243
+ Apply contrast reduction corruption to an image.
244
+
245
+ Reduces image contrast by blending pixels toward their mean values,
246
+ simulating poor lighting conditions or low-quality image sensors.
247
+
248
+ Args:
249
+ x: Input image as PIL Image or numpy array. If numpy array, should be in
250
+ range [0, 255].
251
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
252
+ Higher severity results in lower contrast.
253
+
254
+ Returns:
255
+ numpy.ndarray: Corrupted image as numpy array in range [0, 255].
256
+ """
143
257
  c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
144
258
 
145
259
  x = np.array(x) / 255.0
@@ -148,6 +262,20 @@ def contrast(x, severity=1):
148
262
 
149
263
 
150
264
  def jpeg_compression(x, severity=1):
265
+ """
266
+ Apply JPEG compression artifacts to an image.
267
+
268
+ Simulates compression artifacts from lossy JPEG encoding at various
269
+ quality levels, commonly seen in heavily compressed images.
270
+
271
+ Args:
272
+ x: Input PIL Image.
273
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
274
+ Lower severity uses higher JPEG quality (less compression).
275
+
276
+ Returns:
277
+ PIL.Image: Corrupted image as PIL Image.
278
+ """
151
279
  c = [80, 65, 58, 50, 40][severity - 1]
152
280
 
153
281
  output = BytesIO()
@@ -158,6 +286,23 @@ def jpeg_compression(x, severity=1):
158
286
 
159
287
 
160
288
  def pixelate(x, severity=1):
289
+ """
290
+ Apply pixelation corruption to an image.
291
+
292
+ Reduces image resolution by downsampling and then upsampling,
293
+ creating a blocky, pixelated appearance.
294
+
295
+ Args:
296
+ x: Input PIL Image with size (32, 32).
297
+ severity: Corruption severity level from 1 (mild) to 5 (severe).
298
+ Higher severity results in more pixelation.
299
+
300
+ Returns:
301
+ PIL.Image: Corrupted image as PIL Image with original size (32, 32).
302
+
303
+ Note:
304
+ This function is designed for 32x32 images (e.g., CIFAR-10).
305
+ """
161
306
  c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
162
307
 
163
308
  x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
@@ -170,6 +315,29 @@ def pixelate(x, severity=1):
170
315
 
171
316
 
172
317
  distortion_methods = collections.OrderedDict()
318
+ """
319
+ OrderedDict mapping corruption names to their corresponding functions.
320
+
321
+ Available corruptions:
322
+ - "Gaussian Noise": Additive Gaussian noise
323
+ - "Impulse Noise": Salt-and-pepper noise
324
+ - "Motion Blur": Directional motion blur
325
+ - "Contrast": Reduced contrast
326
+ - "Pixelate": Resolution reduction
327
+ - "JPEG": JPEG compression artifacts
328
+ - "Spatter": Water or mud splatter effects
329
+
330
+ Example:
331
+ ```python
332
+ from PIL import Image
333
+ from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods
334
+
335
+ img = Image.open("example.jpg")
336
+ for name, corruption_fn in distortion_methods.items():
337
+ corrupted = corruption_fn(img, severity=3)
338
+ # Process corrupted image
339
+ ```
340
+ """
173
341
  distortion_methods["Gaussian Noise"] = gaussian_noise
174
342
  distortion_methods["Impulse Noise"] = impulse_noise
175
343
  distortion_methods["Motion Blur"] = motion_blur
@@ -144,7 +144,15 @@ _extra_objects = {
144
144
 
145
145
  if TYPE_CHECKING:
146
146
  from .ada_svd import AdaSVDMergingForCLIPVisionModel
147
- from .adamerging import *
147
+ from .adamerging import (
148
+ CLIPLayerWiseAdaMergingAlgorithm,
149
+ CLIPTaskWiseAdaMergingAlgorithm,
150
+ FlanT5LayerWiseAdaMergingAlgorithm,
151
+ GPT2LayerWiseAdaMergingAlgorithm,
152
+ LayerWiseAdaMergingForLlamaSFT,
153
+ ResNetLayerWiseAdamerging,
154
+ ResNetTaskWiseAdamerging,
155
+ )
148
156
  from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
149
157
  from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
150
158
  from .bitdelta import BitDeltaAlgorithm
@@ -167,6 +175,7 @@ if TYPE_CHECKING:
167
175
  from .dawe import DataAdaptiveWeightEnsemblingForCLIP
168
176
  from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
169
177
  from .doge_ta import DOGE_TA_Algorithm
178
+ from .dop import ContinualDOPForCLIP
170
179
  from .dummy import DummyAlgorithm
171
180
  from .ensemble import (
172
181
  MaxModelPredictorAlgorithm,
@@ -215,7 +224,6 @@ if TYPE_CHECKING:
215
224
  from .model_recombination import ModelRecombinationAlgorithm
216
225
  from .model_stock import ModelStock
217
226
  from .opcm import OPCMForCLIP
218
- from .dop import ContinualDOPForCLIP
219
227
  from .pruning import (
220
228
  MagnitudeDiffPruningAlgorithm,
221
229
  MagnitudePruningForLlama,
@@ -40,6 +40,7 @@ from typing import Optional # noqa: F401
40
40
 
41
41
  from fusion_bench.mixins import BaseYAMLSerializable
42
42
  from fusion_bench.modelpool import BaseModelPool
43
+ from fusion_bench.utils.misc import DeprecationWarningMeta
43
44
 
44
45
  __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
45
46
 
@@ -202,27 +203,36 @@ class BaseAlgorithm(BaseYAMLSerializable):
202
203
  pass
203
204
 
204
205
 
205
- BaseModelFusionAlgorithm = BaseAlgorithm
206
- """
207
- Alias for BaseAlgorithm class.
206
+ # Create a deprecated wrapper class that inherits from BaseAlgorithm
207
+ class BaseModelFusionAlgorithm(BaseAlgorithm, metaclass=DeprecationWarningMeta):
208
+ """
209
+ Alias for BaseAlgorithm class.
208
210
 
209
- This alias is provided for backward compatibility and semantic clarity.
210
- Some users may prefer the more explicit name 'BaseModelFusionAlgorithm'
211
- to emphasize that this class is specifically designed for model fusion
212
- tasks, while others may prefer the shorter 'BaseAlgorithm' name.
211
+ .. deprecated::
212
+ BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
213
+ Use :class:`BaseAlgorithm` instead.
213
214
 
214
- Both names refer to the exact same class and can be used interchangeably.
215
+ This alias was provided for backward compatibility and semantic clarity.
216
+ Both names refer to the same base class and can be used interchangeably,
217
+ but BaseAlgorithm is now the preferred name for all implementations.
215
218
 
216
- Examples:
217
- Using the original name:
218
- >>> class MyAlgorithm(BaseAlgorithm):
219
- ... def run(self, modelpool): pass
219
+ Examples:
220
+ Preferred (using BaseAlgorithm):
220
221
 
221
- Using the alias:
222
- >>> class MyAlgorithm(BaseModelFusionAlgorithm):
223
- ... def run(self, modelpool): pass
222
+ >>> class MyAlgorithm(BaseAlgorithm):
223
+ ... def run(self, modelpool): pass
224
224
 
225
- Note:
226
- The alias is maintained for compatibility but BaseAlgorithm is the
227
- preferred name for new implementations.
228
- """
225
+ Deprecated (using BaseModelFusionAlgorithm):
226
+
227
+ >>> class MyAlgorithm(BaseModelFusionAlgorithm): # Will trigger deprecation warning
228
+ ... def run(self, modelpool): pass
229
+
230
+ Note:
231
+ New implementations should use :class:`BaseAlgorithm` exclusively.
232
+ The BaseModelFusionAlgorithm alias will be removed in a future release.
233
+
234
+ Warning:
235
+ Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
236
+ """
237
+
238
+ pass
@@ -15,7 +15,7 @@ from lightning_utilities.core.rank_zero import rank_zero_only
15
15
  from lit_learn.lit_modules import ERM_LitModule
16
16
  from omegaconf import DictConfig
17
17
  from torch import nn
18
- from torch.utils.data import DataLoader
18
+ from torch.utils.data import DataLoader, random_split
19
19
  from torchmetrics.classification import Accuracy
20
20
 
21
21
  from fusion_bench import (
@@ -29,7 +29,6 @@ from fusion_bench import (
29
29
  from fusion_bench.dataset import CLIPDataset
30
30
  from fusion_bench.modelpool import ResNetForImageClassificationPool
31
31
  from fusion_bench.tasks.clip_classification import get_num_classes
32
- from torch.utils.data import random_split
33
32
 
34
33
  log = get_rankzero_logger(__name__)
35
34
 
@@ -13,41 +13,13 @@ from fusion_bench.modelpool import CLIPVisionModelPool
13
13
  from fusion_bench.models.hf_clip import HFCLIPClassifier
14
14
  from fusion_bench.tasks.clip_classification import get_classnames_and_templates
15
15
  from fusion_bench.utils import timeit_context
16
+ from fusion_bench.utils.data import InfiniteDataLoader
16
17
 
17
18
  from .task_wise_gossip import TaskWiseGossipAlgorithm
18
19
 
19
20
  log = logging.getLogger(__name__)
20
21
 
21
22
 
22
- class InfiniteDataLoader:
23
- """
24
- A wrapper class for DataLoader to create an infinite data loader.
25
- This is useful in case we are only interested in the number of steps and not the number of epochs.
26
-
27
- This class wraps a DataLoader and provides an iterator that resets
28
- when the end of the dataset is reached, creating an infinite loop.
29
-
30
- Attributes:
31
- data_loader (DataLoader): The DataLoader to wrap.
32
- data_iter (iterator): An iterator over the DataLoader.
33
- """
34
-
35
- def __init__(self, data_loader):
36
- self.data_loader = data_loader
37
- self.data_iter = iter(data_loader)
38
-
39
- def __iter__(self):
40
- return self
41
-
42
- def __next__(self):
43
- try:
44
- data = next(self.data_iter)
45
- except StopIteration:
46
- self.data_iter = iter(self.data_loader) # Reset the data loader
47
- data = next(self.data_iter)
48
- return data
49
-
50
-
51
23
  class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
52
24
  """
53
25
  A class for task-wise adaptive merging of CLIP models.
@@ -0,0 +1,2 @@
1
+ # Exploring Model Kinship for Merging LLMs
2
+ # The implementation of this module is borrowed from: https://github.com/zjunlp/ModelKinship/
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import List
3
+
4
+ import numpy
5
+ import torch
6
+
7
+ from .utility import Metric
8
+
9
+
10
+ def cosine_similarity(a, b):
11
+ similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
12
+ return similarity
13
+
14
+
15
+ def calculate_model_kinship(
16
+ delta1: numpy.ndarray, delta2: numpy.ndarray, metrics: List[str]
17
+ ) -> dict:
18
+ """
19
+ Calculate model kinship using specified metrics.
20
+
21
+ Args:
22
+ delta1: Delta parameters for first model
23
+ delta2: Delta parameters for second model
24
+ metrics: List of metrics to calculate
25
+
26
+ Returns:
27
+ dict: Dictionary of metric names and their calculated values
28
+ """
29
+ results = {}
30
+ for metric in metrics:
31
+ try:
32
+ if metric not in Metric.list():
33
+ raise ValueError(f"Unsupported metric: {metric}")
34
+ results[metric] = calculate_metric(delta1, delta2, metric)
35
+ except Exception as e:
36
+ results[metric] = f"Error calculating {metric}: {str(e)}"
37
+ return results
38
+
39
+
40
+ def calculate_metric(
41
+ d_vector_1: torch.Tensor, d_vector_2: torch.Tensor, metric: str
42
+ ) -> str:
43
+ """
44
+ Calculate the specified metric between two delta vectors.
45
+
46
+ Args:
47
+ d_vector_1 (torch.Tensor): Delta parameters for model 1.
48
+ d_vector_2 (torch.Tensor): Delta parameters for model 2.
49
+ metric (str): The metric to calculate ('pcc', 'ed', 'cs').
50
+
51
+ Returns:
52
+ str: A formatted string with the result of the chosen metric.
53
+ """
54
+ logging.info(f"Starting calculation of {metric.upper()} metric...")
55
+
56
+ # Pearson Correlation Coefficient (PCC)
57
+ if metric == "pcc":
58
+ # Stack the two vectors and calculate the Pearson correlation coefficient
59
+ stack = torch.stack((d_vector_1, d_vector_2), dim=0)
60
+ pcc = torch.corrcoef(stack)[0, 1].item()
61
+ return f"Model Kinship based on Pearson Correlation Coefficient: {pcc}"
62
+
63
+ # Euclidean Distance (ED)
64
+ elif metric == "ed":
65
+ # Compute the Euclidean distance between the vectors
66
+ distance = torch.dist(d_vector_1, d_vector_2).item()
67
+ return f"Model Kinship based on Euclidean Distance: {distance}"
68
+
69
+ # Cosine Similarity (CS)
70
+ elif metric == "cs":
71
+ # Compute cosine similarity
72
+ cs = cosine_similarity(d_vector_1, d_vector_2)
73
+ return f"Model Kinship based on Cosine Similarity: {cs}"
74
+
75
+ # If metric is not recognized
76
+ else:
77
+ return "Invalid metric specified."