wavelet-loss 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ from .loss import WaveletLoss
2
+
3
+ __all__ = ["WaveletLoss"]
wavelet_loss/loss.py ADDED
@@ -0,0 +1,662 @@
1
+ import math
2
+ import numpy as np
3
+
4
+ from torch import Tensor
5
+ from typing import Protocol
6
+ from collections.abc import Mapping
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ from wavelet_transform import QuaternionWaveletTransform
13
+
14
+
15
+ class LossCallableMSE(Protocol):
16
+ def __call__(
17
+ self,
18
+ input: Tensor,
19
+ target: Tensor,
20
+ size_average: bool | None = None,
21
+ reduce: bool | None = None,
22
+ reduction: str = "mean",
23
+ ) -> Tensor: ...
24
+
25
+
26
class LossCallableReduction(Protocol):
    """Structural type for loss callables that only take ``reduction``.

    A callable matches when it accepts ``input`` and ``target`` tensors plus
    an optional ``reduction`` string and returns a tensor.
    """

    def __call__(
        self,
        input: Tensor,
        target: Tensor,
        reduction: str = "mean",
    ) -> Tensor:
        """Return the loss between ``input`` and ``target``."""
        ...
28
+
29
+
30
# Flat mapping of metric name -> scalar value, as returned by the metric helpers.
Metrics = dict[str, int | float | None]

# Any loss function accepted by WaveletLoss: either protocol shape is fine.
LossCallable = LossCallableReduction | LossCallableMSE
32
+
33
+
34
+ class WaveletLoss(nn.Module):
35
+ """Wavelet-based loss calculation module."""
36
+
37
def __init__(
    self,
    wavelet="db4",
    level=3,
    transform_type="dwt",
    backend: str = "pytorch_wavelets",
    mode: str = "zero",
    loss_fn: LossCallable = F.mse_loss,
    device=torch.device("cpu"),
    band_level_weights: dict[str, float] | None = None,
    band_weights: dict[str, float] | None = None,
    quaternion_component_weights: dict[str, float] | None = None,
    ll_level_threshold: int | None = -1,
    metrics: bool = False,
    normalize_bands: bool = False,
    max_timestep: float = 1.0,
    timestep_cutoff: float = 0.7,
    timestep_transition_width: float = 0.4,
):
    """Configure the wavelet transform, band weights and timestep weighting.

    Args:
        wavelet: Wavelet family (e.g., 'db4', 'sym7').
        level: Decomposition level.
        transform_type: Type of wavelet transform ('dwt', 'swt' or 'qwt').
            'qwt' additionally registers the transform's Hilbert filters
            as module buffers.
        backend: Backend name handed to ``make_backend``.
        mode: Mode string handed to ``make_backend`` (presumably the
            boundary/padding mode of the transform; confirm against the
            ``wavelet_transform`` package).
        loss_fn: Loss function applied to wavelet coefficients. It is
            called with ``reduction="none"``.
        device: Computation device for the transform and its filters.
        band_level_weights: Optional per-level weights keyed like
            ``"lh1"`` or ``"hh3"``. They take precedence over
            ``band_weights`` for the matching band and level.
        band_weights: Optional per-band weights keyed ``"ll"``, ``"lh"``,
            ``"hl"``, ``"hh"``. ``None`` or an empty dict selects the
            defaults below.
        quaternion_component_weights: Optional weights for the quaternion
            components ``"r"``, ``"i"``, ``"j"``, ``"k"``. Only used when
            the transform is a quaternion wavelet transform.
        ll_level_threshold: LL bands at levels up to the effective
            threshold are skipped. Negative values count from the last
            level (default -1). ``None`` disables skipping.
        metrics: When True, ``forward`` also returns diagnostic metrics.
        normalize_bands: When True, each band of prediction and target is
            standardised with the target band's mean and std before the
            loss is computed.
        max_timestep: Maximum value of the trainer's timestep convention.
            Default 1.0 (flow-matching sigmas, e.g. Flux2). Pass 1000 for
            DDPM-style integer timesteps. Timesteps outside
            [0, max_timestep] raise ValueError.
        timestep_cutoff: Fraction of max_timestep at which the timestep
            weight crosses 0.5. Below the cutoff (less noise) the weight
            is ~1; above it the weight fades toward 0.
        timestep_transition_width: Fraction of the timestep range the
            fade is spread over. Smaller = harder cutoff.

    Raises:
        ValueError: If ``max_timestep`` or ``timestep_transition_width``
            is not strictly positive.
    """
    # Both values are divisors in smooth_timestep_weight; reject them here
    # instead of failing with a division error in the middle of training.
    if max_timestep <= 0:
        raise ValueError(f"max_timestep must be > 0, got {max_timestep}.")
    if timestep_transition_width <= 0:
        raise ValueError(f"timestep_transition_width must be > 0, got {timestep_transition_width}.")

    super().__init__()
    self.level = level
    self.wavelet = wavelet
    self.transform_type = transform_type
    self.loss_fn = loss_fn
    self.device = device
    self.ll_level_threshold: int | None = ll_level_threshold
    self.metrics = metrics
    self.max_timestep = max_timestep
    self.timestep_cutoff = timestep_cutoff
    self.timestep_transition_width = timestep_transition_width
    self.normalize_bands = normalize_bands

    # Initialize transform via backend factory
    from wavelet_transform import make_backend

    self.backend = backend
    self.mode = mode
    self.transform = make_backend(backend, transform_type, wavelet, mode, device)

    # Always defined: forward() picks the quaternion path from the type of
    # the transform, not from the transform_type string.
    self.component_weights = quaternion_component_weights or {
        "r": 1.0,
        "i": 0.7,
        "j": 0.7,
        "k": 0.5,
    }

    if transform_type == "qwt":
        self.register_buffer("hilbert_x", self.transform.hilbert_x)
        self.register_buffer("hilbert_y", self.transform.hilbert_y)
        self.register_buffer("hilbert_xy", self.transform.hilbert_xy)

    # Register wavelet filters as module buffers
    self.register_buffer("dec_lo", self.transform.dec_lo.to(device))
    self.register_buffer("dec_hi", self.transform.dec_hi.to(device))

    # Default weights from paper:
    # "Training Generative Image Super-Resolution Models by Wavelet-Domain Losses"
    self.band_level_weights = band_level_weights or {}
    self.band_weights = band_weights or {
        "ll": 0.1,
        "lh": 0.01,
        "hl": 0.01,
        "hh": 0.05,
    }
122
+
123
def forward(
    self,
    pred_latent: Tensor,
    target_latent: Tensor,
    timestep: torch.Tensor | None = None,
    reduce: bool = True,
) -> tuple[Tensor | list[Tensor], Mapping[str, int | float | None]]:
    """
    Calculate wavelet loss between prediction and target.

    Args:
        pred_latent: Predicted tensor [B, C, H, W].
        target_latent: Target tensor [B, C, H, W].
        timestep: Optional timesteps in [0, max_timestep]. One value per
            sample (any shape with B elements, e.g. [B] or [B, 1, 1, 1])
            or a single value shared by the batch. When given, every band
            loss is scaled by ``smooth_timestep_weight(timestep)``.
        reduce: If True, return the sum of the per-band means; otherwise
            return the list of weighted per-band loss tensors.

    Returns:
        loss: Total wavelet loss (scalar if reduce=True, list of tensors if reduce=False)
        metrics: Wavelet metrics if requested in WaveletLoss(metrics=True)

    Raises:
        ValueError: If either input is not 4D or a timestep is out of range.
    """
    if pred_latent.ndim != 4 or target_latent.ndim != 4:
        raise ValueError(
            f"WaveletLoss expects 4D [B, C, H, W] tensors, got "
            f"pred.ndim={pred_latent.ndim}, target.ndim={target_latent.ndim}."
        )

    if timestep is not None:
        self._validate_timestep(timestep)

    if isinstance(self.transform, QuaternionWaveletTransform):
        return self.quaternion_forward(pred_latent, target_latent, timestep, reduce)

    batch_size = pred_latent.shape[0]
    device = pred_latent.device

    # Decompose inputs
    pred_coeffs = self.transform.decompose(pred_latent, self.level)
    target_coeffs = self.transform.decompose(target_latent, self.level)

    pattern_losses: list[Tensor] = []
    metrics: Metrics = {}

    # Per-sample weight shared by every band. Built out-of-place from a
    # flattened, device/dtype-aligned timestep weight so that timesteps
    # shaped [B, 1, 1, 1] or living on another device do not break an
    # in-place broadcast into the [B] vector.
    base_weight = torch.ones(batch_size, device=device)
    if timestep is not None:
        step_weight = self.smooth_timestep_weight(timestep)
        step_weight = step_weight.reshape(-1).to(device=device, dtype=base_weight.dtype)
        base_weight = base_weight * step_weight
        if self.metrics:
            metrics["wavelet_loss/avg_timestep_adjusted_weight"] = base_weight.detach().mean().item()

    # Every band (approximation "ll" and the three detail bands) at every
    # level; process_band decides whether an LL level is skipped.
    for i in range(self.level):
        for band in ["ll", "lh", "hl", "hh"]:
            band_loss, _, _, band_metrics = self.process_band(
                pred_coeffs, target_coeffs, band, i, base_weight=base_weight
            )
            metrics.update(band_metrics)
            pattern_losses.append(band_loss)

    # METRICS: Calculate all additional metrics (no gradients needed)
    if self.metrics:
        with torch.no_grad():
            metrics.update(self.process_coeff_metrics(pred_coeffs, target_coeffs))
            metrics.update(self.process_loss_metrics(pattern_losses))
            metrics.update(self.process_latent_metrics(pred_latent))

    if reduce:
        total = sum(loss_item.mean() for loss_item in pattern_losses)
        return total, metrics
    return pattern_losses, metrics
189
+
190
def process_coeff_metrics(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
) -> Metrics:
    """Collect every coefficient-level diagnostic into one dictionary.

    Runs the correlation, energy, cross-scale and directional metric
    helpers in that order on the same coefficient pair and merges their
    results; on a key clash the later helper wins.

    Args:
        pred_coeffs: Predicted wavelet coefficients, keyed by band.
        target_coeffs: Target wavelet coefficients, keyed by band.

    Returns:
        The merged metrics dictionary.
    """
    collectors = (
        self.calculate_correlation_metrics,
        self.calculate_energy_metrics,
        self.calculate_cross_scale_consistency_metrics,
        self.calculate_directional_consistency_metrics,
    )

    combined: Metrics = {}
    for collect in collectors:
        combined.update(collect(pred_coeffs, target_coeffs))
    return combined
202
+
203
@torch.no_grad()
def calculate_energy_metrics(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
) -> Metrics:
    """Report the mean squared coefficient value of every band and level.

    Energy (mean of squares) is never negative and reacts to amplitude, so
    it exposes scale errors such as a uniformly scaled prediction that the
    correlation and ratio metrics cannot see. It supersedes the former
    signed-mean ``avg_hf_*`` metrics, which stayed close to zero.

    Args:
        pred_coeffs: Predicted wavelet coefficients, keyed by band.
        target_coeffs: Target wavelet coefficients, keyed by band.

    Returns:
        ``wavelet_loss/energy/<band><level>_pred|_target`` entries for all
        four bands, plus the averages over the three detail bands when at
        least one level exists.
    """

    def mean_square(coeff: Tensor) -> float:
        return torch.mean(coeff**2).item()

    report: Metrics = {}
    detail_pred: list[float] = []
    detail_target: list[float] = []

    for name in ("ll", "lh", "hl", "hh"):
        for idx in range(self.level):
            tag = f"wavelet_loss/energy/{name}{idx + 1}"
            energy_pred = mean_square(pred_coeffs[name][idx])
            energy_target = mean_square(target_coeffs[name][idx])
            report[f"{tag}_pred"] = energy_pred
            report[f"{tag}_target"] = energy_target

            # The approximation band is reported but kept out of the
            # high-frequency averages.
            if name == "ll":
                continue
            detail_pred.append(energy_pred)
            detail_target.append(energy_target)

    if detail_pred:
        report["wavelet_loss/avg_hf_energy_pred"] = sum(detail_pred) / len(detail_pred)
        report["wavelet_loss/avg_hf_energy_target"] = sum(detail_target) / len(detail_target)

    return report
235
+
236
+ def process_latent_metrics(self, pred_latent: Tensor) -> dict[str, int | float | None]:
237
+ """
238
+ Calculate metrics for the latent space.
239
+
240
+ Args:
241
+ pred_latent: The predicted latent tensor
242
+ target_latent: The target latent tensor
243
+
244
+ Returns:
245
+ metrics: The metrics dictionary
246
+ """
247
+ metrics: dict[str, int | float | None] = {}
248
+ metrics.update(self.calculate_latent_regularity_metrics(pred_latent))
249
+
250
+ return metrics
251
+
252
def process_loss_metrics(self, losses: list[Tensor]) -> Metrics:
    """Sum the per-band loss means into the ``wavelet_loss/total`` metric.

    ``losses`` holds the weighted per-band loss tensors, i.e. the very
    tensors whose means are added up for the ``reduce=True`` result, so the
    reported total equals the optimized objective. The per-band breakdown
    is produced by ``process_band``.
    """
    running = 0
    for band_loss in losses:
        running = running + band_loss.detach().mean()
    return {"wavelet_loss/total": float(running)}
263
+
264
def process_band(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
    band: str,
    i: int,
    base_weight: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Metrics]:
    """
    Process a single band and calculate the loss.

    Args:
        pred_coeffs: The predicted coefficients
        target_coeffs: The target coefficients
        band: The band to process (e.g. "lh", "hl", etc.)
        i: The level index (0-based; weight and metric keys use ``i + 1``)
        base_weight: Per-sample weight, one value per batch element

    Returns:
        loss: The weighted, unreduced band loss
        pred: The predicted wavelet component (normalized if enabled)
        target: The target wavelet component (normalized if enabled)
        metrics: The metrics for this band (empty unless metrics are enabled)

        For an LL band at or below the effective LL threshold all three
        tensors are zeros and the metrics are empty.

    Raises:
        KeyError: If neither ``band_level_weights`` nor ``band_weights``
            provides a weight for this band.
    """
    ll_threshold = self._calculate_effective_ll_threshold()
    if ll_threshold is not None and band == "ll" and i + 1 <= ll_threshold:
        return (
            torch.zeros_like(pred_coeffs[band][i]),
            torch.zeros_like(pred_coeffs[band][i]),
            torch.zeros_like(target_coeffs[band][i]),
            {},
        )

    weight_key = f"{band}{i + 1}"
    pred = pred_coeffs[band][i]
    target = target_coeffs[band][i]

    if self.normalize_bands:
        # Shared normalization: use the TARGET band statistics for BOTH tensors
        # so relative amplitude/offset errors are preserved (not zeroed out).
        mean = target.mean()
        std = target.std() + 1e-8
        pred = (pred - mean) / std
        target = (target - mean) / std

    band_loss = self.loss_fn(pred, target, reduction="none")

    # Look the per-band default up only when no per-level weight exists, so
    # a complete band_level_weights table works with partial band_weights.
    if weight_key in self.band_level_weights:
        band_scale = self.band_level_weights[weight_key]
    else:
        band_scale = self.band_weights[band]

    weight = base_weight * band_scale
    # One weight per sample, broadcast over channel and spatial dimensions.
    loss = weight.view(-1, 1, 1, 1) * band_loss

    metrics: Metrics = {}
    if self.metrics:
        metrics = {
            f"wavelet_loss/band_loss/{band}{i + 1}": band_loss.detach().mean().item(),
            f"wavelet_loss/weighted_band_loss/{band}{i + 1}": loss.detach().mean().item(),
        }

    return loss, pred, target, metrics
325
+
326
def quaternion_forward(
    self, pred: Tensor, target: Tensor, timestep: Tensor | None, reduce: bool = True
) -> tuple[Tensor | list[Tensor], Mapping[str, int | float | None]]:
    """
    Calculate QWT loss between prediction and target.

    Args:
        pred: Predicted tensor [B, C, H, W]
        target: Target tensor [B, C, H, W]
        timestep: Optional timesteps in [0, max_timestep]; one value per
            sample (any shape with B elements) or a single shared value.
        reduce: If True, return the sum of the per-band means; otherwise
            return the list of component-weighted band loss tensors.

    Returns:
        Tuple of (loss, metrics). ``loss`` is a scalar when ``reduce`` is
        True, else a list ordered by component, then band, then level.
        ``metrics`` is empty unless metrics are enabled.
    """
    batch_size = pred.shape[0]
    device = pred.device

    assert isinstance(self.transform, QuaternionWaveletTransform), "Not a quaternion wavelet transform"
    # Apply QWT to both inputs
    pred_qwt = self.transform.decompose_quaternion(pred, self.level)
    target_qwt = self.transform.decompose_quaternion(target, self.level)

    pattern_losses: list[Tensor] = []
    metrics: dict[str, float | int | None] = {}

    # Per-sample weight from the timestep. Built out-of-place from a
    # flattened, device/dtype-aligned weight so timesteps shaped
    # [B, 1, 1, 1] or living on another device do not break the broadcast.
    base_weight = torch.ones(batch_size, device=device)
    if timestep is not None:
        step_weight = self.smooth_timestep_weight(timestep)
        step_weight = step_weight.reshape(-1).to(device=device, dtype=base_weight.dtype)
        base_weight = base_weight * step_weight
        if self.metrics:
            metrics["wavelet_loss/avg_timestep_adjusted_weight"] = base_weight.detach().mean().item()

    # Calculate loss for each quaternion component, band and level
    for component in ["r", "i", "j", "k"]:
        component_weight = self.component_weights[component]
        for band in ["ll", "lh", "hl", "hh"]:
            for level_idx in range(self.level):
                band_loss, _, _, band_metrics = self.process_band(
                    pred_qwt[component], target_qwt[component], band, level_idx, base_weight=base_weight
                )
                if self.metrics:
                    metrics[f"{component}_{band}_{level_idx + 1}"] = band_loss.detach().mean().item()
                    # NOTE: these keys carry no component prefix, so the
                    # entries of the last component ("k") are the ones kept.
                    metrics.update(band_metrics)

                pattern_losses.append(component_weight * band_loss)

        if self.metrics:
            component_metrics = self.process_coeff_metrics(pred_qwt[component], target_qwt[component])
            for k, v in component_metrics.items():
                metrics[f"{component}_{k}"] = v

    # METRICS: Calculate all additional metrics
    if self.metrics:
        metrics.update(self.process_loss_metrics(pattern_losses))
        metrics.update(self.process_latent_metrics(pred))

    if reduce:
        total = sum(loss_item.mean() for loss_item in pattern_losses)
        return total, metrics
    return pattern_losses, metrics
394
+
395
@torch.no_grad()
def calculate_cross_scale_consistency_metrics(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
) -> dict:
    """
    Compare how detail-band energy changes between neighbouring levels.

    For each detail band and each pair of adjacent levels, the ratio
    ``coarse energy / fine energy`` is computed for prediction and target
    and the absolute difference of their logarithms is reported.

    Args:
        pred_coeffs: Dictionary of predicted wavelet coefficients
        target_coeffs: Dictionary of target wavelet coefficients

    Returns:
        Dictionary with the per-pair ratios and log differences and, when
        at least one pair exists, their average log difference.
    """

    def mean_square(coeff: Tensor) -> float:
        return torch.mean(coeff**2).item()

    report = {}
    gaps = []

    for name in ("lh", "hl", "hh"):
        for coarse in range(1, self.level):
            fine = coarse - 1
            pred_fine = mean_square(pred_coeffs[name][fine])
            pred_coarse = mean_square(pred_coeffs[name][coarse])
            target_fine = mean_square(target_coeffs[name][fine])
            target_coarse = mean_square(target_coeffs[name][coarse])

            # Small constants keep the division and the logarithm finite.
            pred_ratio = pred_coarse / (pred_fine + 1e-8)
            target_ratio = target_coarse / (target_fine + 1e-8)
            gap = abs(math.log(pred_ratio + 1e-8) - math.log(target_ratio + 1e-8))

            prefix = f"wavelet_loss/cross_scale/{name}{coarse}_to_{coarse + 1}"
            report[f"{prefix}_pred_ratio"] = pred_ratio
            report[f"{prefix}_target_ratio"] = target_ratio
            report[f"{prefix}_log_diff"] = gap
            gaps.append(gap)

    if gaps:
        report["wavelet_loss/cross_scale/avg_difference"] = sum(gaps) / len(gaps)

    return report
442
+
443
@torch.no_grad()
def calculate_correlation_metrics(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
) -> dict:
    """
    Measure how well predicted and target detail coefficients line up spatially.

    For every detail band and level, the coefficients are flattened from
    dimension 2 onward, centered along that axis, and a correlation is
    computed per leading (batch, channel) entry, then averaged.

    Args:
        pred_coeffs: Dictionary of predicted wavelet coefficients
        target_coeffs: Dictionary of target wavelet coefficients

    Returns:
        Dictionary with one correlation per band and level plus the
        per-band average over levels.
    """
    report = {}

    for name in ("lh", "hl", "hh"):
        per_level = []
        for idx in range(self.level):
            pred_flat = pred_coeffs[name][idx].flatten(start_dim=2)
            target_flat = target_coeffs[name][idx].flatten(start_dim=2)

            # Remove the spatial mean of every (batch, channel) row.
            pred_dev = pred_flat - pred_flat.mean(dim=2, keepdim=True)
            target_dev = target_flat - target_flat.mean(dim=2, keepdim=True)

            covariance = (pred_dev * target_dev).sum(dim=2)
            energy_product = (pred_dev**2).sum(dim=2) * (target_dev**2).sum(dim=2)
            # The epsilon under the root avoids dividing by zero for flat bands.
            pearson = covariance / torch.sqrt(energy_product + 1e-8)

            level_value = pearson.mean().item()
            report[f"wavelet_loss/correlation/{name}{idx + 1}"] = level_value
            per_level.append(level_value)

        report[f"wavelet_loss/correlation/{name}_avg"] = np.mean(per_level)

    return report
492
+
493
@torch.no_grad()
def calculate_directional_consistency_metrics(
    self,
    pred_coeffs: dict[str, list[Tensor]],
    target_coeffs: dict[str, list[Tensor]],
) -> dict:
    """
    Compare how energy is split between the detail bands of each level.

    Two ratios are computed per level for prediction and target: ``hl/lh``
    and ``hh/(hl + lh)``. The absolute difference of the logarithms of the
    predicted and target ratio is reported for both.

    Args:
        pred_coeffs: Dictionary of predicted wavelet coefficients
        target_coeffs: Dictionary of target wavelet coefficients

    Returns:
        Dictionary with the per-level ratios and log differences and, when
        at least one level exists, their averages over levels.
    """

    def mean_square(coeff: Tensor) -> float:
        return torch.mean(coeff**2).item()

    def log_gap(first: float, second: float) -> float:
        return abs(math.log(first + 1e-8) - math.log(second + 1e-8))

    report = {}
    hv_gaps = []
    diag_gaps = []

    for idx in range(self.level):
        prefix = f"wavelet_loss/directional/level{idx + 1}"

        # hl versus lh energy
        pred_hl = mean_square(pred_coeffs["hl"][idx])
        pred_lh = mean_square(pred_coeffs["lh"][idx])
        target_hl = mean_square(target_coeffs["hl"][idx])
        target_lh = mean_square(target_coeffs["lh"][idx])

        pred_hv = pred_hl / (pred_lh + 1e-8)
        target_hv = target_hl / (target_lh + 1e-8)
        hv_gap = log_gap(pred_hv, target_hv)

        # hh versus combined hl + lh energy
        pred_hh = mean_square(pred_coeffs["hh"][idx])
        target_hh = mean_square(target_coeffs["hh"][idx])

        pred_diag = pred_hh / (pred_hl + pred_lh + 1e-8)
        target_diag = target_hh / (target_hl + target_lh + 1e-8)
        diag_gap = log_gap(pred_diag, target_diag)

        report[f"{prefix}_hv_pred_ratio"] = pred_hv
        report[f"{prefix}_hv_target_ratio"] = target_hv
        report[f"{prefix}_hv_log_diff"] = hv_gap
        report[f"{prefix}_diag_pred_ratio"] = pred_diag
        report[f"{prefix}_diag_target_ratio"] = target_diag
        report[f"{prefix}_diag_log_diff"] = diag_gap

        hv_gaps.append(hv_gap)
        diag_gaps.append(diag_gap)

    if hv_gaps:
        report["wavelet_loss/directional/avg_hv_diff"] = sum(hv_gaps) / len(hv_gaps)
    if diag_gaps:
        report["wavelet_loss/directional/avg_diag_diff"] = sum(diag_gaps) / len(diag_gaps)

    return report
557
+
558
@torch.no_grad()
def calculate_latent_regularity_metrics(self, pred_latents: Tensor) -> dict:
    """
    Summarise smoothness and basic statistics of the predicted latents.

    Args:
        pred_latents: Predicted latent tensor

    Returns:
        Dictionary with the mean absolute neighbour difference along
        dimension 2 (``tv_x``) and dimension 3 (``tv_y``), their sum, the
        overall standard deviation and mean, and the distance of the
        standard deviation from 1.0.
    """
    # Differences between neighbouring entries along dims 2 and 3.
    step_dim2 = torch.diff(pred_latents, dim=2)
    step_dim3 = torch.diff(pred_latents, dim=3)

    tv_x = step_dim2.abs().mean().item()
    tv_y = step_dim3.abs().mean().item()

    spread = torch.std(pred_latents).item()
    centre = torch.mean(pred_latents).item()

    return {
        "wavelet_loss/latent/tv_x": tv_x,
        "wavelet_loss/latent/tv_y": tv_y,
        "wavelet_loss/latent/tv_total": tv_x + tv_y,
        "wavelet_loss/latent/std": spread,
        "wavelet_loss/latent/mean": centre,
        "wavelet_loss/latent/std_from_normal": abs(spread - 1.0),
    }
599
+
600
def smooth_timestep_weight(self, timestep):
    """
    Map timesteps to a loss weight that fades out smoothly with noise level.

    The timestep is first expressed as a fraction of ``max_timestep``. The
    weight is a sigmoid of the distance to ``timestep_cutoff``: close to 1
    well below the cutoff (low noise, where high-frequency wavelet detail
    is meaningful), exactly 0.5 at the cutoff, and close to 0 as the
    timestep approaches ``max_timestep`` (pure noise).
    ``timestep_transition_width`` sets how wide the fade is.

    Args:
        timestep: Current diffusion timestep tensor, in [0, max_timestep].

    Returns:
        Weight tensor in (0, 1), same shape as ``timestep``.
    """
    fraction = timestep / self.max_timestep
    margin = self.timestep_cutoff - fraction
    logits = margin * 4.0 / self.timestep_transition_width
    return torch.sigmoid(logits)
616
+
617
def _validate_timestep(self, timestep: Tensor) -> None:
    """Raise ValueError for timesteps that are not inside [0, max_timestep].

    Strict by design: a mismatched timestep scale (e.g. DDPM-style 0-1000
    timesteps against the flow-matching default max_timestep=1.0) must
    fail loudly rather than silently saturate the weight. Costs one
    host-device sync when a timestep is provided.

    The check is phrased as "every value is in range" so that NaN
    timesteps, for which every comparison is False, are rejected too
    instead of propagating a NaN weight into the loss.

    Args:
        timestep: Timestep tensor to check. An empty tensor passes.

    Raises:
        ValueError: If any value is NaN, negative, or above ``max_timestep``.
    """
    in_range = (timestep >= 0) & (timestep <= self.max_timestep)
    if bool(in_range.all()):
        return

    raise ValueError(
        f"timestep values must lie in [0, max_timestep={self.max_timestep}], "
        f"got [{timestep.min().item()}, {timestep.max().item()}]. "
        "For DDPM-style integer timesteps pass WaveletLoss(..., max_timestep=1000); "
        "flow-matching sigmas in [0, 1] work with the default max_timestep=1.0."
    )
633
+
634
+ def _calculate_effective_ll_threshold(self) -> int | None:
635
+ """
636
+ Calculate the effective LL level threshold.
637
+
638
+ For positive values, returns the value as-is.
639
+ For negative values, calculates from the end: level + threshold
640
+
641
+ Returns:
642
+ Effective threshold level, or None if no threshold is set
643
+
644
+ Examples:
645
+ level=3, threshold=1 -> 1
646
+ level=3, threshold=2 -> 2
647
+ level=3, threshold=-1 -> 2 (3 + (-1) = 2)
648
+ level=3, threshold=-2 -> 1 (3 + (-2) = 1)
649
+ """
650
+ if self.ll_level_threshold is None:
651
+ return None
652
+
653
+ if self.ll_level_threshold > 0:
654
+ return self.ll_level_threshold
655
+ else:
656
+ return self.level + self.ll_level_threshold
657
+
658
def set_loss_fn(self, loss_fn: LossCallable):
    """
    Replace the loss function applied to the wavelet coefficients.

    Wavelet loss wants l1 or huber loss. The callable is invoked with
    ``reduction="none"`` when a band is processed.

    Args:
        loss_fn: The new loss callable.
    """
    self.loss_fn = loss_fn