fusion-bench 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/constants/__init__.py +5 -1
  3. fusion_bench/constants/runtime.py +111 -7
  4. fusion_bench/dataset/gsm8k.py +6 -2
  5. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  6. fusion_bench/method/__init__.py +1 -1
  7. fusion_bench/method/classification/image_classification_finetune.py +13 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
  10. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  11. fusion_bench/metrics/nyuv2/depth.py +30 -0
  12. fusion_bench/metrics/nyuv2/loss.py +40 -0
  13. fusion_bench/metrics/nyuv2/noise.py +24 -0
  14. fusion_bench/metrics/nyuv2/normal.py +34 -1
  15. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  16. fusion_bench/mixins/clip_classification.py +30 -2
  17. fusion_bench/mixins/lightning_fabric.py +46 -5
  18. fusion_bench/mixins/rich_live.py +76 -0
  19. fusion_bench/modelpool/__init__.py +24 -2
  20. fusion_bench/modelpool/base_pool.py +94 -6
  21. fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
  22. fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
  23. fusion_bench/modelpool/resnet_for_image_classification.py +4 -1
  24. fusion_bench/models/model_card_templates/default.md +1 -1
  25. fusion_bench/scripts/webui.py +250 -17
  26. fusion_bench/utils/__init__.py +14 -0
  27. fusion_bench/utils/data.py +100 -9
  28. fusion_bench/utils/fabric.py +185 -4
  29. fusion_bench/utils/json.py +55 -8
  30. fusion_bench/utils/validation.py +197 -0
  31. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
  32. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +44 -40
  33. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  34. fusion_bench_config/llama_full_finetune.yaml +4 -16
  35. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  36. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
  37. fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
  38. fusion_bench_config/nyuv2_config.yaml +4 -13
  39. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  40. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  41. fusion_bench/utils/auto.py +0 -31
  42. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
  43. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
  44. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
  45. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
fusion_bench/metrics/nyuv2/loss.py
@@ -3,10 +3,35 @@ from torch import Tensor, nn
 
 
 def segmentation_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cross-entropy loss for semantic segmentation.
+
+    Args:
+        pred: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+        gt: Ground truth segmentation labels of shape (batch_size, height, width).
+            Pixels with value -1 are ignored in the loss computation.
+
+    Returns:
+        Tensor: Scalar loss value.
+    """
     return nn.functional.cross_entropy(pred, gt.long(), ignore_index=-1)
 
 
 def depth_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute L1 loss for depth estimation with binary masking.
+
+    This loss function calculates the absolute error between predicted and ground truth
+    depth values, but only for valid pixels (where ground truth depth is non-zero).
+
+    Args:
+        pred: Predicted depth values of shape (batch_size, 1, height, width).
+        gt: Ground truth depth values of shape (batch_size, 1, height, width).
+            Pixels with sum of 0 across channels are considered invalid and masked out.
+
+    Returns:
+        Tensor: Scalar loss value averaged over valid pixels.
+    """
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
     loss = torch.sum(torch.abs(pred - gt) * binary_mask) / torch.nonzero(
         binary_mask, as_tuple=False
@@ -15,6 +40,21 @@ def depth_loss(pred: Tensor, gt: Tensor):
 
 
 def normal_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cosine similarity loss for surface normal prediction.
+
+    This loss measures the angular difference between predicted and ground truth
+    surface normals using normalized cosine similarity (1 - dot product).
+
+    Args:
+        pred: Predicted surface normals of shape (batch_size, 3, height, width).
+            Will be L2-normalized before computing loss.
+        gt: Ground truth surface normals of shape (batch_size, 3, height, width).
+            Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+
+    Returns:
+        Tensor: Scalar loss value (1 - mean cosine similarity) over valid pixels.
+    """
     # gt has been normalized on the NYUv2 dataset
     pred = pred / torch.norm(pred, p=2, dim=1, keepdim=True)
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
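All three losses share the same masking convention: a pixel contributes only if its ground-truth channels do not sum to zero (or, for segmentation, is labelled something other than -1). Below is a minimal sketch of calling them on dummy NYUv2-shaped batches; the import path is assumed from the file listing above and the shapes are purely illustrative.

```python
import torch

# assumed import path, based on the file listing above
from fusion_bench.metrics.nyuv2.loss import depth_loss, normal_loss, segmentation_loss

B, C, H, W = 2, 13, 64, 64                      # small dummy batch, 13 classes
seg_logits = torch.randn(B, C, H, W)            # predicted logits
seg_labels = torch.randint(-1, C, (B, H, W))    # -1 marks ignored pixels

depth_pred = torch.rand(B, 1, H, W)
depth_gt = torch.rand(B, 1, H, W)
depth_gt[:, :, :8, :] = 0.0                     # zero depth -> masked as invalid

normal_pred = torch.randn(B, 3, H, W)
normal_gt = torch.nn.functional.normalize(torch.randn(B, 3, H, W), dim=1)

print(segmentation_loss(seg_logits, seg_labels))  # scalar cross-entropy
print(depth_loss(depth_pred, depth_gt))           # masked L1 over valid pixels
print(normal_loss(normal_pred, normal_gt))        # 1 - mean cosine similarity
```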
fusion_bench/metrics/nyuv2/noise.py
@@ -6,11 +6,35 @@ from torchmetrics import Metric
 
 
 class NoiseMetric(Metric):
+    """
+    A placeholder metric for noise evaluation on NYUv2 dataset.
+
+    This metric currently serves as a placeholder and always returns a value of 1.
+    It can be extended in the future to include actual noise-related metrics.
+
+    Note:
+        This is a dummy implementation that doesn't perform actual noise measurements.
+    """
+
     def __init__(self):
+        """Initialize the NoiseMetric."""
         super().__init__()
 
     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric state (currently a no-op).
+
+        Args:
+            preds: Predicted values (unused).
+            target: Ground truth values (unused).
+        """
         pass
 
     def compute(self):
+        """
+        Compute the metric value.
+
+        Returns:
+            List[int]: A list containing [1] as a placeholder value.
+        """
         return [1]
fusion_bench/metrics/nyuv2/normal.py
@@ -7,14 +7,36 @@ from torchmetrics import Metric
 
 
 class NormalMetric(Metric):
+    """
+    Metric for evaluating surface normal prediction on NYUv2 dataset.
+
+    This metric computes angular error statistics between predicted and ground truth
+    surface normals, including mean, median, and percentage of predictions within
+    specific angular thresholds (11.25°, 22.5°, 30°).
+
+    Attributes:
+        metric_names: List of metric names ["mean", "median", "<11.25", "<22.5", "<30"].
+        record: List storing angular errors (in degrees) for all pixels across batches.
+    """
+
     metric_names = ["mean", "median", "<11.25", "<22.5", "<30"]
 
     def __init__(self):
+        """Initialize the NormalMetric with state for recording angular errors."""
         super(NormalMetric, self).__init__()
 
         self.add_state("record", default=[], dist_reduce_fx="cat")
 
     def update(self, preds, target):
+        """
+        Update metric state with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted surface normals of shape (batch_size, 3, height, width).
+                Will be L2-normalized before computing errors.
+            target: Ground truth surface normals of shape (batch_size, 3, height, width).
+                Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+        """
         # gt has been normalized on the NYUv2 dataset
         preds = preds / torch.norm(preds, p=2, dim=1, keepdim=True)
         binary_mask = torch.sum(target, dim=1) != 0
@@ -33,7 +55,18 @@ class NormalMetric(Metric):
 
     def compute(self):
         """
-        returns mean, median, and percentage of pixels with error less than 11.25, 22.5, and 30 degrees ("mean", "median", "<11.25", "<22.5", "<30")
+        Compute final metric values from all recorded angular errors.
+
+        Returns:
+            List[Tensor]: A list containing five metrics:
+                - mean: Mean angular error in degrees.
+                - median: Median angular error in degrees.
+                - <11.25: Percentage of pixels with error < 11.25°.
+                - <22.5: Percentage of pixels with error < 22.5°.
+                - <30: Percentage of pixels with error < 30°.
+
+        Note:
+            Returns zeros if no data has been recorded.
         """
         if self.record is None:
             return torch.asarray([0.0, 0.0, 0.0, 0.0, 0.0])
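As a torchmetrics `Metric`, `NormalMetric` follows the usual update/compute cycle. A small hedged sketch follows; the import path is inferred from the file listing, and the dummy tensors stand in for real NYUv2 predictions.

```python
import torch

# assumed import path, inferred from the file listing above
from fusion_bench.metrics.nyuv2.normal import NormalMetric

metric = NormalMetric()
preds = torch.randn(2, 3, 64, 64)      # raw normals; update() L2-normalizes them
target = torch.nn.functional.normalize(torch.randn(2, 3, 64, 64), dim=1)

metric.update(preds, target)
values = metric.compute()              # [mean, median, <11.25, <22.5, <30]
print(dict(zip(NormalMetric.metric_names, [float(v) for v in values])))
```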
fusion_bench/metrics/nyuv2/segmentation.py
@@ -6,9 +6,28 @@ from torchmetrics import Metric
 
 
 class SegmentationMetric(Metric):
+    """
+    Metric for evaluating semantic segmentation on NYUv2 dataset.
+
+    This metric computes mean Intersection over Union (mIoU) and pixel accuracy
+    for multi-class segmentation tasks.
+
+    Attributes:
+        metric_names: List of metric names ["mIoU", "pixAcc"].
+        num_classes: Number of segmentation classes (default: 13 for NYUv2).
+        record: Confusion matrix of shape (num_classes, num_classes) tracking
+            predictions vs ground truth.
+    """
+
     metric_names = ["mIoU", "pixAcc"]
 
     def __init__(self, num_classes=13):
+        """
+        Initialize the SegmentationMetric.
+
+        Args:
+            num_classes: Number of segmentation classes. Default is 13 for NYUv2 dataset.
+        """
         super().__init__()
 
         self.num_classes = num_classes
@@ -21,9 +40,19 @@ class SegmentationMetric(Metric):
         )
 
     def reset(self):
+        """Reset the confusion matrix to zeros."""
         self.record.zero_()
 
     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update the confusion matrix with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+                Will be converted to class predictions via softmax and argmax.
+            target: Ground truth segmentation labels of shape (batch_size, height, width).
+                Pixels with negative values or values >= num_classes are ignored.
+        """
         preds = preds.softmax(1).argmax(1).flatten()
         target = target.long().flatten()
 
@@ -35,7 +64,12 @@
 
     def compute(self):
         """
-        return mIoU and pixel accuracy
+        Compute mIoU and pixel accuracy from the confusion matrix.
+
+        Returns:
+            List[Tensor]: A list containing [mIoU, pixel_accuracy]:
+                - mIoU: Mean Intersection over Union across all classes.
+                - pixel_accuracy: Overall pixel classification accuracy.
         """
         h = cast(Tensor, self.record).float()
         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
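`SegmentationMetric` works the same way, accumulating a confusion matrix across batches. A companion sketch under the same assumptions (import path inferred from the file listing, dummy data):

```python
import torch

# assumed import path, inferred from the file listing above
from fusion_bench.metrics.nyuv2.segmentation import SegmentationMetric

metric = SegmentationMetric(num_classes=13)
logits = torch.randn(2, 13, 64, 64)             # per-pixel class logits
labels = torch.randint(0, 13, (2, 64, 64))      # ground-truth class indices
labels[:, :4, :] = -1                           # negative labels are ignored

metric.update(logits, labels)
miou, pix_acc = metric.compute()
print(f"mIoU={float(miou):.3f}  pixAcc={float(pix_acc):.3f}")
```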
fusion_bench/mixins/clip_classification.py
@@ -59,6 +59,15 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     @property
     def clip_processor(self):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
         if self._clip_processor is None:
             assert self.modelpool is not None, "Model pool is not set"
             self._clip_processor = self.modelpool.load_processor()
@@ -125,6 +134,11 @@ class CLIPClassificationMixin(LightningFabricMixin):
             clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
             task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
         """
+        # make sure the task names are equal across all processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
         self.whether_setup_zero_shot_classification_head = True
         # load clip model if not provided
         if clip_model is None:
@@ -147,7 +161,10 @@
             self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
         @cache_with_joblib()
-        def construct_classification_head(task: str):
+        def construct_classification_head(task: str, model_name: str):
+            log.info(
+                f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
+            )
             nonlocal clip_classifier
 
             classnames, templates = get_classnames_and_templates(task)
@@ -163,7 +180,18 @@
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                zeroshot_weights = construct_classification_head(task)
+                if hasattr(clip_model, "config") and hasattr(
+                    clip_model.config, "_name_or_path"
+                ):
+                    model_name = clip_model.config._name_or_path
+                else:
+                    model_name = "unknown_model"
+                    log.warning(
+                        "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
+                    )
+                zeroshot_weights = construct_classification_head(
+                    task, model_name=model_name
+                )
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
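The new `model_name` argument matters because `cache_with_joblib` keys its on-disk cache by the function's arguments: with `task` alone, two different CLIP backbones evaluated on the same task would collide on one cached classification head. A toy sketch of that cache-key behaviour with plain joblib; the cache directory and function here are illustrative, not part of fusion_bench.

```python
from joblib import Memory

memory = Memory("./.cache_demo", verbose=0)  # hypothetical cache directory

@memory.cache
def construct_head(task: str, model_name: str):
    # stand-in for the expensive zero-shot head construction
    print(f"building head for {task} with {model_name}")
    return f"{task}/{model_name}"

construct_head("svhn", "openai/clip-vit-base-patch32")   # computed and cached
construct_head("svhn", "openai/clip-vit-base-patch32")   # cache hit, no rebuild
construct_head("svhn", "openai/clip-vit-large-patch14")  # different key, rebuilt
```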
fusion_bench/mixins/lightning_fabric.py
@@ -1,7 +1,7 @@
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
 import torch
@@ -96,12 +96,24 @@ class LightningFabricMixin:
 
     @property
     def fabric(self):
+        """
+        Get the Lightning Fabric instance, initializing it if necessary.
+
+        Returns:
+            L.Fabric: The Lightning Fabric instance for distributed computing.
+        """
         if self._fabric_instance is None:
             self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
     @fabric.setter
     def fabric(self, instance: L.Fabric):
+        """
+        Set the Lightning Fabric instance.
+
+        Args:
+            instance: The Lightning Fabric instance to use.
+        """
         self._fabric_instance = instance
 
     @property
@@ -172,6 +184,15 @@
     def tensorboard_summarywriter(
         self,
     ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
+        """
+        Get the TensorBoard SummaryWriter for detailed logging.
+
+        Returns:
+            SummaryWriter: The TensorBoard SummaryWriter instance.
+
+        Raises:
+            AttributeError: If the logger is not a TensorBoardLogger.
+        """
         if isinstance(self.fabric.logger, TensorBoardLogger):
             return self.fabric.logger.experiment
         else:
@@ -179,6 +200,12 @@
 
     @property
     def is_debug_mode(self):
+        """
+        Check if the program is running in debug mode (fast_dev_run).
+
+        Returns:
+            bool: True if fast_dev_run is enabled, False otherwise.
+        """
         if hasattr(self, "config") and self.config.get("fast_dev_run", False):
             return True
         elif hasattr(self, "_program") and self._program.config.get(
@@ -190,13 +217,22 @@
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
-        Logs the metric to the fabric's logger.
+        Logs a single metric to the fabric's logger.
+
+        Args:
+            name: The name of the metric to log.
+            value: The value of the metric.
+            step: Optional step number for the metric.
         """
         self.fabric.log(name, value, step=step)
 
-    def log_dict(self, metrics: dict, step: Optional[int] = None):
+    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
         """
-        Logs the metrics to the fabric's logger.
+        Logs multiple metrics to the fabric's logger.
+
+        Args:
+            metrics: Dictionary of metric names and values.
+            step: Optional step number for the metrics.
         """
         self.fabric.log_dict(metrics, step=step)
 
@@ -207,7 +243,12 @@
         name_template: str = "train/lr_group_{0}",
     ):
         """
-        Logs the learning rate of the optimizer to the fabric's logger.
+        Logs the learning rate of each parameter group in the optimizer.
+
+        Args:
+            optimizer: The optimizer whose learning rates should be logged.
+            step: Optional step number for the log entry.
+            name_template: Template string for the log name. Use {0} as placeholder for group index.
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
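The widened `Mapping[str, Any]` hint on `log_dict` reflects that the mixin simply forwards to `Fabric.log` and `Fabric.log_dict`. A minimal self-contained sketch of those underlying Lightning Fabric calls, assuming a TensorBoard logger is configured:

```python
import lightning as L
from lightning.fabric.loggers import TensorBoardLogger

# minimal sketch: the mixin's log/log_dict delegate to these Fabric methods
fabric = L.Fabric(accelerator="cpu", loggers=TensorBoardLogger(root_dir="./tb_logs"))
fabric.launch()

fabric.log("train/loss", 0.42, step=0)                             # single metric
fabric.log_dict({"train/loss": 0.42, "train/acc": 0.91}, step=0)   # any mapping of metrics
```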
fusion_bench/mixins/rich_live.py
@@ -2,20 +2,96 @@ from rich.live import Live
 
 
 class RichLiveMixin:
+    """
+    A mixin class that provides Rich Live display capabilities.
+
+    This mixin integrates Rich's Live display functionality, allowing for
+    dynamic, auto-refreshing console output. It's particularly useful for
+    displaying real-time updates, progress information, or continuously
+    changing data without cluttering the terminal.
+
+    Attributes:
+        _rich_live (Live): The internal Rich Live instance for live display updates.
+
+    Example:
+        ```python
+        class MyTask(RichLiveMixin):
+            def run(self):
+                self.start_rich_live()
+                for i in range(100):
+                    self.rich_live_print(f"Processing item {i}")
+                    time.sleep(0.1)
+                self.stop_rich_live()
+        ```
+    """
+
     _rich_live: Live = None
 
     @property
     def rich_live(self) -> Live:
+        """
+        Get the Rich Live instance, creating it if necessary.
+
+        Returns:
+            Live: The Rich Live instance for dynamic console output.
+        """
         if self._rich_live is None:
             self._rich_live = Live()
         return self._rich_live
 
     def start_rich_live(self):
+        """
+        Start the Rich Live display context.
+
+        This method enters the Rich Live context, enabling dynamic console output.
+        Must be paired with stop_rich_live() to properly clean up resources.
+
+        Returns:
+            The Rich Live instance in its started state.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # Display dynamic content
+            self.rich_live_print("Dynamic output")
+            self.stop_rich_live()
+            ```
+        """
         return self.rich_live.__enter__()
 
     def stop_rich_live(self):
+        """
+        Stop the Rich Live display context and clean up resources.
+
+        This method exits the Rich Live context and resets the internal Live instance.
+        Should be called after start_rich_live() when dynamic display is complete.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # ... display content ...
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.__exit__(None, None, None)
         self._rich_live = None
 
     def rich_live_print(self, msg):
+        """
+        Print a message to the Rich Live console.
+
+        This method displays the given message through the Rich Live console,
+        allowing for formatted, dynamic output that updates in place.
+
+        Args:
+            msg: The message to display. Can be a string or any Rich renderable object.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            self.rich_live_print("[bold green]Success![/bold green]")
+            self.rich_live_print(Panel("Status: Running"))
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.console.print(msg)
fusion_bench/modelpool/__init__.py
@@ -8,6 +8,14 @@ _import_structure = {
     "base_pool": ["BaseModelPool"],
     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
+    "convnext_for_image_classification": [
+        "ConvNextForImageClassificationPool",
+        "load_transformers_convnext",
+    ],
+    "dinov2_for_image_classification": [
+        "Dinov2ForImageClassificationPool",
+        "load_transformers_dinov2",
+    ],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
@@ -18,7 +26,10 @@
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
-    "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
+    "resnet_for_image_classification": [
+        "ResNetForImageClassificationPool",
+        "load_transformers_resnet",
+    ],
 }
 
 
@@ -26,6 +37,14 @@ if TYPE_CHECKING:
     from .base_pool import BaseModelPool
     from .causal_lm import CausalLMBackbonePool, CausalLMPool
     from .clip_vision import CLIPVisionModelPool
+    from .convnext_for_image_classification import (
+        ConvNextForImageClassificationPool,
+        load_transformers_convnext,
+    )
+    from .dinov2_for_image_classification import (
+        Dinov2ForImageClassificationPool,
+        load_transformers_dinov2,
+    )
     from .huggingface_automodel import AutoModelPool
     from .huggingface_gpt2_classification import (
         GPT2ForSequenceClassificationPool,
@@ -34,7 +53,10 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
-    from .resnet_for_image_classification import ResNetForImageClassificationPool
+    from .resnet_for_image_classification import (
+        ResNetForImageClassificationPool,
+        load_transformers_resnet,
+    )
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
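With these `_import_structure` entries (and the matching `TYPE_CHECKING` imports), the new pools and loader helpers become importable from the package root, assuming the lazy-import machinery resolves them the same way as the existing entries:

```python
# sketch of the intended top-level imports after this change;
# the names come from the _import_structure entries above
from fusion_bench.modelpool import (
    ConvNextForImageClassificationPool,
    Dinov2ForImageClassificationPool,
    load_transformers_resnet,
)
```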
fusion_bench/modelpool/base_pool.py
@@ -1,14 +1,19 @@
 import logging
 from copy import deepcopy
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
-from fusion_bench.utils import instantiate, timeit_context
+from fusion_bench.utils import (
+    ValidationError,
+    instantiate,
+    timeit_context,
+    validate_model_name,
+)
 
 __all__ = ["BaseModelPool"]
 
@@ -52,6 +57,23 @@ class BaseModelPool(
     ):
         if isinstance(models, List):
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
+
+        if isinstance(models, dict):
+            try:  # try to convert to DictConfig
+                models = OmegaConf.create(models)
+            except UnsupportedValueType:
+                pass
+
+        if not models:
+            log.warning("Initialized BaseModelPool with empty models dictionary.")
+        else:
+            # Validate model names
+            for model_name in models.keys():
+                try:
+                    validate_model_name(model_name, allow_special=True)
+                except ValidationError as e:
+                    log.warning(f"Invalid model name '{model_name}': {e}")
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
@@ -140,7 +162,9 @@
         """
         return model_name.startswith("_") and model_name.endswith("_")
 
-    def get_model_config(self, model_name: str, return_copy: bool = True) -> DictConfig:
+    def get_model_config(
+        self, model_name: str, return_copy: bool = True
+    ) -> Union[DictConfig, str, Any]:
         """
         Get the configuration for the specified model.
 
@@ -148,10 +172,36 @@
             model_name (str): The name of the model.
 
         Returns:
-            DictConfig: The configuration for the specified model.
+            Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        # raise friendly error if model not found in the pool
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         model_config = self._models[model_name]
+        if isinstance(model_config, nn.Module):
+            log.warning(
+                f"Model configuration for '{model_name}' is a pre-instantiated model. "
+                "Returning the model instance instead of configuration."
+            )
+
         if return_copy:
+            if isinstance(model_config, nn.Module):
+                # raise performance warning
+                log.warning(
+                    f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
+                )
             model_config = deepcopy(model_config)
         return model_config
 
@@ -164,12 +214,28 @@
 
         Returns:
             str: The path for the specified model.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
+            ValueError: If model configuration is not a string path.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         if isinstance(self._models[model_name], str):
            return self._models[model_name]
        else:
            raise ValueError(
-                "Model path is not a string. Try to override this method in derived modelpool class."
+                f"Model configuration for '{model_name}' is not a string path. "
+                "Try to override this method in derived modelpool class."
            )
 
     def load_model(
@@ -350,3 +416,25 @@
         """
         with timeit_context(f"Saving the state dict of model to {path}"):
             torch.save(model.state_dict(), path)
+
+    def __contains__(self, model_name: str) -> bool:
+        """
+        Check if a model with the given name exists in the model pool.
+
+        Examples:
+            >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
+            >>> "modelA" in modelpool
+            True
+            >>> "modelC" in modelpool
+            False
+
+        Args:
+            model_name (str): The name of the model to check.
+
+        Returns:
+            bool: True if the model exists, False otherwise.
+        """
+        if self._models is None:
+            raise RuntimeError("Model pool is not initialized")
+        validate_model_name(model_name, allow_special=True)
+        return model_name in self._models
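Taken together, the name validation, the friendlier KeyError, and the new `__contains__` change how lookups fail and how membership is checked. A short hedged sketch of the resulting behaviour; the model identifiers below are placeholders, not paths shipped with the package.

```python
from fusion_bench.modelpool import BaseModelPool

# placeholder Hugging Face identifiers, for illustration only
pool = BaseModelPool(
    models={
        "_pretrained_": "openai/clip-vit-base-patch32",
        "svhn": "some-org/clip-vit-base-patch32_svhn",
    }
)

print("svhn" in pool)                # True, via the new __contains__
print("mnist" in pool)               # False

print(pool.get_model_path("svhn"))   # string configs are returned as-is
try:
    pool.get_model_config("mnist")
except KeyError as err:
    print(err)                       # message lists the available model names
```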