TorchDiff 2.2.0__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. {torchdiff-2.2.0 → torchdiff-2.3.0}/PKG-INFO +1 -1
  2. {torchdiff-2.2.0 → torchdiff-2.3.0}/README.md +3 -3
  3. {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/PKG-INFO +1 -1
  4. {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/SOURCES.txt +10 -14
  5. {torchdiff-2.2.0 → torchdiff-2.3.0}/setup.py +1 -1
  6. torchdiff-2.3.0/torchdiff/__init__.py +13 -0
  7. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ddim.py +78 -82
  8. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ddpm.py +45 -49
  9. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ldm.py +115 -102
  10. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/sde.py +59 -61
  11. torchdiff-2.3.0/torchdiff/unclip.py +3776 -0
  12. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/utils.py +1 -2
  13. torchdiff-2.2.0/unclip/clip_model.py → torchdiff-2.3.0/unclip/clip_encoder.py +8 -138
  14. torchdiff-2.3.0/unclip/forward_unclip.py +55 -0
  15. torchdiff-2.2.0/unclip/project_prior.py → torchdiff-2.3.0/unclip/projections.py +76 -49
  16. torchdiff-2.3.0/unclip/reverse_unclip.py +114 -0
  17. torchdiff-2.3.0/unclip/scheduler.py +125 -0
  18. torchdiff-2.3.0/unclip/train_unclip_decoder.py +779 -0
  19. torchdiff-2.3.0/unclip/train_unclip_prior.py +497 -0
  20. torchdiff-2.2.0/unclip/decoder_model.py → torchdiff-2.3.0/unclip/unclip_decoder.py +94 -129
  21. torchdiff-2.3.0/unclip/unclip_sampler.py +307 -0
  22. torchdiff-2.3.0/unclip/unclip_trainstormer_prior.py +357 -0
  23. torchdiff-2.3.0/unclip/upsampler_trainer.py +559 -0
  24. torchdiff-2.2.0/unclip/upsampler.py → torchdiff-2.3.0/unclip/upsampler_unclip.py +65 -141
  25. torchdiff-2.2.0/torchdiff/__init__.py +0 -8
  26. torchdiff-2.2.0/torchdiff/tests/test_ldm.py +0 -660
  27. torchdiff-2.2.0/torchdiff/tests/test_unclip.py +0 -315
  28. torchdiff-2.2.0/torchdiff/unclip.py +0 -4171
  29. torchdiff-2.2.0/unclip/ddim_model.py +0 -1296
  30. torchdiff-2.2.0/unclip/prior_diff.py +0 -402
  31. torchdiff-2.2.0/unclip/prior_model.py +0 -264
  32. torchdiff-2.2.0/unclip/project_decoder.py +0 -57
  33. torchdiff-2.2.0/unclip/train_decoder.py +0 -1059
  34. torchdiff-2.2.0/unclip/train_prior.py +0 -757
  35. torchdiff-2.2.0/unclip/unclip_sampler.py +0 -626
  36. torchdiff-2.2.0/unclip/upsampler_trainer.py +0 -784
  37. torchdiff-2.2.0/unclip/utils.py +0 -1793
  38. torchdiff-2.2.0/unclip/val_metrics.py +0 -221
  39. {torchdiff-2.2.0 → torchdiff-2.3.0}/LICENSE +0 -0
  40. {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/dependency_links.txt +0 -0
  41. {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/requires.txt +0 -0
  42. {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/top_level.txt +0 -0
  43. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/__init__.py +0 -0
  44. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/forward_ddim.py +0 -0
  45. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/reverse_ddim.py +0 -0
  46. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/sample_ddim.py +0 -0
  47. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/scheduler.py +0 -0
  48. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/test_ddim.py +0 -0
  49. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/train_ddim.py +0 -0
  50. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/__init__.py +0 -0
  51. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/forward_ddpm.py +0 -0
  52. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/reverse_ddpm.py +0 -0
  53. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/sample_ddpm.py +0 -0
  54. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/scheduler.py +0 -0
  55. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/test_ddpm.py +0 -0
  56. {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/train_ddpm.py +0 -0
  57. {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/__init__.py +0 -0
  58. {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/autoencoder.py +0 -0
  59. {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/sample_ldm.py +0 -0
  60. {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/train_autoencoder.py +0 -0
  61. {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/train_ldm.py +0 -0
  62. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/__init__.py +0 -0
  63. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/forward_sde.py +0 -0
  64. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/reverse_sde.py +0 -0
  65. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/sample_sde.py +0 -0
  66. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/scheduler.py +0 -0
  67. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/test_sde.py +0 -0
  68. {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/train_sde.py +0 -0
  69. {torchdiff-2.2.0 → torchdiff-2.3.0}/setup.cfg +0 -0
  70. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/__init__.py +0 -0
  71. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_ddim.py +0 -0
  72. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_ddpm.py +0 -0
  73. {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_sde.py +0 -0
  74. {torchdiff-2.2.0 → torchdiff-2.3.0}/unclip/__init__.py +0 -0
  75. {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/__init__.py +0 -0
  76. {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/diff_net.py +0 -0
  77. {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/losses.py +0 -0
  78. {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/metrics.py +0 -0
  79. {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/text_encoder.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: TorchDiff
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: A PyTorch-based library for diffusion models
5
5
  Home-page: https://github.com/LoqmanSamani/TorchDiff
6
6
  Author: Loghman Samani
@@ -9,8 +9,8 @@
9
9
 
10
10
  [![License: MIT](https://img.shields.io/badge/license-MIT-red?style=plastic)](https://opensource.org/licenses/MIT)
11
11
  [![PyTorch](https://img.shields.io/badge/PyTorch-white?style=plastic&logo=pytorch&logoColor=red)](https://pytorch.org/)
12
- [![Version](https://img.shields.io/badge/version-2.1.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
13
- [![Python](https://img.shields.io/badge/python-3.8%2B-blue?style=plastic&logo=python&logoColor=white)](https://www.python.org/)
12
+ [![Version](https://img.shields.io/badge/version-2.2.0-blue?style=plastic)](https://pypi.org/project/torchdiff/)
13
+ [![Python](https://img.shields.io/badge/python-3.10%2B-blue?style=plastic&logo=python&logoColor=white)](https://www.python.org/)
14
14
  [![Downloads](https://pepy.tech/badge/torchdiff)](https://pepy.tech/project/torchdiff)
15
15
  [![Stars](https://img.shields.io/github/stars/LoqmanSamani/TorchDiff?style=plastic&color=yellow)](https://github.com/LoqmanSamani/TorchDiff)
16
16
  [![Forks](https://img.shields.io/github/forks/LoqmanSamani/TorchDiff?style=plastic&color=orange)](https://github.com/LoqmanSamani/TorchDiff)
@@ -118,7 +118,7 @@ trainer = TrainDDPM(
118
118
  device = device,
119
119
  grad_acc = 2
120
120
  )
121
- #trainer()
121
+ trainer()
122
122
 
123
123
  # Sampling
124
124
  sampler = SampleDDPM(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: TorchDiff
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: A PyTorch-based library for diffusion models
5
5
  Home-page: https://github.com/LoqmanSamani/TorchDiff
6
6
  Author: Loghman Samani
@@ -42,24 +42,20 @@ torchdiff/utils.py
42
42
  torchdiff/tests/__init__.py
43
43
  torchdiff/tests/test_ddim.py
44
44
  torchdiff/tests/test_ddpm.py
45
- torchdiff/tests/test_ldm.py
46
45
  torchdiff/tests/test_sde.py
47
- torchdiff/tests/test_unclip.py
48
46
  unclip/__init__.py
49
- unclip/clip_model.py
50
- unclip/ddim_model.py
51
- unclip/decoder_model.py
52
- unclip/prior_diff.py
53
- unclip/prior_model.py
54
- unclip/project_decoder.py
55
- unclip/project_prior.py
56
- unclip/train_decoder.py
57
- unclip/train_prior.py
47
+ unclip/clip_encoder.py
48
+ unclip/forward_unclip.py
49
+ unclip/projections.py
50
+ unclip/reverse_unclip.py
51
+ unclip/scheduler.py
52
+ unclip/train_unclip_decoder.py
53
+ unclip/train_unclip_prior.py
54
+ unclip/unclip_decoder.py
58
55
  unclip/unclip_sampler.py
59
- unclip/upsampler.py
56
+ unclip/unclip_trainstormer_prior.py
60
57
  unclip/upsampler_trainer.py
61
- unclip/utils.py
62
- unclip/val_metrics.py
58
+ unclip/upsampler_unclip.py
63
59
  utils/__init__.py
64
60
  utils/diff_net.py
65
61
  utils/losses.py
@@ -17,7 +17,7 @@ if not long_description:
17
17
 
18
18
  setup(
19
19
  name="TorchDiff",
20
- version="2.2.0",
20
+ version="2.3.0",
21
21
  description="A PyTorch-based library for diffusion models",
22
22
  long_description=long_description,
23
23
  long_description_content_type="text/markdown",
@@ -0,0 +1,13 @@
1
+ __version__ = "2.3.0"
2
+
3
+ from .ddim import ForwardDDIM, ReverseDDIM, SchedulerDDIM, TrainDDIM, SampleDDIM
4
+ from .ddpm import ForwardDDPM, ReverseDDPM, SchedulerDDPM, TrainDDPM, SampleDDPM
5
+ from .ldm import TrainLDM, TrainAE, AutoencoderLDM, SampleLDM
6
+ from .sde import ForwardSDE, ReverseSDE, SchedulerSDE, TrainSDE, SampleSDE
7
+ from .unclip import (
8
+ ForwardUnCLIP, ReverseUnCLIP, SchedulerUnCLIP, CLIPEncoder,
9
+ SampleUnCLIP, UnClipDecoder, UnCLIPTransformerPrior,
10
+ CLIPContextProjection, CLIPEmbeddingProjection, TrainUnClipDecoder,
11
+ SampleUnCLIP, UpsamplerUnCLIP, TrainUpsamplerUnCLIP
12
+ )
13
+ from .utils import DiffusionNetwork, TextEncoder, Metrics, mse_loss, snr_capped_loss, ve_sigma_weighted_score_loss
@@ -345,7 +345,6 @@ class SchedulerDDIM(nn.Module):
345
345
  """
346
346
  step_ratio = self.train_steps // self.sample_steps
347
347
  inference_timesteps = torch.arange(0, self.train_steps, step_ratio)
348
-
349
348
  self.register_buffer('inference_timesteps', inference_timesteps)
350
349
 
351
350
  def set_inference_timesteps(self, num_inference_timesteps: int):
@@ -393,49 +392,49 @@ class TrainDDIM(nn.Module):
393
392
 
394
393
  Parameters
395
394
  ----------
396
- `noise_predictor` : nn.Module
397
- Model to predict noise added during the forward diffusion process.
395
+ `diff_net` : nn.Module
396
+ Main model to predict noise/v/x0
398
397
  fwd_ddim : nn.Module
399
398
  Forward DDIM diffusion module for adding noise.
400
399
  rwd_ddim: nn.Module
401
400
  Reverse DDIM diffusion module for denoising.
402
401
  `data_loader` : torch.utils.data.DataLoader
403
402
  DataLoader for training data.
404
- `optimizer` : torch.optim.Optimizer
403
+ `optim` : torch.optim.Optimizer
405
404
  Optimizer for training the noise predictor and conditional model (if applicable).
406
- `objective` : callable
405
+ `loss_fn` : callable
407
406
  Loss function to compute the difference between predicted and actual noise.
408
407
  `val_loader` : torch.utils.data.DataLoader, optional
409
408
  DataLoader for validation data, default None.
410
409
  `max_epochs` : int, optional
411
- Maximum number of training epochs (default: 1000).
412
- `device` : torch.device, optional
413
- Device for computation (default: CUDA if available, else CPU).
414
- `conditional_model` : nn.Module, optional
410
+ Maximum number of training epochs (default: 100).
411
+ `device` : str
412
+ Device for computation (default: CUDA).
413
+ `cond_net` : nn.Module, optional
415
414
  Model for conditional generation (e.g., text embeddings), default None.
416
415
  `metrics_` : object, optional
417
416
  Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
418
- `bert_tokenizer` : BertTokenizer, optional
417
+ `tokenizer` : BertTokenizer, optional
419
418
  Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
420
419
  `max_token_length` : int, optional
421
420
  Maximum length for tokenized prompts (default: 77).
422
421
  `store_path` : str, optional
423
- Path to save model checkpoints (default: "ddim_model.pth").
422
+ Path to save model checkpoints (default: "ddim_train").
424
423
  `patience` : int, optional
425
- Number of epochs to wait for improvement before early stopping (default: 100).
426
- `warmup_epochs` : int, optional
427
- Number of epochs for learning rate warmup (default: 100).
428
- `val_frequency` : int, optional
424
+ Number of epochs to wait for improvement before early stopping (default: 20).
425
+ `warmup_steps` : int, optional
426
+ Number of epochs for learning rate warmup (default: 1000).
427
+ `val_freq` : int, optional
429
428
  Frequency (in epochs) for validation (default: 10).
430
- `output_range` : tuple, optional
429
+ `norm_range` : tuple, optional
431
430
  Range for clamping generated images (default: (-1, 1)).
432
- `normalize_output` : bool, optional
431
+ `norm_output` : bool, optional
433
432
  Whether to normalize generated images to [0, 1] for metrics (default: True).
434
433
  `use_ddp` : bool, optional
435
434
  Whether to use Distributed Data Parallel training (default: False).
436
- `grad_accumulation_steps` : int, optional
435
+ `grad_acc` : int, optional
437
436
  Number of gradient accumulation steps before optimizer update (default: 1).
438
- `log_frequency` : int, optional
437
+ `log_freq` : int, optional
439
438
  Number of epochs before printing loss.
440
439
  use_comp : bool, optional
441
440
  whether the model is internally compiled using torch.compile (default: false)
@@ -449,15 +448,15 @@ class TrainDDIM(nn.Module):
449
448
  optim: torch.optim.Optimizer,
450
449
  loss_fn: Callable,
451
450
  val_loader: Optional[torch.utils.data.DataLoader] = None,
452
- max_epochs: int = 1000,
451
+ max_epochs: int = 100,
453
452
  device: str = 'cuda',
454
- cond_model: torch.nn.Module = None,
453
+ cond_net: torch.nn.Module = None,
455
454
  metrics_: Optional[Any] = None,
456
- bert_tokenizer: Optional[BertTokenizer] = None,
455
+ tokenizer: Optional[BertTokenizer] = None,
457
456
  max_token_length: int = 77,
458
457
  store_path: Optional[str] = None,
459
- patience: int = 100,
460
- warmup_steps: int = 10000,
458
+ patience: int = 20,
459
+ warmup_steps: int = 1000,
461
460
  val_freq: int = 10,
462
461
  norm_range: Tuple[float, float] = (-1, 1),
463
462
  norm_output: bool = True,
@@ -481,11 +480,11 @@ class TrainDDIM(nn.Module):
481
480
  self.diff_net = diff_net.to(self.device)
482
481
  self.fwd_ddim = fwd_ddim.to(self.device)
483
482
  self.rwd_ddim = rwd_ddim.to(self.device)
484
- self.cond_model = cond_model.to(self.device) if cond_model else None
483
+ self.cond_net = cond_net.to(self.device) if cond_net else None
485
484
  self.metrics_ = metrics_
486
485
  self.optim = optim
487
486
  self.loss_fn = loss_fn
488
- self.store_path = store_path or "ddim_model"
487
+ self.store_path = store_path or "ddim_train"
489
488
  self.train_loader = train_loader
490
489
  self.val_loader = val_loader
491
490
  self.max_epochs = max_epochs
@@ -506,7 +505,7 @@ class TrainDDIM(nn.Module):
506
505
  factor=0.5
507
506
  )
508
507
  self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps)
509
- if bert_tokenizer is None:
508
+ if tokenizer is None:
510
509
  try:
511
510
  self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
512
511
  except Exception as e:
@@ -574,14 +573,14 @@ class TrainDDIM(nn.Module):
574
573
  elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
575
574
  state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
576
575
  self.diff_net.load_state_dict(state_dict)
577
- if self.cond_model is not None:
576
+ if self.cond_net is not None:
578
577
  if 'model_state_dict_cond' in checkpoint and checkpoint['model_state_dict_cond'] is not None:
579
578
  cond_state_dict = checkpoint['model_state_dict_cond']
580
579
  if self.use_ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()):
581
580
  cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()}
582
581
  elif not self.use_ddp and any(key.startswith('module.') for key in cond_state_dict.keys()):
583
582
  cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()}
584
- self.cond_model.load_state_dict(cond_state_dict)
583
+ self.cond_net.load_state_dict(cond_state_dict)
585
584
  else:
586
585
  warnings.warn(
587
586
  "Checkpoint contains no 'model_state_dict_cond' or it is None, "
@@ -625,8 +624,8 @@ class TrainDDIM(nn.Module):
625
624
  ----------
626
625
  `optimizer` : torch.optim.Optimizer
627
626
  Optimizer to apply the scheduler to.
628
- `warmup_epochs` : int
629
- Number of epochs for the warmup phase.
627
+ `warmup_steps` : int
628
+ Number of steps for the warmup phase.
630
629
 
631
630
  Returns
632
631
  -------
@@ -646,9 +645,9 @@ class TrainDDIM(nn.Module):
646
645
  device_ids=[self.ddp_local_rank],
647
646
  find_unused_parameters=True
648
647
  )
649
- if self.cond_model is not None:
650
- self.cond_model = DDP(
651
- self.cond_model,
648
+ if self.cond_net is not None:
649
+ self.cond_net = DDP(
650
+ self.cond_net,
652
651
  device_ids=[self.ddp_local_rank],
653
652
  find_unused_parameters=True
654
653
  )
@@ -662,20 +661,17 @@ class TrainDDIM(nn.Module):
662
661
 
663
662
  Returns
664
663
  -------
665
- train_losses : list of float
666
- List of mean training losses per epoch.
667
- best_val_loss : float
668
- Best validation or training loss achieved.
664
+ losses: dictionary containing train and validation losses
669
665
  """
670
666
  self.diff_net.train()
671
- if self.cond_model is not None:
672
- self.cond_model.train()
667
+ if self.cond_net is not None:
668
+ self.cond_net.train()
673
669
 
674
670
  if self.use_comp:
675
671
  try:
676
672
  self.diff_net = torch.compile(self.diff_net)
677
- if self.cond_model is not None:
678
- self.cond_model = torch.compile(self.cond_model)
673
+ if self.cond_net is not None:
674
+ self.cond_net = torch.compile(self.cond_net)
679
675
  except Exception as e:
680
676
  if self.master_process:
681
677
  print(f"Model compilation failed: {e}. Continuing without compilation.")
@@ -690,7 +686,7 @@ class TrainDDIM(nn.Module):
690
686
  train_losses_epoch = []
691
687
  for step, (x, y) in enumerate(pbar):
692
688
  x = x.to(self.device)
693
- if self.cond_model is not None:
689
+ if self.cond_net is not None:
694
690
  y_encoded = self._process_conditional_input(y)
695
691
  else:
696
692
  y_encoded = None
@@ -705,8 +701,8 @@ class TrainDDIM(nn.Module):
705
701
  if (step + 1) % self.grad_acc == 0:
706
702
  scaler.unscale_(self.optim)
707
703
  torch.nn.utils.clip_grad_norm_(self.diff_net.parameters(), max_norm=1.0)
708
- if self.cond_model is not None:
709
- torch.nn.utils.clip_grad_norm_(self.cond_model.parameters(), max_norm=1.0)
704
+ if self.cond_net is not None:
705
+ torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0)
710
706
  scaler.step(self.optim)
711
707
  scaler.update()
712
708
  self.optim.zero_grad()
@@ -786,7 +782,7 @@ class TrainDDIM(nn.Module):
786
782
  ).to(self.device)
787
783
  input_ids = y_encoded["input_ids"]
788
784
  attention_mask = y_encoded["attention_mask"]
789
- y_encoded = self.cond_model(input_ids, attention_mask)
785
+ y_encoded = self.cond_net(input_ids, attention_mask)
790
786
  return y_encoded
791
787
 
792
788
  def _save_checkpoint(self, epoch: int, loss: float, pref: str = "") -> None:
@@ -807,10 +803,10 @@ class TrainDDIM(nn.Module):
807
803
  else self.diff_net.state_dict()
808
804
  )
809
805
  cond_state = None
810
- if self.cond_model is not None:
806
+ if self.cond_net is not None:
811
807
  cond_state = (
812
- self.cond_model.module.state_dict() if self.use_ddp
813
- else self.cond_model.state_dict()
808
+ self.cond_net.module.state_dict() if self.use_ddp
809
+ else self.cond_net.state_dict()
814
810
  )
815
811
  checkpoint = {
816
812
  'epoch': epoch,
@@ -826,8 +822,7 @@ class TrainDDIM(nn.Module):
826
822
  filepath = os.path.join(self.store_path, filename)
827
823
  os.makedirs(self.store_path, exist_ok=True)
828
824
  torch.save(checkpoint, filepath)
829
-
830
- print(f"Model saved at epoch {epoch}")
825
+ print(f"Model saved at epoch {epoch} with loss: {loss:.4f}")
831
826
  except Exception as e:
832
827
  print(f"Failed to save model: {e}")
833
828
 
@@ -856,8 +851,8 @@ class TrainDDIM(nn.Module):
856
851
  """
857
852
 
858
853
  self.diff_net.eval()
859
- if self.cond_model is not None:
860
- self.cond_model.eval()
854
+ if self.cond_net is not None:
855
+ self.cond_net.eval()
861
856
 
862
857
  val_losses = []
863
858
  fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
@@ -865,7 +860,7 @@ class TrainDDIM(nn.Module):
865
860
  for x, y in self.val_loader:
866
861
  x = x.to(self.device)
867
862
  x_orig = x.clone()
868
- if self.cond_model is not None:
863
+ if self.cond_net is not None:
869
864
  y_encoded = self._process_conditional_input(y)
870
865
  else:
871
866
  y_encoded = None
@@ -880,10 +875,10 @@ class TrainDDIM(nn.Module):
880
875
  xt = torch.randn_like(x)
881
876
  timesteps = self.fwd_ddim.vs.inference_timesteps.flip(0)
882
877
  for i in range(len(timesteps) - 1):
883
- t_current = timesteps[i].item()
884
- t_next = timesteps[i + 1].item()
885
- time = torch.full((xt.shape[0],), t_current, device=self.device, dtype=torch.long)
886
- prev_time = torch.full((xt.shape[0],), t_next, device=self.device, dtype=torch.long)
878
+ t_ = timesteps[i].item()
879
+ t_pre = timesteps[i + 1].item()
880
+ time = torch.full((xt.shape[0],), t_, device=self.device, dtype=torch.long)
881
+ prev_time = torch.full((xt.shape[0],), t_pre, device=self.device, dtype=torch.long)
887
882
  pred = self.diff_net(xt, time, y_encoded, clip_embeddings=None)
888
883
  xt, _ = self.rwd_ddim(xt, time, prev_time, pred)
889
884
  x_hat = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
@@ -915,8 +910,8 @@ class TrainDDIM(nn.Module):
915
910
  lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
916
911
 
917
912
  self.diff_net.train()
918
- if self.cond_model is not None:
919
- self.cond_model.train()
913
+ if self.cond_net is not None:
914
+ self.cond_net.train()
920
915
  return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
921
916
 
922
917
  ###==================================================================================================================###
@@ -931,13 +926,13 @@ class SampleDDIM(nn.Module):
931
926
 
932
927
  Parameters
933
928
  ----------
934
- `reverse_diffusion` : nn.Module
929
+ `rwd_ddim` : nn.Module
935
930
  Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
936
- `noise_predictor` : nn.Module
937
- Trained model to predict noise at each time step.
938
- `image_shape` : tuple
931
+ `diff_net` : nn.Module
932
+ Trained model to predict noise/v/x0 at each time step.
933
+ `img_size` : tuple
939
934
  Tuple of (height, width) specifying the generated image dimensions.
940
- `conditional_model` : nn.Module, optional
935
+ `cond_net` : nn.Module, optional
941
936
  Model for conditional generation (e.g., text embeddings), default None.
942
937
  `tokenizer` : str, optional
943
938
  Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
@@ -947,9 +942,9 @@ class SampleDDIM(nn.Module):
947
942
  Number of images to generate per batch (default: 1).
948
943
  `in_channels` : int, optional
949
944
  Number of input channels for generated images (default: 3).
950
- `device` : torch.device, optional
951
- Device for computation (default: CUDA if available, else CPU).
952
- `output_range` : tuple, optional
945
+ `device` : str
946
+ Device for computation (default: CUDA).
947
+ `norm_range` : tuple, optional
953
948
  Tuple of (min, max) for clamping generated images (default: (-1, 1)).
954
949
  """
955
950
  def __init__(
@@ -957,7 +952,7 @@ class SampleDDIM(nn.Module):
957
952
  rwd_ddim: torch.nn.Module,
958
953
  diff_net: torch.nn.Module,
959
954
  img_size: Tuple[int, int],
960
- cond_model: Optional[torch.nn.Module] = None,
955
+ cond_net: Optional[torch.nn.Module] = None,
961
956
  tokenizer: str = "bert-base-uncased",
962
957
  max_token_length: int = 77,
963
958
  batch_size: int = 1,
@@ -972,7 +967,7 @@ class SampleDDIM(nn.Module):
972
967
  self.device = device
973
968
  self.rwd_ddim = rwd_ddim.to(self.device)
974
969
  self.diff_net = diff_net.to(self.device)
975
- self.cond_model = cond_model.to(self.device) if cond_model else None
970
+ self.cond_net = cond_net.to(self.device) if cond_net else None
976
971
  self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
977
972
  self.max_token_length = max_token_length
978
973
  self.in_channels = in_channels
@@ -1035,21 +1030,21 @@ class SampleDDIM(nn.Module):
1035
1030
  `save_imgs` : bool, optional
1036
1031
  If True, saves generated images to `save_path` (default: True).
1037
1032
  `save_path` : str, optional
1038
- Directory to save generated images (default: "ddim_generated").
1033
+ Directory to save generated images (default: "ddim_samples").
1039
1034
 
1040
1035
  Returns
1041
1036
  -------
1042
1037
  samps (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width).
1043
1038
  """
1044
- if conds is not None and self.cond_model is None:
1039
+ if conds is not None and self.cond_net is None:
1045
1040
  raise ValueError("Conditions provided but no conditional model specified")
1046
- if conds is None and self.cond_model is not None:
1041
+ if conds is None and self.cond_net is not None:
1047
1042
  raise ValueError("Conditions must be provided for conditional model")
1048
1043
 
1049
1044
  init_samps = torch.randn(self.batch_size, self.in_channels, self.img_size[0], self.img_size[1]).to(self.device)
1050
1045
  self.diff_net.eval()
1051
- if self.cond_model:
1052
- self.cond_model.eval()
1046
+ if self.cond_net:
1047
+ self.cond_net.eval()
1053
1048
  timesteps = self.rwd_ddim.vs.inference_timesteps
1054
1049
  timesteps = timesteps.flip(0)
1055
1050
  iterator = tqdm(
@@ -1059,10 +1054,10 @@ class SampleDDIM(nn.Module):
1059
1054
  dynamic_ncols=True,
1060
1055
  leave=True,
1061
1056
  )
1062
- if self.cond_model is not None and conds is not None:
1057
+ if self.cond_net is not None and conds is not None:
1063
1058
  input_ids, attention_masks = self.tokenize(conds)
1064
1059
  key_padding_mask = (attention_masks == 0)
1065
- y = self.cond_model(input_ids, key_padding_mask)
1060
+ y = self.cond_net(input_ids, key_padding_mask)
1066
1061
  else:
1067
1062
  y = None
1068
1063
 
@@ -1070,9 +1065,10 @@ class SampleDDIM(nn.Module):
1070
1065
  xt = init_samps
1071
1066
  for i in iterator:
1072
1067
  t_current = timesteps[i].item()
1073
- t_next = timesteps[i + 1].item()
1068
+ t_prev = timesteps[i + 1].item()
1069
+ #assert t_current > t_prev or t_prev == 0
1074
1070
  time = torch.full((self.batch_size,), t_current, device=self.device, dtype=torch.long)
1075
- prev_time = torch.full((self.batch_size,), t_next, device=self.device, dtype=torch.long)
1071
+ prev_time = torch.full((self.batch_size,), t_prev, device=self.device, dtype=torch.long)
1076
1072
  pred = self.diff_net(xt, time, y, clip_embeddings=None)
1077
1073
  xt, _ = self.rwd_ddim(xt, time, prev_time, pred)
1078
1074
  samps = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
@@ -1081,7 +1077,7 @@ class SampleDDIM(nn.Module):
1081
1077
  if save_imgs:
1082
1078
  os.makedirs(save_path, exist_ok=True)
1083
1079
  for i in range(samps.size(0)):
1084
- img_path = os.path.join(save_path, f"img_{i + 1}.png")
1080
+ img_path = os.path.join(save_path, f"img_{i+1}.png")
1085
1081
  save_image(samps[i], img_path)
1086
1082
  return samps
1087
1083
 
@@ -1102,6 +1098,6 @@ class SampleDDIM(nn.Module):
1102
1098
  """
1103
1099
  self.device = device
1104
1100
  self.diff_net.to(device)
1105
- if self.cond_model:
1106
- self.cond_model.to(device)
1101
+ if self.cond_net:
1102
+ self.cond_net.to(device)
1107
1103
  return super().to(device)