orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,645 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Tuple, Optional
|
|
6
|
+
|
|
7
|
+
from orbit.model import BaseBlock, register_model
|
|
8
|
+
from orbit.utils.image import pad_to_patch_size
|
|
9
|
+
from orbit.model.block import (
|
|
10
|
+
ConvBlock, ResBasicBlock,
|
|
11
|
+
SpatialMultiHeadAttention, AttentionOutput,
|
|
12
|
+
MRoPEInterleavedEmbedding,
|
|
13
|
+
LFQ, QuantizerOutput
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
def dynamic_collate_fn(batch, stride: int = 16):
    '''
    Collate variable-sized images into one zero-padded batch plus a validity mask.

    Every image is copied into the top-left corner of its slot. The padded
    height and width are the batch maxima rounded up to a multiple of
    ``stride``.

    Args:
        batch: Sequence of image tensors indexed as [C, H, W]. The channel
            count of the first image is used for the whole batch.
        stride (int): Patch size that the padded height and width must be a
            multiple of. Defaults to 16.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - padded_batch: [B, C, Max_H, Max_W], zero outside each image.
            - mask_batch: [B, 1, Max_H, Max_W] (1.0 = valid, 0.0 = padding).

    Raises:
        ValueError: If ``batch`` is empty or ``stride`` is not positive.
    '''
    images = list(batch)
    if not images:
        raise ValueError("dynamic_collate_fn received an empty batch")
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    def round_up(size):
        # Smallest multiple of `stride` that is >= size.
        return ((size + stride - 1) // stride) * stride

    max_h = round_up(max(img.shape[1] for img in images))
    max_w = round_up(max(img.shape[2] for img in images))

    batch_size = len(images)
    channels = images[0].shape[0]

    # Both buffers use torch's default dtype/device, as before.
    padded_batch = torch.zeros(batch_size, channels, max_h, max_w)
    mask_batch = torch.zeros(batch_size, 1, max_h, max_w)

    for i, img in enumerate(images):
        h, w = img.shape[1], img.shape[2]
        padded_batch[i, :, :h, :w] = img
        mask_batch[i, :, :h, :w] = 1.0

    return padded_batch, mask_batch
|
|
45
|
+
|
|
46
|
+
@dataclass
class EncoderOutput:
    '''
    Result bundle produced by the VQGAN encoder.

    Attributes:
        output (torch.Tensor): Encoded latent, shaped
            [B, out_channels, h_latent, w_latent].
        mask (torch.Tensor): Valid-region mask at the padded input
            resolution, shaped [B, 1, H_pad, W_pad].
        input_shape (Tuple[int, int]): (H, W) of the image before padding.
        padded_shape (Tuple[int, int]): (H_pad, W_pad) of the image after padding.
    '''
    output: torch.Tensor
    mask: torch.Tensor
    input_shape: Tuple[int, int]
    padded_shape: Tuple[int, int]
|
|
61
|
+
|
|
62
|
+
@dataclass
class DecoderOutput:
    '''
    Result bundle produced by the VQGAN decoder.

    Attributes:
        reconstruction (torch.Tensor): Reconstructed image, shaped [B, 3, H, W].
    '''
    reconstruction: torch.Tensor
|
|
71
|
+
|
|
72
|
+
@dataclass
class MotifV1Info:
    '''
    Auxiliary statistics reported alongside a MotifV1 forward pass.

    Attributes:
        perplexity (torch.Tensor): Codebook perplexity.
        entropy (torch.Tensor): Codebook entropy.
        indices (torch.Tensor): Quantization indices, shaped [B, H, W].
        mask (Optional[torch.Tensor]): Valid-region mask, shaped [B, 1, H, W].
    '''
    perplexity: torch.Tensor
    entropy: torch.Tensor
    indices: torch.Tensor
    mask: Optional[torch.Tensor]
|
|
87
|
+
|
|
88
|
+
@dataclass
class MotifV1Output:
    '''
    Result bundle returned by the MotifV1 main model.

    Attributes:
        reconstruction (torch.Tensor): Reconstructed image, shaped [B, 3, H, W].
        loss (torch.Tensor): Quantization loss (commitment loss).
        info (MotifV1Info): Auxiliary statistics for the pass.
    '''
    reconstruction: torch.Tensor
    loss: torch.Tensor
    info: MotifV1Info
|
|
101
|
+
|
|
102
|
+
@register_model()
class MotifV1Encoder(BaseBlock):
    '''
    VQGAN encoder supporting variable input sizes and 2D RoPE.

    Features are extracted by alternating residual stages and stride-2
    convolutions. The middle stage applies spatial multi-head attention with
    2D rotary position embeddings. Inputs are first passed through
    ``pad_to_patch_size`` so that height and width fit ``patch_size``.
    '''
    def __init__(
        self,
        in_channels: int = 3,
        hidden_dim: int = 128,
        out_channels: int = 256,
        patch_size: int = 16,
        num_res_blocks: int = 1,
        num_heads: int = 4,
        dropout: float = 0.0,
        rope_max_len: int = 4096,
        max_channels: int = 256
    ):
        '''
        Initialize MotifV1Encoder.

        Args:
            in_channels (int): Channels of the input image. Defaults to 3.
            hidden_dim (int): Width of the first hidden stage. Defaults to 128.
            out_channels (int): Channels of the output latent. Defaults to 256.
            patch_size (int): Patch size; must be a positive power of 2. Its
                base-2 logarithm is the number of downsampling steps.
                Defaults to 16.
            num_res_blocks (int): Residual blocks per downsampling stage.
                Defaults to 1.
            num_heads (int): Attention heads. Defaults to 4.
            dropout (float): Dropout probability. Defaults to 0.0.
            rope_max_len (int): Maximum length handed to the RoPE module.
                Defaults to 4096.
            max_channels (int): Upper bound for the hidden channel count.
                Defaults to 256.

        Raises:
            ValueError: If patch_size is not a positive power of 2.
        '''
        super().__init__()

        # `n & (n - 1) == 0` alone also accepts 0, so reject non-positive
        # values explicitly.
        if patch_size <= 0 or (patch_size & (patch_size - 1)) != 0:
            raise ValueError(f"patch_size must be a power of 2, got {patch_size}")

        self.patch_size = patch_size
        # Exact integer log2 of a power of two (no float round trip).
        self.num_downsamples = int(patch_size).bit_length() - 1

        self.conv_in = ConvBlock(
            in_channels, hidden_dim, kernel_size=3, padding=1, norm=None, activation=None
        )

        self.down_blocks = nn.ModuleList()
        curr_dim = hidden_dim

        for _ in range(self.num_downsamples):
            # Residual stage at the current width ...
            blocks = [
                ResBasicBlock(
                    curr_dim,
                    curr_dim,
                    dropout=dropout,
                    norm='group',
                    activation='silu',
                    variant='pre_act'
                )
                for _ in range(num_res_blocks)
            ]
            self.down_blocks.append(nn.Sequential(*blocks))

            # ... followed by a stride-2 conv that doubles the width, capped
            # at max_channels.
            next_dim = min(curr_dim * 2, max_channels)
            self.down_blocks.append(ConvBlock(
                curr_dim, next_dim, kernel_size=3, stride=2, padding=1,
                norm=None, activation=None
            ))
            curr_dim = next_dim

        self.mid_res1 = ResBasicBlock(
            curr_dim,
            curr_dim,
            dropout=dropout,
            norm='group',
            activation='silu',
            variant='pre_act'
        )

        self.mid_attn = SpatialMultiHeadAttention(
            hidden_size=curr_dim,
            num_heads=num_heads,
            use_qk_norm=True
        )

        # The rotary embedding is sized per attention head and covers the
        # two spatial axes.
        self.rope = MRoPEInterleavedEmbedding(
            model_dim=curr_dim // num_heads,
            num_axes=2,
            max_len=rope_max_len
        )

        self.mid_res2 = ResBasicBlock(
            curr_dim,
            curr_dim,
            dropout=dropout,
            norm='group',
            activation='silu',
            variant='pre_act'
        )

        self.norm_out = nn.GroupNorm(32, curr_dim)
        self.act_out = nn.SiLU(inplace=True)
        self.conv_out = nn.Conv2d(curr_dim, out_channels, kernel_size=3, padding=1)

    def _build_grid(self, h: int, w: int, device: torch.device) -> torch.Tensor:
        '''
        Build the 2D grid coordinates of an h x w feature map.

        Args:
            h (int): Grid height.
            w (int): Grid width.
            device (torch.device): Device on which to create the tensor.

        Returns:
            torch.Tensor: Coordinates shaped [h * w, 2]; each row is (y, x)
            in row-major order.
        '''
        y = torch.arange(h, device=device)
        x = torch.arange(w, device=device)
        grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
        return torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)

    def forward(self, x: torch.Tensor) -> EncoderOutput:
        '''
        Encode an image batch into a latent feature map.

        Args:
            x (torch.Tensor): Input images, shaped [B, C, H, W].

        Returns:
            EncoderOutput: The latent, the mask reported by
            ``pad_to_patch_size``, the original (H, W) and the padded
            spatial shape.
        '''
        _, _, height, width = x.shape

        patch_out = pad_to_patch_size(x, (self.patch_size, self.patch_size))
        x_in = patch_out.output

        h = self.conv_in(x_in)
        for layer in self.down_blocks:
            h = layer(h)

        h = self.mid_res1(h)

        # Flatten the spatial grid into a token sequence for attention.
        B_feat, C_feat, H_feat, W_feat = h.shape
        h_flat = h.permute(0, 2, 3, 1).reshape(B_feat, H_feat * W_feat, C_feat)

        positions = self._build_grid(H_feat, W_feat, h.device)
        positions = positions.unsqueeze(0).expand(B_feat, -1, -1)

        attn_out: AttentionOutput = self.mid_attn(
            hidden_states=h_flat,
            positions=positions,
            rotary_emb=self.rope
        )

        # reshape (not view) so a non-contiguous attention output is also
        # accepted; back to channels-first for the convolutions.
        h = attn_out.output.reshape(B_feat, H_feat, W_feat, C_feat).permute(0, 3, 1, 2)
        h = self.mid_res2(h)

        h = self.norm_out(h)
        h = self.act_out(h)
        z = self.conv_out(h)

        return EncoderOutput(
            output=z,
            mask=patch_out.mask,
            input_shape=(height, width),
            padded_shape=x_in.shape[-2:]
        )
|
|
275
|
+
|
|
276
|
+
@register_model()
class MotifV1Decoder(BaseBlock):
    '''
    PixelShuffle-based VQGAN decoder.

    Structure: Input Conv -> Mid Block -> Upsample Stack (PixelShuffle) -> Output Conv

    NOTE: ``forward`` accepts an optional mask but does not currently use it;
    the reconstruction is returned at the padded resolution and any cropping
    is left to the caller.
    '''
    def __init__(
        self,
        in_channels: int = 256,
        hidden_dim: int = 256,
        out_channels: int = 3,
        patch_size: int = 16,
        num_res_blocks: int = 1,
        num_heads: int = 4,
        dropout: float = 0.0,
        rope_max_len: int = 4096,
    ):
        '''
        Initialize MotifV1Decoder.

        Args:
            in_channels (int): Channels of the input latent. Defaults to 256.
            hidden_dim (int): Width of the first hidden stage. Defaults to 256.
            out_channels (int): Channels of the output image. Defaults to 3.
            patch_size (int): Patch size; must be a positive power of 2. Its
                base-2 logarithm is the number of upsampling steps.
                Defaults to 16.
            num_res_blocks (int): Residual blocks per upsampling stage.
                Defaults to 1.
            num_heads (int): Attention heads. Defaults to 4.
            dropout (float): Dropout probability. Defaults to 0.0.
            rope_max_len (int): Maximum length handed to the RoPE module.
                Defaults to 4096.

        Raises:
            ValueError: If patch_size is not a positive power of 2.
        '''
        super().__init__()

        # `n & (n - 1) == 0` alone also accepts 0, so reject non-positive
        # values explicitly.
        if patch_size <= 0 or (patch_size & (patch_size - 1)) != 0:
            raise ValueError(f"patch_size must be a power of 2, got {patch_size}")

        self.patch_size = patch_size
        # Exact integer log2 of a power of two (no float round trip).
        self.num_upsamples = int(patch_size).bit_length() - 1

        self.conv_in = ConvBlock(
            in_channels, hidden_dim, kernel_size=3, padding=1, norm=None, activation=None
        )

        self.mid_res1 = ResBasicBlock(
            hidden_dim,
            hidden_dim,
            dropout=dropout,
            norm='group',
            activation='silu',
            variant='pre_act'
        )

        self.mid_attn = SpatialMultiHeadAttention(
            hidden_size=hidden_dim,
            num_heads=num_heads,
            use_qk_norm=True
        )

        # The rotary embedding is sized per attention head and covers the
        # two spatial axes.
        self.rope = MRoPEInterleavedEmbedding(
            model_dim=hidden_dim // num_heads,
            num_axes=2,
            max_len=rope_max_len
        )

        self.mid_res2 = ResBasicBlock(
            hidden_dim,
            hidden_dim,
            dropout=dropout,
            norm='group',
            activation='silu',
            variant='pre_act'
        )

        self.up_blocks = nn.ModuleList()
        curr_dim = hidden_dim

        for _ in range(self.num_upsamples):
            blocks = [
                ResBasicBlock(
                    curr_dim,
                    curr_dim,
                    dropout=dropout,
                    norm='group',
                    activation='silu',
                    variant='pre_act'
                )
                for _ in range(num_res_blocks)
            ]
            self.up_blocks.append(nn.Sequential(*blocks))

            # Halve the width at every stage but never go below 64 channels.
            next_dim = max(curr_dim // 2, 64)

            # The conv emits 4x the target channels; PixelShuffle(2) folds
            # them into a 2x larger spatial grid with next_dim channels.
            self.up_blocks.append(nn.Sequential(
                ConvBlock(curr_dim, next_dim * 4, kernel_size=3, padding=1, norm='group', activation='silu'),
                nn.PixelShuffle(2)
            ))
            curr_dim = next_dim

        self.norm_out = nn.GroupNorm(32, curr_dim)
        self.act_out = nn.SiLU(inplace=True)
        self.conv_out = nn.Conv2d(curr_dim, out_channels, kernel_size=3, padding=1)

    def _build_grid(self, h: int, w: int, device: torch.device) -> torch.Tensor:
        '''
        Build the 2D grid coordinates of an h x w feature map.

        Args:
            h (int): Grid height.
            w (int): Grid width.
            device (torch.device): Device on which to create the tensor.

        Returns:
            torch.Tensor: Coordinates shaped [h * w, 2]; each row is (y, x)
            in row-major order.
        '''
        y = torch.arange(h, device=device)
        x = torch.arange(w, device=device)
        grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
        return torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> DecoderOutput:
        '''
        Decode a latent feature map into an image.

        Args:
            z (torch.Tensor): Quantized latent, shaped [B, C, H, W].
            mask (torch.Tensor, optional): Mask produced by the encoder,
                shaped [B, 1, H_orig, W_orig]. Accepted for interface
                compatibility only: it is not read here, so no cropping is
                performed.

        Returns:
            DecoderOutput: Holds the reconstruction at the padded resolution.
        '''
        h = self.conv_in(z)

        h = self.mid_res1(h)

        # Flatten the spatial grid into a token sequence for attention.
        B_feat, C_feat, H_feat, W_feat = h.shape
        h_flat = h.permute(0, 2, 3, 1).reshape(B_feat, H_feat * W_feat, C_feat)

        positions = self._build_grid(H_feat, W_feat, h.device)
        positions = positions.unsqueeze(0).expand(B_feat, -1, -1)

        attn_out = self.mid_attn(
            hidden_states=h_flat,
            positions=positions,
            rotary_emb=self.rope
        )

        # reshape (not view) so a non-contiguous attention output is also
        # accepted; back to channels-first for the convolutions.
        h = attn_out.output.reshape(B_feat, H_feat, W_feat, C_feat).permute(0, 3, 1, 2)
        h = self.mid_res2(h)

        for layer in self.up_blocks:
            h = layer(h)

        h = self.norm_out(h)
        h = self.act_out(h)
        recon = self.conv_out(h)  # [B, out_channels, H_pad, W_pad]

        return DecoderOutput(reconstruction=recon)
|
|
439
|
+
|
|
440
|
+
@register_model()
class MotifV1(BaseBlock):
    '''
    Motif-V1 main model.

    Combines the encoder, the LFQ quantizer and the decoder.
    '''
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        latent_dim: int = 256,
        codebook_dim: int = 18,
        patch_size: int = 16,

        enc_hidden_dim: int = 64,
        enc_num_res_blocks: int = 1,
        enc_num_heads: int = 4,
        enc_max_channels: int = 256,

        dec_hidden_dim: int = 256,
        dec_num_res_blocks: int = 1,
        dec_num_heads: int = 4,

        dropout: float = 0.0,
        rope_max_len: int = 4096,

        entropy_weight: float = 0.1,
        commitment_weight: float = 0.25,
    ):
        '''
        Initialize the MotifV1 main model.

        Args:
            in_channels (int): Channels of the input image. Defaults to 3.
            out_channels (int): Channels of the output image. Defaults to 3.
            latent_dim (int): Latent width shared by the encoder output and
                the decoder input. Defaults to 256.
            codebook_dim (int): Number of bits (dimensions) of the LFQ
                codebook. Defaults to 18.
            patch_size (int): Patch size; must be a power of 2. Defaults to 16.
            enc_hidden_dim (int): Initial hidden width of the encoder.
                Defaults to 64.
            enc_num_res_blocks (int): Residual blocks per encoder stage.
                Defaults to 1.
            enc_num_heads (int): Encoder attention heads. Defaults to 4.
            enc_max_channels (int): Upper bound for encoder hidden channels.
                Defaults to 256.
            dec_hidden_dim (int): Initial hidden width of the decoder.
                Defaults to 256.
            dec_num_res_blocks (int): Residual blocks per decoder stage.
                Defaults to 1.
            dec_num_heads (int): Decoder attention heads. Defaults to 4.
            dropout (float): Dropout probability. Defaults to 0.0.
            rope_max_len (int): Maximum length handed to the RoPE modules.
                Defaults to 4096.
            entropy_weight (float): Weight of the LFQ entropy loss.
                Defaults to 0.1.
            commitment_weight (float): Weight of the LFQ commitment loss.
                Defaults to 0.25.
        '''
        super().__init__()

        self.codebook_dim = codebook_dim

        self.encoder = MotifV1Encoder(
            in_channels=in_channels,
            hidden_dim=enc_hidden_dim,
            out_channels=latent_dim,
            patch_size=patch_size,
            num_res_blocks=enc_num_res_blocks,
            num_heads=enc_num_heads,
            dropout=dropout,
            rope_max_len=rope_max_len,
            max_channels=enc_max_channels
        )

        self.quantizer = LFQ(
            latent_dim=latent_dim,
            codebook_dim=codebook_dim,
            entropy_weight=entropy_weight,
            commitment_weight=commitment_weight
        )

        self.decoder = MotifV1Decoder(
            in_channels=latent_dim,
            hidden_dim=dec_hidden_dim,
            out_channels=out_channels,
            patch_size=patch_size,
            num_res_blocks=dec_num_res_blocks,
            num_heads=dec_num_heads,
            dropout=dropout,
            rope_max_len=rope_max_len
        )

        # Applied last so every submodule created above is re-initialized.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        '''
        Initialize the weights of a single module (used with ``self.apply``).

        Args:
            m (nn.Module): Module to initialize.

        Note:
            - Conv2d, Linear, Conv1d: Xavier-uniform weight, zero bias.
            - LayerNorm, GroupNorm, BatchNorm2d: weight 1, bias 0.
            - Parameters that are absent (``None``, e.g. a norm layer built
              without affine parameters) are skipped.
        '''
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            # Norm layers expose weight/bias as None when affine is disabled.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @property
    def last_layer_weights(self):
        '''Weight of the decoder's final conv; used by VQGANLoss for its adaptive weight.'''
        return self.decoder.conv_out.weight

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        '''
        Encode only: turn images into quantization indices.

        Args:
            x (torch.Tensor): Input images, shaped [B, C, H, W].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - indices (torch.Tensor): Quantization indices, [B, H, W].
                - mask (torch.Tensor): Valid-region mask from the encoder, [B, 1, H, W].
                - z_q (torch.Tensor): Quantized latent (for debugging), [B, C, H, W].
        '''
        enc_out = self.encoder(x)
        q_out = self.quantizer(enc_out.output)
        return q_out.indices, enc_out.mask, q_out.z_q

    def decode(self, indices: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        '''
        Decode only: turn quantization indices back into images (typically
        used for Transformer inference).

        Args:
            indices (torch.Tensor): Quantization indices, shaped [B, H, W].
            mask (torch.Tensor, optional): Valid-region mask. It is forwarded
                to the decoder unchanged; MotifV1Decoder.forward does not read
                it, so the output is not cropped. Defaults to None.

        Returns:
            torch.Tensor: The decoder's reconstruction.
        '''
        codes = self.indices_to_codes(indices)  # [B, H, W, Codebook_Dim]

        # project_out is applied on the channels-last layout, so project
        # before moving channels to the front.
        if hasattr(self.quantizer, 'project_out'):
            codes = self.quantizer.project_out(codes)  # [B, H, W, Latent_Dim]

        z_q = codes.permute(0, 3, 1, 2).contiguous()

        dec_out = self.decoder(z_q, mask)
        return dec_out.reconstruction

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        '''
        LFQ helper: expand integer indices into binary code vectors.

        Each index is tested bitwise against ``self.quantizer.basis``; a
        non-zero result maps to 1.0, anything else to -1.0.

        Args:
            indices (torch.Tensor): Quantization indices, shaped [B, H, W].

        Returns:
            torch.Tensor: Binary codes, shaped [B, H, W, Codebook_Dim],
            with values -1.0 or 1.0.

        Raises:
            ValueError: If ``indices`` is not a 3-D tensor.
        '''
        if indices.dim() != 3:
            raise ValueError(
                f"indices must have shape [B, H, W], got {tuple(indices.shape)}"
            )

        # presumably one power-of-two entry per codebook bit, [Codebook_Dim] -- TODO confirm
        basis_long = self.quantizer.basis.long()
        indices_long = indices.long().unsqueeze(-1)

        is_one = (indices_long & basis_long) > 0

        positive = torch.tensor(1.0, device=indices.device)
        negative = torch.tensor(-1.0, device=indices.device)
        return torch.where(is_one, positive, negative)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> MotifV1Output:
        '''
        Run encoder, quantizer and decoder end to end.

        Args:
            x (torch.Tensor): Input images, shaped [B, 3, H, W]; typically
                already padded by the collate function.
            mask (torch.Tensor, optional): Externally supplied valid-region
                mask (typically produced by the collate function), shaped
                [B, 1, H, W]. When None, the mask produced by the encoder is
                reported instead.

        Returns:
            MotifV1Output: Reconstruction, quantizer loss and auxiliary info.
        '''
        enc_out: EncoderOutput = self.encoder(x)

        q_out: QuantizerOutput = self.quantizer(enc_out.output)

        dec_out = self.decoder(q_out.z_q, mask)

        # Prefer the caller's mask; fall back to the encoder's.
        final_mask = mask if mask is not None else enc_out.mask
        info = MotifV1Info(
            perplexity=q_out.perplexity,
            entropy=q_out.entropy,
            indices=q_out.indices,
            mask=final_mask
        )

        return MotifV1Output(
            reconstruction=dec_out.reconstruction,
            loss=q_out.loss,
            info=info
        )
|
orbit/model/registry.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import inspect
|
|
3
|
+
from typing import Dict, Type, Any, Optional
|
|
4
|
+
|
|
5
|
+
# Module-wide table mapping a registered model name to its class.
_MODEL_REGISTRY: Dict[str, Type] = {}


def register_model(name: Optional[str] = None):
    '''Create a class decorator that records a model in the registry.

    Args:
        name (str, optional): Name to register under. When omitted, the
            decorated class's ``__name__`` is used.

    Returns:
        A decorator that stores the class in the registry and returns it
        unchanged. Registering an existing name prints a warning and
        replaces the earlier entry.
    '''
    def decorator(cls):
        model_name = cls.__name__ if name is None else name
        already_present = model_name in _MODEL_REGISTRY
        if already_present:
            print(f"Warning: Model '{model_name}' is already registered. Overwriting.")
        _MODEL_REGISTRY[model_name] = cls
        return cls

    return decorator
|
|
20
|
+
|
|
21
|
+
def build_model(name: str, **kwargs) -> Any:
    '''Instantiate a registered model by name.

    Args:
        name (str): Registered model name.
        **kwargs: Keyword arguments forwarded to the model constructor.

    Returns:
        Any: The newly constructed model instance.

    Raises:
        ValueError: If no model is registered under ``name``.
    '''
    if name in _MODEL_REGISTRY:
        model_cls = _MODEL_REGISTRY[name]
        return model_cls(**kwargs)

    raise ValueError(f"Model '{name}' not found in registry. Available models: {list(_MODEL_REGISTRY.keys())}")
|
|
37
|
+
|
|
38
|
+
def list_models() -> list:
    '''Return the names of all registered models as a new list.'''
    return [*_MODEL_REGISTRY]
|
|
41
|
+
|
|
42
|
+
def get_model_class(name: str) -> Type:
    '''Look up a registered model class without instantiating it.

    Args:
        name (str): Registered model name.

    Returns:
        Type: The class registered under ``name``.

    Raises:
        ValueError: If no model is registered under ``name``.
    '''
    if name in _MODEL_REGISTRY:
        return _MODEL_REGISTRY[name]

    raise ValueError(f"Model '{name}' not found in registry.")
|
orbit/optim/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from
|
|
2
|
-
from
|
|
1
|
+
from .sam import SAM
|
|
2
|
+
from .muon import Muon
|