fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +10 -2
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +7 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/scripts/cli.py +14 -0
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/rich_utils.py +123 -6
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
fusion_bench/metrics/model_kinship/calculate_split.py
@@ -0,0 +1,171 @@
+import logging
+from typing import Dict, List
+
+import numpy
+import torch
+from tqdm import tqdm
+
+from .utility import Metric, load_model_state_dict, quantize_8bit
+
+
+def cosine_similarity(a, b):
+    similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
+    return similarity
+
+
+def calculate_model_kinship_split(
+    model_1_name: str,
+    model_2_name: str,
+    model_base_name: str,
+    low_precision: bool,
+    metrics: List[str],
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+) -> dict:
+
+    # Extract state dictionaries from models
+    state_dict_1 = load_model_state_dict(model_1_name, device)
+    state_dict_2 = load_model_state_dict(model_2_name, device)
+    state_dict_base = load_model_state_dict(model_base_name, device)
+    results = {}
+
+    # Validate metrics before processing
+    valid_metrics = Metric.list()
+    for metric in metrics:
+        try:
+            if metric not in valid_metrics:
+                raise ValueError(
+                    f"Unsupported metric: {metric}. Valid metrics are: {', '.join(valid_metrics)}"
+                )
+            results[metric] = calculate_metrics_by_split(
+                state_dict_1, state_dict_2, state_dict_base, low_precision, metric
+            )
+        except Exception as e:
+            logging.error(f"Error calculating {metric}: {str(e)}")
+            results[metric] = f"Error calculating {metric}: {str(e)}"
+
+    return results
+
+
+def calculate_metrics_by_split(
+    state_dict_1: dict,
+    state_dict_2: dict,
+    state_dict_base: dict,
+    low_precision: bool,
+    metric: str,
+) -> str:
+    """
+    Calculate metrics for each key and integrate results.
+
+    Args:
+        state_dict_1 (dict): State dictionary of first model
+        state_dict_2 (dict): State dictionary of second model
+        state_dict_base (dict): State dictionary of base model
+        low_precision (bool): Whether to use 8-bit quantization
+        metric (str): Metric to calculate ('pcc', 'ed', 'cs')
+
+    Returns:
+        str: Integrated metric result as formatted string
+    """
+    total_similarity = 0.0
+    total_weight = 0.0
+    split_results = {}
+
+    # Determine the number of layers
+    num_layers = state_dict_base["lm_head.weight"].shape[0]
+
+    # Check architectures
+    if (
+        state_dict_1["lm_head.weight"].shape[0]
+        != state_dict_2["lm_head.weight"].shape[0]
+    ):
+        shape_1 = state_dict_1["lm_head.weight"].shape
+        shape_2 = state_dict_2["lm_head.weight"].shape
+        logging.warning(
+            f"Warning: Model architectures do not match. "
+            f"Using sub weight space instead.\n"
+            f"Vocab sizes in model 1: {shape_1[0]}, "
+            f"Vocab sizes in model 2: {shape_2[0]}"
+        )
+
+    # Process each key
+    for key, base_params in tqdm(
+        state_dict_base.items(), desc=f"Processing {metric.upper()} by key"
+    ):
+        try:
+            if key not in state_dict_1 or key not in state_dict_2:
+                logging.warning(f"Key {key} not found in one of the models")
+                continue
+
+            # Get parameters and calculate deltas
+            params_1 = state_dict_1[key][:num_layers]
+            params_2 = state_dict_2[key][:num_layers]
+
+            delta_1 = (params_1 - base_params).view(-1)
+            delta_2 = (params_2 - base_params).view(-1)
+
+            if low_precision:
+                delta_1 = quantize_8bit(delta_1)
+                delta_2 = quantize_8bit(delta_2)
+
+            # Calculate weight based on parameter count
+            weight = delta_1.numel()
+
+            # Calculate metric for current key
+            if metric == "pcc":
+                stack = torch.stack((delta_1, delta_2), dim=0)
+                split_similarity = torch.corrcoef(stack)[0, 1].item()
+            elif metric == "ed":
+                split_similarity = torch.dist(delta_1, delta_2).item()
+            elif metric == "cs":
+                split_similarity = cosine_similarity(delta_1, delta_2)
+            else:
+                raise ValueError(f"Unsupported metric: {metric}")
+
+            # Skip NaN values
+            if torch.isnan(torch.tensor(split_similarity)):
+                logging.warning(f"Skipping key {key} due to NaN result")
+                continue
+
+            # Store valid result
+            split_results[key] = split_similarity
+
+            # Update weighted average only for valid results
+            weight = delta_1.numel()
+            total_similarity += split_similarity * weight
+            total_weight += weight
+
+            # Log progress for large layers
+            if weight > 1000000:
+                logging.info(
+                    f"Layer {key}: {metric.upper()} = {split_similarity:.4f}, parameters = {weight}"
+                )
+
+            # Free memory
+            del delta_1, delta_2
+
+        except Exception as e:
+            logging.error(f"Error processing key {key}: {str(e)}")
+            continue
+
+    # Calculate final weighted average
+    if total_weight > 0:
+        final_result = total_similarity / total_weight
+
+        # Log summary statistics
+        logging.info(f"\nSummary for {metric.upper()}:")
+        logging.info(f"Total parameters: {total_weight}")
+
+        # Log detailed results for valid splits
+        logging.info(f"\nDetailed {metric.upper()} results by key:")
+        for key, value in split_results.items():
+            logging.info(f"{key}: {value:.4f}")
+
+        metric_names = {
+            "pcc": "Pearson Correlation Coefficient",
+            "ed": "Euclidean Distance",
+            "cs": "Cosine Similarity",
+        }
+
+        return f"Model Kinship based on {metric_names[metric]} (weighted average): {final_result:.4f}"
+    else:
+        return f"Error: No valid parameters found for {metric.upper()} calculation"
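For orientation, a minimal usage sketch of the new split-wise kinship entry point; the repository IDs below are placeholders for illustration and are not taken from this release:

```python
# Hypothetical usage sketch; the three model names are placeholders.
from fusion_bench.metrics.model_kinship.calculate_split import (
    calculate_model_kinship_split,
)

results = calculate_model_kinship_split(
    model_1_name="org/model-a",          # fine-tuned model 1 (placeholder)
    model_2_name="org/model-b",          # fine-tuned model 2 (placeholder)
    model_base_name="org/base-model",    # shared base model (placeholder)
    low_precision=True,                  # 8-bit quantize the per-key delta vectors
    metrics=["pcc", "cs"],               # Pearson correlation and cosine similarity
)
for name, report in results.items():
    print(name, report)  # each value is a formatted summary string
```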
fusion_bench/metrics/model_kinship/utility.py
@@ -0,0 +1,184 @@
+import logging
+from enum import Enum
+from typing import List
+
+import click
+import torch
+from tqdm import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PretrainedConfig,
+)
+
+
+class Metric(str, Enum):
+    """Enumeration of supported metrics"""
+
+    PCC = "pcc"
+    ED = "ed"
+    CS = "cs"
+
+    @classmethod
+    def list(cls) -> List[str]:
+        """Return list of supported metric values"""
+        return [metric.value for metric in cls]
+
+
+def get_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
+    """
+    Fetch the configuration of a pretrained model from HuggingFace.
+
+    Args:
+        model (str): The name or path of the model to load configuration for.
+        trust_remote_code (bool, optional): Whether to trust remote code during loading.
+            Defaults to False.
+
+    Returns:
+        PretrainedConfig: The configuration object of the specified model.
+    """
+    # Fetch the configuration from HuggingFace's model hub.
+    config = AutoConfig.from_pretrained(
+        model,
+        trust_remote_code=trust_remote_code,  # Whether to allow remote code execution.
+    )
+    return config
+
+
+def validate_models(model_1: str, model_2: str, base_model: str) -> None:
+    """
+    Validate model names to ensure they are different and exist.
+
+    Args:
+        model_1: Name of the first model
+        model_2: Name of the second model
+        base_model: Name of the base model
+
+    Raises:
+        click.BadParameter: If validation fails
+    """
+    if model_1 == model_2 or model_1 == base_model or model_2 == base_model:
+        raise click.BadParameter("All model names must be different")
+
+
+def quantize_8bit(x: torch.Tensor) -> torch.Tensor:
+    # Get absolute min and max values
+    abs_max = torch.max(torch.abs(x))
+
+    # Scale to [-127, 127] range for 8-bit signed integers
+    # Using 127 instead of 128 to keep zero exactly representable
+    scaled = 127 * (x / abs_max)
+
+    # Round to nearest integer
+    quantized = torch.round(scaled)
+
+    # Clamp values to ensure they stay in valid range
+    quantized = torch.clamp(quantized, -127, 127)
+
+    return quantized
+
+
+def load_model_state_dict(model_name: str, device: str) -> dict:
+    """
+    Load a model and return its state dictionary.
+
+    Args:
+        model_name (str): Name or path of the model to load
+        device (str): Device to load the model on ('cuda' or 'cpu')
+
+    Returns:
+        dict: State dictionary of the loaded model
+    """
+    logging.info(f"Loading model: {model_name}")
+    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+    state_dict = model.state_dict()
+    del model  # Free memory
+    return state_dict
+
+
+def extract_delta_parameters(
+    model_1_name: str,
+    model_2_name: str,
+    model_base_name: str,
+    low_precision: bool,
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Extract the delta parameters (weight differences) between two models
+    relative to a base model.
+
+    Args:
+        model_1_name (str): Name or path of the first model.
+        model_2_name (str): Name or path of the second model.
+        model_base_name (str): Name or path of the base model for comparison.
+        low_precision (bool): Whether to use low precision weights
+
+    Returns:
+        (torch.Tensor, torch.Tensor): Delta parameters of model_1 and model_2 relative to base model.
+    """
+
+    # Extract state dictionaries from models
+    state_dict_1 = load_model_state_dict(model_1_name, device)
+    state_dict_2 = load_model_state_dict(model_2_name, device)
+    state_dict_base = load_model_state_dict(model_base_name, device)
+
+    # Determine the number of layers
+    num_layers = state_dict_base["lm_head.weight"].shape[0]
+
+    # Check if model architectures match, log a warning if not
+    if (
+        state_dict_1["lm_head.weight"].shape[0]
+        != state_dict_2["lm_head.weight"].shape[0]
+    ):
+        shape_1 = state_dict_1["lm_head.weight"].shape
+        shape_2 = state_dict_2["lm_head.weight"].shape
+        logging.warning(
+            f"Warning: Model architectures do not match. "
+            f"Using sub weight space instead.\n"
+            f"Vocab sizes in model 1: {shape_1[0]}, "
+            f"Vocab sizes in model 2: {shape_2[0]}"
+        )
+
+    # Initialize lists to store delta parameters for both models
+    d_vector_1, d_vector_2 = [], []
+
+    # Iterate over keys in the base model's state dictionary with tqdm
+    for key, base_params in tqdm(
+        state_dict_base.items(), desc="Processing keys", unit="key"
+    ):
+        # Only proceed if key exists in both models
+        try:
+            if key not in state_dict_1 or key not in state_dict_2:
+                logging.warning(f"Key {key} not found in one of the models")
+                continue
+        except Exception as e:
+            logging.error(f"Error processing key {key}: {str(e)}")
+
+        # Get the parameters for each model (truncate to num_layers for consistency)
+        params_1 = state_dict_1[key][:num_layers]
+        params_2 = state_dict_2[key][:num_layers]
+
+        # Compute the deltas relative to the base model
+        delta_1 = (params_1 - base_params).view(-1)
+        delta_2 = (params_2 - base_params).view(-1)
+
+        # Accumulate deltas
+        d_vector_1.append(delta_1)
+        d_vector_2.append(delta_2)
+
+    # Clear memory
+    del state_dict_1, state_dict_2, state_dict_base
+
+    logging.info("Concatenating delta vectors...")
+
+    d_vector_1 = torch.cat(d_vector_1)
+    d_vector_2 = torch.cat(d_vector_2)
+
+    if low_precision:
+        logging.info("Quantizing delta vectors to 8-bit precision...")
+        d_vector_1 = quantize_8bit(d_vector_1)
+        d_vector_2 = quantize_8bit(d_vector_2)
+        logging.info("Quantization complete")
+
+    return d_vector_1, d_vector_2
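Likewise, a small sketch of how the new utility helpers compose: extract the two delta (task) vectors relative to a shared base checkpoint and correlate them. Model names are placeholders, and the snippet assumes causal LMs with an `lm_head.weight` key, as the helper itself does:

```python
# Hypothetical usage sketch; model names are placeholders, not part of this release.
import torch
from fusion_bench.metrics.model_kinship.utility import (
    Metric,
    extract_delta_parameters,
)

print(Metric.list())  # ['pcc', 'ed', 'cs']

delta_1, delta_2 = extract_delta_parameters(
    model_1_name="org/model-a",        # placeholder
    model_2_name="org/model-b",        # placeholder
    model_base_name="org/base-model",  # placeholder
    low_precision=False,
    device="cpu",
)

# Pearson correlation between the two flattened delta vectors (the "pcc" metric).
pcc = torch.corrcoef(torch.stack((delta_1, delta_2)))[0, 1].item()
print(f"pcc = {pcc:.4f}")
```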
fusion_bench/metrics/nyuv2/__init__.py
@@ -1,3 +1,34 @@
+"""
+NYUv2 Dataset Metrics Module.
+
+This module provides metric classes and loss functions for evaluating multi-task learning
+models on the NYUv2 dataset. NYUv2 is a popular indoor scene understanding dataset that
+includes multiple tasks: semantic segmentation, depth estimation, and surface normal prediction.
+
+Available Metrics:
+    - SegmentationMetric: Computes mIoU and pixel accuracy for semantic segmentation.
+    - DepthMetric: Computes absolute and relative errors for depth estimation.
+    - NormalMetric: Computes angular errors for surface normal prediction.
+    - NoiseMetric: Placeholder metric for noise evaluation.
+
+Usage:
+    ```python
+    from fusion_bench.metrics.nyuv2 import SegmentationMetric, DepthMetric
+
+    # Initialize metrics
+    seg_metric = SegmentationMetric(num_classes=13)
+    depth_metric = DepthMetric()
+
+    # Update with predictions and targets
+    seg_metric.update(seg_preds, seg_targets)
+    depth_metric.update(depth_preds, depth_targets)
+
+    # Compute final metrics
+    miou, pix_acc = seg_metric.compute()
+    abs_err, rel_err = depth_metric.compute()
+    ```
+"""
+
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
fusion_bench/metrics/nyuv2/depth.py
@@ -7,9 +7,23 @@ from torchmetrics import Metric


 class DepthMetric(Metric):
+    """
+    Metric for evaluating depth estimation performance on NYUv2 dataset.
+
+    This metric computes absolute error and relative error for depth predictions,
+    properly handling the binary mask to exclude invalid depth regions.
+
+    Attributes:
+        metric_names: List of metric names ["abs_err", "rel_err"].
+        abs_record: List storing absolute error values for each batch.
+        rel_record: List storing relative error values for each batch.
+        batch_size: List storing batch sizes for weighted averaging.
+    """
+
     metric_names = ["abs_err", "rel_err"]

     def __init__(self):
+        """Initialize the DepthMetric with state variables for tracking errors."""
         super().__init__()

         self.add_state("abs_record", default=[], dist_reduce_fx="cat")
@@ -17,11 +31,20 @@ class DepthMetric(Metric):
         self.add_state("batch_size", default=[], dist_reduce_fx="cat")

     def reset(self):
+        """Reset all metric states to empty lists."""
         self.abs_record = []
         self.rel_record = []
         self.batch_size = []

     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric states with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted depth values of shape (batch_size, 1, height, width).
+            target: Ground truth depth values of shape (batch_size, 1, height, width).
+                Pixels with sum of 0 are considered invalid and masked out.
+        """
         binary_mask = (torch.sum(target, dim=1) != 0).unsqueeze(1)
         preds = preds.masked_select(binary_mask)
         target = target.masked_select(binary_mask)
@@ -38,6 +61,13 @@ class DepthMetric(Metric):
         self.batch_size.append(torch.asarray(preds.size(0), device=preds.device))

     def compute(self):
+        """
+        Compute the final metric values across all batches.
+
+        Returns:
+            List[Tensor]: A list containing [absolute_error, relative_error],
+            where each value is the weighted average across all batches.
+        """
         records = torch.stack(
             [torch.stack(self.abs_record), torch.stack(self.rel_record)]
         )
fusion_bench/metrics/nyuv2/loss.py
@@ -3,10 +3,35 @@ from torch import Tensor, nn


 def segmentation_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cross-entropy loss for semantic segmentation.
+
+    Args:
+        pred: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+        gt: Ground truth segmentation labels of shape (batch_size, height, width).
+            Pixels with value -1 are ignored in the loss computation.
+
+    Returns:
+        Tensor: Scalar loss value.
+    """
     return nn.functional.cross_entropy(pred, gt.long(), ignore_index=-1)


 def depth_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute L1 loss for depth estimation with binary masking.
+
+    This loss function calculates the absolute error between predicted and ground truth
+    depth values, but only for valid pixels (where ground truth depth is non-zero).
+
+    Args:
+        pred: Predicted depth values of shape (batch_size, 1, height, width).
+        gt: Ground truth depth values of shape (batch_size, 1, height, width).
+            Pixels with sum of 0 across channels are considered invalid and masked out.
+
+    Returns:
+        Tensor: Scalar loss value averaged over valid pixels.
+    """
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
     loss = torch.sum(torch.abs(pred - gt) * binary_mask) / torch.nonzero(
         binary_mask, as_tuple=False
@@ -15,6 +40,21 @@ def depth_loss(pred: Tensor, gt: Tensor):


 def normal_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cosine similarity loss for surface normal prediction.
+
+    This loss measures the angular difference between predicted and ground truth
+    surface normals using normalized cosine similarity (1 - dot product).
+
+    Args:
+        pred: Predicted surface normals of shape (batch_size, 3, height, width).
+            Will be L2-normalized before computing loss.
+        gt: Ground truth surface normals of shape (batch_size, 3, height, width).
+            Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+
+    Returns:
+        Tensor: Scalar loss value (1 - mean cosine similarity) over valid pixels.
+    """
     # gt has been normalized on the NYUv2 dataset
     pred = pred / torch.norm(pred, p=2, dim=1, keepdim=True)
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
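The documented losses can be combined into a joint multi-task objective; the uniform task weighting below is illustrative only and not something this release prescribes:

```python
# Illustrative joint objective built from the documented NYUv2 loss functions.
import torch
from fusion_bench.metrics.nyuv2.loss import depth_loss, normal_loss, segmentation_loss

def nyuv2_total_loss(seg_pred, seg_gt, depth_pred, depth_gt, normal_pred, normal_gt):
    # Uniform weighting of the three tasks; real training would likely tune these weights.
    return (
        segmentation_loss(seg_pred, seg_gt)
        + depth_loss(depth_pred, depth_gt)
        + normal_loss(normal_pred, normal_gt)
    )

# Random tensors with the shapes stated in the docstrings above.
B, H, W = 2, 32, 32
loss = nyuv2_total_loss(
    torch.randn(B, 13, H, W), torch.randint(0, 13, (B, H, W)),
    torch.rand(B, 1, H, W), torch.rand(B, 1, H, W) + 0.1,
    torch.randn(B, 3, H, W),
    torch.nn.functional.normalize(torch.randn(B, 3, H, W), dim=1),
)
print(loss.item())
```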
fusion_bench/metrics/nyuv2/noise.py
@@ -6,11 +6,35 @@ from torchmetrics import Metric


 class NoiseMetric(Metric):
+    """
+    A placeholder metric for noise evaluation on NYUv2 dataset.
+
+    This metric currently serves as a placeholder and always returns a value of 1.
+    It can be extended in the future to include actual noise-related metrics.
+
+    Note:
+        This is a dummy implementation that doesn't perform actual noise measurements.
+    """
+
     def __init__(self):
+        """Initialize the NoiseMetric."""
         super().__init__()

     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric state (currently a no-op).
+
+        Args:
+            preds: Predicted values (unused).
+            target: Ground truth values (unused).
+        """
         pass

     def compute(self):
+        """
+        Compute the metric value.
+
+        Returns:
+            List[int]: A list containing [1] as a placeholder value.
+        """
         return [1]
fusion_bench/metrics/nyuv2/normal.py
@@ -7,14 +7,36 @@ from torchmetrics import Metric


 class NormalMetric(Metric):
+    """
+    Metric for evaluating surface normal prediction on NYUv2 dataset.
+
+    This metric computes angular error statistics between predicted and ground truth
+    surface normals, including mean, median, and percentage of predictions within
+    specific angular thresholds (11.25°, 22.5°, 30°).
+
+    Attributes:
+        metric_names: List of metric names ["mean", "median", "<11.25", "<22.5", "<30"].
+        record: List storing angular errors (in degrees) for all pixels across batches.
+    """
+
     metric_names = ["mean", "median", "<11.25", "<22.5", "<30"]

     def __init__(self):
+        """Initialize the NormalMetric with state for recording angular errors."""
         super(NormalMetric, self).__init__()

         self.add_state("record", default=[], dist_reduce_fx="cat")

     def update(self, preds, target):
+        """
+        Update metric state with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted surface normals of shape (batch_size, 3, height, width).
+                Will be L2-normalized before computing errors.
+            target: Ground truth surface normals of shape (batch_size, 3, height, width).
+                Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+        """
         # gt has been normalized on the NYUv2 dataset
         preds = preds / torch.norm(preds, p=2, dim=1, keepdim=True)
         binary_mask = torch.sum(target, dim=1) != 0
@@ -33,7 +55,18 @@ class NormalMetric(Metric):

     def compute(self):
         """
-
+        Compute final metric values from all recorded angular errors.
+
+        Returns:
+            List[Tensor]: A list containing five metrics:
+            - mean: Mean angular error in degrees.
+            - median: Median angular error in degrees.
+            - <11.25: Percentage of pixels with error < 11.25°.
+            - <22.5: Percentage of pixels with error < 22.5°.
+            - <30: Percentage of pixels with error < 30°.
+
+        Note:
+            Returns zeros if no data has been recorded.
         """
         if self.record is None:
             return torch.asarray([0.0, 0.0, 0.0, 0.0, 0.0])
fusion_bench/metrics/nyuv2/segmentation.py
@@ -6,9 +6,28 @@ from torchmetrics import Metric


 class SegmentationMetric(Metric):
+    """
+    Metric for evaluating semantic segmentation on NYUv2 dataset.
+
+    This metric computes mean Intersection over Union (mIoU) and pixel accuracy
+    for multi-class segmentation tasks.
+
+    Attributes:
+        metric_names: List of metric names ["mIoU", "pixAcc"].
+        num_classes: Number of segmentation classes (default: 13 for NYUv2).
+        record: Confusion matrix of shape (num_classes, num_classes) tracking
+            predictions vs ground truth.
+    """
+
     metric_names = ["mIoU", "pixAcc"]

     def __init__(self, num_classes=13):
+        """
+        Initialize the SegmentationMetric.
+
+        Args:
+            num_classes: Number of segmentation classes. Default is 13 for NYUv2 dataset.
+        """
         super().__init__()

         self.num_classes = num_classes
@@ -21,9 +40,19 @@ class SegmentationMetric(Metric):
         )

     def reset(self):
+        """Reset the confusion matrix to zeros."""
         self.record.zero_()

     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update the confusion matrix with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+                Will be converted to class predictions via softmax and argmax.
+            target: Ground truth segmentation labels of shape (batch_size, height, width).
+                Pixels with negative values or values >= num_classes are ignored.
+        """
         preds = preds.softmax(1).argmax(1).flatten()
         target = target.long().flatten()

@@ -35,7 +64,12 @@ class SegmentationMetric(Metric):

     def compute(self):
         """
-
+        Compute mIoU and pixel accuracy from the confusion matrix.
+
+        Returns:
+            List[Tensor]: A list containing [mIoU, pixel_accuracy]:
+                - mIoU: Mean Intersection over Union across all classes.
+                - pixel_accuracy: Overall pixel classification accuracy.
         """
         h = cast(Tensor, self.record).float()
         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
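Putting the documented NYUv2 metrics together, a minimal evaluation sketch; the tensors are random stand-ins with the shapes stated in the docstrings above:

```python
# Minimal evaluation sketch; random tensors stand in for real model outputs.
import torch
from fusion_bench.metrics.nyuv2 import DepthMetric, SegmentationMetric

seg_metric = SegmentationMetric(num_classes=13)
depth_metric = DepthMetric()

# One fake batch with docstring-conforming shapes.
B, H, W = 2, 32, 32
seg_metric.update(torch.randn(B, 13, H, W), torch.randint(0, 13, (B, H, W)))
depth_metric.update(torch.rand(B, 1, H, W) + 0.1, torch.rand(B, 1, H, W) + 0.1)

miou, pix_acc = seg_metric.compute()
abs_err, rel_err = depth_metric.compute()
print("mIoU:", float(miou), "pixAcc:", float(pix_acc))
print("abs_err:", float(abs_err), "rel_err:", float(rel_err))
```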
|