fusion-bench 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +35 -35
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
|
@@ -7,14 +7,36 @@ from torchmetrics import Metric
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class NormalMetric(Metric):
|
|
10
|
+
"""
|
|
11
|
+
Metric for evaluating surface normal prediction on NYUv2 dataset.
|
|
12
|
+
|
|
13
|
+
This metric computes angular error statistics between predicted and ground truth
|
|
14
|
+
surface normals, including mean, median, and percentage of predictions within
|
|
15
|
+
specific angular thresholds (11.25°, 22.5°, 30°).
|
|
16
|
+
|
|
17
|
+
Attributes:
|
|
18
|
+
metric_names: List of metric names ["mean", "median", "<11.25", "<22.5", "<30"].
|
|
19
|
+
record: List storing angular errors (in degrees) for all pixels across batches.
|
|
20
|
+
"""
|
|
21
|
+
|
|
10
22
|
metric_names = ["mean", "median", "<11.25", "<22.5", "<30"]
|
|
11
23
|
|
|
12
24
|
def __init__(self):
|
|
25
|
+
"""Initialize the NormalMetric with state for recording angular errors."""
|
|
13
26
|
super(NormalMetric, self).__init__()
|
|
14
27
|
|
|
15
28
|
self.add_state("record", default=[], dist_reduce_fx="cat")
|
|
16
29
|
|
|
17
30
|
def update(self, preds, target):
|
|
31
|
+
"""
|
|
32
|
+
Update metric state with predictions and targets from a batch.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
preds: Predicted surface normals of shape (batch_size, 3, height, width).
|
|
36
|
+
Will be L2-normalized before computing errors.
|
|
37
|
+
target: Ground truth surface normals of shape (batch_size, 3, height, width).
|
|
38
|
+
Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
|
|
39
|
+
"""
|
|
18
40
|
# gt has been normalized on the NYUv2 dataset
|
|
19
41
|
preds = preds / torch.norm(preds, p=2, dim=1, keepdim=True)
|
|
20
42
|
binary_mask = torch.sum(target, dim=1) != 0
|
|
@@ -33,7 +55,18 @@ class NormalMetric(Metric):
|
|
|
33
55
|
|
|
34
56
|
def compute(self):
|
|
35
57
|
"""
|
|
36
|
-
|
|
58
|
+
Compute final metric values from all recorded angular errors.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
List[Tensor]: A list containing five metrics:
|
|
62
|
+
- mean: Mean angular error in degrees.
|
|
63
|
+
- median: Median angular error in degrees.
|
|
64
|
+
- <11.25: Percentage of pixels with error < 11.25°.
|
|
65
|
+
- <22.5: Percentage of pixels with error < 22.5°.
|
|
66
|
+
- <30: Percentage of pixels with error < 30°.
|
|
67
|
+
|
|
68
|
+
Note:
|
|
69
|
+
Returns zeros if no data has been recorded.
|
|
37
70
|
"""
|
|
38
71
|
if self.record is None:
|
|
39
72
|
return torch.asarray([0.0, 0.0, 0.0, 0.0, 0.0])
|
|
@@ -6,9 +6,28 @@ from torchmetrics import Metric
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class SegmentationMetric(Metric):
|
|
9
|
+
"""
|
|
10
|
+
Metric for evaluating semantic segmentation on NYUv2 dataset.
|
|
11
|
+
|
|
12
|
+
This metric computes mean Intersection over Union (mIoU) and pixel accuracy
|
|
13
|
+
for multi-class segmentation tasks.
|
|
14
|
+
|
|
15
|
+
Attributes:
|
|
16
|
+
metric_names: List of metric names ["mIoU", "pixAcc"].
|
|
17
|
+
num_classes: Number of segmentation classes (default: 13 for NYUv2).
|
|
18
|
+
record: Confusion matrix of shape (num_classes, num_classes) tracking
|
|
19
|
+
predictions vs ground truth.
|
|
20
|
+
"""
|
|
21
|
+
|
|
9
22
|
metric_names = ["mIoU", "pixAcc"]
|
|
10
23
|
|
|
11
24
|
def __init__(self, num_classes=13):
|
|
25
|
+
"""
|
|
26
|
+
Initialize the SegmentationMetric.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
num_classes: Number of segmentation classes. Default is 13 for NYUv2 dataset.
|
|
30
|
+
"""
|
|
12
31
|
super().__init__()
|
|
13
32
|
|
|
14
33
|
self.num_classes = num_classes
|
|
@@ -21,9 +40,19 @@ class SegmentationMetric(Metric):
|
|
|
21
40
|
)
|
|
22
41
|
|
|
23
42
|
def reset(self):
|
|
43
|
+
"""Reset the confusion matrix to zeros."""
|
|
24
44
|
self.record.zero_()
|
|
25
45
|
|
|
26
46
|
def update(self, preds: Tensor, target: Tensor):
|
|
47
|
+
"""
|
|
48
|
+
Update the confusion matrix with predictions and targets from a batch.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
preds: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
|
|
52
|
+
Will be converted to class predictions via softmax and argmax.
|
|
53
|
+
target: Ground truth segmentation labels of shape (batch_size, height, width).
|
|
54
|
+
Pixels with negative values or values >= num_classes are ignored.
|
|
55
|
+
"""
|
|
27
56
|
preds = preds.softmax(1).argmax(1).flatten()
|
|
28
57
|
target = target.long().flatten()
|
|
29
58
|
|
|
@@ -35,7 +64,12 @@ class SegmentationMetric(Metric):
|
|
|
35
64
|
|
|
36
65
|
def compute(self):
|
|
37
66
|
"""
|
|
38
|
-
|
|
67
|
+
Compute mIoU and pixel accuracy from the confusion matrix.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
List[Tensor]: A list containing [mIoU, pixel_accuracy]:
|
|
71
|
+
- mIoU: Mean Intersection over Union across all classes.
|
|
72
|
+
- pixel_accuracy: Overall pixel classification accuracy.
|
|
39
73
|
"""
|
|
40
74
|
h = cast(Tensor, self.record).float()
|
|
41
75
|
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
|
@@ -59,6 +59,15 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
59
59
|
|
|
60
60
|
@property
|
|
61
61
|
def clip_processor(self):
|
|
62
|
+
"""
|
|
63
|
+
Get the CLIP processor, loading it from the model pool if necessary.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
CLIPProcessor: The CLIP processor for image and text preprocessing.
|
|
67
|
+
|
|
68
|
+
Raises:
|
|
69
|
+
AssertionError: If the model pool is not set.
|
|
70
|
+
"""
|
|
62
71
|
if self._clip_processor is None:
|
|
63
72
|
assert self.modelpool is not None, "Model pool is not set"
|
|
64
73
|
self._clip_processor = self.modelpool.load_processor()
|
|
@@ -125,6 +134,11 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
125
134
|
clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
|
|
126
135
|
task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
|
|
127
136
|
"""
|
|
137
|
+
# make sure the task names are equal across all processes
|
|
138
|
+
_task_names = self.fabric.broadcast(task_names, src=0)
|
|
139
|
+
if not self.fabric.is_global_zero and task_names != _task_names:
|
|
140
|
+
raise ValueError("The `task_names` must be the same across all processes.")
|
|
141
|
+
|
|
128
142
|
self.whether_setup_zero_shot_classification_head = True
|
|
129
143
|
# load clip model if not provided
|
|
130
144
|
if clip_model is None:
|
|
@@ -147,7 +161,10 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
147
161
|
self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
|
|
148
162
|
|
|
149
163
|
@cache_with_joblib()
|
|
150
|
-
def construct_classification_head(task: str):
|
|
164
|
+
def construct_classification_head(task: str, model_name: str):
|
|
165
|
+
log.info(
|
|
166
|
+
f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
|
|
167
|
+
)
|
|
151
168
|
nonlocal clip_classifier
|
|
152
169
|
|
|
153
170
|
classnames, templates = get_classnames_and_templates(task)
|
|
@@ -163,7 +180,18 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
163
180
|
):
|
|
164
181
|
zeroshot_weights = None
|
|
165
182
|
if self.fabric.is_global_zero:
|
|
166
|
-
|
|
183
|
+
if hasattr(clip_model, "config") and hasattr(
|
|
184
|
+
clip_model.config, "_name_or_path"
|
|
185
|
+
):
|
|
186
|
+
model_name = clip_model.config._name_or_path
|
|
187
|
+
else:
|
|
188
|
+
model_name = "unknown_model"
|
|
189
|
+
log.warning(
|
|
190
|
+
"CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
|
|
191
|
+
)
|
|
192
|
+
zeroshot_weights = construct_classification_head(
|
|
193
|
+
task, model_name=model_name
|
|
194
|
+
)
|
|
167
195
|
|
|
168
196
|
self.fabric.barrier()
|
|
169
197
|
self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
|
|
4
|
+
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
|
|
5
5
|
|
|
6
6
|
import lightning as L
|
|
7
7
|
import torch
|
|
@@ -96,12 +96,24 @@ class LightningFabricMixin:
|
|
|
96
96
|
|
|
97
97
|
@property
|
|
98
98
|
def fabric(self):
|
|
99
|
+
"""
|
|
100
|
+
Get the Lightning Fabric instance, initializing it if necessary.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
L.Fabric: The Lightning Fabric instance for distributed computing.
|
|
104
|
+
"""
|
|
99
105
|
if self._fabric_instance is None:
|
|
100
106
|
self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
|
|
101
107
|
return self._fabric_instance
|
|
102
108
|
|
|
103
109
|
@fabric.setter
|
|
104
110
|
def fabric(self, instance: L.Fabric):
|
|
111
|
+
"""
|
|
112
|
+
Set the Lightning Fabric instance.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
instance: The Lightning Fabric instance to use.
|
|
116
|
+
"""
|
|
105
117
|
self._fabric_instance = instance
|
|
106
118
|
|
|
107
119
|
@property
|
|
@@ -172,6 +184,15 @@ class LightningFabricMixin:
|
|
|
172
184
|
def tensorboard_summarywriter(
|
|
173
185
|
self,
|
|
174
186
|
) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
|
|
187
|
+
"""
|
|
188
|
+
Get the TensorBoard SummaryWriter for detailed logging.
|
|
189
|
+
|
|
190
|
+
Returns:
|
|
191
|
+
SummaryWriter: The TensorBoard SummaryWriter instance.
|
|
192
|
+
|
|
193
|
+
Raises:
|
|
194
|
+
AttributeError: If the logger is not a TensorBoardLogger.
|
|
195
|
+
"""
|
|
175
196
|
if isinstance(self.fabric.logger, TensorBoardLogger):
|
|
176
197
|
return self.fabric.logger.experiment
|
|
177
198
|
else:
|
|
@@ -179,6 +200,12 @@ class LightningFabricMixin:
|
|
|
179
200
|
|
|
180
201
|
@property
|
|
181
202
|
def is_debug_mode(self):
|
|
203
|
+
"""
|
|
204
|
+
Check if the program is running in debug mode (fast_dev_run).
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
bool: True if fast_dev_run is enabled, False otherwise.
|
|
208
|
+
"""
|
|
182
209
|
if hasattr(self, "config") and self.config.get("fast_dev_run", False):
|
|
183
210
|
return True
|
|
184
211
|
elif hasattr(self, "_program") and self._program.config.get(
|
|
@@ -190,13 +217,22 @@ class LightningFabricMixin:
|
|
|
190
217
|
|
|
191
218
|
def log(self, name: str, value: Any, step: Optional[int] = None):
|
|
192
219
|
"""
|
|
193
|
-
Logs
|
|
220
|
+
Logs a single metric to the fabric's logger.
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
name: The name of the metric to log.
|
|
224
|
+
value: The value of the metric.
|
|
225
|
+
step: Optional step number for the metric.
|
|
194
226
|
"""
|
|
195
227
|
self.fabric.log(name, value, step=step)
|
|
196
228
|
|
|
197
|
-
def log_dict(self, metrics:
|
|
229
|
+
def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
|
|
198
230
|
"""
|
|
199
|
-
Logs
|
|
231
|
+
Logs multiple metrics to the fabric's logger.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
metrics: Dictionary of metric names and values.
|
|
235
|
+
step: Optional step number for the metrics.
|
|
200
236
|
"""
|
|
201
237
|
self.fabric.log_dict(metrics, step=step)
|
|
202
238
|
|
|
@@ -207,7 +243,12 @@ class LightningFabricMixin:
|
|
|
207
243
|
name_template: str = "train/lr_group_{0}",
|
|
208
244
|
):
|
|
209
245
|
"""
|
|
210
|
-
Logs the learning rate of
|
|
246
|
+
Logs the learning rate of each parameter group in the optimizer.
|
|
247
|
+
|
|
248
|
+
Args:
|
|
249
|
+
optimizer: The optimizer whose learning rates should be logged.
|
|
250
|
+
step: Optional step number for the log entry.
|
|
251
|
+
name_template: Template string for the log name. Use {0} as placeholder for group index.
|
|
211
252
|
"""
|
|
212
253
|
for i, param_group in enumerate(optimizer.param_groups):
|
|
213
254
|
self.fabric.log(name_template.format(i), param_group["lr"], step=step)
|
fusion_bench/mixins/rich_live.py
CHANGED
|
@@ -2,20 +2,96 @@ from rich.live import Live
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class RichLiveMixin:
|
|
5
|
+
"""
|
|
6
|
+
A mixin class that provides Rich Live display capabilities.
|
|
7
|
+
|
|
8
|
+
This mixin integrates Rich's Live display functionality, allowing for
|
|
9
|
+
dynamic, auto-refreshing console output. It's particularly useful for
|
|
10
|
+
displaying real-time updates, progress information, or continuously
|
|
11
|
+
changing data without cluttering the terminal.
|
|
12
|
+
|
|
13
|
+
Attributes:
|
|
14
|
+
_rich_live (Live): The internal Rich Live instance for live display updates.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
```python
|
|
18
|
+
class MyTask(RichLiveMixin):
|
|
19
|
+
def run(self):
|
|
20
|
+
self.start_rich_live()
|
|
21
|
+
for i in range(100):
|
|
22
|
+
self.rich_live_print(f"Processing item {i}")
|
|
23
|
+
time.sleep(0.1)
|
|
24
|
+
self.stop_rich_live()
|
|
25
|
+
```
|
|
26
|
+
"""
|
|
27
|
+
|
|
5
28
|
_rich_live: Live = None
|
|
6
29
|
|
|
7
30
|
@property
|
|
8
31
|
def rich_live(self) -> Live:
|
|
32
|
+
"""
|
|
33
|
+
Get the Rich Live instance, creating it if necessary.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
Live: The Rich Live instance for dynamic console output.
|
|
37
|
+
"""
|
|
9
38
|
if self._rich_live is None:
|
|
10
39
|
self._rich_live = Live()
|
|
11
40
|
return self._rich_live
|
|
12
41
|
|
|
13
42
|
def start_rich_live(self):
|
|
43
|
+
"""
|
|
44
|
+
Start the Rich Live display context.
|
|
45
|
+
|
|
46
|
+
This method enters the Rich Live context, enabling dynamic console output.
|
|
47
|
+
Must be paired with stop_rich_live() to properly clean up resources.
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
The Rich Live instance in its started state.
|
|
51
|
+
|
|
52
|
+
Example:
|
|
53
|
+
```python
|
|
54
|
+
self.start_rich_live()
|
|
55
|
+
# Display dynamic content
|
|
56
|
+
self.rich_live_print("Dynamic output")
|
|
57
|
+
self.stop_rich_live()
|
|
58
|
+
```
|
|
59
|
+
"""
|
|
14
60
|
return self.rich_live.__enter__()
|
|
15
61
|
|
|
16
62
|
def stop_rich_live(self):
|
|
63
|
+
"""
|
|
64
|
+
Stop the Rich Live display context and clean up resources.
|
|
65
|
+
|
|
66
|
+
This method exits the Rich Live context and resets the internal Live instance.
|
|
67
|
+
Should be called after start_rich_live() when dynamic display is complete.
|
|
68
|
+
|
|
69
|
+
Example:
|
|
70
|
+
```python
|
|
71
|
+
self.start_rich_live()
|
|
72
|
+
# ... display content ...
|
|
73
|
+
self.stop_rich_live()
|
|
74
|
+
```
|
|
75
|
+
"""
|
|
17
76
|
self.rich_live.__exit__(None, None, None)
|
|
18
77
|
self._rich_live = None
|
|
19
78
|
|
|
20
79
|
def rich_live_print(self, msg):
|
|
80
|
+
"""
|
|
81
|
+
Print a message to the Rich Live console.
|
|
82
|
+
|
|
83
|
+
This method displays the given message through the Rich Live console,
|
|
84
|
+
allowing for formatted, dynamic output that updates in place.
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
msg: The message to display. Can be a string or any Rich renderable object.
|
|
88
|
+
|
|
89
|
+
Example:
|
|
90
|
+
```python
|
|
91
|
+
self.start_rich_live()
|
|
92
|
+
self.rich_live_print("[bold green]Success![/bold green]")
|
|
93
|
+
self.rich_live_print(Panel("Status: Running"))
|
|
94
|
+
self.stop_rich_live()
|
|
95
|
+
```
|
|
96
|
+
"""
|
|
21
97
|
self.rich_live.console.print(msg)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Dict, Generator, List, Optional, Tuple, Union
|
|
3
|
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
|
|
@@ -8,7 +8,12 @@ from torch import nn
|
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
|
|
10
10
|
from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
|
|
11
|
-
from fusion_bench.utils import
|
|
11
|
+
from fusion_bench.utils import (
|
|
12
|
+
ValidationError,
|
|
13
|
+
instantiate,
|
|
14
|
+
timeit_context,
|
|
15
|
+
validate_model_name,
|
|
16
|
+
)
|
|
12
17
|
|
|
13
18
|
__all__ = ["BaseModelPool"]
|
|
14
19
|
|
|
@@ -59,6 +64,16 @@ class BaseModelPool(
|
|
|
59
64
|
except UnsupportedValueType:
|
|
60
65
|
pass
|
|
61
66
|
|
|
67
|
+
if not models:
|
|
68
|
+
log.warning("Initialized BaseModelPool with empty models dictionary.")
|
|
69
|
+
else:
|
|
70
|
+
# Validate model names
|
|
71
|
+
for model_name in models.keys():
|
|
72
|
+
try:
|
|
73
|
+
validate_model_name(model_name, allow_special=True)
|
|
74
|
+
except ValidationError as e:
|
|
75
|
+
log.warning(f"Invalid model name '{model_name}': {e}")
|
|
76
|
+
|
|
62
77
|
self._models = models
|
|
63
78
|
self._train_datasets = train_datasets
|
|
64
79
|
self._val_datasets = val_datasets
|
|
@@ -147,7 +162,9 @@ class BaseModelPool(
|
|
|
147
162
|
"""
|
|
148
163
|
return model_name.startswith("_") and model_name.endswith("_")
|
|
149
164
|
|
|
150
|
-
def get_model_config(
|
|
165
|
+
def get_model_config(
|
|
166
|
+
self, model_name: str, return_copy: bool = True
|
|
167
|
+
) -> Union[DictConfig, str, Any]:
|
|
151
168
|
"""
|
|
152
169
|
Get the configuration for the specified model.
|
|
153
170
|
|
|
@@ -155,10 +172,36 @@ class BaseModelPool(
|
|
|
155
172
|
model_name (str): The name of the model.
|
|
156
173
|
|
|
157
174
|
Returns:
|
|
158
|
-
DictConfig: The configuration for the specified model.
|
|
175
|
+
Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.
|
|
176
|
+
|
|
177
|
+
Raises:
|
|
178
|
+
ValidationError: If model_name is invalid.
|
|
179
|
+
KeyError: If model_name is not found in the pool.
|
|
159
180
|
"""
|
|
181
|
+
# Validate model name
|
|
182
|
+
validate_model_name(model_name, allow_special=True)
|
|
183
|
+
|
|
184
|
+
# raise friendly error if model not found in the pool
|
|
185
|
+
if model_name not in self._models:
|
|
186
|
+
available_models = list(self._models.keys())
|
|
187
|
+
raise KeyError(
|
|
188
|
+
f"Model '{model_name}' not found in model pool. "
|
|
189
|
+
f"Available models: {available_models}"
|
|
190
|
+
)
|
|
191
|
+
|
|
160
192
|
model_config = self._models[model_name]
|
|
193
|
+
if isinstance(model_config, nn.Module):
|
|
194
|
+
log.warning(
|
|
195
|
+
f"Model configuration for '{model_name}' is a pre-instantiated model. "
|
|
196
|
+
"Returning the model instance instead of configuration."
|
|
197
|
+
)
|
|
198
|
+
|
|
161
199
|
if return_copy:
|
|
200
|
+
if isinstance(model_config, nn.Module):
|
|
201
|
+
# raise performance warning
|
|
202
|
+
log.warning(
|
|
203
|
+
f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
|
|
204
|
+
)
|
|
162
205
|
model_config = deepcopy(model_config)
|
|
163
206
|
return model_config
|
|
164
207
|
|
|
@@ -171,12 +214,28 @@ class BaseModelPool(
|
|
|
171
214
|
|
|
172
215
|
Returns:
|
|
173
216
|
str: The path for the specified model.
|
|
217
|
+
|
|
218
|
+
Raises:
|
|
219
|
+
ValidationError: If model_name is invalid.
|
|
220
|
+
KeyError: If model_name is not found in the pool.
|
|
221
|
+
ValueError: If model configuration is not a string path.
|
|
174
222
|
"""
|
|
223
|
+
# Validate model name
|
|
224
|
+
validate_model_name(model_name, allow_special=True)
|
|
225
|
+
|
|
226
|
+
if model_name not in self._models:
|
|
227
|
+
available_models = list(self._models.keys())
|
|
228
|
+
raise KeyError(
|
|
229
|
+
f"Model '{model_name}' not found in model pool. "
|
|
230
|
+
f"Available models: {available_models}"
|
|
231
|
+
)
|
|
232
|
+
|
|
175
233
|
if isinstance(self._models[model_name], str):
|
|
176
234
|
return self._models[model_name]
|
|
177
235
|
else:
|
|
178
236
|
raise ValueError(
|
|
179
|
-
"Model
|
|
237
|
+
f"Model configuration for '{model_name}' is not a string path. "
|
|
238
|
+
"Try to override this method in derived modelpool class."
|
|
180
239
|
)
|
|
181
240
|
|
|
182
241
|
def load_model(
|
|
@@ -357,3 +416,25 @@ class BaseModelPool(
|
|
|
357
416
|
"""
|
|
358
417
|
with timeit_context(f"Saving the state dict of model to {path}"):
|
|
359
418
|
torch.save(model.state_dict(), path)
|
|
419
|
+
|
|
420
|
+
def __contains__(self, model_name: str) -> bool:
|
|
421
|
+
"""
|
|
422
|
+
Check if a model with the given name exists in the model pool.
|
|
423
|
+
|
|
424
|
+
Examples:
|
|
425
|
+
>>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
|
|
426
|
+
>>> "modelA" in modelpool
|
|
427
|
+
True
|
|
428
|
+
>>> "modelC" in modelpool
|
|
429
|
+
False
|
|
430
|
+
|
|
431
|
+
Args:
|
|
432
|
+
model_name (str): The name of the model to check.
|
|
433
|
+
|
|
434
|
+
Returns:
|
|
435
|
+
bool: True if the model exists, False otherwise.
|
|
436
|
+
"""
|
|
437
|
+
if self._models is None:
|
|
438
|
+
raise RuntimeError("Model pool is not initialized")
|
|
439
|
+
validate_model_name(model_name, allow_special=True)
|
|
440
|
+
return model_name in self._models
|