fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,171 @@
+ import logging
+ from typing import Dict, List
+
+ import numpy
+ import torch
+ from tqdm import tqdm
+
+ from .utility import Metric, load_model_state_dict, quantize_8bit
+
+
+ def cosine_similarity(a, b):
+     similarity = numpy.sqrt(numpy.dot(a, b) ** 2 / (numpy.dot(a, a) * numpy.dot(b, b)))
+     return similarity
+
+
+ def calculate_model_kinship_split(
+     model_1_name: str,
+     model_2_name: str,
+     model_base_name: str,
+     low_precision: bool,
+     metrics: List[str],
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ ) -> dict:
+
+     # Extract state dictionaries from models
+     state_dict_1 = load_model_state_dict(model_1_name, device)
+     state_dict_2 = load_model_state_dict(model_2_name, device)
+     state_dict_base = load_model_state_dict(model_base_name, device)
+     results = {}
+
+     # Validate metrics before processing
+     valid_metrics = Metric.list()
+     for metric in metrics:
+         try:
+             if metric not in valid_metrics:
+                 raise ValueError(
+                     f"Unsupported metric: {metric}. Valid metrics are: {', '.join(valid_metrics)}"
+                 )
+             results[metric] = calculate_metrics_by_split(
+                 state_dict_1, state_dict_2, state_dict_base, low_precision, metric
+             )
+         except Exception as e:
+             logging.error(f"Error calculating {metric}: {str(e)}")
+             results[metric] = f"Error calculating {metric}: {str(e)}"
+
+     return results
+
+
+ def calculate_metrics_by_split(
+     state_dict_1: dict,
+     state_dict_2: dict,
+     state_dict_base: dict,
+     low_precision: bool,
+     metric: str,
+ ) -> str:
+     """
+     Calculate metrics for each key and integrate results.
+
+     Args:
+         state_dict_1 (dict): State dictionary of first model
+         state_dict_2 (dict): State dictionary of second model
+         state_dict_base (dict): State dictionary of base model
+         low_precision (bool): Whether to use 8-bit quantization
+         metric (str): Metric to calculate ('pcc', 'ed', 'cs')
+
+     Returns:
+         str: Integrated metric result as formatted string
+     """
+     total_similarity = 0.0
+     total_weight = 0.0
+     split_results = {}
+
+     # Determine the number of layers
+     num_layers = state_dict_base["lm_head.weight"].shape[0]
+
+     # Check architectures
+     if (
+         state_dict_1["lm_head.weight"].shape[0]
+         != state_dict_2["lm_head.weight"].shape[0]
+     ):
+         shape_1 = state_dict_1["lm_head.weight"].shape
+         shape_2 = state_dict_2["lm_head.weight"].shape
+         logging.warning(
+             f"Warning: Model architectures do not match. "
+             f"Using sub weight space instead.\n"
+             f"Vocab sizes in model 1: {shape_1[0]}, "
+             f"Vocab sizes in model 2: {shape_2[0]}"
+         )
+
+     # Process each key
+     for key, base_params in tqdm(
+         state_dict_base.items(), desc=f"Processing {metric.upper()} by key"
+     ):
+         try:
+             if key not in state_dict_1 or key not in state_dict_2:
+                 logging.warning(f"Key {key} not found in one of the models")
+                 continue
+
+             # Get parameters and calculate deltas
+             params_1 = state_dict_1[key][:num_layers]
+             params_2 = state_dict_2[key][:num_layers]
+
+             delta_1 = (params_1 - base_params).view(-1)
+             delta_2 = (params_2 - base_params).view(-1)
+
+             if low_precision:
+                 delta_1 = quantize_8bit(delta_1)
+                 delta_2 = quantize_8bit(delta_2)
+
+             # Calculate weight based on parameter count
+             weight = delta_1.numel()
+
+             # Calculate metric for current key
+             if metric == "pcc":
+                 stack = torch.stack((delta_1, delta_2), dim=0)
+                 split_similarity = torch.corrcoef(stack)[0, 1].item()
+             elif metric == "ed":
+                 split_similarity = torch.dist(delta_1, delta_2).item()
+             elif metric == "cs":
+                 split_similarity = cosine_similarity(delta_1, delta_2)
+             else:
+                 raise ValueError(f"Unsupported metric: {metric}")
+
+             # Skip NaN values
+             if torch.isnan(torch.tensor(split_similarity)):
+                 logging.warning(f"Skipping key {key} due to NaN result")
+                 continue
+
+             # Store valid result
+             split_results[key] = split_similarity
+
+             # Update weighted average only for valid results
+             weight = delta_1.numel()
+             total_similarity += split_similarity * weight
+             total_weight += weight
+
+             # Log progress for large layers
+             if weight > 1000000:
+                 logging.info(
+                     f"Layer {key}: {metric.upper()} = {split_similarity:.4f}, parameters = {weight}"
+                 )
+
+             # Free memory
+             del delta_1, delta_2
+
+         except Exception as e:
+             logging.error(f"Error processing key {key}: {str(e)}")
+             continue
+
+     # Calculate final weighted average
+     if total_weight > 0:
+         final_result = total_similarity / total_weight
+
+         # Log summary statistics
+         logging.info(f"\nSummary for {metric.upper()}:")
+         logging.info(f"Total parameters: {total_weight}")
+
+         # Log detailed results for valid splits
+         logging.info(f"\nDetailed {metric.upper()} results by key:")
+         for key, value in split_results.items():
+             logging.info(f"{key}: {value:.4f}")
+
+         metric_names = {
+             "pcc": "Pearson Correlation Coefficient",
+             "ed": "Euclidean Distance",
+             "cs": "Cosine Similarity",
+         }
+
+         return f"Model Kinship based on {metric_names[metric]} (weighted average): {final_result:.4f}"
+     else:
+         return f"Error: No valid parameters found for {metric.upper()} calculation"
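The hunk above is the new module `fusion_bench/metrics/model_kinship/calculate_split.py` (the +171-line file in the list). A minimal usage sketch, assuming the three model identifiers below are placeholders (not real Hub IDs) for two fine-tuned checkpoints that share a common base model:

```python
# Hypothetical usage of the per-key ("split") model-kinship calculation added above.
from fusion_bench.metrics.model_kinship.calculate_split import calculate_model_kinship_split

results = calculate_model_kinship_split(
    model_1_name="org/finetune-a",      # placeholder fine-tuned checkpoint
    model_2_name="org/finetune-b",      # placeholder fine-tuned checkpoint
    model_base_name="org/base-model",   # placeholder shared base model
    low_precision=True,                 # quantize per-key deltas to 8-bit before comparing
    metrics=["pcc", "cs"],              # Pearson correlation and cosine similarity
)
for metric, report in results.items():
    # Each entry is either a formatted summary string or an error message.
    print(metric, "->", report)
```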
@@ -0,0 +1,184 @@
+ import logging
+ from enum import Enum
+ from typing import List
+
+ import click
+ import torch
+ from tqdm import tqdm
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PretrainedConfig,
+ )
+
+
+ class Metric(str, Enum):
+     """Enumeration of supported metrics"""
+
+     PCC = "pcc"
+     ED = "ed"
+     CS = "cs"
+
+     @classmethod
+     def list(cls) -> List[str]:
+         """Return list of supported metric values"""
+         return [metric.value for metric in cls]
+
+
+ def get_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
+     """
+     Fetch the configuration of a pretrained model from HuggingFace.
+
+     Args:
+         model (str): The name or path of the model to load configuration for.
+         trust_remote_code (bool, optional): Whether to trust remote code during loading.
+             Defaults to False.
+
+     Returns:
+         PretrainedConfig: The configuration object of the specified model.
+     """
+     # Fetch the configuration from HuggingFace's model hub.
+     config = AutoConfig.from_pretrained(
+         model,
+         trust_remote_code=trust_remote_code, # Whether to allow remote code execution.
+     )
+     return config
+
+
+ def validate_models(model_1: str, model_2: str, base_model: str) -> None:
+     """
+     Validate model names to ensure they are different and exist.
+
+     Args:
+         model_1: Name of the first model
+         model_2: Name of the second model
+         base_model: Name of the base model
+
+     Raises:
+         click.BadParameter: If validation fails
+     """
+     if model_1 == model_2 or model_1 == base_model or model_2 == base_model:
+         raise click.BadParameter("All model names must be different")
+
+
+ def quantize_8bit(x: torch.Tensor) -> torch.Tensor:
+     # Get absolute min and max values
+     abs_max = torch.max(torch.abs(x))
+
+     # Scale to [-127, 127] range for 8-bit signed integers
+     # Using 127 instead of 128 to keep zero exactly representable
+     scaled = 127 * (x / abs_max)
+
+     # Round to nearest integer
+     quantized = torch.round(scaled)
+
+     # Clamp values to ensure they stay in valid range
+     quantized = torch.clamp(quantized, -127, 127)
+
+     return quantized
+
+
+ def load_model_state_dict(model_name: str, device: str) -> dict:
+     """
+     Load a model and return its state dictionary.
+
+     Args:
+         model_name (str): Name or path of the model to load
+         device (str): Device to load the model on ('cuda' or 'cpu')
+
+     Returns:
+         dict: State dictionary of the loaded model
+     """
+     logging.info(f"Loading model: {model_name}")
+     model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+     state_dict = model.state_dict()
+     del model # Free memory
+     return state_dict
+
+
+ def extract_delta_parameters(
+     model_1_name: str,
+     model_2_name: str,
+     model_base_name: str,
+     low_precision: bool,
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Extract the delta parameters (weight differences) between two models
+     relative to a base model.
+
+     Args:
+         model_1_name (str): Name or path of the first model.
+         model_2_name (str): Name or path of the second model.
+         model_base_name (str): Name or path of the base model for comparison.
+         low_precision (bool): Whether to use low precision weights
+
+     Returns:
+         (torch.Tensor, torch.Tensor): Delta parameters of model_1 and model_2 relative to base model.
+     """
+
+     # Extract state dictionaries from models
+     state_dict_1 = load_model_state_dict(model_1_name, device)
+     state_dict_2 = load_model_state_dict(model_2_name, device)
+     state_dict_base = load_model_state_dict(model_base_name, device)
+
+     # Determine the number of layers
+     num_layers = state_dict_base["lm_head.weight"].shape[0]
+
+     # Check if model architectures match, log a warning if not
+     if (
+         state_dict_1["lm_head.weight"].shape[0]
+         != state_dict_2["lm_head.weight"].shape[0]
+     ):
+         shape_1 = state_dict_1["lm_head.weight"].shape
+         shape_2 = state_dict_2["lm_head.weight"].shape
+         logging.warning(
+             f"Warning: Model architectures do not match. "
+             f"Using sub weight space instead.\n"
+             f"Vocab sizes in model 1: {shape_1[0]}, "
+             f"Vocab sizes in model 2: {shape_2[0]}"
+         )
+
+     # Initialize lists to store delta parameters for both models
+     d_vector_1, d_vector_2 = [], []
+
+     # Iterate over keys in the base model's state dictionary with tqdm
+     for key, base_params in tqdm(
+         state_dict_base.items(), desc="Processing keys", unit="key"
+     ):
+         # Only proceed if key exists in both models
+         try:
+             if key not in state_dict_1 or key not in state_dict_2:
+                 logging.warning(f"Key {key} not found in one of the models")
+                 continue
+         except Exception as e:
+             logging.error(f"Error processing key {key}: {str(e)}")
+
+         # Get the parameters for each model (truncate to num_layers for consistency)
+         params_1 = state_dict_1[key][:num_layers]
+         params_2 = state_dict_2[key][:num_layers]
+
+         # Compute the deltas relative to the base model
+         delta_1 = (params_1 - base_params).view(-1)
+         delta_2 = (params_2 - base_params).view(-1)
+
+         # Accumulate deltas
+         d_vector_1.append(delta_1)
+         d_vector_2.append(delta_2)
+
+     # Clear memory
+     del state_dict_1, state_dict_2, state_dict_base
+
+     logging.info("Concatenating delta vectors...")
+
+     d_vector_1 = torch.cat(d_vector_1)
+     d_vector_2 = torch.cat(d_vector_2)
+
+     if low_precision:
+         logging.info("Quantizing delta vectors to 8-bit precision...")
+         d_vector_1 = quantize_8bit(d_vector_1)
+         d_vector_2 = quantize_8bit(d_vector_2)
+         logging.info("Quantization complete")
+
+     return d_vector_1, d_vector_2
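The second hunk is the new helper module `fusion_bench/metrics/model_kinship/utility.py` (+184 lines). As a quick sanity check of `quantize_8bit` defined above (a symmetric scale onto the integer grid [-127, 127], returned as rounded floats rather than an int8 tensor), a small worked example:

```python
import torch
from fusion_bench.metrics.model_kinship.utility import quantize_8bit

x = torch.tensor([0.5, -1.0, 0.25])
q = quantize_8bit(x)
# abs_max = 1.0, so values are scaled by 127, rounded, and clamped:
# 0.5 * 127 = 63.5 -> 64, -1.0 * 127 = -127, 0.25 * 127 = 31.75 -> 32
print(q)  # -> tensor([64., -127., 32.])
```

Note that the input needs at least one non-zero element, since the scale is the tensor's `abs_max`.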
@@ -1,3 +1,34 @@
+ """
+ NYUv2 Dataset Metrics Module.
+
+ This module provides metric classes and loss functions for evaluating multi-task learning
+ models on the NYUv2 dataset. NYUv2 is a popular indoor scene understanding dataset that
+ includes multiple tasks: semantic segmentation, depth estimation, and surface normal prediction.
+
+ Available Metrics:
+     - SegmentationMetric: Computes mIoU and pixel accuracy for semantic segmentation.
+     - DepthMetric: Computes absolute and relative errors for depth estimation.
+     - NormalMetric: Computes angular errors for surface normal prediction.
+     - NoiseMetric: Placeholder metric for noise evaluation.
+
+ Usage:
+     ```python
+     from fusion_bench.metrics.nyuv2 import SegmentationMetric, DepthMetric
+
+     # Initialize metrics
+     seg_metric = SegmentationMetric(num_classes=13)
+     depth_metric = DepthMetric()
+
+     # Update with predictions and targets
+     seg_metric.update(seg_preds, seg_targets)
+     depth_metric.update(depth_preds, depth_targets)
+
+     # Compute final metrics
+     miou, pix_acc = seg_metric.compute()
+     abs_err, rel_err = depth_metric.compute()
+     ```
+ """
+
  from .depth import DepthMetric
  from .noise import NoiseMetric
  from .normal import NormalMetric
@@ -7,9 +7,23 @@ from torchmetrics import Metric
 
 
  class DepthMetric(Metric):
+     """
+     Metric for evaluating depth estimation performance on NYUv2 dataset.
+
+     This metric computes absolute error and relative error for depth predictions,
+     properly handling the binary mask to exclude invalid depth regions.
+
+     Attributes:
+         metric_names: List of metric names ["abs_err", "rel_err"].
+         abs_record: List storing absolute error values for each batch.
+         rel_record: List storing relative error values for each batch.
+         batch_size: List storing batch sizes for weighted averaging.
+     """
+
      metric_names = ["abs_err", "rel_err"]
 
      def __init__(self):
+         """Initialize the DepthMetric with state variables for tracking errors."""
          super().__init__()
 
          self.add_state("abs_record", default=[], dist_reduce_fx="cat")
@@ -17,11 +31,20 @@ class DepthMetric(Metric):
          self.add_state("batch_size", default=[], dist_reduce_fx="cat")
 
      def reset(self):
+         """Reset all metric states to empty lists."""
          self.abs_record = []
          self.rel_record = []
          self.batch_size = []
 
      def update(self, preds: Tensor, target: Tensor):
+         """
+         Update metric states with predictions and targets from a batch.
+
+         Args:
+             preds: Predicted depth values of shape (batch_size, 1, height, width).
+             target: Ground truth depth values of shape (batch_size, 1, height, width).
+                 Pixels with sum of 0 are considered invalid and masked out.
+         """
          binary_mask = (torch.sum(target, dim=1) != 0).unsqueeze(1)
          preds = preds.masked_select(binary_mask)
          target = target.masked_select(binary_mask)
@@ -38,6 +61,13 @@ class DepthMetric(Metric):
          self.batch_size.append(torch.asarray(preds.size(0), device=preds.device))
 
      def compute(self):
+         """
+         Compute the final metric values across all batches.
+
+         Returns:
+             List[Tensor]: A list containing [absolute_error, relative_error],
+             where each value is the weighted average across all batches.
+         """
          records = torch.stack(
              [torch.stack(self.abs_record), torch.stack(self.rel_record)]
          )
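With the docstrings above, `DepthMetric` (from `fusion_bench/metrics/nyuv2/depth.py`) can be exercised on dummy tensors. A minimal sketch, assuming random positive depths and a partially zeroed target to illustrate the invalid-pixel mask:

```python
import torch
from fusion_bench.metrics.nyuv2 import DepthMetric

metric = DepthMetric()
preds = torch.rand(2, 1, 4, 4) + 0.1    # dummy positive depth predictions
target = torch.rand(2, 1, 4, 4) + 0.1   # dummy ground-truth depths
target[:, :, :2, :] = 0.0               # zero depth marks invalid pixels, excluded by the mask
metric.update(preds, target)
abs_err, rel_err = metric.compute()     # weighted averages over all update() calls
print("abs_err:", abs_err, "rel_err:", rel_err)
```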
@@ -3,10 +3,35 @@ from torch import Tensor, nn
 
 
  def segmentation_loss(pred: Tensor, gt: Tensor):
+     """
+     Compute cross-entropy loss for semantic segmentation.
+
+     Args:
+         pred: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+         gt: Ground truth segmentation labels of shape (batch_size, height, width).
+             Pixels with value -1 are ignored in the loss computation.
+
+     Returns:
+         Tensor: Scalar loss value.
+     """
      return nn.functional.cross_entropy(pred, gt.long(), ignore_index=-1)
 
 
  def depth_loss(pred: Tensor, gt: Tensor):
+     """
+     Compute L1 loss for depth estimation with binary masking.
+
+     This loss function calculates the absolute error between predicted and ground truth
+     depth values, but only for valid pixels (where ground truth depth is non-zero).
+
+     Args:
+         pred: Predicted depth values of shape (batch_size, 1, height, width).
+         gt: Ground truth depth values of shape (batch_size, 1, height, width).
+             Pixels with sum of 0 across channels are considered invalid and masked out.
+
+     Returns:
+         Tensor: Scalar loss value averaged over valid pixels.
+     """
      binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
      loss = torch.sum(torch.abs(pred - gt) * binary_mask) / torch.nonzero(
          binary_mask, as_tuple=False
@@ -15,6 +40,21 @@ def depth_loss(pred: Tensor, gt: Tensor):
 
 
  def normal_loss(pred: Tensor, gt: Tensor):
+     """
+     Compute cosine similarity loss for surface normal prediction.
+
+     This loss measures the angular difference between predicted and ground truth
+     surface normals using normalized cosine similarity (1 - dot product).
+
+     Args:
+         pred: Predicted surface normals of shape (batch_size, 3, height, width).
+             Will be L2-normalized before computing loss.
+         gt: Ground truth surface normals of shape (batch_size, 3, height, width).
+             Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+
+     Returns:
+         Tensor: Scalar loss value (1 - mean cosine similarity) over valid pixels.
+     """
      # gt has been normalized on the NYUv2 dataset
      pred = pred / torch.norm(pred, p=2, dim=1, keepdim=True)
      binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
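The three loss functions documented above live in `fusion_bench/metrics/nyuv2/loss.py`. A short sketch on dummy tensors (shapes follow the docstrings; the zeroed depth region shows the masking behaviour):

```python
import torch
from fusion_bench.metrics.nyuv2.loss import segmentation_loss, depth_loss, normal_loss

seg_pred = torch.randn(2, 13, 8, 8)           # logits over 13 classes
seg_gt = torch.randint(-1, 13, (2, 8, 8))     # -1 labels are ignored by cross-entropy
print(segmentation_loss(seg_pred, seg_gt))

depth_pred = torch.rand(2, 1, 8, 8)
depth_gt = torch.rand(2, 1, 8, 8) + 0.1
depth_gt[:, :, :4, :] = 0.0                   # zero depth is masked out of the L1 loss
print(depth_loss(depth_pred, depth_gt))

normal_pred = torch.randn(2, 3, 8, 8)         # normalized inside normal_loss
normal_gt = torch.nn.functional.normalize(torch.randn(2, 3, 8, 8), dim=1)
print(normal_loss(normal_pred, normal_gt))
```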
@@ -6,11 +6,35 @@ from torchmetrics import Metric
 
 
  class NoiseMetric(Metric):
+     """
+     A placeholder metric for noise evaluation on NYUv2 dataset.
+
+     This metric currently serves as a placeholder and always returns a value of 1.
+     It can be extended in the future to include actual noise-related metrics.
+
+     Note:
+         This is a dummy implementation that doesn't perform actual noise measurements.
+     """
+
      def __init__(self):
+         """Initialize the NoiseMetric."""
          super().__init__()
 
      def update(self, preds: Tensor, target: Tensor):
+         """
+         Update metric state (currently a no-op).
+
+         Args:
+             preds: Predicted values (unused).
+             target: Ground truth values (unused).
+         """
          pass
 
      def compute(self):
+         """
+         Compute the metric value.
+
+         Returns:
+             List[int]: A list containing [1] as a placeholder value.
+         """
          return [1]
@@ -7,14 +7,36 @@ from torchmetrics import Metric
 
 
  class NormalMetric(Metric):
+     """
+     Metric for evaluating surface normal prediction on NYUv2 dataset.
+
+     This metric computes angular error statistics between predicted and ground truth
+     surface normals, including mean, median, and percentage of predictions within
+     specific angular thresholds (11.25°, 22.5°, 30°).
+
+     Attributes:
+         metric_names: List of metric names ["mean", "median", "<11.25", "<22.5", "<30"].
+         record: List storing angular errors (in degrees) for all pixels across batches.
+     """
+
      metric_names = ["mean", "median", "<11.25", "<22.5", "<30"]
 
      def __init__(self):
+         """Initialize the NormalMetric with state for recording angular errors."""
          super(NormalMetric, self).__init__()
 
          self.add_state("record", default=[], dist_reduce_fx="cat")
 
      def update(self, preds, target):
+         """
+         Update metric state with predictions and targets from a batch.
+
+         Args:
+             preds: Predicted surface normals of shape (batch_size, 3, height, width).
+                 Will be L2-normalized before computing errors.
+             target: Ground truth surface normals of shape (batch_size, 3, height, width).
+                 Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+         """
          # gt has been normalized on the NYUv2 dataset
          preds = preds / torch.norm(preds, p=2, dim=1, keepdim=True)
          binary_mask = torch.sum(target, dim=1) != 0
@@ -33,7 +55,18 @@ class NormalMetric(Metric):
 
      def compute(self):
          """
-         returns mean, median, and percentage of pixels with error less than 11.25, 22.5, and 30 degrees ("mean", "median", "<11.25", "<22.5", "<30")
+         Compute final metric values from all recorded angular errors.
+
+         Returns:
+             List[Tensor]: A list containing five metrics:
+                 - mean: Mean angular error in degrees.
+                 - median: Median angular error in degrees.
+                 - <11.25: Percentage of pixels with error < 11.25°.
+                 - <22.5: Percentage of pixels with error < 22.5°.
+                 - <30: Percentage of pixels with error < 30°.
+
+         Note:
+             Returns zeros if no data has been recorded.
          """
          if self.record is None:
              return torch.asarray([0.0, 0.0, 0.0, 0.0, 0.0])
@@ -6,9 +6,28 @@ from torchmetrics import Metric
 
 
  class SegmentationMetric(Metric):
+     """
+     Metric for evaluating semantic segmentation on NYUv2 dataset.
+
+     This metric computes mean Intersection over Union (mIoU) and pixel accuracy
+     for multi-class segmentation tasks.
+
+     Attributes:
+         metric_names: List of metric names ["mIoU", "pixAcc"].
+         num_classes: Number of segmentation classes (default: 13 for NYUv2).
+         record: Confusion matrix of shape (num_classes, num_classes) tracking
+             predictions vs ground truth.
+     """
+
      metric_names = ["mIoU", "pixAcc"]
 
      def __init__(self, num_classes=13):
+         """
+         Initialize the SegmentationMetric.
+
+         Args:
+             num_classes: Number of segmentation classes. Default is 13 for NYUv2 dataset.
+         """
          super().__init__()
 
          self.num_classes = num_classes
@@ -21,9 +40,19 @@ class SegmentationMetric(Metric):
          )
 
      def reset(self):
+         """Reset the confusion matrix to zeros."""
          self.record.zero_()
 
      def update(self, preds: Tensor, target: Tensor):
+         """
+         Update the confusion matrix with predictions and targets from a batch.
+
+         Args:
+             preds: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+                 Will be converted to class predictions via softmax and argmax.
+             target: Ground truth segmentation labels of shape (batch_size, height, width).
+                 Pixels with negative values or values >= num_classes are ignored.
+         """
          preds = preds.softmax(1).argmax(1).flatten()
          target = target.long().flatten()
 
@@ -35,7 +64,12 @@ class SegmentationMetric(Metric):
 
      def compute(self):
          """
-         return mIoU and pixel accuracy
+         Compute mIoU and pixel accuracy from the confusion matrix.
+
+         Returns:
+             List[Tensor]: A list containing [mIoU, pixel_accuracy]:
+                 - mIoU: Mean Intersection over Union across all classes.
+                 - pixel_accuracy: Overall pixel classification accuracy.
          """
          h = cast(Tensor, self.record).float()
          iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
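The hunk above ends partway through `SegmentationMetric.compute`, which derives per-class IoU from the accumulated confusion matrix. A minimal usage sketch with dummy logits and labels (shapes follow the docstrings):

```python
import torch
from fusion_bench.metrics.nyuv2 import SegmentationMetric

metric = SegmentationMetric(num_classes=13)
logits = torch.randn(2, 13, 8, 8)          # raw class scores
labels = torch.randint(0, 13, (2, 8, 8))   # ground-truth class indices
metric.update(logits, labels)              # accumulates a 13x13 confusion matrix
miou, pix_acc = metric.compute()
print("mIoU:", miou, "pixAcc:", pix_acc)
```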