dnnlpy 2026.6.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. dnnlpy-2026.6.11/PKG-INFO +164 -0
  2. dnnlpy-2026.6.11/README.md +137 -0
  3. dnnlpy-2026.6.11/pyproject.toml +44 -0
  4. dnnlpy-2026.6.11/src/dnnlpy/__init__.py +7 -0
  5. dnnlpy-2026.6.11/src/dnnlpy/configtools.py +64 -0
  6. dnnlpy-2026.6.11/src/dnnlpy/models/__init__.py +5 -0
  7. dnnlpy-2026.6.11/src/dnnlpy/models/ddpm/__init__.py +5 -0
  8. dnnlpy-2026.6.11/src/dnnlpy/models/ddpm/ddpm.py +154 -0
  9. dnnlpy-2026.6.11/src/dnnlpy/models/ddpm/embedding.py +44 -0
  10. dnnlpy-2026.6.11/src/dnnlpy/models/ddpm/unet.py +259 -0
  11. dnnlpy-2026.6.11/src/dnnlpy/models/ddpm/utils.py +39 -0
  12. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/__init__.py +13 -0
  13. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/activation.py +120 -0
  14. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/base.py +128 -0
  15. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/layer.py +89 -0
  16. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/loss.py +59 -0
  17. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/mlp.py +42 -0
  18. dnnlpy-2026.6.11/src/dnnlpy/models/mlp/optimizer.py +29 -0
  19. dnnlpy-2026.6.11/src/dnnlpy/models/seq2seq/__init__.py +1 -0
  20. dnnlpy-2026.6.11/src/dnnlpy/models/seq2seq/transformer.py +87 -0
  21. dnnlpy-2026.6.11/src/dnnlpy/models/vae/__init__.py +2 -0
  22. dnnlpy-2026.6.11/src/dnnlpy/models/vae/autoencoder.py +59 -0
  23. dnnlpy-2026.6.11/src/dnnlpy/models/vae/vae.py +112 -0
  24. dnnlpy-2026.6.11/src/dnnlpy/models/vit/__init__.py +11 -0
  25. dnnlpy-2026.6.11/src/dnnlpy/models/vit/embedding.py +247 -0
  26. dnnlpy-2026.6.11/src/dnnlpy/models/vit/utils.py +31 -0
  27. dnnlpy-2026.6.11/src/dnnlpy/models/vit/vit.py +283 -0
  28. dnnlpy-2026.6.11/src/dnnlpy/nn/__init__.py +17 -0
  29. dnnlpy-2026.6.11/src/dnnlpy/nn/activation.py +81 -0
  30. dnnlpy-2026.6.11/src/dnnlpy/nn/attention.py +113 -0
  31. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/__init__.py +14 -0
  32. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/activation.py +47 -0
  33. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/attention.py +231 -0
  34. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/flash_attention.py +385 -0
  35. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/linear.py +23 -0
  36. dnnlpy-2026.6.11/src/dnnlpy/nn/functional/loss.py +43 -0
  37. dnnlpy-2026.6.11/src/dnnlpy/nn/linear.py +73 -0
  38. dnnlpy-2026.6.11/src/dnnlpy/nn/loss.py +25 -0
  39. dnnlpy-2026.6.11/src/dnnlpy/nn/transformer.py +557 -0
  40. dnnlpy-2026.6.11/src/dnnlpy/optim/__init__.py +15 -0
  41. dnnlpy-2026.6.11/src/dnnlpy/optim/adadelta.py +84 -0
  42. dnnlpy-2026.6.11/src/dnnlpy/optim/adagrad.py +76 -0
  43. dnnlpy-2026.6.11/src/dnnlpy/optim/adam.py +77 -0
  44. dnnlpy-2026.6.11/src/dnnlpy/optim/adamw.py +80 -0
  45. dnnlpy-2026.6.11/src/dnnlpy/optim/base.py +44 -0
  46. dnnlpy-2026.6.11/src/dnnlpy/optim/muon.py +120 -0
  47. dnnlpy-2026.6.11/src/dnnlpy/optim/rmsprop.py +83 -0
  48. dnnlpy-2026.6.11/src/dnnlpy/optim/sgd.py +200 -0
  49. dnnlpy-2026.6.11/src/dnnlpy/optim/utils.py +120 -0
  50. dnnlpy-2026.6.11/src/dnnlpy/py.typed +0 -0
  51. dnnlpy-2026.6.11/src/dnnlpy/pylabtools.py +158 -0
  52. dnnlpy-2026.6.11/src/dnnlpy/trainingtools.py +165 -0
@@ -0,0 +1,164 @@
1
+ Metadata-Version: 2.4
2
+ Name: dnnlpy
3
+ Version: 2026.6.11
4
+ Summary: A utility library for deep learning notes, providing common functions and tools to support the main content of the deep learning notes collection.
5
+ Keywords: deep learning notes,library,utilities
6
+ Author: Yunjie Lin
7
+ Author-email: Yunjie Lin <jshn9510@gmail.com>
8
+ License-Expression: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Intended Audience :: Education
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Classifier: Programming Language :: Python :: 3 :: Only
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Programming Language :: Python :: 3.14
19
+ Requires-Dist: numpy>=2.4.0,<2.5.0
20
+ Requires-Dist: matplotlib>=3.11.0,<3.12.0
21
+ Requires-Dist: torch>=2.12.0,<2.13.0
22
+ Requires-Dist: pytest>=9.0.0,<9.1.0 ; extra == 'test'
23
+ Requires-Python: >=3.12, <3.15
24
+ Project-URL: Homepage, https://github.com/jshn9515/deep-learning-notes
25
+ Provides-Extra: test
26
+ Description-Content-Type: text/markdown
27
+
28
+ # dnnlpy
29
+
30
+ **dnnlpy** is the companion Python package for the **Deep Learning Notes Library**.
31
+
32
+ It provides code examples, helper functions, and small utilities used throughout the tutorial, similar in spirit to the `d2l` package for _Dive into Deep Learning_.
33
+
34
+ The package structure is similar to PyTorch, but keeps a clear boundary between reusable neural network building blocks and complete model implementations:
35
+
36
+ - `dnnlpy.nn` contains general neural network modules, such as attention layers, positional encodings, and other reusable components.
37
+ - `dnnlpy.nn.functional` contains stateless helper functions, such as functional attention implementations.
38
+ - `dnnlpy.models` contains higher-level model architectures or model-specific components, such as ViT, DDPM, or other models introduced in the notes.
39
+
40
+ The APIs are designed to feel close to their PyTorch counterparts where practical, while still keeping the code lightweight and easy to read for tutorial purposes.
41
+
42
+ This package is intended as a lightweight code supplement rather than a general-purpose deep learning framework. Its goal is to make the examples in the notes easier to run, reuse, and extend.
43
+
44
+ ## What is this package for?
45
+
46
+ The `dnnlpy` package is designed to support the code in the **Deep Learning Notes Library** tutorial.
47
+
48
+ It can be used to:
49
+
50
+ - Organize example code from the notes
51
+ - Provide reusable utility functions
52
+ - Reduce repeated boilerplate in notebooks and scripts
53
+ - Make tutorial examples easier to reproduce
54
+
55
+ In short, this package serves as the code companion to the tutorial.
56
+
57
+ ## Requirements
58
+
59
+ - Python 3.12, 3.13, or 3.14
60
+ - PyTorch 2.12.x, NumPy 2.4.x, and Matplotlib 3.11.x (installed automatically as dependencies)
61
+
62
+ ## Installation
63
+
64
+ This project uses [uv](https://docs.astral.sh/uv/) for package management.
65
+
66
+ ```bash
67
+ git clone https://github.com/jshn9515/deep-learning-notes.git
68
+ cd deep-learning-notes/dnnlpy
69
+ uv pip install .
70
+ ```
71
+
72
+ If you want to modify the package while working through the notes, editable installation is recommended:
73
+
74
+ ```bash
75
+ uv pip install -e .
76
+ ```
77
+
78
+ This way, changes to the source code take effect immediately without reinstalling the package each time.
79
+
80
+ ## Examples
81
+
82
+ After installation, you can import reusable neural network modules from `dnnlpy.nn`:
83
+
84
+ ```python
85
+ import torch
86
+ import dnnlpy.nn as dnn
87
+
88
+ attn = dnn.MultiheadAttention(embed_dim=16, num_heads=4)
89
+
90
+ query = torch.randn(2, 8, 16)
91
+ key = torch.randn(2, 8, 16)
92
+ value = torch.randn(2, 8, 16)
93
+
94
+ output = attn(query, key, value)
95
+ ```
96
+
97
+ You can also import stateless functions from `dnnlpy.nn.functional`:
98
+
99
+ ```python
100
+ import torch
101
+ import dnnlpy.nn.functional as dF
102
+
103
+ query = torch.randn(2, 4, 8, 16)
104
+ key = torch.randn(2, 4, 8, 16)
105
+ value = torch.randn(2, 4, 8, 16)
106
+
107
+ output, weights = dF.scaled_dot_product_attention(
108
+ query,
109
+ key,
110
+ value,
111
+ need_weights=True,
112
+ )
113
+ ```
114
+
115
+ Higher-level model architectures live under `dnnlpy.models`:
116
+
117
+ ```python
118
+ import torch
119
+ from dnnlpy.models.vit import ViTForImageClassification
120
+
121
+ model = ViTForImageClassification(
122
+ image_size=224,
123
+ patch_size=16,
124
+ in_channels=3,
125
+ num_classes=1000,
126
+ embed_dim=768,
127
+ num_heads=12,
128
+ num_layers=12,
129
+ )
130
+
131
+ images = torch.randn(2, 3, 224, 224)
132
+ logits = model(images)
133
+ ```
134
+
135
+ The `dnnlpy.models.mlp` package contains small NumPy modules for teaching manual
136
+ forward and backward passes:
137
+
138
+ ```python
139
+ import dnnlpy.models.mlp as mlp
140
+ import numpy as np
141
+
142
+ model = mlp.MLP(input_dim=4, hidden_dim=8, num_classes=3)
143
+ loss_fn = mlp.CrossEntropyLoss()
144
+ optimizer = mlp.SGD(model.parameters(), lr=0.1)
145
+
146
+ x = np.random.randn(2, 4)
147
+ targets = np.array([0, 2])
148
+
149
+ logits = model(x)
150
+ loss = loss_fn(logits, targets)
151
+ model.backward(loss_fn.backward())
152
+
153
+ optimizer.step()
154
+ optimizer.zero_grad()
155
+ ```
156
+
157
+ A simple rule of thumb is:
158
+
159
+ - Use `dnnlpy.nn` when a component is reusable across many models.
160
+ - Use `dnnlpy.models` when the code represents a complete architecture or is tightly coupled to one model family.
161
+
162
+ ## License
163
+
164
+ This project is licensed under the **MIT License**.
@@ -0,0 +1,137 @@
1
+ # dnnlpy
2
+
3
+ **dnnlpy** is the companion Python package for the **Deep Learning Notes Library**.
4
+
5
+ It provides code examples, helper functions, and small utilities used throughout the tutorial, similar in spirit to the `d2l` package for _Dive into Deep Learning_.
6
+
7
+ The package structure is similar to PyTorch, but keeps a clear boundary between reusable neural network building blocks and complete model implementations:
8
+
9
+ - `dnnlpy.nn` contains general neural network modules, such as attention layers, positional encodings, and other reusable components.
10
+ - `dnnlpy.nn.functional` contains stateless helper functions, such as functional attention implementations.
11
+ - `dnnlpy.models` contains higher-level model architectures or model-specific components, such as ViT, DDPM, or other models introduced in the notes.
12
+
13
+ The APIs are designed to feel close to their PyTorch counterparts where practical, while still keeping the code lightweight and easy to read for tutorial purposes.
14
+
15
+ This package is intended as a lightweight code supplement rather than a general-purpose deep learning framework. Its goal is to make the examples in the notes easier to run, reuse, and extend.
16
+
17
+ ## What is this package for?
18
+
19
+ The `dnnlpy` package is designed to support the code in the **Deep Learning Notes Library** tutorial.
20
+
21
+ It can be used to:
22
+
23
+ - Organize example code from the notes
24
+ - Provide reusable utility functions
25
+ - Reduce repeated boilerplate in notebooks and scripts
26
+ - Make tutorial examples easier to reproduce
27
+
28
+ In short, this package serves as the code companion to the tutorial.
29
+
30
+ ## Requirements
31
+
32
+ - Python 3.12, 3.13, or 3.14
33
+ - PyTorch 2.12.x, NumPy 2.4.x, and Matplotlib 3.11.x (installed automatically as dependencies)
34
+
35
+ ## Installation
36
+
37
+ This project uses [uv](https://docs.astral.sh/uv/) for package management.
38
+
39
+ ```bash
40
+ git clone https://github.com/jshn9515/deep-learning-notes.git
41
+ cd deep-learning-notes/dnnlpy
42
+ uv pip install .
43
+ ```
44
+
45
+ If you want to modify the package while working through the notes, editable installation is recommended:
46
+
47
+ ```bash
48
+ uv pip install -e .
49
+ ```
50
+
51
+ This way, changes to the source code take effect immediately without reinstalling the package each time.
52
+
53
+ ## Examples
54
+
55
+ After installation, you can import reusable neural network modules from `dnnlpy.nn`:
56
+
57
+ ```python
58
+ import torch
59
+ import dnnlpy.nn as dnn
60
+
61
+ attn = dnn.MultiheadAttention(embed_dim=16, num_heads=4)
62
+
63
+ query = torch.randn(2, 8, 16)
64
+ key = torch.randn(2, 8, 16)
65
+ value = torch.randn(2, 8, 16)
66
+
67
+ output = attn(query, key, value)
68
+ ```
69
+
70
+ You can also import stateless functions from `dnnlpy.nn.functional`:
71
+
72
+ ```python
73
+ import torch
74
+ import dnnlpy.nn.functional as dF
75
+
76
+ query = torch.randn(2, 4, 8, 16)
77
+ key = torch.randn(2, 4, 8, 16)
78
+ value = torch.randn(2, 4, 8, 16)
79
+
80
+ output, weights = dF.scaled_dot_product_attention(
81
+ query,
82
+ key,
83
+ value,
84
+ need_weights=True,
85
+ )
86
+ ```
87
+
88
+ Higher-level model architectures live under `dnnlpy.models`:
89
+
90
+ ```python
91
+ import torch
92
+ from dnnlpy.models.vit import ViTForImageClassification
93
+
94
+ model = ViTForImageClassification(
95
+ image_size=224,
96
+ patch_size=16,
97
+ in_channels=3,
98
+ num_classes=1000,
99
+ embed_dim=768,
100
+ num_heads=12,
101
+ num_layers=12,
102
+ )
103
+
104
+ images = torch.randn(2, 3, 224, 224)
105
+ logits = model(images)
106
+ ```
107
+
108
+ The `dnnlpy.models.mlp` package contains small NumPy modules for teaching manual
109
+ forward and backward passes:
110
+
111
+ ```python
112
+ import dnnlpy.models.mlp as mlp
113
+ import numpy as np
114
+
115
+ model = mlp.MLP(input_dim=4, hidden_dim=8, num_classes=3)
116
+ loss_fn = mlp.CrossEntropyLoss()
117
+ optimizer = mlp.SGD(model.parameters(), lr=0.1)
118
+
119
+ x = np.random.randn(2, 4)
120
+ targets = np.array([0, 2])
121
+
122
+ logits = model(x)
123
+ loss = loss_fn(logits, targets)
124
+ model.backward(loss_fn.backward())
125
+
126
+ optimizer.step()
127
+ optimizer.zero_grad()
128
+ ```
129
+
130
+ A simple rule of thumb is:
131
+
132
+ - Use `dnnlpy.nn` when a component is reusable across many models.
133
+ - Use `dnnlpy.models` when the code represents a complete architecture or is tightly coupled to one model family.
134
+
135
+ ## License
136
+
137
+ This project is licensed under the **MIT License**.
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "dnnlpy"
3
+ version = "2026.06.11"
4
+ keywords = ["deep learning notes", "library", "utilities"]
5
+ description = "A utility library for deep learning notes, providing common functions and tools to support the main content of the deep learning notes collection."
6
+ authors = [{ name = "Yunjie Lin", email = "jshn9510@gmail.com" }]
7
+ readme = "README.md"
8
+ license = "MIT"
9
+ requires-python = ">=3.12,<3.15"
10
+ classifiers = [
11
+ "Development Status :: 4 - Beta",
12
+ "Operating System :: OS Independent",
13
+ "Intended Audience :: Education",
14
+ "Intended Audience :: Science/Research",
15
+ "Topic :: Scientific/Engineering :: Mathematics",
16
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
17
+ "Programming Language :: Python :: 3 :: Only",
18
+ "Programming Language :: Python :: 3.12",
19
+ "Programming Language :: Python :: 3.13",
20
+ "Programming Language :: Python :: 3.14",
21
+ ]
22
+ dependencies = [
23
+ 'numpy>=2.4.0,<2.5.0',
24
+ 'matplotlib>=3.11.0,<3.12.0',
25
+ 'torch>=2.12.0,<2.13.0',
26
+ ]
27
+
28
+ [project.optional-dependencies]
29
+ test = ['pytest>=9.0.0,<9.1.0']
30
+
31
+ [project.urls]
32
+ Homepage = "https://github.com/jshn9515/deep-learning-notes"
33
+
34
+ [tool.uv.sources]
35
+ torch = { index = "pytorch" }
36
+
37
+ [[tool.uv.index]]
38
+ name = "pytorch"
39
+ url = "https://download.pytorch.org/whl/cpu"
40
+ explicit = true
41
+
42
+ [build-system]
43
+ requires = ["uv_build>=0.11.0,<0.12.0"]
44
+ build-backend = "uv_build"
@@ -0,0 +1,7 @@
1
+ from . import models as models
2
+ from . import nn as nn
3
+ from . import optim as optim
4
+ from .configtools import get_data_root as get_data_root
5
+ from .configtools import get_default_device as get_default_device
6
+ from .configtools import set_seed as set_seed
7
+ from .pylabtools import set_matplotlib_format as set_matplotlib_format
@@ -0,0 +1,64 @@
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.accelerator as accl
7
+
8
# Public helpers of this module; ``dnnlpy/__init__.py`` re-exports all three.
__all__ = [
    'set_seed',
    'get_default_device',
    'get_data_root',
]
13
+
14
+
15
def set_seed(
    seed: int = 42,
    *,
    deterministic: bool = False,
    benchmark: bool = False,
    warn_only: bool = True,
) -> torch.Generator:
    """Make Python, NumPy, and PyTorch randomness reproducible.

    Args:
        seed (int, default: 42): Value used to seed every supported random
            number generator.
        deterministic (bool, default: False): Whether to request deterministic
            algorithms from PyTorch and cuDNN.
        benchmark (bool, default: False): Whether to turn on cuDNN benchmark mode.
        warn_only (bool, default: True): When deterministic algorithms are
            requested, warn on nondeterministic operations instead of raising.

    Returns:
        Generator: The generator object returned by ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    generator = torch.manual_seed(seed)

    torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = deterministic
    cudnn.benchmark = benchmark

    return generator
45
+
46
+
47
+ def get_default_device() -> torch.device:
48
+ """Return the current accelerator device, or CPU when none is available."""
49
+ device = accl.current_accelerator(check_available=True)
50
+ if device is not None:
51
+ return device
52
+ return torch.device('cpu')
53
+
54
+
55
def get_data_root() -> str:
    """Return the dataset root directory, creating it (and any parents) if needed.

    The ``DNNL_DATA_ROOT`` environment variable overrides the default
    ``~/datasets`` location. An empty value is treated as unset, and a leading
    ``~`` in either value is expanded to the user's home directory.

    Returns:
        str: Path of the dataset root directory, which exists on return.

    Raises:
        FileExistsError: If the path already exists but is not a directory.
    """
    root = os.path.expanduser(os.getenv('DNNL_DATA_ROOT') or '~/datasets')
    # ``makedirs(exist_ok=True)`` creates missing parent directories and avoids the
    # race between a separate existence check and ``mkdir`` (e.g. several
    # dataloader workers or processes starting at the same time).
    os.makedirs(root, exist_ok=True)
    return root
@@ -0,0 +1,5 @@
1
+ from . import ddpm as ddpm
2
+ from . import mlp as mlp
3
+ from . import seq2seq as seq2seq
4
+ from . import vae as vae
5
+ from . import vit as vit
@@ -0,0 +1,5 @@
1
+ from .ddpm import DDPMScheduler as DDPMScheduler
2
+ from .embedding import SinusoidalTimestepEmbedding as SinusoidalTimestepEmbedding
3
+ from .unet import UNet2DModel as UNet2DModel
4
+ from .utils import add_noise as add_noise
5
+ from .utils import denoise as denoise
@@ -0,0 +1,154 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.types import Device
4
+
5
__all__ = ['DDPMScheduler']


class DDPMScheduler:
    """Noise schedule and reverse-step helper for DDPM sampling."""

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
    ):
        """Scheduler for Denoising Diffusion Probabilistic Models (DDPM).

        Defines a linear beta (noise variance) schedule, adds noise to clean
        samples for training, and performs reverse diffusion steps for sampling.

        Args:
            num_train_timesteps (int, default: 1000): The total number of time steps
                used during training, which determines the length of the noise schedule.
            beta_start (float, default: 0.0001): The noise variance (beta) at time step 0.
            beta_end (float, default: 0.02): The noise variance (beta) at the final time step.
        """
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)

        # By default, sampling visits every training timestep from T-1 down to 0.
        self.num_inference_steps = num_train_timesteps
        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.long)

    def add_noise(
        self,
        original_samples: Tensor,
        noise: Tensor,
        timesteps: Tensor,
    ) -> Tensor:
        """Add noise to clean samples at the requested timesteps.

        Computes ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            original_samples (Tensor): Clean samples ``x_0`` whose first dimension is
                the batch; the remaining dimensions may have any rank.
            noise (Tensor): Gaussian noise with the same shape as ``original_samples``.
            timesteps (Tensor): 1D tensor of timestep indices, one per batch item.

        Returns:
            Tensor: Noisy samples ``x_t``.

        Raises:
            AssertionError: If the shapes of ``original_samples`` and ``noise`` differ,
                or ``timesteps`` is not 1D.
        """
        if original_samples.shape != noise.shape:
            raise AssertionError(
                '`original_samples` and `noise` must have the same shape.'
            )

        if timesteps.ndim != 1:
            raise AssertionError(
                '`timesteps` must be a 1D tensor of shape (batch_size,).'
            )

        device = original_samples.device
        # Keep the schedule on the sample device so repeated training calls avoid
        # a host-to-device copy; index with timesteps on that same device.
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        alpha_bar = self.alphas_cumprod[timesteps.to(device)]

        # Reshape per-sample coefficients to (batch, 1, ..., 1) so they broadcast
        # over samples of any rank, not only 4D images.
        shape = (-1,) + (1,) * (original_samples.ndim - 1)
        sqrt_alpha_bar = alpha_bar.sqrt().view(shape)
        sqrt_one_minus_alpha_bar = (1.0 - alpha_bar).sqrt().view(shape)

        return sqrt_alpha_bar * original_samples + sqrt_one_minus_alpha_bar * noise

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Device = 'cpu',
    ):
        """Set the inference timestep schedule.

        The schedule consists of ``num_inference_steps`` evenly spaced integer
        timesteps from ``num_train_timesteps - 1`` down to 0.

        Args:
            num_inference_steps (int): Number of reverse diffusion steps to run; must be
                in ``(0, num_train_timesteps]``.
            device (Device, default: 'cpu'): Device where the timestep tensor should live.

        Raises:
            AssertionError: If ``num_inference_steps`` is outside the valid range.
        """
        if not 0 < num_inference_steps <= self.num_train_timesteps:
            raise AssertionError(
                f'num_inference_steps must be in the range (0, {self.num_train_timesteps}].'
            )

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            self.num_train_timesteps - 1,
            0,
            num_inference_steps,
            dtype=torch.long,
            device=device,
        )

    def previous_timestep(self, timestep: int) -> int:
        """Return the timestep after ``timestep`` in the current reverse schedule.

        Args:
            timestep (int): Current timestep (an int or a 0-dim integer tensor).

        Returns:
            int: The next (smaller) timestep of the schedule, or -1 after the last step.

        Raises:
            ValueError: If a strided schedule is active and ``timestep`` is not part of it.
        """
        t = int(timestep)
        if self.num_inference_steps == self.num_train_timesteps:
            return t - 1

        matches = (self.timesteps == t).nonzero().flatten()
        if matches.numel() == 0:
            # Previously argmax silently returned index 0 here, giving a wrong step.
            raise ValueError(
                f'timestep {t} is not part of the current inference schedule.'
            )
        index = int(matches[0])
        if index == len(self.timesteps) - 1:
            return -1
        return int(self.timesteps[index + 1])

    def step(self, model_output: Tensor, timestep: int, sample: Tensor) -> Tensor:
        """Perform a single reverse diffusion step from ``timestep`` to the previous
        timestep of the current schedule.

        Args:
            model_output (Tensor): The output from the diffusion model, which is typically
                the predicted noise component at the current time step.
            timestep (int): The current time step in the reverse diffusion process (an
                int or a 0-dim tensor, e.g. an element of ``self.timesteps``).
            sample (Tensor): The current noisy sample at the given time step.

        Returns:
            Tensor: The sample at the previous timestep. No noise is added on the
            final step of the schedule.
        """
        # Convert to a Python int so indexing CPU buffers works even when the
        # timestep tensor lives on an accelerator.
        t = int(timestep)
        prev_t = self.previous_timestep(t)

        alphas_cumprod = self.alphas_cumprod.to(sample.device)
        alpha_bar_t = alphas_cumprod[t]
        if prev_t >= 0:
            alpha_bar_prev = alphas_cumprod[prev_t]
        else:
            alpha_bar_prev = torch.tensor(1.0, device=sample.device)

        # Effective alpha/beta for the jump t -> prev_t. When prev_t == t - 1 these
        # equal alphas[t] / betas[t]; for strided schedules (set_timesteps with fewer
        # steps) the single-step values would give the wrong posterior.
        alpha_t = alpha_bar_t / alpha_bar_prev
        beta_t = 1 - alpha_t

        # Estimate x_0 from the predicted noise.
        pred_original_sample = (
            sample - (1 - alpha_bar_t).sqrt() * model_output
        ) / alpha_bar_t.sqrt()

        # Mean of the posterior q(x_prev | x_t, x_0).
        coef_x0 = alpha_bar_prev.sqrt() * beta_t / (1 - alpha_bar_t)
        coef_xt = alpha_t.sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        mean = coef_x0 * pred_original_sample + coef_xt * sample

        if prev_t < 0:
            # Final step: return the mean without adding noise.
            return mean

        variance = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
        noise = torch.randn_like(sample)
        return mean + variance.sqrt() * noise
@@ -0,0 +1,44 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
__all__ = ['SinusoidalTimestepEmbedding']


class SinusoidalTimestepEmbedding(nn.Module):
    """Map diffusion timesteps to fixed sinusoidal feature vectors."""

    def __init__(self, embedding_dim: int, max_period: int = 10000):
        """Store the embedding configuration.

        Args:
            embedding_dim (int): Number of features produced for each timestep.
            max_period (int, default: 10000): Controls the lowest sinusoidal
                frequency (``1 / max_period`` when at least two frequencies are used).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period

    def forward(self, timesteps: Tensor) -> Tensor:
        """Embed a 1D tensor of timesteps.

        Returns a ``(len(timesteps), embedding_dim)`` float tensor. The first half
        holds sines and the second half cosines of geometrically spaced
        frequencies. One zero column is appended when ``embedding_dim`` is odd.
        """
        half = self.embedding_dim // 2
        device = timesteps.device

        # With fewer than two features there is no sin/cos pair; emit zeros.
        if not half:
            return torch.zeros(
                timesteps.size(0),
                self.embedding_dim,
                device=device,
                dtype=torch.float32,
            )

        log_step = -math.log(self.max_period) / max(half - 1, 1)
        frequencies = torch.exp(torch.arange(half, device=device) * log_step)
        angles = timesteps[:, None] * frequencies[None, :]
        emb = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

        return F.pad(emb, (0, 1)) if self.embedding_dim % 2 else emb