fusion-bench 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/classification/image_classification_finetune.py +13 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +94 -6
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/json.py +55 -8
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +44 -40
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
fusion_bench/metrics/nyuv2/loss.py
CHANGED
@@ -3,10 +3,35 @@ from torch import Tensor, nn
 
 
 def segmentation_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cross-entropy loss for semantic segmentation.
+
+    Args:
+        pred: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+        gt: Ground truth segmentation labels of shape (batch_size, height, width).
+            Pixels with value -1 are ignored in the loss computation.
+
+    Returns:
+        Tensor: Scalar loss value.
+    """
     return nn.functional.cross_entropy(pred, gt.long(), ignore_index=-1)
 
 
 def depth_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute L1 loss for depth estimation with binary masking.
+
+    This loss function calculates the absolute error between predicted and ground truth
+    depth values, but only for valid pixels (where ground truth depth is non-zero).
+
+    Args:
+        pred: Predicted depth values of shape (batch_size, 1, height, width).
+        gt: Ground truth depth values of shape (batch_size, 1, height, width).
+            Pixels with sum of 0 across channels are considered invalid and masked out.
+
+    Returns:
+        Tensor: Scalar loss value averaged over valid pixels.
+    """
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
     loss = torch.sum(torch.abs(pred - gt) * binary_mask) / torch.nonzero(
         binary_mask, as_tuple=False
@@ -15,6 +40,21 @@ def depth_loss(pred: Tensor, gt: Tensor):
 
 
 def normal_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cosine similarity loss for surface normal prediction.
+
+    This loss measures the angular difference between predicted and ground truth
+    surface normals using normalized cosine similarity (1 - dot product).
+
+    Args:
+        pred: Predicted surface normals of shape (batch_size, 3, height, width).
+            Will be L2-normalized before computing loss.
+        gt: Ground truth surface normals of shape (batch_size, 3, height, width).
+            Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+
+    Returns:
+        Tensor: Scalar loss value (1 - mean cosine similarity) over valid pixels.
+    """
     # gt has been normalized on the NYUv2 dataset
     pred = pred / torch.norm(pred, p=2, dim=1, keepdim=True)
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
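
The three functions above are per-task objectives for NYUv2 multi-task training. A minimal sketch of combining them into one training loss; the tensor shapes follow the docstrings, while the random inputs and the unweighted sum are illustrative assumptions, not part of the diff:

```python
import torch
import torch.nn.functional as F

from fusion_bench.metrics.nyuv2.loss import depth_loss, normal_loss, segmentation_loss

# Hypothetical batch matching the documented shapes (13 classes as in NYUv2).
B, C, H, W = 2, 13, 32, 32
seg_pred = torch.randn(B, C, H, W)
seg_gt = torch.randint(-1, C, (B, H, W))           # -1 marks ignored pixels
depth_pred = torch.rand(B, 1, H, W)
depth_gt = torch.rand(B, 1, H, W)                  # zero pixels would be masked out
normal_pred = torch.randn(B, 3, H, W)
normal_gt = F.normalize(torch.randn(B, 3, H, W), dim=1)

# Unweighted sum for illustration; fusion methods may weight the tasks differently.
total_loss = (
    segmentation_loss(seg_pred, seg_gt)
    + depth_loss(depth_pred, depth_gt)
    + normal_loss(normal_pred, normal_gt)
)
print(total_loss)
```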
fusion_bench/metrics/nyuv2/noise.py
CHANGED
@@ -6,11 +6,35 @@ from torchmetrics import Metric
 
 
 class NoiseMetric(Metric):
+    """
+    A placeholder metric for noise evaluation on NYUv2 dataset.
+
+    This metric currently serves as a placeholder and always returns a value of 1.
+    It can be extended in the future to include actual noise-related metrics.
+
+    Note:
+        This is a dummy implementation that doesn't perform actual noise measurements.
+    """
+
     def __init__(self):
+        """Initialize the NoiseMetric."""
         super().__init__()
 
     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric state (currently a no-op).
+
+        Args:
+            preds: Predicted values (unused).
+            target: Ground truth values (unused).
+        """
         pass
 
     def compute(self):
+        """
+        Compute the metric value.
+
+        Returns:
+            List[int]: A list containing [1] as a placeholder value.
+        """
         return [1]
fusion_bench/metrics/nyuv2/normal.py
CHANGED
@@ -7,14 +7,36 @@ from torchmetrics import Metric
 
 
 class NormalMetric(Metric):
+    """
+    Metric for evaluating surface normal prediction on NYUv2 dataset.
+
+    This metric computes angular error statistics between predicted and ground truth
+    surface normals, including mean, median, and percentage of predictions within
+    specific angular thresholds (11.25°, 22.5°, 30°).
+
+    Attributes:
+        metric_names: List of metric names ["mean", "median", "<11.25", "<22.5", "<30"].
+        record: List storing angular errors (in degrees) for all pixels across batches.
+    """
+
     metric_names = ["mean", "median", "<11.25", "<22.5", "<30"]
 
     def __init__(self):
+        """Initialize the NormalMetric with state for recording angular errors."""
         super(NormalMetric, self).__init__()
 
         self.add_state("record", default=[], dist_reduce_fx="cat")
 
     def update(self, preds, target):
+        """
+        Update metric state with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted surface normals of shape (batch_size, 3, height, width).
+                Will be L2-normalized before computing errors.
+            target: Ground truth surface normals of shape (batch_size, 3, height, width).
+                Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+        """
         # gt has been normalized on the NYUv2 dataset
         preds = preds / torch.norm(preds, p=2, dim=1, keepdim=True)
         binary_mask = torch.sum(target, dim=1) != 0
@@ -33,7 +55,18 @@ class NormalMetric(Metric):
 
     def compute(self):
         """
-
+        Compute final metric values from all recorded angular errors.
+
+        Returns:
+            List[Tensor]: A list containing five metrics:
+                - mean: Mean angular error in degrees.
+                - median: Median angular error in degrees.
+                - <11.25: Percentage of pixels with error < 11.25°.
+                - <22.5: Percentage of pixels with error < 22.5°.
+                - <30: Percentage of pixels with error < 30°.
+
+        Note:
+            Returns zeros if no data has been recorded.
         """
         if self.record is None:
             return torch.asarray([0.0, 0.0, 0.0, 0.0, 0.0])
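
`compute` reduces the per-pixel angular errors accumulated in `record` to the five numbers listed in the docstring above. A standalone sketch of that reduction; the error values are made up for illustration:

```python
import torch

# Hypothetical flat tensor of per-pixel angular errors in degrees, standing in
# for the concatenated `record` state accumulated by NormalMetric.update().
errors = torch.tensor([4.0, 9.5, 14.0, 27.0, 52.0])

metrics = [
    errors.mean(),                      # mean angular error
    errors.median(),                    # median angular error
    (errors < 11.25).float().mean(),    # fraction of pixels with error < 11.25°
    (errors < 22.5).float().mean(),     # fraction of pixels with error < 22.5°
    (errors < 30.0).float().mean(),     # fraction of pixels with error < 30°
]
print(metrics)
```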
fusion_bench/metrics/nyuv2/segmentation.py
CHANGED
@@ -6,9 +6,28 @@ from torchmetrics import Metric
 
 
 class SegmentationMetric(Metric):
+    """
+    Metric for evaluating semantic segmentation on NYUv2 dataset.
+
+    This metric computes mean Intersection over Union (mIoU) and pixel accuracy
+    for multi-class segmentation tasks.
+
+    Attributes:
+        metric_names: List of metric names ["mIoU", "pixAcc"].
+        num_classes: Number of segmentation classes (default: 13 for NYUv2).
+        record: Confusion matrix of shape (num_classes, num_classes) tracking
+            predictions vs ground truth.
+    """
+
     metric_names = ["mIoU", "pixAcc"]
 
     def __init__(self, num_classes=13):
+        """
+        Initialize the SegmentationMetric.
+
+        Args:
+            num_classes: Number of segmentation classes. Default is 13 for NYUv2 dataset.
+        """
         super().__init__()
 
         self.num_classes = num_classes
@@ -21,9 +40,19 @@ class SegmentationMetric(Metric):
         )
 
     def reset(self):
+        """Reset the confusion matrix to zeros."""
         self.record.zero_()
 
     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update the confusion matrix with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+                Will be converted to class predictions via softmax and argmax.
+            target: Ground truth segmentation labels of shape (batch_size, height, width).
+                Pixels with negative values or values >= num_classes are ignored.
+        """
         preds = preds.softmax(1).argmax(1).flatten()
         target = target.long().flatten()
 
@@ -35,7 +64,12 @@ class SegmentationMetric(Metric):
 
     def compute(self):
         """
-
+        Compute mIoU and pixel accuracy from the confusion matrix.
+
+        Returns:
+            List[Tensor]: A list containing [mIoU, pixel_accuracy]:
+                - mIoU: Mean Intersection over Union across all classes.
+                - pixel_accuracy: Overall pixel classification accuracy.
         """
         h = cast(Tensor, self.record).float()
         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
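
The `compute` method above reduces the accumulated confusion matrix to the two reported values. A self-contained sketch of that reduction with a made-up 3-class confusion matrix:

```python
import torch

# Hypothetical 3-class confusion matrix accumulated over a dataset;
# entry [i, j] counts pixels of class i that were predicted as class j.
h = torch.tensor(
    [[50.0, 2.0, 3.0],
     [4.0, 40.0, 1.0],
     [2.0, 3.0, 45.0]]
)

# Per-class IoU: intersection (diagonal) over union (row sum + column sum - diagonal),
# mirroring the expression used in SegmentationMetric.compute().
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
miou = iu.mean()

# Pixel accuracy: correctly classified pixels over all counted pixels.
pix_acc = torch.diag(h).sum() / h.sum()

print(miou.item(), pix_acc.item())
```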
fusion_bench/mixins/clip_classification.py
CHANGED
@@ -59,6 +59,15 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     @property
     def clip_processor(self):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
         if self._clip_processor is None:
             assert self.modelpool is not None, "Model pool is not set"
             self._clip_processor = self.modelpool.load_processor()
@@ -125,6 +134,11 @@ class CLIPClassificationMixin(LightningFabricMixin):
             clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
             task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
         """
+        # make sure the task names are equal across all processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
         self.whether_setup_zero_shot_classification_head = True
         # load clip model if not provided
         if clip_model is None:
@@ -147,7 +161,10 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
         @cache_with_joblib()
-        def construct_classification_head(task: str):
+        def construct_classification_head(task: str, model_name: str):
+            log.info(
+                f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
+            )
             nonlocal clip_classifier
 
             classnames, templates = get_classnames_and_templates(task)
@@ -163,7 +180,18 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                zeroshot_weights = construct_classification_head(task)
+                if hasattr(clip_model, "config") and hasattr(
+                    clip_model.config, "_name_or_path"
+                ):
+                    model_name = clip_model.config._name_or_path
+                else:
+                    model_name = "unknown_model"
+                    log.warning(
+                        "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
+                    )
+                zeroshot_weights = construct_classification_head(
+                    task, model_name=model_name
+                )
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
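
Two behaviours change here: `task_names` is broadcast from rank zero so all processes agree on the work, and the joblib-cached head builder now also takes the backbone's `model_name`, so heads built from different CLIP checkpoints no longer share a cache entry. A small sketch of why the extra argument matters for an argument-keyed cache, using `functools.lru_cache` as a stand-in for `cache_with_joblib` (an assumption about its keying behaviour, not FusionBench's actual implementation):

```python
from functools import lru_cache


# Stand-in for an argument-keyed cache such as joblib.Memory. Without `model_name`
# in the signature, the same `task` computed with two different CLIP backbones
# would collide on a single cache entry.
@lru_cache(maxsize=None)
def construct_classification_head(task: str, model_name: str) -> str:
    # Placeholder "head"; in FusionBench this is built from the task's class names
    # and prompt templates, encoded by the named CLIP model.
    return f"zeroshot-head({task}, {model_name})"


print(construct_classification_head("svhn", "openai/clip-vit-base-patch32"))
print(construct_classification_head("svhn", "openai/clip-vit-large-patch14"))  # separate entry
```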
fusion_bench/mixins/lightning_fabric.py
CHANGED
@@ -1,7 +1,7 @@
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
 import torch
@@ -96,12 +96,24 @@ class LightningFabricMixin:
 
     @property
     def fabric(self):
+        """
+        Get the Lightning Fabric instance, initializing it if necessary.
+
+        Returns:
+            L.Fabric: The Lightning Fabric instance for distributed computing.
+        """
         if self._fabric_instance is None:
             self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
     @fabric.setter
     def fabric(self, instance: L.Fabric):
+        """
+        Set the Lightning Fabric instance.
+
+        Args:
+            instance: The Lightning Fabric instance to use.
+        """
         self._fabric_instance = instance
 
     @property
@@ -172,6 +184,15 @@ class LightningFabricMixin:
     def tensorboard_summarywriter(
         self,
     ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
+        """
+        Get the TensorBoard SummaryWriter for detailed logging.
+
+        Returns:
+            SummaryWriter: The TensorBoard SummaryWriter instance.
+
+        Raises:
+            AttributeError: If the logger is not a TensorBoardLogger.
+        """
         if isinstance(self.fabric.logger, TensorBoardLogger):
             return self.fabric.logger.experiment
         else:
@@ -179,6 +200,12 @@ class LightningFabricMixin:
 
     @property
     def is_debug_mode(self):
+        """
+        Check if the program is running in debug mode (fast_dev_run).
+
+        Returns:
+            bool: True if fast_dev_run is enabled, False otherwise.
+        """
         if hasattr(self, "config") and self.config.get("fast_dev_run", False):
             return True
         elif hasattr(self, "_program") and self._program.config.get(
@@ -190,13 +217,22 @@ class LightningFabricMixin:
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
-        Logs
+        Logs a single metric to the fabric's logger.
+
+        Args:
+            name: The name of the metric to log.
+            value: The value of the metric.
+            step: Optional step number for the metric.
         """
         self.fabric.log(name, value, step=step)
 
-    def log_dict(self, metrics:
+    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
         """
-        Logs
+        Logs multiple metrics to the fabric's logger.
+
+        Args:
+            metrics: Dictionary of metric names and values.
+            step: Optional step number for the metrics.
         """
         self.fabric.log_dict(metrics, step=step)
 
@@ -207,7 +243,12 @@ class LightningFabricMixin:
         name_template: str = "train/lr_group_{0}",
     ):
         """
-        Logs the learning rate of
+        Logs the learning rate of each parameter group in the optimizer.
+
+        Args:
+            optimizer: The optimizer whose learning rates should be logged.
+            step: Optional step number for the log entry.
+            name_template: Template string for the log name. Use {0} as placeholder for group index.
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
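
`log_dict` now has a complete signature accepting any `Mapping[str, Any]`, and both helpers simply forward to Lightning Fabric. A minimal runnable sketch of the underlying Fabric calls; no logger is attached here, so the calls are no-ops, whereas the mixin's `tensorboard_summarywriter` property suggests a TensorBoardLogger is the expected logger:

```python
import lightning as L

# The mixin's log() and log_dict() forward to these Fabric methods.
fabric = L.Fabric(accelerator="cpu", devices=1)

fabric.log("train/lr_group_0", 1e-4, step=0)
fabric.log_dict({"train/loss": 0.42, "train/accuracy": 0.91}, step=0)
```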
fusion_bench/mixins/rich_live.py
CHANGED
@@ -2,20 +2,96 @@ from rich.live import Live
 
 
 class RichLiveMixin:
+    """
+    A mixin class that provides Rich Live display capabilities.
+
+    This mixin integrates Rich's Live display functionality, allowing for
+    dynamic, auto-refreshing console output. It's particularly useful for
+    displaying real-time updates, progress information, or continuously
+    changing data without cluttering the terminal.
+
+    Attributes:
+        _rich_live (Live): The internal Rich Live instance for live display updates.
+
+    Example:
+        ```python
+        class MyTask(RichLiveMixin):
+            def run(self):
+                self.start_rich_live()
+                for i in range(100):
+                    self.rich_live_print(f"Processing item {i}")
+                    time.sleep(0.1)
+                self.stop_rich_live()
+        ```
+    """
+
     _rich_live: Live = None
 
     @property
     def rich_live(self) -> Live:
+        """
+        Get the Rich Live instance, creating it if necessary.
+
+        Returns:
+            Live: The Rich Live instance for dynamic console output.
+        """
         if self._rich_live is None:
             self._rich_live = Live()
         return self._rich_live
 
     def start_rich_live(self):
+        """
+        Start the Rich Live display context.
+
+        This method enters the Rich Live context, enabling dynamic console output.
+        Must be paired with stop_rich_live() to properly clean up resources.
+
+        Returns:
+            The Rich Live instance in its started state.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # Display dynamic content
+            self.rich_live_print("Dynamic output")
+            self.stop_rich_live()
+            ```
+        """
         return self.rich_live.__enter__()
 
     def stop_rich_live(self):
+        """
+        Stop the Rich Live display context and clean up resources.
+
+        This method exits the Rich Live context and resets the internal Live instance.
+        Should be called after start_rich_live() when dynamic display is complete.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # ... display content ...
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.__exit__(None, None, None)
         self._rich_live = None
 
     def rich_live_print(self, msg):
+        """
+        Print a message to the Rich Live console.
+
+        This method displays the given message through the Rich Live console,
+        allowing for formatted, dynamic output that updates in place.
+
+        Args:
+            msg: The message to display. Can be a string or any Rich renderable object.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            self.rich_live_print("[bold green]Success![/bold green]")
+            self.rich_live_print(Panel("Status: Running"))
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.console.print(msg)
fusion_bench/modelpool/__init__.py
CHANGED
@@ -8,6 +8,14 @@ _import_structure = {
     "base_pool": ["BaseModelPool"],
     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
+    "convnext_for_image_classification": [
+        "ConvNextForImageClassificationPool",
+        "load_transformers_convnext",
+    ],
+    "dinov2_for_image_classification": [
+        "Dinov2ForImageClassificationPool",
+        "load_transformers_dinov2",
+    ],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
@@ -18,7 +26,10 @@ _import_structure = {
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
-    "resnet_for_image_classification": [
+    "resnet_for_image_classification": [
+        "ResNetForImageClassificationPool",
+        "load_transformers_resnet",
+    ],
 }
 
 
@@ -26,6 +37,14 @@ if TYPE_CHECKING:
     from .base_pool import BaseModelPool
     from .causal_lm import CausalLMBackbonePool, CausalLMPool
     from .clip_vision import CLIPVisionModelPool
+    from .convnext_for_image_classification import (
+        ConvNextForImageClassificationPool,
+        load_transformers_convnext,
+    )
+    from .dinov2_for_image_classification import (
+        Dinov2ForImageClassificationPool,
+        load_transformers_dinov2,
+    )
     from .huggingface_automodel import AutoModelPool
     from .huggingface_gpt2_classification import (
         GPT2ForSequenceClassificationPool,
@@ -34,7 +53,10 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
-    from .resnet_for_image_classification import
+    from .resnet_for_image_classification import (
+        ResNetForImageClassificationPool,
+        load_transformers_resnet,
+    )
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
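
These entries register the new ConvNeXt and DINOv2 pools with the package's lazy-import table, so the symbols resolve on first access from `fusion_bench.modelpool`. A short usage sketch, assuming the package is installed; the names come straight from the mapping above:

```python
# Resolution of these names is deferred until first access by the package's
# lazy-import machinery declared in _import_structure.
from fusion_bench.modelpool import (
    ConvNextForImageClassificationPool,
    Dinov2ForImageClassificationPool,
    ResNetForImageClassificationPool,
)

print(ConvNextForImageClassificationPool, Dinov2ForImageClassificationPool)
```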
fusion_bench/modelpool/base_pool.py
CHANGED
@@ -1,14 +1,19 @@
 import logging
 from copy import deepcopy
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
-from fusion_bench.utils import
+from fusion_bench.utils import (
+    ValidationError,
+    instantiate,
+    timeit_context,
+    validate_model_name,
+)
 
 __all__ = ["BaseModelPool"]
 
@@ -52,6 +57,23 @@ class BaseModelPool(
     ):
         if isinstance(models, List):
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
+
+        if isinstance(models, dict):
+            try:  # try to convert to DictConfig
+                models = OmegaConf.create(models)
+            except UnsupportedValueType:
+                pass
+
+        if not models:
+            log.warning("Initialized BaseModelPool with empty models dictionary.")
+        else:
+            # Validate model names
+            for model_name in models.keys():
+                try:
+                    validate_model_name(model_name, allow_special=True)
+                except ValidationError as e:
+                    log.warning(f"Invalid model name '{model_name}': {e}")
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
@@ -140,7 +162,9 @@ class BaseModelPool(
         """
         return model_name.startswith("_") and model_name.endswith("_")
 
-    def get_model_config(
+    def get_model_config(
+        self, model_name: str, return_copy: bool = True
+    ) -> Union[DictConfig, str, Any]:
         """
         Get the configuration for the specified model.
 
@@ -148,10 +172,36 @@ class BaseModelPool(
             model_name (str): The name of the model.
 
         Returns:
-            DictConfig: The configuration for the specified model.
+            Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        # raise friendly error if model not found in the pool
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         model_config = self._models[model_name]
+        if isinstance(model_config, nn.Module):
+            log.warning(
+                f"Model configuration for '{model_name}' is a pre-instantiated model. "
+                "Returning the model instance instead of configuration."
+            )
+
         if return_copy:
+            if isinstance(model_config, nn.Module):
+                # raise performance warning
+                log.warning(
+                    f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
+                )
             model_config = deepcopy(model_config)
         return model_config
 
@@ -164,12 +214,28 @@ class BaseModelPool(
 
         Returns:
             str: The path for the specified model.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
+            ValueError: If model configuration is not a string path.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         if isinstance(self._models[model_name], str):
             return self._models[model_name]
         else:
             raise ValueError(
-                "Model
+                f"Model configuration for '{model_name}' is not a string path. "
+                "Try to override this method in derived modelpool class."
             )
 
     def load_model(
@@ -350,3 +416,25 @@ class BaseModelPool(
         """
         with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)
+
+    def __contains__(self, model_name: str) -> bool:
+        """
+        Check if a model with the given name exists in the model pool.
+
+        Examples:
+            >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
+            >>> "modelA" in modelpool
+            True
+            >>> "modelC" in modelpool
+            False
+
+        Args:
+            model_name (str): The name of the model to check.
+
+        Returns:
+            bool: True if the model exists, False otherwise.
+        """
+        if self._models is None:
+            raise RuntimeError("Model pool is not initialized")
+        validate_model_name(model_name, allow_special=True)
+        return model_name in self._models