fusion-bench 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +1 -1
  6. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  7. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  8. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  9. fusion_bench/metrics/nyuv2/depth.py +30 -0
  10. fusion_bench/metrics/nyuv2/loss.py +40 -0
  11. fusion_bench/metrics/nyuv2/noise.py +24 -0
  12. fusion_bench/metrics/nyuv2/normal.py +34 -1
  13. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  14. fusion_bench/mixins/clip_classification.py +30 -2
  15. fusion_bench/mixins/lightning_fabric.py +46 -5
  16. fusion_bench/mixins/rich_live.py +76 -0
  17. fusion_bench/modelpool/base_pool.py +86 -5
  18. fusion_bench/scripts/webui.py +250 -17
  19. fusion_bench/utils/__init__.py +14 -0
  20. fusion_bench/utils/data.py +100 -9
  21. fusion_bench/utils/fabric.py +185 -4
  22. fusion_bench/utils/json.py +6 -0
  23. fusion_bench/utils/validation.py +197 -0
  24. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
  25. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +35 -35
  26. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  27. fusion_bench_config/llama_full_finetune.yaml +4 -16
  28. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  29. fusion_bench_config/nyuv2_config.yaml +4 -13
  30. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  31. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  32. fusion_bench/utils/auto.py +0 -31
  33. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
  34. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
  35. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
  36. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
fusion_bench/constants/__init__.py
@@ -5,4 +5,8 @@ from .paths import *
 from .runtime import RuntimeConstants

 # fusionbench version
-FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
+try:
+    FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
+except importlib.metadata.PackageNotFoundError:
+    # Fallback when package is not installed (e.g., during development)
+    FUSION_BENCH_VERSION = "0.0.0.dev"
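The fallback above only matters when fusion-bench runs from a source checkout whose distribution metadata is not installed. A minimal sketch of the same pattern in isolation, assuming nothing beyond the standard library (the `get_version` helper is illustrative, not part of fusion-bench):

```python
import importlib.metadata


def get_version(dist_name: str, fallback: str = "0.0.0.dev") -> str:
    """Return the installed distribution version, or a fallback if not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        # Raised when no distribution metadata is available,
        # e.g. running from a plain source checkout.
        return fallback


print(get_version("fusion-bench"))     # "0.2.29" when the wheel is installed
print(get_version("not-a-real-dist"))  # "0.0.0.dev"
```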
fusion_bench/constants/runtime.py
@@ -1,3 +1,29 @@
+"""
+Runtime Constants Module.
+
+This module provides a thread-safe singleton class for managing runtime configuration
+and constants across the Fusion Bench framework. It centralizes access to runtime
+settings like cache directories, debug flags, and logging preferences.
+
+Example:
+    ```python
+    from fusion_bench.constants.runtime import RuntimeConstants
+
+    # Get the singleton instance
+    runtime = RuntimeConstants()
+
+    # Configure cache directory
+    runtime.cache_dir = "/custom/cache/path"
+
+    # Enable debug mode
+    runtime.debug = True
+
+    # Control function call logging
+    runtime.print_function_call = True
+    ```
+"""
+
+import os
 import threading
 from pathlib import Path
 from typing import Optional, Union
@@ -5,18 +31,46 @@ from typing import Optional, Union

 class RuntimeConstants:
     """
-    This class holds constants related to the runtime environment of the Fusion Bench framework.
-    It includes default values for cache directories and other runtime configurations.
+    Thread-safe singleton for managing runtime configuration and constants.
+
+    This class provides centralized access to runtime settings that affect the
+    behavior of the entire Fusion Bench framework. It ensures consistent
+    configuration across all modules and supports thread-safe access in
+    multi-threaded environments.
+
+    Attributes:
+        debug: Global debug flag for enabling verbose logging and debugging features.
+
+    Example:
+        ```python
+        runtime = RuntimeConstants()
+
+        # Configure caching
+        runtime.cache_dir = Path.home() / ".cache" / "fusion_bench"
+
+        # Enable debugging
+        runtime.debug = True
+        runtime.print_function_call = True
+        ```

-    Implemented as a thread-safe singleton to ensure consistent runtime configuration
-    across the entire application.
+    Note:
+        This class implements the singleton pattern with thread-safe initialization.
+        Multiple calls to the constructor will return the same instance.
     """

     _instance: Optional["RuntimeConstants"] = None
     _lock = threading.Lock()

     def __new__(cls) -> "RuntimeConstants":
-        """Create a new instance using singleton pattern with thread safety."""
+        """
+        Create or return the singleton instance with thread safety.
+
+        Uses double-check locking pattern to ensure thread-safe singleton creation
+        while minimizing synchronization overhead.
+
+        Returns:
+            The singleton RuntimeConstants instance.
+        """
         with cls._lock:
             # Double-check locking pattern
             if cls._instance is None:
@@ -25,33 +79,83 @@ class RuntimeConstants:
         return cls._instance

     def __init__(self):
-        """Initialize the singleton instance only once."""
+        """
+        Initialize the singleton instance only once.
+
+        Subsequent calls to __init__ are no-ops to maintain singleton behavior.
+        """
         if not self._initialized:
-            # Add your runtime constants here
+            # Initialize default values
             self._initialized = True

     debug = False
+    """Global debug flag for enabling verbose logging and debugging features."""

     @property
     def cache_dir(self) -> Path:
+        """
+        Get the default cache directory for models and datasets.
+
+        Returns:
+            Path object pointing to the cache directory.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            print(f"Cache directory: {runtime.cache_dir}")
+            ```
+        """
         from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR

         return DEFAULT_CACHE_DIR

     @cache_dir.setter
     def cache_dir(self, path: Union[str, Path]) -> None:
+        """
+        Set the default cache directory for models and datasets.
+
+        Args:
+            path: New cache directory path as string or Path object.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            runtime.cache_dir = "/data/fusion_bench_cache"
+            ```
+        """
         from fusion_bench.utils.cache_utils import set_default_cache_dir

         set_default_cache_dir(path)

     @property
     def print_function_call(self) -> bool:
+        """
+        Get whether function calls are printed during instantiation.
+
+        Returns:
+            True if function call printing is enabled, False otherwise.
+        """
         from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL

         return PRINT_FUNCTION_CALL

     @print_function_call.setter
     def print_function_call(self, enable: bool) -> None:
+        """
+        Set whether to print function calls during instantiation.
+
+        Useful for debugging to see which functions are being called
+        when instantiating objects from configuration.
+
+        Args:
+            enable: True to enable printing, False to disable.
+
+        Example:
+            ```python
+            runtime = RuntimeConstants()
+            runtime.print_function_call = True  # Enable verbose logging
+            ```
+        """
         from fusion_bench.utils.instantiate_utils import set_print_function_call

         set_print_function_call(enable)
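The hunks above show the new docstrings but elide the middle of `__new__`, where the instance is actually created. For reference, a minimal double-checked locking singleton looks roughly like the following sketch; this is the general pattern the docstring names, not a copy of the fusion-bench implementation:

```python
import threading
from typing import Optional


class Singleton:
    """Minimal thread-safe singleton using double-checked locking."""

    _instance: Optional["Singleton"] = None
    _lock = threading.Lock()

    def __new__(cls) -> "Singleton":
        # First check avoids taking the lock on the common fast path.
        if cls._instance is None:
            with cls._lock:
                # Second check: another thread may have created the instance
                # while this thread was waiting for the lock.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Subsequent constructions are no-ops, matching the behavior documented above.
        if not self._initialized:
            self._initialized = True


assert Singleton() is Singleton()  # repeated construction returns the same object
```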
fusion_bench/dataset/gsm8k.py
@@ -13,8 +13,12 @@ def load_gsm8k_question_label_data(

     An example in the dataset:

-    {'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
-    'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}
+    ```python
+    {
+        'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
+        'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'
+    }
+    ```

     Args:
         dataset_name (Literal["train", "test", "train_socratic", "test_socratic"]): The name of the dataset to load.
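For context on the example above: each GSM8K `answer` ends with a `#### <number>` marker, so the numeric label is typically recovered by splitting on that marker. A hedged sketch of such parsing (the `extract_label` helper is illustrative, not part of fusion-bench):

```python
def extract_label(answer: str) -> str:
    """Return the final numeric answer after the '#### ' marker."""
    # GSM8K answers end with a line such as "#### 72".
    return answer.split("####")[-1].strip()


answer = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)
assert extract_label(answer) == "72"
```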
fusion_bench/dataset/image_corruption/make_corruption.py
@@ -1,4 +1,31 @@
 # -*- coding: utf-8 -*-
+"""
+Image Corruption Module for Robustness Testing.
+
+This module provides various image corruption functions to test model robustness.
+It implements common corruptions such as noise, blur, compression artifacts, and
+weather effects. These corruptions are commonly used in benchmark datasets like
+ImageNet-C and CIFAR-10-C.
+
+The corruptions can be applied at different severity levels (1-5), where higher
+levels indicate stronger corruption effects.
+
+Example:
+    ```python
+    from PIL import Image
+    from fusion_bench.dataset.image_corruption.make_corruption import gaussian_noise, motion_blur
+
+    # Load an image
+    img = Image.open("example.jpg")
+
+    # Apply gaussian noise at severity level 3
+    corrupted_img = gaussian_noise(img, severity=3)
+
+    # Apply motion blur at severity level 2
+    blurred_img = motion_blur(img, severity=2)
+    ```
+"""
+
 import logging

 logger = logging.getLogger(__name__)
@@ -37,11 +64,39 @@ warnings.simplefilter("ignore", UserWarning)

 # /////////////// Distortions ///////////////
 class MotionImage(WandImage):
+    """
+    Extended WandImage class with motion blur capability.
+
+    This class wraps ImageMagick's motion blur functionality through the Wand library.
+    """
+
     def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        """
+        Apply motion blur effect to the image.
+
+        Args:
+            radius: The radius of the Gaussian, in pixels, not counting the center pixel.
+            sigma: The standard deviation of the Gaussian, in pixels.
+            angle: Apply the effect along this angle in degrees.
+        """
         wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)


 def gaussian_noise(x, severity=1):
+    """
+    Apply Gaussian noise corruption to an image.
+
+    Adds random Gaussian noise to the image, simulating sensor noise or
+    environmental interference.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]

     x = np.array(x) / 255.0
@@ -49,6 +104,20 @@ def gaussian_noise(x, severity=1):


 def impulse_noise(x, severity=1):
+    """
+    Apply impulse (salt-and-pepper) noise corruption to an image.
+
+    Randomly replaces pixels with either maximum or minimum intensity values,
+    simulating transmission errors or faulty pixels.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]

     x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
@@ -56,6 +125,21 @@ def impulse_noise(x, severity=1):


 def motion_blur(x, severity=1):
+    """
+    Apply motion blur corruption to an image.
+
+    Simulates camera shake or object motion during image capture by applying
+    directional blur at a random angle.
+
+    Args:
+        x: Input PIL Image.
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity increases blur radius and sigma.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+            Returns RGB image regardless of input format.
+    """
     c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]

     output = BytesIO()
@@ -73,6 +157,21 @@ def motion_blur(x, severity=1):


 def spatter(x, severity=1):
+    """
+    Apply spatter corruption to an image.
+
+    Simulates liquid splatter effects (water or mud) on the image, creating
+    realistic occlusions similar to raindrops or dirt on a camera lens.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Levels 1-3 simulate water splatter, levels 4-5 simulate mud splatter.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [
         (0.62, 0.1, 0.7, 0.7, 0.5, 0),
         (0.65, 0.1, 0.8, 0.7, 0.5, 0),
@@ -140,6 +239,21 @@ def spatter(x, severity=1):


 def contrast(x, severity=1):
+    """
+    Apply contrast reduction corruption to an image.
+
+    Reduces image contrast by blending pixels toward their mean values,
+    simulating poor lighting conditions or low-quality image sensors.
+
+    Args:
+        x: Input image as PIL Image or numpy array. If numpy array, should be in
+            range [0, 255].
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity results in lower contrast.
+
+    Returns:
+        numpy.ndarray: Corrupted image as numpy array in range [0, 255].
+    """
     c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]

     x = np.array(x) / 255.0
@@ -148,6 +262,20 @@ def contrast(x, severity=1):


 def jpeg_compression(x, severity=1):
+    """
+    Apply JPEG compression artifacts to an image.
+
+    Simulates compression artifacts from lossy JPEG encoding at various
+    quality levels, commonly seen in heavily compressed images.
+
+    Args:
+        x: Input PIL Image.
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Lower severity uses higher JPEG quality (less compression).
+
+    Returns:
+        PIL.Image: Corrupted image as PIL Image.
+    """
     c = [80, 65, 58, 50, 40][severity - 1]

     output = BytesIO()
@@ -158,6 +286,23 @@ def jpeg_compression(x, severity=1):


 def pixelate(x, severity=1):
+    """
+    Apply pixelation corruption to an image.
+
+    Reduces image resolution by downsampling and then upsampling,
+    creating a blocky, pixelated appearance.
+
+    Args:
+        x: Input PIL Image with size (32, 32).
+        severity: Corruption severity level from 1 (mild) to 5 (severe).
+            Higher severity results in more pixelation.
+
+    Returns:
+        PIL.Image: Corrupted image as PIL Image with original size (32, 32).
+
+    Note:
+        This function is designed for 32x32 images (e.g., CIFAR-10).
+    """
     c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]

     x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
@@ -170,6 +315,29 @@ def pixelate(x, severity=1):


 distortion_methods = collections.OrderedDict()
+"""
+OrderedDict mapping corruption names to their corresponding functions.
+
+Available corruptions:
+    - "Gaussian Noise": Additive Gaussian noise
+    - "Impulse Noise": Salt-and-pepper noise
+    - "Motion Blur": Directional motion blur
+    - "Contrast": Reduced contrast
+    - "Pixelate": Resolution reduction
+    - "JPEG": JPEG compression artifacts
+    - "Spatter": Water or mud splatter effects
+
+Example:
+    ```python
+    from PIL import Image
+    from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods

+    img = Image.open("example.jpg")
+    for name, corruption_fn in distortion_methods.items():
+        corrupted = corruption_fn(img, severity=3)
+        # Process corrupted image
+    ```
+"""
 distortion_methods["Gaussian Noise"] = gaussian_noise
 distortion_methods["Impulse Noise"] = impulse_noise
 distortion_methods["Motion Blur"] = motion_blur
fusion_bench/method/__init__.py
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
     from .dawe import DataAdaptiveWeightEnsemblingForCLIP
     from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
     from .doge_ta import DOGE_TA_Algorithm
+    from .dop import ContinualDOPForCLIP
     from .dummy import DummyAlgorithm
     from .ensemble import (
         MaxModelPredictorAlgorithm,
@@ -215,7 +216,6 @@ if TYPE_CHECKING:
     from .model_recombination import ModelRecombinationAlgorithm
     from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
-    from .dop import ContinualDOPForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
         MagnitudePruningForLlama,
fusion_bench/method/classification/image_classification_finetune.py
@@ -15,7 +15,7 @@ from lightning_utilities.core.rank_zero import rank_zero_only
 from lit_learn.lit_modules import ERM_LitModule
 from omegaconf import DictConfig
 from torch import nn
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, random_split
 from torchmetrics.classification import Accuracy

 from fusion_bench import (
@@ -29,7 +29,6 @@ from fusion_bench import (
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.modelpool import ResNetForImageClassificationPool
 from fusion_bench.tasks.clip_classification import get_num_classes
-from torch.utils.data import random_split

 log = get_rankzero_logger(__name__)

fusion_bench/method/gossip/clip_task_wise_gossip.py
@@ -13,41 +13,13 @@ from fusion_bench.modelpool import CLIPVisionModelPool
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader

 from .task_wise_gossip import TaskWiseGossipAlgorithm

 log = logging.getLogger(__name__)


-class InfiniteDataLoader:
-    """
-    A wrapper class for DataLoader to create an infinite data loader.
-    This is useful in case we are only interested in the number of steps and not the number of epochs.
-
-    This class wraps a DataLoader and provides an iterator that resets
-    when the end of the dataset is reached, creating an infinite loop.
-
-    Attributes:
-        data_loader (DataLoader): The DataLoader to wrap.
-        data_iter (iterator): An iterator over the DataLoader.
-    """
-
-    def __init__(self, data_loader):
-        self.data_loader = data_loader
-        self.data_iter = iter(data_loader)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            data = next(self.data_iter)
-        except StopIteration:
-            self.data_iter = iter(self.data_loader)  # Reset the data loader
-            data = next(self.data_iter)
-        return data
-
-
 class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
     """
     A class for task-wise adaptive merging of CLIP models.
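The wrapper removed here now lives in `fusion_bench.utils.data` (see the new import above). A short usage sketch of the relocated class, assuming its behavior is unchanged from the removed code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader

# A tiny dataset: a plain DataLoader would be exhausted after 4 batches.
dataset = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))
loader = InfiniteDataLoader(DataLoader(dataset, batch_size=4))

# Step-based training loop: iterate a fixed number of steps, not epochs;
# the wrapper silently restarts the underlying loader when it runs out.
for step, (inputs, labels) in zip(range(10), loader):
    pass  # forward/backward would go here
```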
fusion_bench/metrics/nyuv2/__init__.py
@@ -1,3 +1,34 @@
+"""
+NYUv2 Dataset Metrics Module.
+
+This module provides metric classes and loss functions for evaluating multi-task learning
+models on the NYUv2 dataset. NYUv2 is a popular indoor scene understanding dataset that
+includes multiple tasks: semantic segmentation, depth estimation, and surface normal prediction.
+
+Available Metrics:
+    - SegmentationMetric: Computes mIoU and pixel accuracy for semantic segmentation.
+    - DepthMetric: Computes absolute and relative errors for depth estimation.
+    - NormalMetric: Computes angular errors for surface normal prediction.
+    - NoiseMetric: Placeholder metric for noise evaluation.
+
+Usage:
+    ```python
+    from fusion_bench.metrics.nyuv2 import SegmentationMetric, DepthMetric
+
+    # Initialize metrics
+    seg_metric = SegmentationMetric(num_classes=13)
+    depth_metric = DepthMetric()
+
+    # Update with predictions and targets
+    seg_metric.update(seg_preds, seg_targets)
+    depth_metric.update(depth_preds, depth_targets)
+
+    # Compute final metrics
+    miou, pix_acc = seg_metric.compute()
+    abs_err, rel_err = depth_metric.compute()
+    ```
+"""
+
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
fusion_bench/metrics/nyuv2/depth.py
@@ -7,9 +7,23 @@ from torchmetrics import Metric


 class DepthMetric(Metric):
+    """
+    Metric for evaluating depth estimation performance on NYUv2 dataset.
+
+    This metric computes absolute error and relative error for depth predictions,
+    properly handling the binary mask to exclude invalid depth regions.
+
+    Attributes:
+        metric_names: List of metric names ["abs_err", "rel_err"].
+        abs_record: List storing absolute error values for each batch.
+        rel_record: List storing relative error values for each batch.
+        batch_size: List storing batch sizes for weighted averaging.
+    """
+
     metric_names = ["abs_err", "rel_err"]

     def __init__(self):
+        """Initialize the DepthMetric with state variables for tracking errors."""
         super().__init__()

         self.add_state("abs_record", default=[], dist_reduce_fx="cat")
@@ -17,11 +31,20 @@ class DepthMetric(Metric):
         self.add_state("batch_size", default=[], dist_reduce_fx="cat")

     def reset(self):
+        """Reset all metric states to empty lists."""
         self.abs_record = []
         self.rel_record = []
         self.batch_size = []

     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric states with predictions and targets from a batch.
+
+        Args:
+            preds: Predicted depth values of shape (batch_size, 1, height, width).
+            target: Ground truth depth values of shape (batch_size, 1, height, width).
+                Pixels with sum of 0 are considered invalid and masked out.
+        """
         binary_mask = (torch.sum(target, dim=1) != 0).unsqueeze(1)
         preds = preds.masked_select(binary_mask)
         target = target.masked_select(binary_mask)
@@ -38,6 +61,13 @@ class DepthMetric(Metric):
         self.batch_size.append(torch.asarray(preds.size(0), device=preds.device))

     def compute(self):
+        """
+        Compute the final metric values across all batches.
+
+        Returns:
+            List[Tensor]: A list containing [absolute_error, relative_error],
+                where each value is the weighted average across all batches.
+        """
         records = torch.stack(
             [torch.stack(self.abs_record), torch.stack(self.rel_record)]
         )
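The `compute` hunk stops after stacking the per-batch records; the docstring states the result is a batch-size weighted average. A hedged sketch of that reduction on toy values (illustrative, not the exact fusion-bench code):

```python
import torch

# Per-batch absolute errors, relative errors, and batch sizes, as tracked by DepthMetric.
abs_record = [torch.tensor(0.30), torch.tensor(0.20)]
rel_record = [torch.tensor(0.10), torch.tensor(0.05)]
batch_size = [torch.tensor(8), torch.tensor(24)]

records = torch.stack([torch.stack(abs_record), torch.stack(rel_record)])  # shape (2, num_batches)
weights = torch.stack(batch_size).float()

# Batch-size weighted average of each metric.
abs_err, rel_err = (records * weights).sum(dim=1) / weights.sum()
print(abs_err.item(), rel_err.item())  # 0.225, 0.0625
```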
fusion_bench/metrics/nyuv2/loss.py
@@ -3,10 +3,35 @@ from torch import Tensor, nn


 def segmentation_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cross-entropy loss for semantic segmentation.
+
+    Args:
+        pred: Predicted segmentation logits of shape (batch_size, num_classes, height, width).
+        gt: Ground truth segmentation labels of shape (batch_size, height, width).
+            Pixels with value -1 are ignored in the loss computation.
+
+    Returns:
+        Tensor: Scalar loss value.
+    """
     return nn.functional.cross_entropy(pred, gt.long(), ignore_index=-1)


 def depth_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute L1 loss for depth estimation with binary masking.
+
+    This loss function calculates the absolute error between predicted and ground truth
+    depth values, but only for valid pixels (where ground truth depth is non-zero).
+
+    Args:
+        pred: Predicted depth values of shape (batch_size, 1, height, width).
+        gt: Ground truth depth values of shape (batch_size, 1, height, width).
+            Pixels with sum of 0 across channels are considered invalid and masked out.
+
+    Returns:
+        Tensor: Scalar loss value averaged over valid pixels.
+    """
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
     loss = torch.sum(torch.abs(pred - gt) * binary_mask) / torch.nonzero(
         binary_mask, as_tuple=False
@@ -15,6 +40,21 @@ def depth_loss(pred: Tensor, gt: Tensor):


 def normal_loss(pred: Tensor, gt: Tensor):
+    """
+    Compute cosine similarity loss for surface normal prediction.
+
+    This loss measures the angular difference between predicted and ground truth
+    surface normals using normalized cosine similarity (1 - dot product).
+
+    Args:
+        pred: Predicted surface normals of shape (batch_size, 3, height, width).
+            Will be L2-normalized before computing loss.
+        gt: Ground truth surface normals of shape (batch_size, 3, height, width).
+            Already normalized on NYUv2 dataset. Pixels with sum of 0 are invalid.
+
+    Returns:
+        Tensor: Scalar loss value (1 - mean cosine similarity) over valid pixels.
+    """
     # gt has been normalized on the NYUv2 dataset
     pred = pred / torch.norm(pred, p=2, dim=1, keepdim=True)
     binary_mask = (torch.sum(gt, dim=1) != 0).float().unsqueeze(1).to(pred.device)
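On NYUv2 these three losses are typically combined into a single multi-task objective. A hedged sketch of how they might be used together on random tensors of the documented shapes; the uniform weighting is illustrative, not prescribed by fusion-bench:

```python
import torch
from fusion_bench.metrics.nyuv2.loss import (
    depth_loss,
    normal_loss,
    segmentation_loss,
)

batch, height, width, num_classes = 2, 8, 8, 13
seg_pred = torch.randn(batch, num_classes, height, width)
seg_gt = torch.randint(0, num_classes, (batch, height, width))
depth_pred = torch.rand(batch, 1, height, width)
depth_gt = torch.rand(batch, 1, height, width)
normal_pred = torch.randn(batch, 3, height, width)
normal_gt = torch.nn.functional.normalize(torch.randn(batch, 3, height, width), dim=1)

# Simple uniform weighting of the three task losses.
total = (
    segmentation_loss(seg_pred, seg_gt)
    + depth_loss(depth_pred, depth_gt)
    + normal_loss(normal_pred, normal_gt)
)
```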
fusion_bench/metrics/nyuv2/noise.py
@@ -6,11 +6,35 @@ from torchmetrics import Metric


 class NoiseMetric(Metric):
+    """
+    A placeholder metric for noise evaluation on NYUv2 dataset.
+
+    This metric currently serves as a placeholder and always returns a value of 1.
+    It can be extended in the future to include actual noise-related metrics.
+
+    Note:
+        This is a dummy implementation that doesn't perform actual noise measurements.
+    """
+
     def __init__(self):
+        """Initialize the NoiseMetric."""
         super().__init__()

     def update(self, preds: Tensor, target: Tensor):
+        """
+        Update metric state (currently a no-op).
+
+        Args:
+            preds: Predicted values (unused).
+            target: Ground truth values (unused).
+        """
         pass

     def compute(self):
+        """
+        Compute the metric value.
+
+        Returns:
+            List[int]: A list containing [1] as a placeholder value.
+        """
         return [1]