orbit_torch-0.0.4a1-py3-none-any.whl → orbit_torch-0.1.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/motif/vision/v1.py ADDED
@@ -0,0 +1,645 @@
+ import torch
+ import torch.nn as nn
+
+ from dataclasses import dataclass
+ from typing import Tuple, Optional
+
+ from orbit.model import BaseBlock, register_model
+ from orbit.utils.image import pad_to_patch_size
+ from orbit.model.block import (
+     ConvBlock, ResBasicBlock,
+     SpatialMultiHeadAttention, AttentionOutput,
+     MRoPEInterleavedEmbedding,
+     LFQ, QuantizerOutput
+ )
+
+ def dynamic_collate_fn(batch):
+     '''
+     Collate a batch of variably sized images into one uniform size (padding) and build the validity mask.
+
+     Returns:
+         padded_batch: [B, C, Max_H, Max_W]
+         mask_batch: [B, 1, Max_H, Max_W] (1.0 = Valid, 0.0 = Padding)
+     '''
+     images = [item for item in batch]
+
+     max_h = max([img.shape[1] for img in images])
+     max_w = max([img.shape[2] for img in images])
+
+     stride = 16 # Patch Size
+     max_h = ((max_h + stride - 1) // stride) * stride
+     max_w = ((max_w + stride - 1) // stride) * stride
+
+     batch_size = len(images)
+     channels = images[0].shape[0]
+
+     padded_batch = torch.zeros(batch_size, channels, max_h, max_w)
+     mask_batch = torch.zeros(batch_size, 1, max_h, max_w)
+
+     for i, img in enumerate(images):
+         h, w = img.shape[1], img.shape[2]
+         padded_batch[i, :, :h, :w] = img
+         mask_batch[i, :, :h, :w] = 1.0
+
+     return padded_batch, mask_batch
+
+ @dataclass
+ class EncoderOutput:
+     '''
+     Output dataclass for the VQGAN Encoder.
+
+     Attributes:
+         output (torch.Tensor): Encoded latent of shape [B, out_channels, h_latent, w_latent].
+         mask (torch.Tensor): Validity mask at the original resolution, shape [B, 1, H_pad, W_pad].
+         input_shape (Tuple[int, int]): Size (H, W) of the original input image.
+         padded_shape (Tuple[int, int]): Size (H_pad, W_pad) of the padded image.
+     '''
+     output: torch.Tensor
+     mask: torch.Tensor
+     input_shape: Tuple[int, int]
+     padded_shape: Tuple[int, int]
+
+ @dataclass
+ class DecoderOutput:
+     '''
+     Output dataclass for the VQGAN Decoder.
+
+     Attributes:
+         reconstruction (torch.Tensor): Reconstructed image [B, 3, H, W].
+     '''
+     reconstruction: torch.Tensor
+
+ @dataclass
+ class MotifV1Info:
+     '''
+     Auxiliary-information dataclass for MotifV1.
+
+     Attributes:
+         perplexity (torch.Tensor): Codebook perplexity.
+         entropy (torch.Tensor): Codebook entropy.
+         indices (torch.Tensor): Quantization indices [B, H, W].
+         mask (Optional[torch.Tensor]): Validity mask [B, 1, H, W].
+     '''
+     perplexity: torch.Tensor
+     entropy: torch.Tensor
+     indices: torch.Tensor
+     mask: Optional[torch.Tensor]
+
+ @dataclass
+ class MotifV1Output:
+     '''
+     Output dataclass for the MotifV1 main model.
+
+     Attributes:
+         reconstruction (torch.Tensor): Reconstructed image [B, 3, H, W].
+         loss (torch.Tensor): Quantization loss (commitment loss).
+         info (MotifV1Info): Auxiliary information object.
+     '''
+     reconstruction: torch.Tensor
+     loss: torch.Tensor
+     info: MotifV1Info
+
+ @register_model()
+ class MotifV1Encoder(BaseBlock):
+     '''
+     VQGAN Encoder supporting variable input sizes and 2D RoPE.
+
+     The encoder extracts features through a stack of downsampling convolution blocks and residual blocks,
+     applying spatial multi-head attention and 2D rotary position embeddings (2D RoPE) in the middle stage.
+     It accepts images of arbitrary size, padding them so the dimensions satisfy the patch_size requirement.
+     '''
+     def __init__(
+         self,
+         in_channels: int = 3,
+         hidden_dim: int = 128,
+         out_channels: int = 256,
+         patch_size: int = 16,
+         num_res_blocks: int = 1,
+         num_heads: int = 4,
+         dropout: float = 0.0,
+         rope_max_len: int = 4096,
+         max_channels: int = 256
+     ):
+         '''
+         Initialize MotifV1Encoder.
+
+         Args:
+             in_channels (int): Number of input image channels. Defaults to 3.
+             hidden_dim (int): Initial hidden dimension. Defaults to 128.
+             out_channels (int): Number of output latent channels. Defaults to 256.
+             patch_size (int): Patch size; must be a power of 2. Determines the number of downsampling steps. Defaults to 16.
+             num_res_blocks (int): Number of residual blocks per downsampling stage. Defaults to 1.
+             num_heads (int): Number of attention heads. Defaults to 4.
+             dropout (float): Dropout probability. Defaults to 0.0.
+             rope_max_len (int): Maximum sequence length for RoPE. Defaults to 4096.
+             max_channels (int): Upper bound on the hidden channel count. Defaults to 256.
+
+         Raises:
+             ValueError: If patch_size is not a power of 2.
+         '''
+         super().__init__()
+
+         if (patch_size & (patch_size - 1)) != 0:
+             raise ValueError(f"patch_size must be a power of 2, got {patch_size}")
+
+         self.patch_size = patch_size
+         self.num_downsamples = int(torch.log2(torch.tensor(patch_size)))
+
+         self.conv_in = ConvBlock(
+             in_channels, hidden_dim, kernel_size=3, padding=1, norm=None, activation=None
+         )
+
+         self.down_blocks = nn.ModuleList()
+         curr_dim = hidden_dim
+
+         for _ in range(self.num_downsamples):
+             blocks = []
+             for _ in range(num_res_blocks):
+                 blocks.append(ResBasicBlock(
+                     curr_dim,
+                     curr_dim,
+                     dropout=dropout,
+                     norm='group',
+                     activation='silu',
+                     variant='pre_act'
+                 ))
+             self.down_blocks.append(nn.Sequential(*blocks))
+
+             next_dim = min(curr_dim * 2, max_channels)
+
+             self.down_blocks.append(ConvBlock(
+                 curr_dim, next_dim, kernel_size=3, stride=2, padding=1,
+                 norm=None, activation=None
+             ))
+             curr_dim = next_dim
+
+         self.mid_res1 = ResBasicBlock(
+             curr_dim,
+             curr_dim,
+             dropout=dropout,
+             norm='group',
+             activation='silu',
+             variant='pre_act'
+         )
+
+         self.mid_attn = SpatialMultiHeadAttention(
+             hidden_size=curr_dim,
+             num_heads=num_heads,
+             use_qk_norm=True
+         )
+
+         self.rope = MRoPEInterleavedEmbedding(
+             model_dim=curr_dim // num_heads,
+             num_axes=2,
+             max_len=rope_max_len
+         )
+
+         self.mid_res2 = ResBasicBlock(
+             curr_dim,
+             curr_dim,
+             dropout=dropout,
+             norm='group',
+             activation='silu',
+             variant='pre_act'
+         )
+
+         self.norm_out = nn.GroupNorm(32, curr_dim)
+         self.act_out = nn.SiLU(inplace=True)
+         self.conv_out = nn.Conv2d(curr_dim, out_channels, kernel_size=3, padding=1)
+
+     def _build_grid(self, h: int, w: int, device: torch.device) -> torch.Tensor:
+         '''
+         Build 2D grid coordinates.
+
+         Args:
+             h (int): Grid height.
+             w (int): Grid width.
+             device (torch.device): Device to create the tensor on.
+
+         Returns:
+             torch.Tensor: Grid coordinate tensor of shape [h * w, 2].
+         '''
+         y = torch.arange(h, device=device)
+         x = torch.arange(w, device=device)
+         grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
+         grid = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)
+         return grid
+
+     def forward(self, x: torch.Tensor) -> EncoderOutput:
+         '''
+         Forward pass.
+
+         Args:
+             x (torch.Tensor): Input image tensor of shape [B, C, H, W].
+
+         Returns:
+             EncoderOutput: Object holding the encoded output, the mask, and size information.
+         '''
+         B, C, H, W = x.shape
+
+         patch_out = pad_to_patch_size(x, (self.patch_size, self.patch_size))
+         x_in = patch_out.output
+         mask = patch_out.mask
+
+         h = self.conv_in(x_in)
+
+         for layer in self.down_blocks:
+             h = layer(h)
+
+         h = self.mid_res1(h)
+
+         B_feat, C_feat, H_feat, W_feat = h.shape
+         h_flat = h.permute(0, 2, 3, 1).reshape(B_feat, H_feat * W_feat, C_feat)
+
+         positions = self._build_grid(H_feat, W_feat, h.device)
+         positions = positions.unsqueeze(0).expand(B_feat, -1, -1)
+
+         attn_out: AttentionOutput = self.mid_attn(
+             hidden_states=h_flat,
+             positions=positions,
+             rotary_emb=self.rope
+         )
+
+         h = attn_out.output.view(B_feat, H_feat, W_feat, C_feat).permute(0, 3, 1, 2)
+         h = self.mid_res2(h)
+
+         h = self.norm_out(h)
+         h = self.act_out(h)
+         z = self.conv_out(h)
+
+         return EncoderOutput(
+             output=z,
+             mask=mask,
+             input_shape=(H, W),
+             padded_shape=patch_out.output.shape[-2:]
+         )
+
+ @register_model()
+ class MotifV1Decoder(BaseBlock):
+     '''
+     PixelShuffle-based VQGAN Decoder.
+
+     Structure: Input Conv -> Mid Block -> Upsample Stack (PixelShuffle) -> Output Conv.
+     The mask provided by the Encoder can be used to crop the output image back to its original size.
+     '''
+     def __init__(
+         self,
+         in_channels: int = 256,
+         hidden_dim: int = 256,
+         out_channels: int = 3,
+         patch_size: int = 16,
+         num_res_blocks: int = 1,
+         num_heads: int = 4,
+         dropout: float = 0.0,
+         rope_max_len: int = 4096,
+     ):
+         '''
+         Initialize MotifV1Decoder.
+
+         Args:
+             in_channels (int): Number of input latent channels. Defaults to 256.
+             hidden_dim (int): Initial hidden dimension. Defaults to 256.
+             out_channels (int): Number of output image channels. Defaults to 3.
+             patch_size (int): Patch size; must be a power of 2. Determines the number of upsampling steps. Defaults to 16.
+             num_res_blocks (int): Number of residual blocks per upsampling stage. Defaults to 1.
+             num_heads (int): Number of attention heads. Defaults to 4.
+             dropout (float): Dropout probability. Defaults to 0.0.
+             rope_max_len (int): Maximum sequence length for RoPE. Defaults to 4096.
+
+         Raises:
+             ValueError: If patch_size is not a power of 2.
+         '''
+         super().__init__()
+
+         if (patch_size & (patch_size - 1)) != 0:
+             raise ValueError(f"patch_size must be a power of 2, got {patch_size}")
+
+         self.patch_size = patch_size
+         self.num_upsamples = int(torch.log2(torch.tensor(patch_size)))
+
+         self.conv_in = ConvBlock(
+             in_channels, hidden_dim, kernel_size=3, padding=1, norm=None, activation=None
+         )
+
+         self.mid_res1 = ResBasicBlock(
+             hidden_dim,
+             hidden_dim,
+             dropout=dropout,
+             norm='group',
+             activation='silu',
+             variant='pre_act'
+         )
+
+         self.mid_attn = SpatialMultiHeadAttention(
+             hidden_size=hidden_dim,
+             num_heads=num_heads,
+             use_qk_norm=True
+         )
+
+         self.rope = MRoPEInterleavedEmbedding(
+             model_dim=hidden_dim // num_heads,
+             num_axes=2,
+             max_len=rope_max_len
+         )
+
+         self.mid_res2 = ResBasicBlock(
+             hidden_dim,
+             hidden_dim,
+             dropout=dropout,
+             norm='group',
+             activation='silu',
+             variant='pre_act'
+         )
+
+         self.up_blocks = nn.ModuleList()
+         curr_dim = hidden_dim
+
+         for _ in range(self.num_upsamples):
+             blocks = []
+             for _ in range(num_res_blocks):
+                 blocks.append(ResBasicBlock(
+                     curr_dim,
+                     curr_dim,
+                     dropout=dropout,
+                     norm='group',
+                     activation='silu',
+                     variant='pre_act'
+                 ))
+             self.up_blocks.append(nn.Sequential(*blocks))
+
+             next_dim = curr_dim // 2
+
+             if next_dim < 64:
+                 next_dim = 64
+
+             self.up_blocks.append(nn.Sequential(
+                 ConvBlock(curr_dim, next_dim * 4, kernel_size=3, padding=1, norm='group', activation='silu'),
+                 nn.PixelShuffle(2)
+             ))
+             curr_dim = next_dim
+
+         self.norm_out = nn.GroupNorm(32, curr_dim)
+         self.act_out = nn.SiLU(inplace=True)
+         self.conv_out = nn.Conv2d(curr_dim, out_channels, kernel_size=3, padding=1)
+
+     def _build_grid(self, h: int, w: int, device: torch.device) -> torch.Tensor:
+         '''
+         Build 2D grid coordinates.
+
+         Args:
+             h (int): Grid height.
+             w (int): Grid width.
+             device (torch.device): Device to create the tensor on.
+
+         Returns:
+             torch.Tensor: Grid coordinate tensor of shape [h * w, 2].
+         '''
+         y = torch.arange(h, device=device)
+         x = torch.arange(w, device=device)
+         grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
+         grid = torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)
+         return grid
+
+     def forward(self, z: torch.Tensor, mask: torch.Tensor = None) -> DecoderOutput:
+         '''
+         Args:
+             z (torch.Tensor): Quantized latent [B, C, H, W].
+             mask (torch.Tensor, optional): Mask produced by the Encoder [B, 1, H_orig, W_orig].
+                 Intended for cropping away the padding (note: the crop is not applied inside this method).
+
+         Returns:
+             DecoderOutput: Object holding the reconstructed image.
+         '''
+         B, C, H, W = z.shape
+
+         h = self.conv_in(z)
+
+         h = self.mid_res1(h)
+
+         B_feat, C_feat, H_feat, W_feat = h.shape
+         h_flat = h.permute(0, 2, 3, 1).reshape(B_feat, H_feat * W_feat, C_feat)
+
+         positions = self._build_grid(H_feat, W_feat, h.device)
+         positions = positions.unsqueeze(0).expand(B_feat, -1, -1)
+
+         attn_out = self.mid_attn(
+             hidden_states=h_flat,
+             positions=positions,
+             rotary_emb=self.rope
+         )
+
+         h = attn_out.output.view(B_feat, H_feat, W_feat, C_feat).permute(0, 3, 1, 2)
+         h = self.mid_res2(h)
+
+         for layer in self.up_blocks:
+             h = layer(h)
+
+         h = self.norm_out(h)
+         h = self.act_out(h)
+         recon = self.conv_out(h) # [B, 3, H_pad, W_pad]
+
+         return DecoderOutput(reconstruction=recon)
+
+ @register_model()
+ class MotifV1(BaseBlock):
+     '''
+     Motif-V1 main model.
+     Combines the Encoder, the Decoder, and the LFQ Quantizer.
+     '''
+     def __init__(
+         self,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         latent_dim: int = 256,
+         codebook_dim: int = 18,
+         patch_size: int = 16,
+
+         enc_hidden_dim: int = 64,
+         enc_num_res_blocks: int = 1,
+         enc_num_heads: int = 4,
+         enc_max_channels: int = 256,
+
+         dec_hidden_dim: int = 256,
+         dec_num_res_blocks: int = 1,
+         dec_num_heads: int = 4,
+
+         dropout: float = 0.0,
+         rope_max_len: int = 4096,
+
+         entropy_weight: float = 0.1,
+         commitment_weight: float = 0.25,
+     ):
+         '''
+         Initialize the MotifV1 main model.
+
+         Args:
+             in_channels (int): Number of input image channels. Defaults to 3.
+             out_channels (int): Number of output image channels. Defaults to 3.
+             latent_dim (int): Latent dimension of the Encoder output and Decoder input. Defaults to 256.
+             codebook_dim (int): Number of bits (dimension) of the LFQ codebook. Defaults to 18.
+             patch_size (int): Patch size; must be a power of 2. Defaults to 16.
+             enc_hidden_dim (int): Initial hidden dimension of the Encoder. Defaults to 64.
+             enc_num_res_blocks (int): Number of residual blocks per Encoder stage. Defaults to 1.
+             enc_num_heads (int): Number of attention heads in the Encoder. Defaults to 4.
+             enc_max_channels (int): Upper bound on the Encoder hidden channel count. Defaults to 256.
+             dec_hidden_dim (int): Initial hidden dimension of the Decoder. Defaults to 256.
+             dec_num_res_blocks (int): Number of residual blocks per Decoder stage. Defaults to 1.
+             dec_num_heads (int): Number of attention heads in the Decoder. Defaults to 4.
+             dropout (float): Dropout probability. Defaults to 0.0.
+             rope_max_len (int): Maximum sequence length for RoPE. Defaults to 4096.
+             entropy_weight (float): Weight of the LFQ entropy loss. Defaults to 0.1.
+             commitment_weight (float): Weight of the LFQ commitment loss. Defaults to 0.25.
+         '''
+         super().__init__()
+
+         self.codebook_dim = codebook_dim
+
+         self.encoder = MotifV1Encoder(
+             in_channels=in_channels,
+             hidden_dim=enc_hidden_dim,
+             out_channels=latent_dim,
+             patch_size=patch_size,
+             num_res_blocks=enc_num_res_blocks,
+             num_heads=enc_num_heads,
+             dropout=dropout,
+             rope_max_len=rope_max_len,
+             max_channels=enc_max_channels
+         )
+
+         self.quantizer = LFQ(
+             latent_dim=latent_dim,
+             codebook_dim=codebook_dim,
+             entropy_weight=entropy_weight,
+             commitment_weight=commitment_weight
+         )
+
+         self.decoder = MotifV1Decoder(
+             in_channels=latent_dim,
+             hidden_dim=dec_hidden_dim,
+             out_channels=out_channels,
+             patch_size=patch_size,
+             num_res_blocks=dec_num_res_blocks,
+             num_heads=dec_num_heads,
+             dropout=dropout,
+             rope_max_len=rope_max_len
+         )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         '''
+         Initialize model weights.
+
+         Args:
+             m (nn.Module): Module to initialize.
+
+         Note:
+             - Conv2d, Linear, and Conv1d use Xavier uniform initialization with zero bias.
+             - LayerNorm, GroupNorm, and BatchNorm2d get weight 1 and bias 0.
+         '''
+         if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
+             nn.init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.weight, 1)
+             nn.init.constant_(m.bias, 0)
+
+     @property
+     def last_layer_weights(self):
+         '''Used by VQGANLoss to compute the adaptive weight.'''
+         return self.decoder.conv_out.weight
+
+     def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         '''
+         Encode only: map an image to quantization indices.
+
+         Args:
+             x (torch.Tensor): Input image of shape [B, C, H, W].
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                 - indices (torch.Tensor): Quantization indices of shape [B, H, W].
+                 - mask (torch.Tensor): Validity mask of shape [B, 1, H, W].
+                 - z_q (torch.Tensor): Quantized latent (for debugging), shape [B, C, H, W].
+         '''
+         enc_out = self.encoder(x)
+         q_out = self.quantizer(enc_out.output)
+         return q_out.indices, enc_out.mask, q_out.z_q
+
+     def decode(self, indices: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         '''
+         Decode only: map quantization indices back to an image (typically used for Transformer inference).
+
+         Args:
+             indices (torch.Tensor): Quantization indices of shape [B, H, W].
+             mask (torch.Tensor, optional): Validity mask used to crop the output. Defaults to None.
+
+         Returns:
+             torch.Tensor: Reconstructed image of shape [B, 3, H, W].
+         '''
+         z_q = self.indices_to_codes(indices) # [B, H, W, Codebook_Dim]
+
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         if hasattr(self.quantizer, 'project_out'):
+             z_q_permuted = z_q.permute(0, 2, 3, 1)
+             z_projected = self.quantizer.project_out(z_q_permuted)
+             # [B, H, W, Latent_Dim] -> [B, Latent_Dim, H, W]
+             z_q = z_projected.permute(0, 3, 1, 2).contiguous()
+
+         dec_out = self.decoder(z_q, mask)
+         return dec_out.reconstruction
+
+     def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
+         '''
+         Core LFQ utility: expand integer indices back into binarized vectors.
+
+         Converts integer indices into binary codes (-1 or 1).
+
+         Args:
+             indices (torch.Tensor): Quantization indices of shape [B, H, W].
+
+         Returns:
+             torch.Tensor: Binarized vectors of shape [B, H, W, Codebook_Dim] with values -1.0 or 1.0.
+         '''
+         B, H, W = indices.shape
+         basis = self.quantizer.basis # [Codebook_Dim]
+
+         basis_long = basis.long()
+         indices_long = indices.long().unsqueeze(-1)
+
+         # A bit set in the index maps to +1.0; an unset bit maps to -1.0.
+         is_one = (indices_long & basis_long) > 0
+
+         codes = torch.where(is_one, torch.tensor(1.0, device=indices.device), torch.tensor(-1.0, device=indices.device))
+
+         return codes
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> MotifV1Output:
+         '''
+         Forward pass.
+
+         Args:
+             x (torch.Tensor): Input image, already padded, of shape [B, 3, H, W].
+             mask (torch.Tensor, optional): Externally supplied validity mask (usually produced by the
+                 collate_fn), shape [B, 1, H, W]. If None, the mask produced by the Encoder is used.
+
+         Returns:
+             MotifV1Output: Object holding the reconstruction, the loss, and auxiliary information.
+         '''
+         enc_out: EncoderOutput = self.encoder(x)
+
+         q_out: QuantizerOutput = self.quantizer(enc_out.output)
+
+         dec_out = self.decoder(q_out.z_q, mask)
+
+         final_mask = mask if mask is not None else enc_out.mask
+         info = MotifV1Info(
+             perplexity=q_out.perplexity,
+             entropy=q_out.entropy,
+             indices=q_out.indices,
+             mask=final_mask
+         )
+
+         return MotifV1Output(
+             reconstruction=dec_out.reconstruction,
+             loss=q_out.loss,
+             info=info
+         )
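
The module above is new in 0.1.0b1. Here is a minimal round-trip sketch based only on the signatures visible in this diff; the image sizes, DataLoader wiring, and variable names are illustrative, not part of the package:

import torch
from torch.utils.data import DataLoader

from orbit.model.motif.vision.v1 import MotifV1, dynamic_collate_fn

# Images of different sizes; dynamic_collate_fn pads each batch to a shared
# size that is a multiple of the 16-pixel patch size and returns the mask.
images = [torch.rand(3, 95, 120), torch.rand(3, 64, 64), torch.rand(3, 130, 77)]
loader = DataLoader(images, batch_size=3, collate_fn=dynamic_collate_fn)

model = MotifV1()  # defaults: patch_size=16, codebook_dim=18

for padded, mask in loader:
    out = model(padded, mask)                  # MotifV1Output
    recon, vq_loss = out.reconstruction, out.loss
    # 16x downsampling: indices are [B, H_pad/16, W_pad/16]
    print(recon.shape, out.info.indices.shape)

# Tokenize / detokenize round trip through the LFQ indices:
indices, enc_mask, _ = model.encode(padded)
restored = model.decode(indices, enc_mask)     # [B, 3, H_pad, W_pad]

Note that forward expects its input to be padded already: either let the encoder pad internally via pad_to_patch_size, or pad in the collate_fn and pass its mask through, as above.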
orbit/model/registry.py ADDED
@@ -0,0 +1,53 @@
+ import sys
+ import inspect
+ from typing import Dict, Type, Any, Optional
+
+ _MODEL_REGISTRY: Dict[str, Type] = {}
+
+ def register_model(name: Optional[str] = None):
+     '''Model registration decorator.
+
+     Args:
+         name (str, optional): Model name. If omitted, the class name is used.
+     '''
+     def decorator(cls):
+         model_name = name if name is not None else cls.__name__
+         if model_name in _MODEL_REGISTRY:
+             print(f"Warning: Model '{model_name}' is already registered. Overwriting.")
+         _MODEL_REGISTRY[model_name] = cls
+         return cls
+     return decorator
+
+ def build_model(name: str, **kwargs) -> Any:
+     '''Build a model by name.
+
+     Args:
+         name (str): Model name.
+         **kwargs: Arguments forwarded to the model constructor.
+
+     Returns:
+         Any: The instantiated model object.
+
+     Raises:
+         ValueError: If the model name is not registered.
+     '''
+     if name not in _MODEL_REGISTRY:
+         raise ValueError(f"Model '{name}' not found in registry. Available models: {list(_MODEL_REGISTRY.keys())}")
+     return _MODEL_REGISTRY[name](**kwargs)
+
+ def list_models() -> list:
+     '''List the names of all registered models.'''
+     return list(_MODEL_REGISTRY.keys())
+
+ def get_model_class(name: str) -> Type:
+     '''Get a model class by name.
+
+     Args:
+         name (str): Model name.
+
+     Returns:
+         Type: The model class.
+     '''
+     if name not in _MODEL_REGISTRY:
+         raise ValueError(f"Model '{name}' not found in registry.")
+     return _MODEL_REGISTRY[name]
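
The registry is self-contained, so its intended use follows directly from the code above. A short sketch; ToyNet is a made-up class, not part of the package:

import torch.nn as nn

from orbit.model.registry import register_model, build_model, list_models

@register_model(name='toy')          # omit name to register under the class name
class ToyNet(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

print(list_models())                 # [..., 'toy']
net = build_model('toy', dim=16)     # kwargs are forwarded to ToyNet.__init__

MotifV1 and its encoder/decoder in v1.py register themselves through the same decorator, so they are buildable by name as well.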
orbit/optim/__init__.py CHANGED
@@ -1,2 +1,2 @@
- from orbit.optim.sam import SAM
- from orbit.optim.muon import Muon
+ from .sam import SAM
+ from .muon import Muon