wavelet-loss 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavelet_loss/__init__.py +3 -0
- wavelet_loss/loss.py +662 -0
- wavelet_loss-2.0.0.dist-info/METADATA +87 -0
- wavelet_loss-2.0.0.dist-info/RECORD +9 -0
- wavelet_loss-2.0.0.dist-info/WHEEL +4 -0
- wavelet_loss-2.0.0.dist-info/licenses/LICENSE +177 -0
- wavelet_transform/__init__.py +23 -0
- wavelet_transform/backends.py +130 -0
- wavelet_transform/transform.py +429 -0
wavelet_loss/__init__.py
ADDED
wavelet_loss/loss.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
from collections.abc import Mapping
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torch.nn import functional as F
|
|
11
|
+
|
|
12
|
+
from wavelet_transform import QuaternionWaveletTransform
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LossCallableMSE(Protocol):
|
|
16
|
+
def __call__(
|
|
17
|
+
self,
|
|
18
|
+
input: Tensor,
|
|
19
|
+
target: Tensor,
|
|
20
|
+
size_average: bool | None = None,
|
|
21
|
+
reduce: bool | None = None,
|
|
22
|
+
reduction: str = "mean",
|
|
23
|
+
) -> Tensor: ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LossCallableReduction(Protocol):
    """Structural type for loss callables taking ``input``, ``target`` and ``reduction``."""

    def __call__(
        self,
        input: Tensor,
        target: Tensor,
        reduction: str = "mean",
    ) -> Tensor:
        ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Any callable accepted as the per-band loss function; WaveletLoss.process_band
# invokes it as ``loss_fn(pred, target, reduction="none")``.
LossCallable = LossCallableReduction | LossCallableMSE
# Flat mapping of metric name -> scalar value, returned alongside the loss.
Metrics = dict[str, int | float | None]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class WaveletLoss(nn.Module):
|
|
35
|
+
"""Wavelet-based loss calculation module."""
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
wavelet="db4",
|
|
40
|
+
level=3,
|
|
41
|
+
transform_type="dwt",
|
|
42
|
+
backend: str = "pytorch_wavelets",
|
|
43
|
+
mode: str = "zero",
|
|
44
|
+
loss_fn: LossCallable = F.mse_loss,
|
|
45
|
+
device=torch.device("cpu"),
|
|
46
|
+
band_level_weights: dict[str, float] | None = None,
|
|
47
|
+
band_weights: dict[str, float] | None = None,
|
|
48
|
+
quaternion_component_weights: dict[str, float] | None = None,
|
|
49
|
+
ll_level_threshold: int | None = -1,
|
|
50
|
+
metrics: bool = False,
|
|
51
|
+
normalize_bands: bool = False,
|
|
52
|
+
max_timestep: float = 1.0,
|
|
53
|
+
timestep_cutoff: float = 0.7,
|
|
54
|
+
timestep_transition_width: float = 0.4,
|
|
55
|
+
):
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
wavelet: Wavelet family (e.g., 'db4', 'sym7')
|
|
60
|
+
level: Decomposition level
|
|
61
|
+
transform_type: Type of wavelet transform ('dwt' or 'swt')
|
|
62
|
+
loss_fn: Loss function to apply to wavelet coefficients
|
|
63
|
+
device: Computation device
|
|
64
|
+
band_level_weights: Optional custom weights for different bands on different levels
|
|
65
|
+
band_weights: Optional custom weights for different bands
|
|
66
|
+
component_weights: Weights for quaternion components
|
|
67
|
+
ll_level_threshold: Level when applying loss for ll. Default -1 or last level.
|
|
68
|
+
max_timestep: Maximum value of the trainer's timestep convention.
|
|
69
|
+
Default 1.0 (flow-matching sigmas, e.g. Flux2). Pass 1000 for
|
|
70
|
+
DDPM-style integer timesteps. Timesteps outside
|
|
71
|
+
[0, max_timestep] raise ValueError.
|
|
72
|
+
timestep_cutoff: Fraction of max_timestep at which the timestep
|
|
73
|
+
weight crosses 0.5. Below the cutoff (less noise) the weight
|
|
74
|
+
is ~1; above it the weight fades toward 0.
|
|
75
|
+
timestep_transition_width: Fraction of the timestep range the
|
|
76
|
+
fade is spread over. Smaller = harder cutoff.
|
|
77
|
+
"""
|
|
78
|
+
super().__init__()
|
|
79
|
+
self.level = level
|
|
80
|
+
self.wavelet = wavelet
|
|
81
|
+
self.transform_type = transform_type
|
|
82
|
+
self.loss_fn = loss_fn
|
|
83
|
+
self.device = device
|
|
84
|
+
self.ll_level_threshold: int | None = ll_level_threshold
|
|
85
|
+
self.metrics = metrics
|
|
86
|
+
self.max_timestep = max_timestep
|
|
87
|
+
self.timestep_cutoff = timestep_cutoff
|
|
88
|
+
self.timestep_transition_width = timestep_transition_width
|
|
89
|
+
self.normalize_bands = normalize_bands
|
|
90
|
+
|
|
91
|
+
# Initialize transform via backend factory
|
|
92
|
+
from wavelet_transform import make_backend
|
|
93
|
+
|
|
94
|
+
self.backend = backend
|
|
95
|
+
self.mode = mode
|
|
96
|
+
self.transform = make_backend(backend, transform_type, wavelet, mode, device)
|
|
97
|
+
|
|
98
|
+
if transform_type == "qwt":
|
|
99
|
+
self.register_buffer("hilbert_x", self.transform.hilbert_x)
|
|
100
|
+
self.register_buffer("hilbert_y", self.transform.hilbert_y)
|
|
101
|
+
self.register_buffer("hilbert_xy", self.transform.hilbert_xy)
|
|
102
|
+
self.component_weights = quaternion_component_weights or {
|
|
103
|
+
"r": 1.0,
|
|
104
|
+
"i": 0.7,
|
|
105
|
+
"j": 0.7,
|
|
106
|
+
"k": 0.5,
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
# Register wavelet filters as module buffers
|
|
110
|
+
self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
|
|
111
|
+
self.register_buffer("dec_hi", self.transform.dec_hi.to(device))
|
|
112
|
+
|
|
113
|
+
# Default weights from paper:
|
|
114
|
+
# "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
|
|
115
|
+
self.band_level_weights = band_level_weights or {}
|
|
116
|
+
self.band_weights = band_weights or {
|
|
117
|
+
"ll": 0.1,
|
|
118
|
+
"lh": 0.01,
|
|
119
|
+
"hl": 0.01,
|
|
120
|
+
"hh": 0.05,
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
def forward(
|
|
124
|
+
self,
|
|
125
|
+
pred_latent: Tensor,
|
|
126
|
+
target_latent: Tensor,
|
|
127
|
+
timestep: torch.Tensor | None = None,
|
|
128
|
+
reduce: bool = True,
|
|
129
|
+
) -> tuple[Tensor | list[Tensor], Mapping[str, int | float | None]]:
|
|
130
|
+
"""
|
|
131
|
+
Calculate wavelet loss between prediction and target.
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
loss: Total wavelet loss (scalar if reduce=True, list of tensors if reduce=False)
|
|
135
|
+
metrics: Wavelet metrics if requested in WaveletLoss(metrics=True)
|
|
136
|
+
"""
|
|
137
|
+
if pred_latent.ndim != 4 or target_latent.ndim != 4:
|
|
138
|
+
raise ValueError(
|
|
139
|
+
f"WaveletLoss expects 4D [B, C, H, W] tensors, got "
|
|
140
|
+
f"pred.ndim={pred_latent.ndim}, target.ndim={target_latent.ndim}."
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
if timestep is not None:
|
|
144
|
+
self._validate_timestep(timestep)
|
|
145
|
+
|
|
146
|
+
if isinstance(self.transform, QuaternionWaveletTransform):
|
|
147
|
+
return self.quaternion_forward(pred_latent, target_latent, timestep, reduce)
|
|
148
|
+
|
|
149
|
+
batch_size = pred_latent.shape[0]
|
|
150
|
+
device = pred_latent.device
|
|
151
|
+
|
|
152
|
+
# Decompose inputs
|
|
153
|
+
pred_coeffs = self.transform.decompose(pred_latent, self.level)
|
|
154
|
+
target_coeffs = self.transform.decompose(target_latent, self.level)
|
|
155
|
+
|
|
156
|
+
# Calculate weighted loss
|
|
157
|
+
pattern_losses = []
|
|
158
|
+
metrics: Metrics = {}
|
|
159
|
+
|
|
160
|
+
base_weight = torch.ones((batch_size), device=device)
|
|
161
|
+
if timestep is not None:
|
|
162
|
+
base_weight *= self.smooth_timestep_weight(timestep)
|
|
163
|
+
if self.metrics:
|
|
164
|
+
metrics["wavelet_loss/avg_timestep_adjusted_weight"] = base_weight.detach().mean().item()
|
|
165
|
+
|
|
166
|
+
for i in range(self.level):
|
|
167
|
+
# High frequency bands
|
|
168
|
+
for band in ["ll", "lh", "hl", "hh"]:
|
|
169
|
+
band_loss, pred, target, band_metrics = self.process_band(
|
|
170
|
+
pred_coeffs, target_coeffs, band, i, base_weight=base_weight
|
|
171
|
+
)
|
|
172
|
+
metrics.update(band_metrics)
|
|
173
|
+
|
|
174
|
+
pattern_losses.append(band_loss)
|
|
175
|
+
|
|
176
|
+
losses = pattern_losses
|
|
177
|
+
|
|
178
|
+
# METRICS: Calculate all additional metrics (no gradients needed)
|
|
179
|
+
if self.metrics:
|
|
180
|
+
with torch.no_grad():
|
|
181
|
+
metrics.update(self.process_coeff_metrics(pred_coeffs, target_coeffs))
|
|
182
|
+
metrics.update(self.process_loss_metrics(pattern_losses))
|
|
183
|
+
metrics.update(self.process_latent_metrics(pred_latent))
|
|
184
|
+
|
|
185
|
+
if reduce:
|
|
186
|
+
total = sum(loss_item.mean() for loss_item in losses)
|
|
187
|
+
return total, metrics
|
|
188
|
+
return losses, metrics
|
|
189
|
+
|
|
190
|
+
def process_coeff_metrics(
|
|
191
|
+
self,
|
|
192
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
193
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
194
|
+
) -> Metrics:
|
|
195
|
+
metrics: Metrics = {}
|
|
196
|
+
metrics.update(self.calculate_correlation_metrics(pred_coeffs, target_coeffs))
|
|
197
|
+
metrics.update(self.calculate_energy_metrics(pred_coeffs, target_coeffs))
|
|
198
|
+
metrics.update(self.calculate_cross_scale_consistency_metrics(pred_coeffs, target_coeffs))
|
|
199
|
+
metrics.update(self.calculate_directional_consistency_metrics(pred_coeffs, target_coeffs))
|
|
200
|
+
|
|
201
|
+
return metrics
|
|
202
|
+
|
|
203
|
+
@torch.no_grad()
|
|
204
|
+
def calculate_energy_metrics(
|
|
205
|
+
self,
|
|
206
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
207
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
208
|
+
) -> Metrics:
|
|
209
|
+
"""Per-band coefficient energy (mean of squares) for pred and target.
|
|
210
|
+
|
|
211
|
+
Energy is non-negative and amplitude-sensitive, so it surfaces the
|
|
212
|
+
scale/amplitude errors that correlation and ratio metrics are blind to
|
|
213
|
+
(e.g. a uniformly scaled prediction). Replaces the old signed-mean
|
|
214
|
+
``avg_hf_*`` metrics, which hovered near zero by construction.
|
|
215
|
+
"""
|
|
216
|
+
metrics: Metrics = {}
|
|
217
|
+
hf_pred: list[float] = []
|
|
218
|
+
hf_target: list[float] = []
|
|
219
|
+
|
|
220
|
+
for band in ["ll", "lh", "hl", "hh"]:
|
|
221
|
+
for i in range(self.level):
|
|
222
|
+
pred_e = torch.mean(pred_coeffs[band][i] ** 2).item()
|
|
223
|
+
target_e = torch.mean(target_coeffs[band][i] ** 2).item()
|
|
224
|
+
metrics[f"wavelet_loss/energy/{band}{i + 1}_pred"] = pred_e
|
|
225
|
+
metrics[f"wavelet_loss/energy/{band}{i + 1}_target"] = target_e
|
|
226
|
+
if band != "ll":
|
|
227
|
+
hf_pred.append(pred_e)
|
|
228
|
+
hf_target.append(target_e)
|
|
229
|
+
|
|
230
|
+
if hf_pred:
|
|
231
|
+
metrics["wavelet_loss/avg_hf_energy_pred"] = sum(hf_pred) / len(hf_pred)
|
|
232
|
+
metrics["wavelet_loss/avg_hf_energy_target"] = sum(hf_target) / len(hf_target)
|
|
233
|
+
|
|
234
|
+
return metrics
|
|
235
|
+
|
|
236
|
+
def process_latent_metrics(self, pred_latent: Tensor) -> dict[str, int | float | None]:
|
|
237
|
+
"""
|
|
238
|
+
Calculate metrics for the latent space.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
pred_latent: The predicted latent tensor
|
|
242
|
+
target_latent: The target latent tensor
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
metrics: The metrics dictionary
|
|
246
|
+
"""
|
|
247
|
+
metrics: dict[str, int | float | None] = {}
|
|
248
|
+
metrics.update(self.calculate_latent_regularity_metrics(pred_latent))
|
|
249
|
+
|
|
250
|
+
return metrics
|
|
251
|
+
|
|
252
|
+
def process_loss_metrics(self, losses: list[Tensor]) -> Metrics:
|
|
253
|
+
"""Aggregate the weighted per-band losses into a scalar total metric.
|
|
254
|
+
|
|
255
|
+
``losses`` are the weighted per-band loss tensors — the same ones summed
|
|
256
|
+
for the ``reduce=True`` return — so this mirrors the optimized objective
|
|
257
|
+
exactly. Per-band breakdowns are emitted by ``process_band``.
|
|
258
|
+
"""
|
|
259
|
+
metrics: Metrics = {}
|
|
260
|
+
total = sum(loss_item.detach().mean() for loss_item in losses)
|
|
261
|
+
metrics["wavelet_loss/total"] = float(total)
|
|
262
|
+
return metrics
|
|
263
|
+
|
|
264
|
+
def process_band(
|
|
265
|
+
self,
|
|
266
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
267
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
268
|
+
band: str,
|
|
269
|
+
i: int,
|
|
270
|
+
base_weight: Tensor,
|
|
271
|
+
) -> tuple[Tensor, Tensor, Tensor, Metrics]:
|
|
272
|
+
"""
|
|
273
|
+
Process a single band and calculate the loss.
|
|
274
|
+
|
|
275
|
+
Args:
|
|
276
|
+
pred_coeffs: The predicted coefficients
|
|
277
|
+
target_coeffs: The target coefficients
|
|
278
|
+
band: The band to process (e.g. "lh", "hl", etc.)
|
|
279
|
+
i: The level index
|
|
280
|
+
base_weight: The base weight for the band
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
loss: The band loss
|
|
284
|
+
pred: The predicted wavelet component
|
|
285
|
+
target: The target wavelet component
|
|
286
|
+
metrics: The metrics for this band
|
|
287
|
+
|
|
288
|
+
"""
|
|
289
|
+
# # If negative it's from the end of the levels else it's the level.
|
|
290
|
+
# ll_threshold = None
|
|
291
|
+
ll_threshold = self._calculate_effective_ll_threshold()
|
|
292
|
+
if ll_threshold is not None and band == "ll" and i + 1 <= ll_threshold:
|
|
293
|
+
return (
|
|
294
|
+
torch.zeros_like(pred_coeffs[band][i]),
|
|
295
|
+
torch.zeros_like(pred_coeffs[band][i]),
|
|
296
|
+
torch.zeros_like(target_coeffs[band][i]),
|
|
297
|
+
{},
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
weight_key = f"{band}{i + 1}"
|
|
301
|
+
pred = pred_coeffs[band][i]
|
|
302
|
+
target = target_coeffs[band][i]
|
|
303
|
+
|
|
304
|
+
if self.normalize_bands:
|
|
305
|
+
# Shared normalization: use the TARGET band statistics for BOTH tensors
|
|
306
|
+
# so relative amplitude/offset errors are preserved (not zeroed out).
|
|
307
|
+
mean = target.mean()
|
|
308
|
+
std = target.std() + 1e-8
|
|
309
|
+
pred = (pred - mean) / std
|
|
310
|
+
target = (target - mean) / std
|
|
311
|
+
|
|
312
|
+
band_loss = self.loss_fn(pred, target, reduction="none")
|
|
313
|
+
|
|
314
|
+
weight = base_weight * self.band_level_weights.get(weight_key, self.band_weights[band])
|
|
315
|
+
loss = weight.view(-1, 1, 1, 1) * band_loss
|
|
316
|
+
|
|
317
|
+
metrics: Metrics = {}
|
|
318
|
+
if self.metrics:
|
|
319
|
+
metrics = {
|
|
320
|
+
f"wavelet_loss/band_loss/{band}{i + 1}": band_loss.detach().mean().item(),
|
|
321
|
+
f"wavelet_loss/weighted_band_loss/{band}{i + 1}": loss.detach().mean().item(),
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
return loss, pred, target, metrics
|
|
325
|
+
|
|
326
|
+
def quaternion_forward(
|
|
327
|
+
self, pred: Tensor, target: Tensor, timestep: Tensor | None, reduce: bool = True
|
|
328
|
+
) -> tuple[Tensor | list[Tensor], Mapping[str, int | float | None]]:
|
|
329
|
+
"""
|
|
330
|
+
Calculate QWT loss between prediction and target.
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
pred: Predicted tensor [B, C, H, W]
|
|
334
|
+
target: Target tensor [B, C, H, W]
|
|
335
|
+
|
|
336
|
+
Returns:
|
|
337
|
+
Tuple of (total loss, detailed component losses)
|
|
338
|
+
"""
|
|
339
|
+
batch_size = pred.shape[0]
|
|
340
|
+
device = pred.device
|
|
341
|
+
|
|
342
|
+
assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"
|
|
343
|
+
# Apply QWT to both inputs
|
|
344
|
+
pred_qwt = self.transform.decompose_quaternion(pred, self.level)
|
|
345
|
+
target_qwt = self.transform.decompose_quaternion(target, self.level)
|
|
346
|
+
|
|
347
|
+
# Initialize total loss and component losses
|
|
348
|
+
pattern_losses = []
|
|
349
|
+
component_losses = {
|
|
350
|
+
f"{component}_{band}_{level + 1}": torch.zeros_like(pred_qwt[component][band][level], device=pred.device)
|
|
351
|
+
for level in range(self.level)
|
|
352
|
+
for component in ["r", "i", "j", "k"]
|
|
353
|
+
for band in ["ll", "lh", "hl", "hh"]
|
|
354
|
+
}
|
|
355
|
+
metrics: dict[str, float | int | None] = {}
|
|
356
|
+
|
|
357
|
+
# Calculate the weighted loss based on the timestep
|
|
358
|
+
base_weight = torch.ones((batch_size), device=device)
|
|
359
|
+
if timestep is not None:
|
|
360
|
+
base_weight *= self.smooth_timestep_weight(timestep)
|
|
361
|
+
if self.metrics:
|
|
362
|
+
metrics["wavelet_loss/avg_timestep_adjusted_weight"] = base_weight.detach().mean().item()
|
|
363
|
+
|
|
364
|
+
# Calculate loss for each quaternion component, band and level
|
|
365
|
+
|
|
366
|
+
for component in ["r", "i", "j", "k"]:
|
|
367
|
+
component_weight = self.component_weights[component]
|
|
368
|
+
for band in ["ll", "lh", "hl", "hh"]:
|
|
369
|
+
for level_idx in range(self.level):
|
|
370
|
+
band_loss, pred_coeffs, target_coeffs, band_metrics = self.process_band(
|
|
371
|
+
pred_qwt[component], target_qwt[component], band, level_idx, base_weight=base_weight
|
|
372
|
+
)
|
|
373
|
+
component_losses[f"{component}_{band}_{level_idx + 1}"] = band_loss
|
|
374
|
+
if self.metrics:
|
|
375
|
+
metrics[f"{component}_{band}_{level_idx + 1}"] = band_loss.detach().mean().item()
|
|
376
|
+
metrics.update(band_metrics)
|
|
377
|
+
|
|
378
|
+
pattern_losses.append(component_weight * band_loss)
|
|
379
|
+
|
|
380
|
+
if self.metrics:
|
|
381
|
+
component_metrics = self.process_coeff_metrics(pred_qwt[component], target_qwt[component])
|
|
382
|
+
for k, v in component_metrics.items():
|
|
383
|
+
metrics[f"{component}_{k}"] = v
|
|
384
|
+
|
|
385
|
+
# METRICS: Calculate all additional metrics
|
|
386
|
+
if self.metrics:
|
|
387
|
+
metrics.update(self.process_loss_metrics(pattern_losses))
|
|
388
|
+
metrics.update(self.process_latent_metrics(pred))
|
|
389
|
+
|
|
390
|
+
if reduce:
|
|
391
|
+
total = sum(loss_item.mean() for loss_item in pattern_losses)
|
|
392
|
+
return total, metrics
|
|
393
|
+
return pattern_losses, metrics
|
|
394
|
+
|
|
395
|
+
@torch.no_grad()
|
|
396
|
+
def calculate_cross_scale_consistency_metrics(
|
|
397
|
+
self,
|
|
398
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
399
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
400
|
+
) -> dict:
|
|
401
|
+
"""
|
|
402
|
+
Calculate metrics for cross-scale consistency between adjacent wavelet levels.
|
|
403
|
+
|
|
404
|
+
Args:
|
|
405
|
+
pred_coeffs: Dictionary of predicted wavelet coefficients
|
|
406
|
+
target_coeffs: Dictionary of target wavelet coefficients
|
|
407
|
+
|
|
408
|
+
Returns:
|
|
409
|
+
Dictionary containing cross-scale consistency metrics
|
|
410
|
+
|
|
411
|
+
Notes:
|
|
412
|
+
- Compares energy ratios between adjacent scales
|
|
413
|
+
- Uses log-scale differences for stability
|
|
414
|
+
- Provides per-level and averaged metrics
|
|
415
|
+
"""
|
|
416
|
+
metrics = {}
|
|
417
|
+
|
|
418
|
+
for band in ["lh", "hl", "hh"]:
|
|
419
|
+
for i in range(1, self.level):
|
|
420
|
+
# Compare ratio of energies between adjacent scales
|
|
421
|
+
pred_energy_fine = torch.mean(pred_coeffs[band][i - 1] ** 2).item()
|
|
422
|
+
pred_energy_coarse = torch.mean(pred_coeffs[band][i] ** 2).item()
|
|
423
|
+
target_energy_fine = torch.mean(target_coeffs[band][i - 1] ** 2).item()
|
|
424
|
+
target_energy_coarse = torch.mean(target_coeffs[band][i] ** 2).item()
|
|
425
|
+
|
|
426
|
+
# Calculate ratios and log differences
|
|
427
|
+
pred_ratio = pred_energy_coarse / (pred_energy_fine + 1e-8)
|
|
428
|
+
target_ratio = target_energy_coarse / (target_energy_fine + 1e-8)
|
|
429
|
+
log_ratio_diff = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8))
|
|
430
|
+
|
|
431
|
+
# Store individual metrics
|
|
432
|
+
metrics[f"wavelet_loss/cross_scale/{band}{i}_to_{i + 1}_pred_ratio"] = pred_ratio
|
|
433
|
+
metrics[f"wavelet_loss/cross_scale/{band}{i}_to_{i + 1}_target_ratio"] = target_ratio
|
|
434
|
+
metrics[f"wavelet_loss/cross_scale/{band}{i}_to_{i + 1}_log_diff"] = log_ratio_diff
|
|
435
|
+
|
|
436
|
+
# Calculate average difference across all bands and scales
|
|
437
|
+
log_diffs = [v for k, v in metrics.items() if k.endswith("_log_diff")]
|
|
438
|
+
if log_diffs:
|
|
439
|
+
metrics["wavelet_loss/cross_scale/avg_difference"] = sum(log_diffs) / len(log_diffs)
|
|
440
|
+
|
|
441
|
+
return metrics
|
|
442
|
+
|
|
443
|
+
@torch.no_grad()
|
|
444
|
+
def calculate_correlation_metrics(
|
|
445
|
+
self,
|
|
446
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
447
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
448
|
+
) -> dict:
|
|
449
|
+
"""
|
|
450
|
+
Calculate spatial correlation metrics between predicted and target wavelet coefficients.
|
|
451
|
+
|
|
452
|
+
Args:
|
|
453
|
+
pred_coeffs: Dictionary of predicted wavelet coefficients
|
|
454
|
+
target_coeffs: Dictionary of target wavelet coefficients
|
|
455
|
+
|
|
456
|
+
Returns:
|
|
457
|
+
Dictionary containing correlation metrics for each band and level
|
|
458
|
+
|
|
459
|
+
Notes:
|
|
460
|
+
- Calculates correlation across spatial dimensions
|
|
461
|
+
- Provides per-level and per-band averaged correlations
|
|
462
|
+
- Uses centered coefficients for accurate correlation measurement
|
|
463
|
+
"""
|
|
464
|
+
metrics = {}
|
|
465
|
+
|
|
466
|
+
for band in ["lh", "hl", "hh"]:
|
|
467
|
+
band_correlations = []
|
|
468
|
+
for i in range(self.level):
|
|
469
|
+
pred = pred_coeffs[band][i] # [B, C, H, W]
|
|
470
|
+
target = target_coeffs[band][i]
|
|
471
|
+
|
|
472
|
+
# Flatten spatial dims but keep batch/channel separate
|
|
473
|
+
pred_flat = pred.flatten(start_dim=2) # [B, C, H*W]
|
|
474
|
+
target_flat = target.flatten(start_dim=2)
|
|
475
|
+
|
|
476
|
+
# Calculate correlation across spatial dimension
|
|
477
|
+
pred_centered = pred_flat - pred_flat.mean(dim=2, keepdim=True)
|
|
478
|
+
target_centered = target_flat - target_flat.mean(dim=2, keepdim=True)
|
|
479
|
+
|
|
480
|
+
numerator = torch.sum(pred_centered * target_centered, dim=2)
|
|
481
|
+
denom = torch.sqrt(torch.sum(pred_centered**2, dim=2) * torch.sum(target_centered**2, dim=2) + 1e-8)
|
|
482
|
+
|
|
483
|
+
correlation = numerator / denom # [B, C]
|
|
484
|
+
avg_corr = correlation.mean().item()
|
|
485
|
+
|
|
486
|
+
metrics[f"wavelet_loss/correlation/{band}{i + 1}"] = avg_corr
|
|
487
|
+
band_correlations.append(avg_corr)
|
|
488
|
+
|
|
489
|
+
metrics[f"wavelet_loss/correlation/{band}_avg"] = np.mean(band_correlations)
|
|
490
|
+
|
|
491
|
+
return metrics
|
|
492
|
+
|
|
493
|
+
@torch.no_grad()
|
|
494
|
+
def calculate_directional_consistency_metrics(
|
|
495
|
+
self,
|
|
496
|
+
pred_coeffs: dict[str, list[Tensor]],
|
|
497
|
+
target_coeffs: dict[str, list[Tensor]],
|
|
498
|
+
) -> dict:
|
|
499
|
+
"""
|
|
500
|
+
Calculate metrics for directional consistency between wavelet bands.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
pred_coeffs: Dictionary of predicted wavelet coefficients
|
|
504
|
+
target_coeffs: Dictionary of target wavelet coefficients
|
|
505
|
+
|
|
506
|
+
Returns:
|
|
507
|
+
Dictionary containing directional consistency metrics
|
|
508
|
+
|
|
509
|
+
Notes:
|
|
510
|
+
- Analyzes horizontal vs vertical energy ratios (hl/lh)
|
|
511
|
+
- Analyzes diagonal vs horizontal+vertical energy ratios (hh/(hl+lh))
|
|
512
|
+
- Uses log-scale differences for stability
|
|
513
|
+
- Provides per-level and averaged metrics
|
|
514
|
+
"""
|
|
515
|
+
metrics = {}
|
|
516
|
+
hv_diffs = []
|
|
517
|
+
diag_diffs = []
|
|
518
|
+
|
|
519
|
+
for i in range(1, self.level + 1):
|
|
520
|
+
# Horizontal to vertical energy ratio
|
|
521
|
+
pred_hl_energy = torch.mean(pred_coeffs["hl"][i - 1] ** 2).item()
|
|
522
|
+
pred_lh_energy = torch.mean(pred_coeffs["lh"][i - 1] ** 2).item()
|
|
523
|
+
target_hl_energy = torch.mean(target_coeffs["hl"][i - 1] ** 2).item()
|
|
524
|
+
target_lh_energy = torch.mean(target_coeffs["lh"][i - 1] ** 2).item()
|
|
525
|
+
|
|
526
|
+
pred_hv_ratio = pred_hl_energy / (pred_lh_energy + 1e-8)
|
|
527
|
+
target_hv_ratio = target_hl_energy / (target_lh_energy + 1e-8)
|
|
528
|
+
hv_log_diff = abs(math.log(pred_hv_ratio + 1e-8) - math.log(target_hv_ratio + 1e-8))
|
|
529
|
+
|
|
530
|
+
# Diagonal to (horizontal+vertical) energy ratio
|
|
531
|
+
pred_hh_energy = torch.mean(pred_coeffs["hh"][i - 1] ** 2).item()
|
|
532
|
+
target_hh_energy = torch.mean(target_coeffs["hh"][i - 1] ** 2).item()
|
|
533
|
+
|
|
534
|
+
pred_d_ratio = pred_hh_energy / (pred_hl_energy + pred_lh_energy + 1e-8)
|
|
535
|
+
target_d_ratio = target_hh_energy / (target_hl_energy + target_lh_energy + 1e-8)
|
|
536
|
+
diag_log_diff = abs(math.log(pred_d_ratio + 1e-8) - math.log(target_d_ratio + 1e-8))
|
|
537
|
+
|
|
538
|
+
# Store metrics
|
|
539
|
+
metrics[f"wavelet_loss/directional/level{i}_hv_pred_ratio"] = pred_hv_ratio
|
|
540
|
+
metrics[f"wavelet_loss/directional/level{i}_hv_target_ratio"] = target_hv_ratio
|
|
541
|
+
metrics[f"wavelet_loss/directional/level{i}_hv_log_diff"] = hv_log_diff
|
|
542
|
+
|
|
543
|
+
metrics[f"wavelet_loss/directional/level{i}_diag_pred_ratio"] = pred_d_ratio
|
|
544
|
+
metrics[f"wavelet_loss/directional/level{i}_diag_target_ratio"] = target_d_ratio
|
|
545
|
+
metrics[f"wavelet_loss/directional/level{i}_diag_log_diff"] = diag_log_diff
|
|
546
|
+
|
|
547
|
+
hv_diffs.append(hv_log_diff)
|
|
548
|
+
diag_diffs.append(diag_log_diff)
|
|
549
|
+
|
|
550
|
+
# Average metrics
|
|
551
|
+
if hv_diffs:
|
|
552
|
+
metrics["wavelet_loss/directional/avg_hv_diff"] = sum(hv_diffs) / len(hv_diffs)
|
|
553
|
+
if diag_diffs:
|
|
554
|
+
metrics["wavelet_loss/directional/avg_diag_diff"] = sum(diag_diffs) / len(diag_diffs)
|
|
555
|
+
|
|
556
|
+
return metrics
|
|
557
|
+
|
|
558
|
+
@torch.no_grad()
|
|
559
|
+
def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict:
|
|
560
|
+
"""
|
|
561
|
+
Calculate metrics for latent space regularity and smoothness.
|
|
562
|
+
|
|
563
|
+
Args:
|
|
564
|
+
pred_latents: Predicted latent tensor
|
|
565
|
+
|
|
566
|
+
Returns:
|
|
567
|
+
Dictionary containing latent regularity metrics
|
|
568
|
+
|
|
569
|
+
Notes:
|
|
570
|
+
- Calculates total variation (TV) for smoothness measurement
|
|
571
|
+
- Provides statistical metrics (mean, std)
|
|
572
|
+
- Measures deviation from normal distribution (std from 1.0)
|
|
573
|
+
"""
|
|
574
|
+
metrics = {}
|
|
575
|
+
|
|
576
|
+
# Calculate gradient magnitude of latent representation
|
|
577
|
+
grad_x = pred_latents[:, :, 1:, :] - pred_latents[:, :, :-1, :]
|
|
578
|
+
grad_y = pred_latents[:, :, :, 1:] - pred_latents[:, :, :, :-1]
|
|
579
|
+
|
|
580
|
+
# Total variation
|
|
581
|
+
tv_x = torch.mean(torch.abs(grad_x)).item()
|
|
582
|
+
tv_y = torch.mean(torch.abs(grad_y)).item()
|
|
583
|
+
tv_total = tv_x + tv_y
|
|
584
|
+
|
|
585
|
+
# Statistical metrics
|
|
586
|
+
std_value = torch.std(pred_latents).item()
|
|
587
|
+
mean_value = torch.mean(pred_latents).item()
|
|
588
|
+
std_diff = abs(std_value - 1.0)
|
|
589
|
+
|
|
590
|
+
# Store metrics
|
|
591
|
+
metrics["wavelet_loss/latent/tv_x"] = tv_x
|
|
592
|
+
metrics["wavelet_loss/latent/tv_y"] = tv_y
|
|
593
|
+
metrics["wavelet_loss/latent/tv_total"] = tv_total
|
|
594
|
+
metrics["wavelet_loss/latent/std"] = std_value
|
|
595
|
+
metrics["wavelet_loss/latent/mean"] = mean_value
|
|
596
|
+
metrics["wavelet_loss/latent/std_from_normal"] = std_diff
|
|
597
|
+
|
|
598
|
+
return metrics
|
|
599
|
+
|
|
600
|
+
def smooth_timestep_weight(self, timestep):
|
|
601
|
+
"""
|
|
602
|
+
Timestep-dependent loss weight with a smooth sigmoid fade.
|
|
603
|
+
|
|
604
|
+
The weight is ~1 for timesteps below ``timestep_cutoff * max_timestep``
|
|
605
|
+
(low noise, where high-frequency wavelet detail is meaningful) and
|
|
606
|
+
fades to ~0 as the timestep approaches ``max_timestep`` (pure noise).
|
|
607
|
+
|
|
608
|
+
Args:
|
|
609
|
+
timestep: Current diffusion timestep tensor, in [0, max_timestep].
|
|
610
|
+
|
|
611
|
+
Returns:
|
|
612
|
+
Weight tensor in (0, 1), same shape as ``timestep``.
|
|
613
|
+
"""
|
|
614
|
+
t_frac = timestep / self.max_timestep
|
|
615
|
+
return torch.sigmoid((self.timestep_cutoff - t_frac) * 4.0 / self.timestep_transition_width)
|
|
616
|
+
|
|
617
|
+
def _validate_timestep(self, timestep: Tensor) -> None:
|
|
618
|
+
"""Raise ValueError for timesteps outside [0, max_timestep].
|
|
619
|
+
|
|
620
|
+
Strict by design: a mismatched timestep scale (e.g. DDPM-style 0-1000
|
|
621
|
+
timesteps against the flow-matching default max_timestep=1.0) must
|
|
622
|
+
fail loudly rather than silently saturate the weight. Costs one
|
|
623
|
+
host-device sync when a timestep is provided.
|
|
624
|
+
"""
|
|
625
|
+
invalid = (timestep < 0) | (timestep > self.max_timestep)
|
|
626
|
+
if bool(invalid.any()):
|
|
627
|
+
raise ValueError(
|
|
628
|
+
f"timestep values must lie in [0, max_timestep={self.max_timestep}], "
|
|
629
|
+
f"got [{timestep.min().item()}, {timestep.max().item()}]. "
|
|
630
|
+
"For DDPM-style integer timesteps pass WaveletLoss(..., max_timestep=1000); "
|
|
631
|
+
"flow-matching sigmas in [0, 1] work with the default max_timestep=1.0."
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
def _calculate_effective_ll_threshold(self) -> int | None:
|
|
635
|
+
"""
|
|
636
|
+
Calculate the effective LL level threshold.
|
|
637
|
+
|
|
638
|
+
For positive values, returns the value as-is.
|
|
639
|
+
For negative values, calculates from the end: level + threshold
|
|
640
|
+
|
|
641
|
+
Returns:
|
|
642
|
+
Effective threshold level, or None if no threshold is set
|
|
643
|
+
|
|
644
|
+
Examples:
|
|
645
|
+
level=3, threshold=1 -> 1
|
|
646
|
+
level=3, threshold=2 -> 2
|
|
647
|
+
level=3, threshold=-1 -> 2 (3 + (-1) = 2)
|
|
648
|
+
level=3, threshold=-2 -> 1 (3 + (-2) = 1)
|
|
649
|
+
"""
|
|
650
|
+
if self.ll_level_threshold is None:
|
|
651
|
+
return None
|
|
652
|
+
|
|
653
|
+
if self.ll_level_threshold > 0:
|
|
654
|
+
return self.ll_level_threshold
|
|
655
|
+
else:
|
|
656
|
+
return self.level + self.ll_level_threshold
|
|
657
|
+
|
|
658
|
+
def set_loss_fn(self, loss_fn: LossCallable):
|
|
659
|
+
"""
|
|
660
|
+
Set loss function to use. Wavelet loss wants l1 or huber loss.
|
|
661
|
+
"""
|
|
662
|
+
self.loss_fn = loss_fn
|