@xdev-asia/xdev-knowledge-mcp 1.0.57 → 1.0.59
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/content/blog/ai/minimax-danh-gia-chi-tiet-nen-tang-ai-full-stack-trung-quoc.md +450 -0
- package/content/blog/ai/nvidia-dli-generative-ai-chung-chi-va-lo-trinh-hoc.md +894 -0
- package/content/metadata/authors/duy-tran.md +2 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/01-deep-learning-foundations/lessons/01-bai-1-pytorch-neural-network-fundamentals.md +790 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/01-deep-learning-foundations/lessons/02-bai-2-transformer-architecture-attention.md +984 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/02-diffusion-models/lessons/01-bai-3-unet-architecture-denoising.md +1111 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/02-diffusion-models/lessons/02-bai-4-ddpm-forward-reverse-diffusion.md +1007 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/02-diffusion-models/lessons/03-bai-5-clip-text-to-image-pipeline.md +1037 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/03-llm-applications-rag/lessons/01-bai-6-llm-inference-pipeline-design.md +929 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/03-llm-applications-rag/lessons/02-bai-7-rag-retrieval-augmented-generation.md +1099 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/03-llm-applications-rag/lessons/03-bai-8-rag-agent-build-evaluate.md +1249 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/04-agentic-ai-customization/lessons/01-bai-9-agentic-ai-multi-agent-systems.md +1357 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/chapters/04-agentic-ai-customization/lessons/02-bai-10-llm-evaluation-lora-fine-tuning.md +1867 -0
- package/content/series/luyen-thi/luyen-thi-nvidia-dli-generative-ai/index.md +237 -0
- package/data/quizzes/nvidia-dli-generative-ai.json +350 -0
- package/data/quizzes.json +14 -0
- package/package.json +1 -1
@@ -0,0 +1,1111 @@
---
id: 019c9619-nv01-p2-l03
title: 'Bài 3: U-Net Architecture & Denoising Basics'
slug: bai-3-unet-architecture-denoising
description: >-
  U-Net encoder-decoder with skip connections.
  Build U-Net from scratch in PyTorch. Train denoiser model.
  Group Normalization, GELU activation, Rearrange Pooling.
  Sinusoidal Position Embeddings for timestep encoding.
duration_minutes: 90
is_free: true
video_url: null
sort_order: 3
section_title: "Part 2: Generative AI with Diffusion Models"
course:
  id: 019c9619-nv01-7001-c001-nv0100000001
  title: 'Luyện thi NVIDIA DLI — Generative AI with Diffusion Models & LLMs'
  slug: luyen-thi-nvidia-dli-generative-ai
---

<h2 id="gioi-thieu">1. Introduction: Why is U-Net the Heart of Diffusion Models?</h2>

<p>In the previous lesson you saw how the <strong>forward process</strong> adds noise to an image at each timestep. The question now is: which model learns to <strong>denoise</strong>, i.e. to reverse that process? The answer is <strong>U-Net</strong>.</p>

<p><strong>U-Net</strong> was originally designed for medical <strong>image segmentation</strong> (Ronneberger et al., 2015). Its distinctive architecture, an encoder-decoder with <strong>skip connections</strong>, preserves spatial detail while learning features at multiple levels of abstraction. This is exactly what diffusion models need.</p>

<blockquote><p><strong>Exam tip:</strong> In the assessment you will implement U-Net from scratch. Knowing the tensor shape after every layer is the key. NVIDIA DLI requires code that runs correctly, not just theoretical understanding.</p></blockquote>

<figure><img src="/storage/uploads/2026/04/nvidia-dli-bai3-unet-architecture.png" alt="U-Net architecture: encoder-decoder with skip connections for image denoising" loading="lazy" /><figcaption>U-Net architecture: encoder-decoder with skip connections for image denoising</figcaption></figure>

<h2 id="unet-architecture">2. U-Net Architecture: Encoder-Decoder with Skip Connections</h2>

<h3 id="tong-quan-kien-truc">2.1 Architecture Overview</h3>

<p>U-Net is "U"-shaped, with three main stages plus the connections between them:</p>

<ul>
<li><strong>Encoder (Contracting Path)</strong>: reduces spatial resolution and increases channels, learning increasingly abstract features</li>
<li><strong>Bottleneck</strong>: smallest spatial size, largest channel count; captures global context</li>
<li><strong>Decoder (Expanding Path)</strong>: increases spatial resolution and reduces channels, recovering detail</li>
<li><strong>Skip Connections</strong>: feed encoder features directly to the matching decoder level, preserving fine-grained details</li>
</ul>

<pre><code class="language-text">
U-Net Architecture for Diffusion (input 64×64×1)
═══════════════════════════════════════════════

        ENCODER                                      DECODER
   (Contracting Path)                            (Expanding Path)

  ┌─────────────────┐                          ┌─────────────────┐
  │ 64 × 64 × 1     │  Input Image             │ 64 × 64 × 1     │  Output (denoised)
  └────────┬────────┘                          └────────▲────────┘
           │                                            │
           ▼                                            │
  ┌─────────────────┐    skip connection       ┌─────────────────┐
  │ 64 × 64 × 64    │  ───────────────────►    │ 64 × 64 × 64    │  UpBlock + Concat
  │ Conv→GN→GELU    │     (concatenate)        │ Conv→GN→GELU    │
  └────────┬────────┘                          └────────▲────────┘
           │ Downsample                                 │ Upsample
           ▼                                            │
  ┌─────────────────┐    skip connection       ┌─────────────────┐
  │ 32 × 32 × 128   │  ───────────────────►    │ 32 × 32 × 128   │  UpBlock + Concat
  │ Conv→GN→GELU    │     (concatenate)        │ Conv→GN→GELU    │
  └────────┬────────┘                          └────────▲────────┘
           │ Downsample                                 │ Upsample
           ▼                                            │
  ┌─────────────────┐    skip connection       ┌─────────────────┐
  │ 16 × 16 × 256   │  ───────────────────►    │ 16 × 16 × 256   │  UpBlock + Concat
  │ Conv→GN→GELU    │     (concatenate)        │ Conv→GN→GELU    │
  └────────┬────────┘                          └────────▲────────┘
           │ Downsample                                 │ Upsample
           ▼                                            │
  ┌──────────────────────────────────────────────────────────────┐
  │                         8 × 8 × 512                          │
  │                         BOTTLENECK                           │
  │                Conv → GN → GELU → Conv → GN                  │
  │            (smallest spatial, largest channels)              │
  └──────────────────────────────────────────────────────────────┘

  + Timestep Embedding ──► injected into EVERY ResidualBlock via a linear projection
</code></pre>

<h3 id="encoder-path">2.2 Encoder Path (Contracting)</h3>

<p>Each encoder level performs:</p>

<ol>
<li><strong>Convolution</strong>: 3×3 conv with padding=1 (spatial size unchanged)</li>
<li><strong>Group Normalization</strong>: normalize over channel groups instead of the batch</li>
<li><strong>GELU Activation</strong>: a smoother non-linearity than ReLU</li>
<li><strong>Downsample</strong>: halve the spatial resolution (via a stride-2 conv or Rearrange Pooling)</li>
</ol>

<p>At each level the number of <strong>channels doubles</strong> and the <strong>spatial size halves</strong>. For example:</p>

<table>
<thead>
<tr><th>Level</th><th>Input Shape</th><th>Output Shape</th><th>Operation</th></tr>
</thead>
<tbody>
<tr><td>0</td><td>B × 1 × 64 × 64</td><td>B × 64 × 64 × 64</td><td>Initial Conv</td></tr>
<tr><td>1</td><td>B × 64 × 64 × 64</td><td>B × 128 × 32 × 32</td><td>ResBlock → Down</td></tr>
<tr><td>2</td><td>B × 128 × 32 × 32</td><td>B × 256 × 16 × 16</td><td>ResBlock → Down</td></tr>
<tr><td>3</td><td>B × 256 × 16 × 16</td><td>B × 512 × 8 × 8</td><td>ResBlock → Down</td></tr>
</tbody>
</table>
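
<p>The progression in the table can be checked with a minimal sketch, using hypothetical stride-2 convolutions as stand-ins for the ResBlock → Down pairs (the course itself uses Rearrange Pooling, covered in section 3.3):</p>

```python
import torch
import torch.nn as nn

# Hypothetical sketch: each level doubles channels and halves spatial size,
# matching the shape table above.
levels = nn.ModuleList([
    nn.Conv2d(1, 64, 3, padding=1),               # Level 0: initial conv
    nn.Conv2d(64, 128, 3, stride=2, padding=1),   # Level 1: 64×64 → 32×32
    nn.Conv2d(128, 256, 3, stride=2, padding=1),  # Level 2: 32×32 → 16×16
    nn.Conv2d(256, 512, 3, stride=2, padding=1),  # Level 3: 16×16 → 8×8
])

x = torch.randn(2, 1, 64, 64)  # B × 1 × 64 × 64
for conv in levels:
    x = conv(x)
    print(tuple(x.shape))
# (2, 64, 64, 64) → (2, 128, 32, 32) → (2, 256, 16, 16) → (2, 512, 8, 8)
```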

<h3 id="decoder-path">2.3 Decoder Path (Expanding)</h3>

<p>The decoder mirrors the encoder: it <strong>increases spatial size</strong> and <strong>reduces channels</strong>:</p>

<ol>
<li><strong>Upsample</strong>: double the spatial resolution (typically <code>nn.Upsample</code> or <code>nn.ConvTranspose2d</code>)</li>
<li><strong>Concatenate</strong> with the skip connection from the encoder level at the same resolution</li>
<li><strong>Convolution → GroupNorm → GELU</strong>: process the concatenated features</li>
</ol>

<blockquote><p><strong>Exam tip:</strong> Concatenating the skip connection temporarily <strong>doubles</strong> the channel count. For example: a 256-channel upsample output + a 256-channel skip = 512 channels entering the next conv. This is a common implementation bug: watch the <code>in_channels</code> of the conv after the concat!</p></blockquote>
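
<p>A quick standalone shape check of the upsample-then-concat pattern described above (the full UpBlock comes in section 5.3); the shapes here are illustrative:</p>

```python
import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2, mode='nearest')

x = torch.randn(2, 256, 16, 16)     # from the level below
skip = torch.randn(2, 256, 32, 32)  # encoder skip at the target resolution

x = up(x)                        # (2, 256, 32, 32)
h = torch.cat([x, skip], dim=1)  # channels double: 256 + 256 = 512
print(h.shape)  # torch.Size([2, 512, 32, 32])

# The conv after the concat must therefore accept 512 input channels:
conv = nn.Conv2d(512, 256, kernel_size=3, padding=1)
out = conv(h)  # (2, 256, 32, 32)
```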

<h3 id="skip-connections">2.4 Skip Connections: Why Do They Matter?</h3>

<p>Without skip connections, the decoder would have to "guess" all spatial detail back from the 8×8 bottleneck alone, which is nearly impossible. Skip connections provide:</p>

<ul>
<li><strong>Gradient flow</strong>: gradients travel directly from the loss to deep encoder layers, making training easier</li>
<li><strong>Detail preservation</strong>: high-resolution encoder levels keep edges and textures, which the decoder reuses instead of relearning</li>
<li><strong>Multi-scale features</strong>: the decoder receives both high-level features (from the bottleneck) and low-level features (from the skips)</li>
</ul>

<h2 id="key-components">3. Key Components: GroupNorm, GELU, Rearrange Pooling</h2>

<h3 id="group-normalization">3.1 Group Normalization</h3>

<p>In diffusion models the <strong>batch size is usually very small</strong> (4-8) because each image takes a lot of GPU memory. <strong>Batch Normalization</strong> performs poorly with small batches because the statistics (mean, variance) computed over the batch are unstable.</p>

<p><strong>Group Normalization</strong> solves this by splitting the channels into <strong>groups</strong> and normalizing <strong>within each group, per individual sample</strong>, independent of the batch size.</p>

<pre><code class="language-text">
Group Normalization vs Batch Normalization
══════════════════════════════════════════

Batch Normalization:                Group Normalization:
normalize over N (batch)            normalize per group within C

  ┌───┬───┬───┬───┐                 ┌───┬───┬───┬───┐
  │ N │   │   │   │                 │   │   │   │   │  N (batch)
  ├───┼───┼───┼───┤                 ├───┼───┼───┼───┤
  │   │   │   │   │  C              │ G1│ G1│ G2│ G2│  C (channels)
  ├───┼───┼───┼───┤  (channels)     ├───┼───┼───┼───┤  split into groups
  │   │   │   │   │                 │ G1│ G1│ G2│ G2│
  ├───┼───┼───┼───┤                 ├───┼───┼───┼───┤
  │   │   │   │   │  H×W            │   │   │   │   │  H×W
  └───┴───┴───┴───┘                 └───┴───┴───┴───┘
          ▲                                 ▲
  normalize columns (across N)      normalize blocks (within group)
  ⚠ small batch → unstable          ✓ independent of batch size
</code></pre>

<pre><code class="language-python">
import torch
import torch.nn as nn

# GroupNorm: split 64 channels into 8 groups (8 channels per group)
norm = nn.GroupNorm(num_groups=8, num_channels=64)

# With input shape (B, 64, 32, 32):
# - the 64 channels are split into 8 groups of 8 channels each
# - mean and var are computed over (8, 32, 32) = 8192 elements per group per sample
# - each sample and each group is normalized independently

x = torch.randn(4, 64, 32, 32)
out = norm(x)  # shape: (4, 64, 32, 32), unchanged
</code></pre>

<table>
<thead>
<tr><th>Feature</th><th>BatchNorm</th><th>GroupNorm</th><th>LayerNorm</th><th>InstanceNorm</th></tr>
</thead>
<tbody>
<tr><td>Normalize across</td><td>Batch (N)</td><td>Channel groups</td><td>All channels</td><td>Each channel</td></tr>
<tr><td>Batch size dependency</td><td>Yes ⚠</td><td>No ✓</td><td>No ✓</td><td>No ✓</td></tr>
<tr><td>Small batch performance</td><td>Poor</td><td>Good</td><td>OK</td><td>OK</td></tr>
<tr><td>Use case</td><td>Classification</td><td>Diffusion, Detection</td><td>Transformers (NLP)</td><td>Style Transfer</td></tr>
<tr><td>PyTorch API</td><td><code>nn.BatchNorm2d(C)</code></td><td><code>nn.GroupNorm(G, C)</code></td><td><code>nn.LayerNorm(shape)</code></td><td><code>nn.InstanceNorm2d(C)</code></td></tr>
</tbody>
</table>
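
<p>The batch-size independence claimed in the table can be verified directly: GroupNorm produces the same result for a sample whether it is normalized alone or inside a batch, while BatchNorm in training mode does not (a minimal sketch):</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 64, 8, 8)

gn = nn.GroupNorm(num_groups=8, num_channels=64)
bn = nn.BatchNorm2d(64).train()  # training mode: uses batch statistics

# GroupNorm: per-sample statistics, so batch composition does not matter
gn_batch = gn(x)
gn_single = gn(x[:1])
print(torch.allclose(gn_batch[:1], gn_single, atol=1e-5))  # True

# BatchNorm: statistics come from the whole batch, so the results differ
bn_batch = bn(x)
bn_single = bn(x[:1])
print(torch.allclose(bn_batch[:1], bn_single, atol=1e-5))  # False
```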

<h3 id="gelu-activation">3.2 GELU Activation</h3>

<p><strong>GELU</strong> (Gaussian Error Linear Unit) is the standard activation function in modern models (Transformers, Diffusion Models). Unlike the "hard" ReLU, which cuts negatives to 0, GELU is smooth and lets a small fraction of negative values "leak" through.</p>

<p>Formula: <strong>GELU(x) = x · Φ(x)</strong>, where Φ(x) is the CDF of the standard normal distribution.</p>

<pre><code class="language-text">
Activation Functions Comparison
═══════════════════════════════

  Output                             Output
    │      ReLU                        │      GELU
    │      ╱                           │      ╱
    │     ╱                            │     ╱
    │    ╱                             │    ╱
 ───┼───╱────── Input               ───┼──╱─────── Input
    │  ╱                              ╱│
    │ ╱                              ╱ │
    │╱  (hard cutoff at 0)          ╱  │  (smooth curve, allows
    │                              ╱   │   small negative values)

 ReLU(x) = max(0, x)              GELU(x) = x · Φ(x)
 ⚠ Dead neurons problem           ✓ Smoother gradient flow
 ⚠ Not differentiable at 0        ✓ Better for deep networks
</code></pre>

<pre><code class="language-python">
import torch
import torch.nn as nn

x = torch.randn(4, 64)

# Option 1: module
activation = nn.GELU()
out = activation(x)

# Option 2: functional
import torch.nn.functional as F
out = F.gelu(x)

# Option 3: tanh approximation (faster; the DLI course uses this)
activation = nn.GELU(approximate='tanh')
</code></pre>
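
<p>A quick numeric check of the "leak" behaviour described above: ReLU zeroes every negative input, while GELU lets small negative values through (sketch; the sample inputs are arbitrary):</p>

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0])

print(F.relu(x))  # negatives hard-cut to 0
print(F.gelu(x))  # small negatives leak through

# GELU(x) = x · Φ(x): for x = -1, Φ(-1) ≈ 0.1587, so the output is ≈ -0.1587
```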

<h3 id="rearrange-pooling">3.3 Rearrange Pooling (Space-to-Channel)</h3>

<p><strong>Rearrange Pooling</strong> is a downsampling technique that replaces MaxPool/AvgPool. Instead of discarding information (MaxPool keeps only the max, AvgPool only the mean), Rearrange "folds" the spatial dimensions into the channel dimension, retaining <strong>all</strong> of the information.</p>

<pre><code class="language-text">
Rearrange Pooling: (B, C, 2H, 2W) → (B, 4C, H, W)
═════════════════════════════════════════════════

Input: (B, C, 4, 4)            Output: (B, 4C, 2, 2)

Channel c:                     4 channels (one per position in each 2×2 patch):
┌───┬───┬───┬───┐              Channel c_0:   Channel c_1:
│ a │ b │ e │ f │              ┌───┬───┐      ┌───┬───┐
├───┼───┼───┼───┤              │ a │ e │      │ b │ f │
│ c │ d │ g │ h │    ────►     ├───┼───┤      ├───┼───┤
├───┼───┼───┼───┤  Rearrange   │ i │ m │      │ j │ n │
│ i │ j │ m │ n │              └───┴───┘      └───┴───┘
├───┼───┼───┼───┤
│ k │ l │ o │ p │              Channel c_2:   Channel c_3:
└───┴───┴───┴───┘              ┌───┬───┐      ┌───┬───┐
                               │ c │ g │      │ d │ h │
Spatial: 4×4, Channels: C      ├───┼───┤      ├───┼───┤
                               │ k │ o │      │ l │ p │
                               └───┴───┘      └───┴───┘

                               Spatial: 2×2, Channels: 4C
                               ✓ NO information lost!
</code></pre>

<pre><code class="language-python">
import torch
from einops import rearrange

def rearrange_downsample(x):
    """Downsample by rearranging spatial dims into channels.
    (B, C, H, W) -> (B, 4C, H/2, W/2)
    """
    return rearrange(x, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)

# Example:
x = torch.randn(2, 64, 32, 32)
out = rearrange_downsample(x)
print(out.shape)  # torch.Size([2, 256, 16, 16])

# Without einops, in plain PyTorch:
def rearrange_downsample_pure(x):
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // 2, 2, W // 2, 2)
    x = x.permute(0, 1, 3, 5, 2, 4)  # (B, C, 2, 2, H/2, W/2)
    x = x.reshape(B, C * 4, H // 2, W // 2)
    return x
</code></pre>

<table>
<thead>
<tr><th>Downsampling Method</th><th>Information Loss</th><th>Channel Change</th><th>Use in Diffusion</th></tr>
</thead>
<tbody>
<tr><td>MaxPool2d</td><td>High (keeps only the max)</td><td>None</td><td>Rare</td></tr>
<tr><td>AvgPool2d</td><td>Medium (keeps the mean)</td><td>None</td><td>Rare</td></tr>
<tr><td>Stride-2 Conv</td><td>Learned (trainable)</td><td>Configurable</td><td>Common</td></tr>
<tr><td>Rearrange Pooling</td><td>None ✓</td><td>×4</td><td>NVIDIA DLI course ✓</td></tr>
</tbody>
</table>

<blockquote><p><strong>Exam tip:</strong> NVIDIA DLI uses <strong>Rearrange Pooling</strong> instead of MaxPool. In the assessment you may need to implement it with <code>einops.rearrange</code> or plain PyTorch (<code>reshape</code> + <code>permute</code>). Remember that the channel count grows <strong>4×</strong> when each spatial dimension shrinks by 2×.</p></blockquote>
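
<p>Because rearrange pooling is a pure reshuffle, it is exactly invertible, which is what "no information loss" means in the table. A round-trip sketch using the plain-PyTorch version; <code>rearrange_up</code> is a hypothetical inverse helper, not part of the course code:</p>

```python
import torch

def rearrange_down(x):
    """(B, C, H, W) -> (B, 4C, H/2, W/2): fold 2×2 patches into channels."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // 2, 2, W // 2, 2)
    return x.permute(0, 1, 3, 5, 2, 4).reshape(B, C * 4, H // 2, W // 2)

def rearrange_up(x):
    """(B, 4C, H, W) -> (B, C, 2H, 2W): the exact inverse."""
    B, C4, H, W = x.shape
    C = C4 // 4
    x = x.reshape(B, C, 2, 2, H, W)
    return x.permute(0, 1, 4, 2, 5, 3).reshape(B, C, H * 2, W * 2)

x = torch.randn(2, 64, 32, 32)
restored = rearrange_up(rearrange_down(x))
print(torch.equal(x, restored))  # True: every value is recovered exactly
```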

<h2 id="sinusoidal-embeddings">4. Sinusoidal Position Embeddings for the Timestep</h2>

<h3 id="tai-sao-can-timestep">4.1 Why a Timestep Embedding?</h3>

<p>U-Net needs to know <strong>which timestep</strong> of the diffusion process it is at in order to denoise appropriately:</p>

<ul>
<li>Large timestep (t near T): the image is almost pure noise → the model must recover the overall structure</li>
<li>Small timestep (t near 0): the image is almost clean → the model only needs to refine small details</li>
</ul>

<p>We convert the integer timestep <strong>t</strong> into a <strong>continuous embedding vector</strong> of length <code>embed_dim</code> and inject it into every layer of the U-Net.</p>

<h3 id="sinusoidal-formula">4.2 The Sinusoidal Embedding Formula</h3>

<p>This is exactly the <strong>Positional Encoding</strong> from the Transformer ("Attention Is All You Need"):</p>

<pre><code class="language-text">
PE(t, 2i)   = sin(t / 10000^(2i/d))
PE(t, 2i+1) = cos(t / 10000^(2i/d))

Where:
  t = timestep (integer: 0, 1, 2, ..., T)
  d = embedding dimension (e.g., 128)
  i = index into the embedding vector (0, 1, 2, ..., d/2 - 1)

Example with d=8 (10000^(2i/8) = 1, 10, 100, 1000 for i = 0..3):
  PE(t) = [sin(t/1), cos(t/1), sin(t/10), cos(t/10),
           sin(t/100), cos(t/100), sin(t/1000), cos(t/1000)]

→ High-frequency terms (front): change fast → encode fine timestep differences
→ Low-frequency terms (back): change slowly → encode the "big picture" timestep
</code></pre>

<h3 id="implement-timestep-embedding">4.3 Implementing the Timestep Embedding</h3>

<pre><code class="language-python">
import torch
import torch.nn as nn
import math

class SinusoidalPositionEmbedding(nn.Module):
    """Convert integer timesteps into sinusoidal embedding vectors."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, timesteps):
        """
        Args:
            timesteps: (B,) integer timesteps
        Returns:
            embeddings: (B, embed_dim) sinusoidal embeddings
        """
        device = timesteps.device
        half_dim = self.embed_dim // 2

        # Frequencies: 1/10000^(2i/d) for i = 0, 1, ..., d/2-1
        exponent = torch.arange(half_dim, device=device).float() / half_dim
        freqs = torch.exp(-math.log(10000.0) * exponent)  # shape: (d/2,)

        # Multiply timesteps by frequencies: (B, 1) * (1, d/2) = (B, d/2)
        args = timesteps[:, None].float() * freqs[None, :]

        # Concatenate sin and cos: (B, d/2) cat (B, d/2) = (B, d)
        embeddings = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        return embeddings  # shape: (B, embed_dim)


class TimestepMLPEmbedding(nn.Module):
    """Sinusoidal embedding + MLP projection (as used in the DLI course)."""

    def __init__(self, embed_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim * 4

        self.sinusoidal = SinusoidalPositionEmbedding(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, timesteps):
        """
        Args:
            timesteps: (B,) integer timesteps
        Returns:
            (B, hidden_dim) projected timestep embeddings
        """
        x = self.sinusoidal(timesteps)  # (B, embed_dim)
        x = self.mlp(x)                 # (B, hidden_dim)
        return x
</code></pre>
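
<p>A quick sanity check, recomputed standalone with plain tensor ops mirroring the forward pass above (d=8 for readability): the output has shape (B, d), every value lies in [-1, 1], and the first column equals sin(t) because the i=0 frequency is 1:</p>

```python
import math
import torch

embed_dim = 8
half_dim = embed_dim // 2
timesteps = torch.tensor([0, 1, 10, 999])

exponent = torch.arange(half_dim).float() / half_dim
freqs = torch.exp(-math.log(10000.0) * exponent)  # [1, 10^-1, 10^-2, 10^-3]
args = timesteps[:, None].float() * freqs[None, :]            # (4, 4)
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)   # (4, 8)

print(emb.shape)  # torch.Size([4, 8])
print(torch.allclose(emb[:, 0], torch.sin(timesteps.float())))  # True
```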

<h3 id="inject-timestep-unet">4.4 Injecting the Timestep into the U-Net</h3>

<p>The timestep embedding is injected into <strong>every ResidualBlock</strong> by:</p>

<ol>
<li>Projecting the timestep embedding to the same channel count as the feature map (using <code>nn.Linear</code>)</li>
<li>Reshaping it to <code>(B, C, 1, 1)</code> so it broadcasts</li>
<li><strong>Adding</strong> it to the feature map after the first GroupNorm</li>
</ol>

<pre><code class="language-text">
Timestep Injection Flow
═══════════════════════

timestep t ──► SinusoidalEmbed ──► MLP ──► t_emb (B, hidden_dim)
                                             │
                                   Linear(hidden_dim, C)
                                             │
                                   (B, C, 1, 1) ← reshape to broadcast
                                             │
Feature Map: ─── Conv ─── GroupNorm ────── (+) ────── GELU ─── Conv ─── ...
                                         add here
</code></pre>

<blockquote><p><strong>Exam tip:</strong> The timestep embedding is <strong>added</strong>, not concatenated, to the feature map. The injection happens <strong>after GroupNorm, before GELU</strong> in each ResidualBlock. This is a fixed pattern in the DLI course.</p></blockquote>
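
<p>The three steps can be sketched standalone; the reshape to <code>(B, C, 1, 1)</code> is what lets one channel vector per sample broadcast-add onto a <code>(B, C, H, W)</code> feature map (the dimensions here are illustrative):</p>

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 64, 16, 16
t_dim = 512  # hypothetical hidden_dim of the timestep MLP

proj = nn.Linear(t_dim, C)   # step 1: project embedding to C channels
t_emb = torch.randn(B, t_dim)
h = torch.randn(B, C, H, W)  # feature map after the first GroupNorm

t = proj(t_emb)          # (B, C)
t = t[:, :, None, None]  # step 2: (B, C, 1, 1) for broadcasting
out = h + t              # step 3: added (not concatenated)

print(out.shape)  # torch.Size([2, 64, 16, 16]): channel count unchanged
```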

<h2 id="build-unet-from-scratch">5. Build U-Net from Scratch, Step by Step</h2>

<h3 id="residual-block">5.1 ResidualBlock</h3>

<p>This is the most basic building block. Each ResidualBlock consists of two conv + GroupNorm + GELU layers, plus a <strong>residual connection</strong> and <strong>timestep injection</strong>.</p>

<pre><code class="language-python">
class ResidualBlock(nn.Module):
    """Residual block with timestep embedding injection.

    Flow: x → Conv1 → GN1 → (+t_emb) → GELU → Conv2 → GN2 → GELU → (+residual) → out
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()

        # First conv layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)

        # Second conv layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)

        # Activation
        self.act = nn.GELU()

        # Timestep embedding projection: project to out_channels
        self.time_mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, out_channels),
        )

        # Residual connection: a 1x1 conv is needed if in_channels != out_channels
        if in_channels != out_channels:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        Args:
            x: (B, in_channels, H, W) input feature map
            t_emb: (B, time_emb_dim) timestep embedding
        Returns:
            (B, out_channels, H, W)
        """
        residual = self.residual_conv(x)  # (B, out_channels, H, W)

        # First layer
        h = self.conv1(x)  # (B, out_channels, H, W)
        h = self.norm1(h)  # normalize

        # Inject timestep embedding
        t = self.time_mlp(t_emb)  # (B, out_channels)
        t = t[:, :, None, None]   # (B, out_channels, 1, 1) broadcast
        h = h + t                 # add timestep info

        h = self.act(h)  # GELU activation

        # Second layer
        h = self.conv2(h)  # (B, out_channels, H, W)
        h = self.norm2(h)  # normalize
        h = self.act(h)    # GELU activation

        return h + residual  # residual connection
</code></pre>

<h3 id="down-block">5.2 DownBlock (Encoder Level)</h3>

<pre><code class="language-python">
class DownBlock(nn.Module):
    """Encoder block: ResidualBlock + Rearrange Downsample."""

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_emb_dim)

    def downsample(self, x):
        """Rearrange pooling: (B, C, H, W) -> (B, 4C, H/2, W/2)"""
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // 2, 2, W // 2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4).reshape(B, C * 4, H // 2, W // 2)
        return x

    def forward(self, x, t_emb):
        """
        Args:
            x: (B, in_channels, H, W)
            t_emb: (B, time_emb_dim)
        Returns:
            skip: (B, out_channels, H, W) for the skip connection
            down: (B, out_channels*4, H/2, W/2) downsampled for the next level
        """
        skip = self.res_block(x, t_emb)  # (B, out_channels, H, W)
        down = self.downsample(skip)     # (B, out_channels*4, H/2, W/2)
        return skip, down
</code></pre>

<h3 id="up-block">5.3 UpBlock (Decoder Level)</h3>

<pre><code class="language-python">
class UpBlock(nn.Module):
    """Decoder block: Upsample + Concat skip + ResidualBlock."""

    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim):
        super().__init__()
        # in_channels = channels coming from the level below, after upsampling
        # After concat with skip: in_channels + skip_channels
        self.res_block = ResidualBlock(
            in_channels + skip_channels, out_channels, time_emb_dim
        )
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, skip, t_emb):
        """
        Args:
            x: (B, in_channels, H, W) from the level below
            skip: (B, skip_channels, 2H, 2W) skip connection from the encoder
            t_emb: (B, time_emb_dim)
        Returns:
            (B, out_channels, 2H, 2W)
        """
        x = self.upsample(x)             # (B, in_channels, 2H, 2W)
        x = torch.cat([x, skip], dim=1)  # (B, in_channels+skip_channels, 2H, 2W)
        x = self.res_block(x, t_emb)     # (B, out_channels, 2H, 2W)
        return x
</code></pre>

<h3 id="full-unet">5.4 Full U-Net Assembly</h3>

<pre><code class="language-python">
class UNet(nn.Module):
    """Complete U-Net for diffusion denoising.

    Architecture: 64×64×1 → encoder (3 levels) → bottleneck → decoder (3 levels) → 64×64×1
    Channel progression: 1 → 64 → 128 → 256 → 512 (bottleneck) → 256 → 128 → 64 → 1
    """

    def __init__(self, in_channels=1, base_channels=64, time_emb_dim=128):
        super().__init__()

        # Timestep embedding
        self.time_embed = TimestepMLPEmbedding(
            embed_dim=time_emb_dim,
            hidden_dim=time_emb_dim * 4
        )
        t_dim = time_emb_dim * 4  # output dim of the MLP

        # Initial convolution: 1 → 64
        self.init_conv = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)

        # Encoder path
        # Level 1: 64ch, 64×64 → Rearrange → 256ch, 32×32
        self.down1 = DownBlock(base_channels, base_channels, t_dim)  # 64 → 64 (skip), 256 (down)

        # Level 2: 256ch, 32×32 → Rearrange → 512ch, 16×16
        # A 1x1 conv is needed first because Rearrange produces 4× the channels
        self.down1_proj = nn.Conv2d(base_channels * 4, base_channels * 2, kernel_size=1)
        self.down2 = DownBlock(base_channels * 2, base_channels * 2, t_dim)  # 128 → 128 (skip), 512 (down)

        # Level 3: 512ch, 16×16 → Rearrange → 1024ch, 8×8
        self.down2_proj = nn.Conv2d(base_channels * 8, base_channels * 4, kernel_size=1)
        self.down3 = DownBlock(base_channels * 4, base_channels * 4, t_dim)  # 256 → 256 (skip), 1024 (down)

        # Bottleneck: 1024ch, 8×8 → 512ch, 8×8
        self.down3_proj = nn.Conv2d(base_channels * 16, base_channels * 8, kernel_size=1)
        self.bottleneck = ResidualBlock(base_channels * 8, base_channels * 8, t_dim)  # 512 → 512

        # Decoder path
        # Level 3: upsample 512 to 16×16, concat skip(256) → 768 → 256
        self.up3 = UpBlock(base_channels * 8, base_channels * 4, base_channels * 4, t_dim)

        # Level 2: upsample 256 to 32×32, concat skip(128) → 384 → 128
        self.up2 = UpBlock(base_channels * 4, base_channels * 2, base_channels * 2, t_dim)

        # Level 1: upsample 128 to 64×64, concat skip(64) → 192 → 64
        self.up1 = UpBlock(base_channels * 2, base_channels, base_channels, t_dim)

        # Final output: 64 → 1
        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, base_channels),
            nn.GELU(),
            nn.Conv2d(base_channels, in_channels, kernel_size=1),
        )

    def forward(self, x, timesteps):
        """
        Args:
            x: (B, 1, 64, 64) noisy image
            timesteps: (B,) integer timesteps
        Returns:
            (B, 1, 64, 64) predicted clean image (or noise)
        """
        # Timestep embedding
        t_emb = self.time_embed(timesteps)  # (B, t_dim)

        # Initial conv
        x = self.init_conv(x)  # (B, 64, 64, 64)

        # Encoder
        skip1, x = self.down1(x, t_emb)  # skip1: (B,64,64,64), x: (B,256,32,32)
        x = self.down1_proj(x)           # (B, 128, 32, 32)

        skip2, x = self.down2(x, t_emb)  # skip2: (B,128,32,32), x: (B,512,16,16)
        x = self.down2_proj(x)           # (B, 256, 16, 16)

        skip3, x = self.down3(x, t_emb)  # skip3: (B,256,16,16), x: (B,1024,8,8)
        x = self.down3_proj(x)           # (B, 512, 8, 8)

        # Bottleneck
        x = self.bottleneck(x, t_emb)  # (B, 512, 8, 8)

        # Decoder
        x = self.up3(x, skip3, t_emb)  # (B, 256, 16, 16)
        x = self.up2(x, skip2, t_emb)  # (B, 128, 32, 32)
        x = self.up1(x, skip1, t_emb)  # (B, 64, 64, 64)

        # Final output
        x = self.final_conv(x)  # (B, 1, 64, 64)
        return x
</code></pre>
|
|
641
|
+
|
|
642
|
+
<p>Kiểm tra tensor shapes:</p>
|
|
643
|
+
|
|
644
|
+
<pre><code class="language-python">
|
|
645
|
+
# Verify shapes
|
|
646
|
+
model = UNet(in_channels=1, base_channels=64, time_emb_dim=128)
|
|
647
|
+
x = torch.randn(2, 1, 64, 64)
|
|
648
|
+
t = torch.randint(0, 1000, (2,))
|
|
649
|
+
out = model(x, t)
|
|
650
|
+
print(f"Input: {x.shape}") # torch.Size([2, 1, 64, 64])
|
|
651
|
+
print(f"Output: {out.shape}") # torch.Size([2, 1, 64, 64])
|
|
652
|
+
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
|
|
653
|
+
</code></pre>
|
|
654
|
+
|
|
655
|
+
<pre><code class="language-text">
|
|
656
|
+
Tensor Shape Flow qua U-Net (base_channels=64)
|
|
657
|
+
═══════════════════════════════════════════════
|
|
658
|
+
|
|
659
|
+
Layer Shape Notes
|
|
660
|
+
──────────────────────────────────────────────────────────
|
|
661
|
+
Input (B, 1, 64, 64)
|
|
662
|
+
init_conv (B, 64, 64, 64) Conv2d(1, 64)
|
|
663
|
+
|
|
664
|
+
down1 ResBlock (B, 64, 64, 64) skip1 ─────────────────┐
|
|
665
|
+
down1 Rearrange (B, 256, 32, 32) 4× channels │
|
|
666
|
+
down1_proj (B, 128, 32, 32) 1×1 conv reduce │
|
|
667
|
+
│
|
|
668
|
+
down2 ResBlock (B, 128, 32, 32) skip2 ──────────┐ │
|
|
669
|
+
down2 Rearrange (B, 512, 16, 16) 4× channels │ │
|
|
670
|
+
down2_proj (B, 256, 16, 16) 1×1 conv reduce │ │
|
|
671
|
+
│ │
|
|
672
|
+
down3 ResBlock (B, 256, 16, 16) skip3 ───┐ │ │
|
|
673
|
+
down3 Rearrange (B, 1024, 8, 8) 4× ch │ │ │
|
|
674
|
+
down3_proj (B, 512, 8, 8) reduce │ │ │
|
|
675
|
+
│ │ │
|
|
676
|
+
bottleneck (B, 512, 8, 8) │ │ │
|
|
677
|
+
│ │ │
|
|
678
|
+
up3 Upsample (B, 512, 16, 16) │ │ │
|
|
679
|
+
up3 Concat skip3 (B, 768, 16, 16) ◄──────────────┘ │ │
|
|
680
|
+
up3 ResBlock (B, 256, 16, 16) │ │
|
|
681
|
+
│ │
|
|
682
|
+
up2 Upsample (B, 256, 32, 32) │ │
|
|
683
|
+
up2 Concat skip2 (B, 384, 32, 32) ◄─────────────────────┘ │
|
|
684
|
+
up2 ResBlock (B, 128, 32, 32) │
|
|
685
|
+
│
|
|
686
|
+
up1 Upsample (B, 128, 64, 64) │
|
|
687
|
+
up1 Concat skip1 (B, 192, 64, 64) ◄────────────────────────────┘
|
|
688
|
+
up1 ResBlock (B, 64, 64, 64)
|
|
689
|
+
|
|
690
|
+
final_conv (B, 1, 64, 64) Output = denoised image
|
|
691
|
+
</code></pre>
|
|
692
|
+
|
|
693
|
+
<blockquote><p><strong>Exam tip:</strong> Trong assessment, bạn sẽ cần tính toán chính xác tensor shapes. Quy tắc nhớ: <strong>Rearrange → channels ×4, spatial ÷2</strong>. Sau concat skip connection, số channels = channels từ upsample + channels từ skip. Ghi ra giấy nháp trước khi code!</p></blockquote>
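The two rules from the tip can be sanity-checked in plain Python before touching any model code (a sketch; the helper names below are illustrative, not part of the lesson's model):

```python
# Rule 1: Rearrange pooling multiplies channels by 4 and halves each spatial dim.
def rearrange_pool(c, h, w):
    return c * 4, h // 2, w // 2

# Rule 2: after concatenating a skip connection, channel counts simply add.
def concat_channels(c_upsample, c_skip):
    return c_upsample + c_skip

# Encoder level 1 of the model above (base_channels=64, 64×64 input):
print(rearrange_pool(64, 64, 64))  # (256, 32, 32), matches the shape-flow table

# Decoder level 1: upsampled 128-channel tensor concatenated with skip1 (64ch):
print(concat_channels(128, 64))    # 192
```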

<h2 id="train-denoiser">6. Train Denoiser Model</h2>

<h3 id="denoising-task">6.1 A Simple Denoising Task</h3>

<p>Before tackling the full diffusion process (many timesteps), we start with a simpler problem:<br/>
<strong>add Gaussian noise to an image → train the U-Net to recover the original</strong></p>

<pre><code class="language-text">
Simple Denoising Task
═════════════════════

Original Image      Add Noise           Noisy Image        U-Net       Denoised Output
   ┌──────┐          ┌──────┐            ┌──────┐          ┌─────┐        ┌──────┐
   │ 🖼️  │    +     │ ░░░░ │ noise_level│ ░🖼️░ │  ────►  │U-Net│ ────►  │ 🖼️  │
   │      │          │ ░░░░ │ * N(0,1)   │ ░░░░ │          │     │        │      │
   └──────┘          └──────┘            └──────┘          └─────┘        └──────┘
      x₀                ε                x_noisy           predict          x̂₀
                                       = x₀ + σ·ε             x₀

Loss = MSE(x̂₀, x₀) = ‖U-Net(x_noisy, t) - x₀‖²
</code></pre>
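The corruption step and the loss above can be sketched in a few lines of NumPy (illustrative only; the real PyTorch training loop follows in 6.2):

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 1, 64, 64))  # stand-in for clean images x0
sigma = 0.5                               # noise level (the σ above)
eps = rng.standard_normal(x0.shape)       # ε ~ N(0, 1)
x_noisy = x0 + sigma * eps                # forward corruption: x0 + σ·ε

def mse(pred, target):
    """Mean of squared element-wise differences."""
    return float(np.mean((pred - target) ** 2))

# A perfect denoiser scores 0; leaving the noise in place scores about
# sigma**2 = 0.25, the baseline a trained U-Net must beat.
print(mse(x0, x0))       # 0.0
print(mse(x_noisy, x0))  # ~0.25
```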

<h3 id="training-code">6.2 Training Loop</h3>

<pre><code class="language-python">
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EPOCHS = 50
NOISE_LEVEL = 0.5  # σ: controls how much noise to add
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset: MNIST (grayscale 28×28 → resize to 64×64)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),              # [0, 1]
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model, optimizer, loss
model = UNet(in_channels=1, base_channels=64, time_emb_dim=128).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# Training loop
for epoch in range(EPOCHS):
    total_loss = 0
    for batch_idx, (images, _) in enumerate(dataloader):
        images = images.to(DEVICE)  # (B, 1, 64, 64)

        # Random timesteps (a different timestep for each sample)
        timesteps = torch.randint(0, 1000, (images.shape[0],), device=DEVICE)

        # Scale the noise level by timestep (simple linear scaling)
        noise_scales = (timesteps.float() / 1000.0 * NOISE_LEVEL)  # (B,)
        noise_scales = noise_scales[:, None, None, None]           # (B,1,1,1)

        # Add noise
        noise = torch.randn_like(images)              # (B, 1, 64, 64)
        noisy_images = images + noise_scales * noise  # (B, 1, 64, 64)

        # Forward pass: predict clean image
        predicted_clean = model(noisy_images, timesteps)  # (B, 1, 64, 64)

        # Loss: MSE between predicted clean and actual clean
        loss = loss_fn(predicted_clean, images)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.6f}")
</code></pre>

<h3 id="visualize-results">6.3 Visualize Results</h3>

<pre><code class="language-python">
import matplotlib.pyplot as plt

@torch.no_grad()
def visualize_denoising(model, dataloader, noise_level=0.5, num_images=5):
    """Display: Original → Noisy → Denoised."""
    model.eval()
    images, _ = next(iter(dataloader))
    images = images[:num_images].to(DEVICE)

    # Add noise
    timesteps = torch.full((num_images,), 500, device=DEVICE)
    noise = torch.randn_like(images) * noise_level
    noisy = images + noise

    # Denoise
    denoised = model(noisy, timesteps)

    # Plot
    fig, axes = plt.subplots(3, num_images, figsize=(num_images * 3, 9))
    titles = ['Original', 'Noisy', 'Denoised']

    for i in range(num_images):
        for j, (img, title) in enumerate(zip(
            [images[i], noisy[i], denoised[i]], titles
        )):
            ax = axes[j][i]
            # Denormalize: [-1,1] → [0,1]
            img_np = (img.cpu().squeeze() * 0.5 + 0.5).clamp(0, 1).numpy()
            ax.imshow(img_np, cmap='gray')
            ax.set_title(title if i == 0 else '')
            ax.axis('off')

    plt.tight_layout()
    plt.savefig('denoising_results.png', dpi=150)
    plt.show()

visualize_denoising(model, dataloader)
</code></pre>

<blockquote><p><strong>Exam tip:</strong> In the assessment you may be asked to complete a training loop. Remember the three key steps: (1) add noise to the clean image, (2) run the noisy image and timestep through the U-Net, (3) compute the MSE loss between the prediction and the original. <strong>Don't forget to pass the timestep into the model!</strong></p></blockquote>

<h2 id="cheat-sheet">7. Cheat Sheet — U-Net & Denoising</h2>

<table>
<thead>
<tr><th>Concept</th><th>Key Detail</th><th>Code/Formula</th></tr>
</thead>
<tbody>
<tr><td>U-Net Structure</td><td>Encoder → Bottleneck → Decoder + Skip Connections</td><td>U-shaped; skip = concatenate</td></tr>
<tr><td>GroupNorm</td><td>Normalize per group, batch-size independent</td><td><code>nn.GroupNorm(8, channels)</code></td></tr>
<tr><td>GELU</td><td>Smooth activation, x·Φ(x)</td><td><code>nn.GELU()</code></td></tr>
<tr><td>Rearrange Pooling</td><td>(B,C,2H,2W) → (B,4C,H,W), lossless</td><td><code>rearrange(x, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)</code></td></tr>
<tr><td>Sinusoidal Embed</td><td>sin/cos at varying frequencies</td><td><code>sin(t/10000^(2i/d))</code>, <code>cos(t/10000^(2i/d))</code></td></tr>
<tr><td>Timestep Injection</td><td>Add to feature maps after GroupNorm</td><td><code>h = h + t_emb[:,:,None,None]</code></td></tr>
<tr><td>ResidualBlock</td><td>Conv→GN→(+t)→GELU→Conv→GN→GELU + skip</td><td>2 conv layers + residual + timestep</td></tr>
<tr><td>Denoising Loss</td><td>MSE between predicted clean &amp; actual clean</td><td><code>MSE(model(x_noisy, t), x_clean)</code></td></tr>
<tr><td>Skip Connection Role</td><td>Preserve spatial details, improve gradient flow</td><td><code>torch.cat([upsample, skip], dim=1)</code></td></tr>
<tr><td>Channels after Concat</td><td>Channels from upsample + channels from skip</td><td>Must match the next conv's in_channels</td></tr>
</tbody>
</table>
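The "lossless" claim for Rearrange Pooling in the table can be verified directly: the operation is a pure reshape/transpose, so an exact inverse exists. Below is a NumPy sketch of the einops pattern written out by hand (function names are illustrative):

```python
import numpy as np

def rearrange_pool(x):
    """(B, C, 2H, 2W) -> (B, 4C, H, W): 'b c (h p1) (w p2) -> b (c p1 p2) h w'."""
    b, c, H, W = x.shape
    h, w = H // 2, W // 2
    x = x.reshape(b, c, h, 2, w, 2)    # split each spatial dim into (size, 2)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # reorder to b, c, p1, p2, h, w
    return x.reshape(b, c * 4, h, w)

def rearrange_unpool(x):
    """Exact inverse of rearrange_pool, proving no information is lost."""
    b, c4, h, w = x.shape
    c = c4 // 4
    x = x.reshape(b, c, 2, 2, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # reorder to b, c, h, p1, w, p2
    return x.reshape(b, c, h * 2, w * 2)

x = np.random.randn(2, 3, 8, 8)
pooled = rearrange_pool(x)
print(pooled.shape)                                 # (2, 12, 4, 4)
print(np.array_equal(rearrange_unpool(pooled), x))  # True: bit-exact roundtrip
```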

<h2 id="practice-questions">8. Practice Questions</h2>

<p>The questions below simulate the NVIDIA DLI coding assessment. Try coding each one yourself before looking at the answer!</p>

<p><strong>Q1: Implement ResidualBlock with Timestep Injection</strong></p>

<p>Complete the <code>forward</code> method of the <code>ResidualBlock</code> below. The block should apply two conv layers with GroupNorm and GELU, inject the timestep embedding after the first normalization, and add a residual connection.</p>

<pre><code class="language-python">
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act = nn.GELU()
        self.time_proj = nn.Linear(t_dim, out_ch)
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        # TODO: implement this method
        pass
</code></pre>

<details>
<summary>Show Answer Q1</summary>

<pre><code class="language-python">
def forward(self, x, t_emb):
    residual = self.res_conv(x)

    h = self.conv1(x)
    h = self.norm1(h)

    # Inject timestep: project t_emb to out_ch, reshape for broadcasting, add
    t = self.time_proj(t_emb)  # (B, out_ch)
    t = t[:, :, None, None]    # (B, out_ch, 1, 1)
    h = h + t

    h = self.act(h)

    h = self.conv2(h)
    h = self.norm2(h)
    h = self.act(h)

    return h + residual
</code></pre>

<p><em>Explanation: Key points — (1) the timestep is ADDED, not concatenated, (2) reshaping to (B, C, 1, 1) enables broadcasting across H×W, (3) the injection happens after norm1 and before GELU, (4) the residual path uses a 1×1 conv if the channel dimensions mismatch.</em></p>
</details>

<p><strong>Q2: What happens if you remove skip connections from U-Net?</strong></p>

<p>Consider the following modified U-Net that does NOT use skip connections in the decoder:</p>

<pre><code class="language-python">
# Original (with skip connections):
x = self.upsample(x)
x = torch.cat([x, skip], dim=1)  # concat skip
x = self.res_block(x, t_emb)

# Modified (WITHOUT skip connections):
x = self.upsample(x)
# skip connection removed!
x = self.res_block(x, t_emb)
</code></pre>

<p>What will happen to the denoised output? Choose all that apply:</p>

<ul>
<li>A) Output will be blurry, losing fine details</li>
<li>B) Model fails to compile due to shape mismatch</li>
<li>C) Training loss will increase significantly</li>
<li>D) Model produces identical output regardless of input</li>
</ul>

<details>
<summary>Show Answer Q2</summary>

<p><strong>A and C are correct.</strong></p>

<p><em>Explanation: (A) Without skip connections, the decoder only has bottleneck information (8×8 at 512 channels) to reconstruct 64×64 details — fine-grained textures and edges are lost, resulting in blurry outputs. (B) Incorrect: as long as the res_block's in_channels is adjusted, there is no shape mismatch. (C) Correct — the model has less information to reconstruct from, so the MSE loss between prediction and clean image stays higher. (D) This would only happen in extreme cases of a total information bottleneck; the model can still capture rough structure from the bottleneck features.</em></p>
</details>

<p><strong>Q3: Calculate output shapes through each U-Net level</strong></p>

<p>Given the following U-Net configuration, fill in the missing tensor shapes:</p>

<pre><code class="language-python">
# Config: in_channels=1, base_channels=32, image_size=32×32
# Using Rearrange Pooling for downsampling

x = input            # Shape: (B, 1, 32, 32)
x = init_conv(x)     # Shape: (B, 32, 32, 32)

# Encoder Level 1
skip1, x = down1(x)  # skip1: (B, 32, 32, 32), x after rearrange: ???
x = proj1(x)         # Shape: ???

# Encoder Level 2
skip2, x = down2(x)  # skip2: ???, x after rearrange: ???
x = proj2(x)         # Shape: ???

# Bottleneck
x = bottleneck(x)    # Shape: ???

# Decoder Level 2
x = upsample(x)      # Shape: ???
x = cat(x, skip2)    # Shape: ???
x = res_block(x)     # Shape: ???

# Decoder Level 1
x = upsample(x)      # Shape: ???
x = cat(x, skip1)    # Shape: ???
x = res_block(x)     # Shape: ???

x = final_conv(x)    # Shape: (B, 1, 32, 32)
</code></pre>

<details>
<summary>Show Answer Q3</summary>

<pre><code class="language-python">
x = input            # (B, 1, 32, 32)
x = init_conv(x)     # (B, 32, 32, 32)

# Encoder Level 1
skip1 = res1(x)      # skip1: (B, 32, 32, 32)
x = rearrange(skip1) # (B, 128, 16, 16) ← 32×4=128, 32/2=16
x = proj1(x)         # (B, 64, 16, 16)  ← 1×1 conv reduce

# Encoder Level 2
skip2 = res2(x)      # skip2: (B, 64, 16, 16)
x = rearrange(skip2) # (B, 256, 8, 8)   ← 64×4=256, 16/2=8
x = proj2(x)         # (B, 128, 8, 8)   ← 1×1 conv reduce

# Bottleneck
x = bottleneck(x)    # (B, 128, 8, 8)

# Decoder Level 2
x = upsample(x)      # (B, 128, 16, 16) ← spatial ×2
x = cat(x, skip2)    # (B, 192, 16, 16) ← 128+64=192
x = res_block(x)     # (B, 64, 16, 16)  ← project down

# Decoder Level 1
x = upsample(x)      # (B, 64, 32, 32)  ← spatial ×2
x = cat(x, skip1)    # (B, 96, 32, 32)  ← 64+32=96
x = res_block(x)     # (B, 32, 32, 32)  ← project down

x = final_conv(x)    # (B, 1, 32, 32)
</code></pre>

<p><em>Explanation: The key pattern — Rearrange Pooling multiplies channels by 4 and halves the spatial dimensions. After concatenating with a skip, channels = upsample_channels + skip_channels. Track these carefully to set the correct in_channels for each layer.</em></p>
</details>

<p><strong>Q4: Implement SinusoidalPositionEmbedding class</strong></p>

<p>Implement the <code>forward</code> method that converts integer timesteps to sinusoidal embeddings:</p>

<pre><code class="language-python">
class SinusoidalPositionEmbedding(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim  # must be even

    def forward(self, timesteps):
        """
        Args:
            timesteps: (B,) — integer timesteps
        Returns:
            (B, embed_dim) — sinusoidal embeddings
        """
        # TODO: implement using the formula:
        # PE(t, 2i)   = sin(t / 10000^(2i/d))
        # PE(t, 2i+1) = cos(t / 10000^(2i/d))
        pass
</code></pre>

<details>
<summary>Show Answer Q4</summary>

<pre><code class="language-python">
import math

def forward(self, timesteps):
    device = timesteps.device
    half_dim = self.embed_dim // 2

    # Step 1: Compute frequency terms
    # exp(-log(10000) * i/(d/2)) = 1/10000^(i/(d/2)) for i in [0, d/2)
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half_dim, device=device).float() / half_dim
    )

    # Step 2: Outer product of timesteps and frequencies
    # (B, 1) * (1, d/2) → (B, d/2)
    args = timesteps[:, None].float() * freqs[None, :]

    # Step 3: Apply sin and cos, concatenate
    embeddings = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    # Result shape: (B, embed_dim)

    return embeddings
</code></pre>

<p><em>Explanation: Three key steps — (1) compute the frequency terms with exp(-log(10000) · i/half_dim), which is equivalent to 1/10000^(2i/d), (2) multiply each timestep by all frequencies via broadcasting, (3) apply sin to the first half and cos to the second half, then concatenate. The math.log(10000.0) formulation is numerically more stable than computing 10000**(2i/d) directly.</em></p>
</details>

<p><strong>Q5: Debug U-Net — Output is always the mean of the training set</strong></p>

<p>A student implemented a U-Net for denoising, but the output always looks like the blurry average of MNIST digits regardless of the input. Review the code below and find the bug:</p>

<pre><code class="language-python">
class BuggyUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # BUG: Notice in_channels here — where is the skip connection?
        self.res_block = ResidualBlock(in_channels, out_channels, time_emb_dim)

    def forward(self, x, skip, t_emb):
        x = self.upsample(x)
        # BUG: skip connection is received but never used!
        x = self.res_block(x, t_emb)
        return x


class BuggyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ... encoder and bottleneck (correct) ...

        # Decoder — uses BuggyUpBlock
        self.up3 = BuggyUpBlock(512, 256, t_dim)  # skip not concatenated
        self.up2 = BuggyUpBlock(256, 128, t_dim)  # skip not concatenated
        self.up1 = BuggyUpBlock(128, 64, t_dim)   # skip not concatenated
</code></pre>

<p>What is the bug and how do you fix it?</p>

<details>
<summary>Show Answer Q5</summary>

<p><strong>Bug:</strong> The <code>skip</code> tensor is passed to <code>forward()</code> but <strong>never concatenated</strong> with <code>x</code>. The decoder only sees bottleneck features (heavily compressed, 8×8) and cannot reconstruct spatial details → the output converges to the dataset mean.</p>

<pre><code class="language-python">
class FixedUpBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, time_emb_dim):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # FIX: in_channels = in_channels + skip_channels (after concat)
        self.res_block = ResidualBlock(
            in_channels + skip_channels, out_channels, time_emb_dim
        )

    def forward(self, x, skip, t_emb):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)  # FIX: concatenate the skip connection!
        x = self.res_block(x, t_emb)
        return x
</code></pre>

<p><em>Explanation: This is a common and subtle bug. The model still trains and produces output of the correct shape, but without skip connections the decoder is a pure upsampling network fed only bottleneck features. Since the 8×8 bottleneck captures global statistics but not spatial details, the model learns to output the average image (the minimum-MSE solution when detail information is missing). Two fixes: (1) add <code>torch.cat([x, skip], dim=1)</code> in forward, (2) change the ResidualBlock's in_channels to account for the concatenated skip channels.</em></p>
</details>

<blockquote><p><strong>Exam tip:</strong> In the real assessment, debugging exercises usually involve <strong>shape mismatches</strong> or <strong>missing connections</strong>. When the model output looks "normal" but is blurry and identical for every input, immediately suspect a missing or miswired skip connection. Always print tensor shapes at every layer when debugging!</p></blockquote>
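The "print tensor shapes at every layer" advice can be automated with PyTorch forward hooks (a sketch; the helper name is illustrative):

```python
import torch
import torch.nn as nn

def attach_shape_printer(model):
    """Register hooks that print each leaf module's output shape; returns handles."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inp, out, name=name):
                if isinstance(out, torch.Tensor):
                    print(f"{name:20s} -> {tuple(out.shape)}")
            handles.append(module.register_forward_hook(hook))
    return handles

# Usage on a toy two-layer model: a misplaced skip concat would show up
# immediately as an unexpected channel count in the printed shapes.
toy = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.Upsample(scale_factor=2))
handles = attach_shape_printer(toy)
out = toy(torch.randn(1, 1, 16, 16))  # prints the shape after each layer
for h in handles:
    h.remove()  # detach the hooks when done debugging
```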