orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from orbit.model import BaseBlock, register_model
6
+
7
+
8
@register_model()
class MLP(BaseBlock):
    '''Multi-layer perceptron (MLP) block.

    Supports a standard two-layer MLP and a gated MLP variant. Both variants
    use SiLU as the activation function.

    Args:
        in_features (int): Input feature dimension.
        hidden_features (int): Hidden feature dimension.
        out_features (int, optional): Output feature dimension. Falls back to
            in_features when None (or any other falsy value). Defaults to None.
        bias (bool, optional): Whether the linear layers carry a bias term.
            Defaults to True.
        use_gate (bool, optional): Whether to use the gated structure
            (gate_proj / up_proj / down_proj) instead of fc1 / fc2.
            Defaults to False.
        dropout (float, optional): Dropout probability, applied after the
            hidden activation and after the output projection in both
            variants. Defaults to 0.0.
    '''
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int = None,
        bias: bool = True,
        use_gate: bool = False,
        dropout: float = 0.0
    ):
        super(MLP, self).__init__()

        out_features = out_features or in_features

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.bias = bias
        self.use_gate = use_gate
        self.dropout = nn.Dropout(dropout)
        self.act = nn.SiLU()

        if use_gate:
            self.gate_proj = nn.Linear(in_features, hidden_features, bias=bias)
            self.up_proj = nn.Linear(in_features, hidden_features, bias=bias)
            self.down_proj = nn.Linear(hidden_features, out_features, bias=bias)
        else:
            self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
            self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Run the forward pass.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is in_features.

        Returns:
            torch.Tensor: Output tensor whose last dimension is out_features.
        '''
        if self.use_gate:
            # Gated variant: the activated gate branch modulates the up branch.
            hidden = self.act(self.gate_proj(x)) * self.up_proj(x)
            # Dropout is applied at the same two points as in the plain
            # variant so the `dropout` argument is honoured in both modes
            # (it is a no-op in eval mode or when the probability is 0).
            hidden = self.dropout(hidden)
            return self.dropout(self.down_proj(hidden))

        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from orbit.model import BaseBlock, register_model
6
+ from orbit.model.block.mlp import MLP
7
+ from orbit.model.block.gate import TopKGate
8
+
9
+
10
@register_model()
class MoE(BaseBlock):
    '''Sparse mixture-of-experts block.

    A TopKGate router selects `top_k` of `num_experts` MLP experts for every
    token; the selected experts' outputs are summed, each scaled by the weight
    the router returned for it.

    Args:
        in_features (int): Input feature dimension.
        out_features (int): Output feature dimension of every expert.
        num_experts (int, optional): Number of experts. Defaults to 4.
        top_k (int, optional): Number of experts selected per token.
            Defaults to 2.
        hidden_features (int, optional): Hidden dimension passed to the
            experts and to the router. Falls back to in_features when None
            (or any other falsy value). Defaults to None.
        dropout (float, optional): Dropout probability passed to the experts.
            Defaults to 0.1.
        use_gate (bool, optional): Whether the experts use the gated MLP
            structure. Defaults to False.
        use_mlp_router (bool, optional): Forwarded to the router as `use_mlp`.
            Defaults to False.
    '''
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int = 4,
        top_k: int = 2,
        hidden_features: int = None,
        dropout: float = 0.1,
        use_gate: bool = False,
        use_mlp_router: bool = False
    ):
        super(MoE, self).__init__()

        hidden_features = hidden_features or in_features

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.dropout = dropout
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_gate = use_gate
        self.use_mlp_router = use_mlp_router

        self.router = TopKGate(
            in_features=in_features,
            out_features=num_experts,
            k=top_k,
            use_mlp=use_mlp_router,
            hidden_features=hidden_features,
            post_softmax=True
        )

        self.experts = nn.ModuleList([
            MLP(
                in_features=in_features,
                hidden_features=hidden_features,
                out_features=out_features,
                dropout=dropout,
                use_gate=use_gate
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        '''Run the forward pass.

        Args:
            x (torch.Tensor): Input tensor. Shape: [..., in_dim], typically
                [batch_size, seq_len, in_dim]. All leading dimensions are
                flattened into a single token axis for routing.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Output tensor. Shape: [..., out_features], with the same
                  leading dimensions as the input.
                - Auxiliary loss (scalar).
        '''
        *lead_shape, dim = x.shape

        # reshape (not view) so non-contiguous inputs, e.g. after a
        # transpose/permute, are accepted as well.
        x_flat = x.reshape(-1, dim)

        # The router output is used through its .logits, .indices and .values
        # attributes; .indices / .values are indexed as [token, k] below.
        gate_output = self.router(x_flat)

        routing_probs = F.softmax(gate_output.logits, dim=-1)
        selection_mask = torch.zeros_like(routing_probs).scatter_(1, gate_output.indices, 1.0)

        # Per-expert selection rate times per-expert mean routing probability,
        # summed over experts and scaled by the number of experts.
        fraction = selection_mask.mean(dim=0)
        mean_probs = routing_probs.mean(dim=0)
        aux_loss = self.num_experts * (fraction * mean_probs).sum()

        final_output = torch.zeros(x_flat.shape[0], self.out_features, device=x.device, dtype=x.dtype)

        for i, expert in enumerate(self.experts):
            mask = (gate_output.indices == i)
            token_idx, k_idx = torch.where(mask)

            # Skip experts that received no tokens.
            if token_idx.numel() == 0:
                continue

            expert_out = expert(x_flat[token_idx])
            w = gate_output.values[token_idx, k_idx].unsqueeze(-1)
            # index_add_ requires source and destination dtypes to match;
            # under autocast the expert output can be lower precision than x.
            weighted = (expert_out * w).to(final_output.dtype)
            final_output.index_add_(0, token_idx, weighted)

        return final_output.view(*lead_shape, self.out_features), aux_loss
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional
4
+
5
+ from orbit.model import BaseBlock, register_model
6
+ from orbit.model.block.conv import CausalConv1d
7
+
8
@register_model()
class TCN(BaseBlock):
    '''
    Temporal Convolutional Network (TCN).

    Built from a stack of CausalConv1d layers. The stack is either described
    explicitly through a per-layer channel list, or built automatically by
    CausalConv1d.auto_block from a target step count.
    '''

    def __init__(
        self,
        in_channels: int,
        num_channels: Optional[List[int]] = None,
        out_channels: Optional[int] = None,
        step: Optional[int] = None,
        kernel_size: int = 3,
        dropout: float = 0.2,
        use_res: bool = True,
        norm: str = None,
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.1
    ):
        '''
        Initialize the TCN module.

        Args:
            in_channels (int): Number of input channels.
            num_channels (List[int], optional): Output channels of every layer.
                When given, out_channels and step are ignored.
            out_channels (int, optional): Output channel count used in
                automatic construction mode.
            step (int, optional): Target step count passed to
                CausalConv1d.auto_block in automatic construction mode.
            kernel_size (int, optional): Convolution kernel size. Defaults to 3.
            dropout (float, optional): Dropout probability. Defaults to 0.2.
            use_res (bool, optional): Whether to use residual connections.
                Defaults to True.
            norm (str, optional): Normalization type forwarded to CausalConv1d.
                Defaults to None.
            activation (str, optional): Activation type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope forwarded to
                CausalConv1d. Defaults to 0.1.

        Raises:
            ValueError: If neither num_channels nor both step and out_channels
                are provided.
        '''
        super(TCN, self).__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout = dropout

        # Keyword arguments shared by both construction modes.
        conv_options = dict(
            kernel_size=kernel_size,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu,
            use_res=use_res,
            dropout=dropout,
        )

        if num_channels is not None:
            # Layer i reads the previous layer's width (in_channels for the
            # first layer) and uses dilation 2 ** i.
            input_widths = [in_channels, *num_channels[:-1]]
            self.network = nn.Sequential(*(
                CausalConv1d(
                    in_channels=width_in,
                    out_channels=width_out,
                    dilation=2 ** level,
                    **conv_options
                )
                for level, (width_in, width_out) in enumerate(zip(input_widths, num_channels))
            ))
            self.out_channels = num_channels[-1]
        elif step is None or out_channels is None:
            raise ValueError("Must provide either 'num_channels' (list) or both 'step' and 'out_channels' (int).")
        else:
            self.network = CausalConv1d.auto_block(
                in_channels=in_channels,
                out_channels=out_channels,
                step=step,
                **conv_options
            )
            self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Run the forward pass.

        Args:
            x (torch.Tensor): Input tensor. Shape: [Batch, in_channels, Seq_Len]

        Returns:
            torch.Tensor: Output tensor. Shape: [Batch, out_channels, Seq_Len]
        '''
        output = self.network(x)
        return output
orbit/model/config.py ADDED
@@ -0,0 +1,62 @@
1
+ import json
2
+ import os
3
+ from typing import Any, Dict
4
+
5
class ModelConfig:
    '''Base configuration class for managing model hyperparameters.

    Every keyword argument passed to the constructor becomes an instance
    attribute. The configuration can be loaded from and saved to a JSON file,
    and its attributes can also be accessed dictionary-style.
    '''
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    @classmethod
    def from_pretrained(cls, path: str) -> 'ModelConfig':
        '''Load a configuration from a JSON file.

        Args:
            path (str): Path to the JSON file.

        Returns:
            ModelConfig: The loaded configuration object.

        Raises:
            FileNotFoundError: If `path` does not exist.
            TypeError: If the JSON root is not an object.
        '''
        if not os.path.exists(path):
            raise FileNotFoundError(f"Config file not found: {path}")

        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        # `cls(**value)` already raises TypeError for a non-mapping; raise the
        # same exception type with a message that names the actual problem.
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Config file must contain a JSON object, got {type(config_dict).__name__}: {path}"
            )

        return cls(**config_dict)

    def save_pretrained(self, path: str):
        '''Save the configuration to a JSON file.

        Missing parent directories are created.

        Args:
            path (str): Destination path.
        '''
        directory = os.path.dirname(path)
        if directory:
            # exist_ok avoids the check-then-create race when several
            # processes save into the same directory at once.
            os.makedirs(directory, exist_ok=True)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=4, ensure_ascii=False)

    def to_dict(self) -> Dict[str, Any]:
        '''Return the instance attributes as a dict, skipping names that start with '_'.'''
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __contains__(self, key):
        # hasattr() raises TypeError for non-string names; a membership test
        # should simply answer False for them.
        return isinstance(key, str) and hasattr(self, key)

    def get(self, key, default=None):
        '''Return the attribute named `key`, or `default` when it is missing.'''
        return getattr(self, key, default)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.to_dict()})"
@@ -0,0 +1,6 @@
1
+ from .discriminator import (
2
+ NLayerDiscriminator
3
+ )
4
+ from .losses import (
5
+ VQGANLossOutput, VQGANDiscriminatorLog, VQGANGeneratorLog, VQGANLoss
6
+ )
@@ -0,0 +1,46 @@
1
+ import torch.nn as nn
2
+ from orbit.model import BaseBlock, register_model
3
+
4
+
5
@register_model()
class NLayerDiscriminator(BaseBlock):
    '''
    PatchGAN discriminator.

    The output is not a single scalar but a one-channel N x N map in which
    every position scores whether the corresponding patch is real or fake.
    '''
    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        def conv_norm_act(mult_in, mult_out, stride):
            # Bias-free 4x4 convolution followed by batch norm and LeakyReLU.
            return [
                nn.Conv2d(ndf * mult_in, ndf * mult_out, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(ndf * mult_out),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]

        # Stem: strided convolution without normalization.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # Strided stages; the channel multiplier doubles per stage, capped at 8.
        prev_mult = cur_mult = 1
        for depth in range(1, n_layers):
            prev_mult, cur_mult = cur_mult, min(2 ** depth, 8)
            layers.extend(conv_norm_act(prev_mult, cur_mult, stride=2))

        # One more stage at stride 1.
        prev_mult, cur_mult = cur_mult, min(2 ** n_layers, 8)
        layers.extend(conv_norm_act(prev_mult, cur_mult, stride=1))

        # Projection to the one-channel patch score map.
        layers.append(nn.Conv2d(ndf * cur_mult, 1, kernel_size=4, stride=1, padding=1))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        scores = self.main(input)
        return scores
@@ -0,0 +1,193 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from lpips import LPIPS
5
+ from dataclasses import dataclass
6
+ from typing import Dict, Tuple, Optional, Union
7
+
8
@dataclass
class VQGANGeneratorLog:
    '''Log record for the generator stage of VQGAN training.

    Attributes:
        total_loss (torch.Tensor): Total loss.
        quant_loss (torch.Tensor): Quantization loss.
        nll_loss (torch.Tensor): Negative log-likelihood (reconstruction) loss.
        p_loss (torch.Tensor): Perceptual loss.
        rec_loss (torch.Tensor): Total reconstruction loss (pixel + perceptual).
        d_weight (torch.Tensor): Adaptive weight of the adversarial loss.
        g_loss (torch.Tensor): Generator adversarial loss.
    '''
    total_loss: torch.Tensor
    quant_loss: torch.Tensor
    nll_loss: torch.Tensor
    p_loss: torch.Tensor
    rec_loss: torch.Tensor
    d_weight: torch.Tensor
    g_loss: torch.Tensor
28
+
29
@dataclass
class VQGANDiscriminatorLog:
    '''Log record for the discriminator stage of VQGAN training.

    Attributes:
        disc_loss (torch.Tensor): Total discriminator loss.
        logits_real (torch.Tensor): Mean logits on real samples.
        logits_fake (torch.Tensor): Mean logits on generated samples.
    '''
    disc_loss: torch.Tensor
    logits_real: torch.Tensor
    logits_fake: torch.Tensor
41
+
42
@dataclass
class VQGANLossOutput:
    '''Return value of VQGANLoss.

    Attributes:
        loss (torch.Tensor): Total loss scalar.
        log (Union[VQGANGeneratorLog, VQGANDiscriminatorLog]): Loss log object
            for the stage that was evaluated.
    '''
    loss: torch.Tensor
    log: Union[VQGANGeneratorLog, VQGANDiscriminatorLog]
52
+
53
class VQGANLoss(nn.Module):
    '''Loss module for VQGAN training.

    Combines an L1 reconstruction loss, an LPIPS perceptual loss, a
    hinge-style adversarial loss and the quantizer (codebook) loss.
    '''
    def __init__(
        self,
        disc_start: int = 10000,
        kl_weight: float = 1.0,
        pixelloss_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_weight: float = 0.8,
        disc_factor: float = 1.0
    ):
        '''Initialize VQGANLoss.

        Args:
            disc_start (int): Global step from which the adversarial terms are
                switched on. Defaults to 10000.
            kl_weight (float): Weight applied to the quantizer loss.
                Defaults to 1.0.
            pixelloss_weight (float): Weight of the pixel-level reconstruction
                loss. Defaults to 1.0.
            perceptual_weight (float): Weight of the perceptual loss.
                Defaults to 1.0.
            disc_weight (float): Multiplier applied to the adaptive
                adversarial weight. Defaults to 0.8.
            disc_factor (float): Scale of the adversarial terms once
                `disc_start` has been reached. Defaults to 1.0.
        '''
        super().__init__()

        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_weight = perceptual_weight
        self.disc_factor = disc_factor
        self.disc_weight = disc_weight
        self.disc_start = disc_start

        # NOTE(review): .eval() only sets the mode once; calling .train() on
        # this module will switch the LPIPS submodule back to training mode.
        # Confirm whether that is intended.
        self.perceptual_loss = LPIPS(net='vgg', verbose=False).eval()

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer_weights):
        '''Compute the adaptive weight (lambda) of the adversarial loss.

        The weight is the ratio of the gradient norms of the reconstruction
        loss and the generator loss with respect to `last_layer_weights`,
        clamped to [0, 1e4], detached and scaled by `disc_weight`.

        Args:
            nll_loss (torch.Tensor): Reconstruction loss.
            g_loss (torch.Tensor): Generator loss.
            last_layer_weights (torch.Tensor): Weights of the decoder's last layer.

        Returns:
            torch.Tensor: The adaptive weight.
        '''
        nll_grads = torch.autograd.grad(nll_loss, last_layer_weights, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer_weights, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.disc_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        quantizer_loss: torch.Tensor,
        global_step: int,
        last_layer_weights: torch.Tensor,
        discriminator: nn.Module,
        optimizer_idx: int,
        mask: torch.Tensor = None
    ) -> VQGANLossOutput:
        '''Compute the loss for one optimizer stage.

        Args:
            inputs (torch.Tensor): Original input images.
            reconstructions (torch.Tensor): Reconstructed images.
            quantizer_loss (torch.Tensor): Quantizer loss.
            global_step (int): Current global step.
            last_layer_weights (torch.Tensor): Weights of the decoder's last layer.
            discriminator (nn.Module): Discriminator model.
            optimizer_idx (int): Optimizer index (0 for the generator, 1 for
                the discriminator).
            mask (torch.Tensor, optional): Valid-region mask for the pixel
                loss; resized with nearest interpolation when its spatial
                size differs from the images. Defaults to None.

        Returns:
            VQGANLossOutput: Object holding the loss and its log.

        Raises:
            ValueError: If `optimizer_idx` is neither 0 nor 1.
            RuntimeError: If the adaptive weight cannot be computed while the
                module is in training mode.
        '''
        rec_loss_tensor = torch.abs(inputs - reconstructions)

        if mask is not None:
            if mask.shape[-2:] != rec_loss_tensor.shape[-2:]:
                mask = F.interpolate(mask, size=rec_loss_tensor.shape[-2:], mode='nearest')

            # Masked mean; the epsilon guards against an all-zero mask.
            mask_expanded = mask.expand_as(rec_loss_tensor)
            nll_loss = (rec_loss_tensor * mask_expanded).sum() / (mask_expanded.sum() + 1e-6)
        else:
            nll_loss = torch.mean(rec_loss_tensor)

        p_loss_scalar = torch.tensor(0.0, device=inputs.device)
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs, reconstructions)
            p_loss_scalar = p_loss.mean()

        rec_loss_total = nll_loss * self.pixel_weight + self.perceptual_weight * p_loss_scalar

        # The adversarial terms are switched off before `disc_start` and
        # scaled by the configured `disc_factor` afterwards.
        disc_factor = self.disc_factor if global_step >= self.disc_start else 0.0

        if optimizer_idx == 0:
            logits_fake = discriminator(reconstructions)
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(rec_loss_total, g_loss, last_layer_weights)
            except RuntimeError:
                # Gradients are unavailable during evaluation (e.g. under
                # no_grad); in training this is a real error and is surfaced.
                if self.training:
                    raise
                d_weight = torch.tensor(0.0, device=inputs.device)

            loss = rec_loss_total + \
                self.kl_weight * quantizer_loss + \
                d_weight * disc_factor * g_loss

            log = VQGANGeneratorLog(
                total_loss=loss.detach(),
                quant_loss=quantizer_loss.detach(),
                nll_loss=nll_loss.detach(),
                p_loss=p_loss_scalar.detach(),
                rec_loss=rec_loss_total.detach(),
                d_weight=d_weight.detach(),
                g_loss=g_loss.detach()
            )
            return VQGANLossOutput(loss=loss, log=log)

        if optimizer_idx == 1:
            logits_real = discriminator(inputs.detach())
            logits_fake = discriminator(reconstructions.detach())

            # Hinge Loss
            loss_real = torch.mean(F.relu(1. - logits_real))
            loss_fake = torch.mean(F.relu(1. + logits_fake))
            d_loss = disc_factor * 0.5 * (loss_real + loss_fake)

            log = VQGANDiscriminatorLog(
                disc_loss=d_loss.detach(),
                logits_real=logits_real.mean().detach(),
                logits_fake=logits_fake.mean().detach()
            )
            return VQGANLossOutput(loss=d_loss, log=log)

        raise ValueError(f"optimizer_idx must be 0 (generator) or 1 (discriminator), got {optimizer_idx!r}")
File without changes
File without changes