orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/embedding.py

@@ -0,0 +1,252 @@

import torch
import torch.nn as nn
import math

from orbit.model import BaseBlock, register_model


@register_model()
class RotaryPositionalEmbedding(BaseBlock):
    '''
    Rotary Positional Embedding (RoPE).
    '''

    def __init__(self, model_dim: int, max_len: int = 128000, base: int = 10000):
        '''
        Initialize the RoPE module.

        Args:
            model_dim (int): Model dimension (or head_dim). Must be even.
            max_len (int, optional): Maximum sequence length for the precomputed embeddings. Defaults to 128000.
            base (int, optional): Base used to compute the frequencies. Defaults to 10000.
        '''
        super(RotaryPositionalEmbedding, self).__init__()

        self.model_dim = model_dim
        self.max_len = max_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, model_dim, 2).float() / model_dim))

        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.outer(t, inv_freq)

        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Split the vector into two halves and rotate them: [-x2, x1].
        Whether the input is 3D or 4D, the split acts on the last dimension (model_dim).

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Rotated tensor.
        '''
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        '''
        Apply the rotary positional embedding.

        Automatically adapts to two input layouts:
        1. [Batch, Seq_Len, Dim]
        2. [Batch, Head, Seq_Len, Head_Dim]

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int, optional): Starting position index, used for KV-cache inference. Defaults to 0.

        Returns:
            torch.Tensor: Tensor with positional information applied.
        '''
        ndim = x.ndim
        seq_len = x.shape[-2]

        cos = self.cos_cached[start_pos : start_pos + seq_len, :]
        sin = self.sin_cached[start_pos : start_pos + seq_len, :]

        shape = [1] * (ndim - 2) + [seq_len, -1]
        cos = cos.view(*shape)
        sin = sin.view(*shape)

        return (x * cos) + (self._rotate_half(x) * sin)


@register_model()
class SinusoidalPositionalEmbedding(BaseBlock):

    def __init__(self, model_dim: int, max_len: int = 128000):
        '''
        Initialize the absolute (sinusoidal) positional embedding module.

        Args:
            model_dim (int): Model dimension.
            max_len (int, optional): Maximum sequence length. Defaults to 128000.
        '''
        super(SinusoidalPositionalEmbedding, self).__init__()

        self.model_dim = model_dim
        self.max_len = max_len

        pe = torch.zeros(max_len, model_dim)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor. Shape: [Batch_Size, Seq_Len, model_dim].

        Returns:
            torch.Tensor: Tensor with the positional encoding added.
        '''
        x = x + self.pe[:, :x.size(1), :]
        return x


@register_model()
class MRoPEInterleavedEmbedding(BaseBlock):
    '''
    Interleaved multimodal rotary positional embedding (MRoPE-Interleave).
    Supports 3D positions (time t, height h, width w); frequency channels are assigned round-robin in an interleaved pattern (thw...thw...thw).
    '''

    def __init__(self, model_dim: int, max_len: int = 128000, base: int = 10000, num_axes: int = 3):
        '''
        Initialize the MRoPEInterleaved module.

        Args:
            model_dim (int): Model dimension. Must be even and divisible by num_axes.
            max_len (int, optional): Maximum sequence length for the precomputed embeddings. Defaults to 128000.
            base (int, optional): Base used to compute the frequencies. Defaults to 10000.
            num_axes (int, optional): Number of positional axes (e.g. 3 for time, height, width). Defaults to 3.
        '''
        super().__init__()
        assert model_dim % 2 == 0, 'model_dim must be even'
        assert model_dim % num_axes == 0, f'model_dim {model_dim} not divisible by num_axes {num_axes}'

        self.model_dim = model_dim
        self.max_len = max_len
        self.base = base
        self.num_axes = num_axes

        inv_freq = 1.0 / (base ** (torch.arange(0, model_dim, 2).float() / model_dim))

        t_range = torch.arange(max_len, dtype=torch.float)
        freqs = torch.outer(t_range, inv_freq)  # [max_len, dim/2]

        emb = torch.cat((freqs, freqs), dim=-1)  # [max_len, dim]

        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

        self.register_buffer(
            'axis_mask',
            torch.arange(model_dim) % num_axes,
            persistent=False
        )

        k = model_dim // num_axes
        idx = []
        for p in range(model_dim):
            j = p % num_axes
            i = p // num_axes
            pos_in_old = j * k + i
            idx.append(pos_in_old)

        self.register_buffer('interleave_idx', torch.tensor(idx, dtype=torch.long), persistent=False)

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Split the vector into two halves and rotate them: [-x2, x1].

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Rotated tensor.
        '''
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, positions: torch.Tensor = None, start_pos: int = 0) -> torch.Tensor:
        '''
        Apply the multimodal rotary positional embedding.

        Args:
            x (torch.Tensor): Input tensor. Shape: [Batch, Seq_Len, Dim] or [Batch, Head, Seq_Len, Head_Dim].
            positions (torch.Tensor, optional): Position index tensor. Shape: [Batch, Seq_Len] or [Batch, Seq_Len, num_axes].
                A 2D tensor is automatically expanded to [Batch, Seq_Len, num_axes].
                If None and num_axes=1, a linear position index is created automatically.
            start_pos (int, optional): Starting position index. Defaults to 0.

        Returns:
            torch.Tensor: Tensor with positional information applied.

        Raises:
            ValueError: If positions is None and num_axes > 1.
        '''
        ndim = x.ndim
        seq_len = x.shape[-2]
        batch_size = x.shape[0]

        if positions is None:
            if self.num_axes == 1:
                positions = torch.arange(0, seq_len, device=x.device, dtype=torch.long)
            else:
                raise ValueError("positions must be provided when num_axes > 1 (e.g. for vision/multimodal inputs)")

        if positions.ndim == 1:
            positions = positions.unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, self.num_axes)

        if positions.ndim == 2:
            positions = positions.unsqueeze(-1).expand(-1, -1, self.num_axes)

        if positions.ndim == 3 and positions.shape[-1] == 1:
            positions = positions.expand(-1, -1, self.num_axes)

        batch_size = positions.shape[0]

        cos_list, sin_list = [], []

        for ax in range(self.num_axes):
            pos_ax = positions[..., ax]
            pos_ax = torch.clamp(pos_ax + start_pos, 0, self.max_len - 1).long()

            cos_full = self.cos_cached[pos_ax]
            sin_full = self.sin_cached[pos_ax]

            mask = (self.axis_mask == ax)
            cos_ax = cos_full[..., mask]
            sin_ax = sin_full[..., mask]

            cos_list.append(cos_ax)
            sin_list.append(sin_ax)

        cos_all = torch.cat(cos_list, dim=-1)
        sin_all = torch.cat(sin_list, dim=-1)

        cos_all = cos_all[..., self.interleave_idx]
        sin_all = sin_all[..., self.interleave_idx]

        if ndim == 4:
            shape = [batch_size, 1, seq_len, -1]
            cos_all = cos_all.view(*shape)
            sin_all = sin_all.view(*shape)

        return (x * cos_all) + (self._rotate_half(x) * sin_all)
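The embedding hunk above is easy to smoke-test with a shape check. Below is a minimal usage sketch for the two sequence-level blocks, assuming the module path orbit.model.block.embedding from the RECORD listing exposes these classes directly (the diff itself does not show the package's __init__ re-exports):

import torch

from orbit.model.block.embedding import (
    RotaryPositionalEmbedding,
    SinusoidalPositionalEmbedding,
)

# 4D attention layout: [Batch, Head, Seq_Len, Head_Dim]
rope = RotaryPositionalEmbedding(model_dim=64)
q = torch.randn(2, 8, 16, 64)
q_rot = rope(q)                        # positions 0..15, shape unchanged

# KV-cache style decoding: rotate one new token at position 16
q_step = rope(torch.randn(2, 8, 1, 64), start_pos=16)

# 3D layout: [Batch, Seq_Len, model_dim]
sinu = SinusoidalPositionalEmbedding(model_dim=64)
h = sinu(torch.randn(2, 16, 64))       # adds the precomputed table, shape unchanged

assert q_rot.shape == q.shape and h.shape == (2, 16, 64)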
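MRoPEInterleavedEmbedding requires explicit per-axis positions whenever num_axes > 1. Here is a sketch, under the same import-path assumption, of building (t, h, w) indices for a flattened grid of video patches; the grid sizes are arbitrary:

import torch

from orbit.model.block.embedding import MRoPEInterleavedEmbedding

# model_dim must be even and divisible by num_axes (48 = 3 * 16)
mrope = MRoPEInterleavedEmbedding(model_dim=48, num_axes=3)

T, H, W = 2, 4, 4                      # frames, patch rows, patch cols
t, h, w = torch.meshgrid(
    torch.arange(T), torch.arange(H), torch.arange(W), indexing='ij'
)
# [Batch, Seq_Len, 3]: one (t, h, w) triple per flattened patch
positions = torch.stack([t, h, w], dim=-1).reshape(1, -1, 3)

x = torch.randn(1, T * H * W, 48)      # [Batch, Seq_Len, Dim]
x_rot = mrope(x, positions=positions)  # channels rotate as t,h,w,t,h,w,...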
orbit/model/block/film.py

@@ -0,0 +1,176 @@

import torch
import torch.nn as nn

from typing import Optional
from dataclasses import dataclass

from orbit.model import BaseBlock, register_model

@dataclass
class FiLMOutput:
    ''' Output container for the FiLM module.

    Attributes:
        output (torch.Tensor): Features after gamma/beta modulation.
        gate (Optional[torch.Tensor]): Gate values for the residual connection.
    '''
    output: torch.Tensor
    gate: Optional[torch.Tensor] = None

    @property
    def gated_output(self):
        if self.gate is None: return self.output
        return self.output * self.gate


@register_model()
class FiLM(BaseBlock):
    ''' Feature-wise Linear Modulation (FiLM) module.

    Applies an affine transform to the input features: FiLM(x) = (1 + gamma(z)) * x + beta(z),
    where gamma and beta are generated from the conditioning input z.
    At initialization, gamma and beta are both 0, i.e. an identity mapping.

    Args:
        in_features (int): Input feature dimension.
        cond_features (int): Conditioning feature dimension.
        use_beta (bool, optional): Whether to use the shift term (beta). Defaults to True.
        use_gamma (bool, optional): Whether to use the scale term (gamma). Defaults to True.
        use_gate (bool, optional): Whether to use the gate term. Defaults to True.
        use_context_gate (bool, optional): Whether to use a context gate.
            If True, the gate is generated from the concatenation of the input and conditioning features, overriding the use_gate setting. Defaults to False.
        channel_first (bool, optional): Whether the feature dimension is dim 1 (as in CNN [B, C, H, W]).
            If False, features are assumed to be in the last dimension (as in Transformer [B, L, C]). Defaults to False.
    '''
    def __init__(
        self,
        in_features: int,
        cond_features: int,
        use_beta: bool = True,
        use_gamma: bool = True,
        use_gate: bool = True,
        use_context_gate: bool = False,
        channel_first: bool = False
    ):
        super(FiLM, self).__init__()

        if use_context_gate: use_gate = False

        self.in_features = in_features
        self.cond_features = cond_features
        self.use_beta = use_beta
        self.use_gamma = use_gamma
        self.use_gate = use_gate
        self.use_context_gate = use_context_gate
        self.channel_first = channel_first

        self.out_dim = 0
        if use_gamma: self.out_dim += in_features
        if use_beta: self.out_dim += in_features
        if use_gate: self.out_dim += in_features

        self.gate_proj = nn.Linear(in_features + cond_features, in_features) if use_context_gate else nn.Identity()

        if self.out_dim > 0:
            self.proj = nn.Linear(cond_features, self.out_dim)
        else: self.proj = None

        self._init_weights(self)

    def _init_weights(self, model: nn.Module):
        ''' Initialize weights.

        Initializes the projection layer's weight and bias to 0 so that the initial state is an identity mapping.
        If the context gate is used, its projection layer is initialized with Xavier uniform.

        Args:
            model (nn.Module): Model to initialize.
        '''
        if model is self and self.proj is not None:
            nn.init.constant_(self.proj.weight, 0)
            nn.init.constant_(self.proj.bias, 0)
        if isinstance(self.gate_proj, nn.Identity): return
        nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.1)
        nn.init.zeros_(self.gate_proj.bias)

    def _reshape(self, param: torch.Tensor, ref_ndim: int) -> torch.Tensor:
        ''' Reshape a parameter to match the input feature dimensions for broadcasting.

        Args:
            param (torch.Tensor): Parameter tensor to reshape.
            ref_ndim (int): Number of dimensions of the reference tensor (usually the input features x).

        Returns:
            torch.Tensor: Reshaped parameter tensor.
        '''
        if self.channel_first:
            param = param.movedim(-1, 1)
            for _ in range(ref_ndim - param.ndim):
                param = param.unsqueeze(-1)
        else:
            for _ in range(ref_ndim - param.ndim):
                param = param.unsqueeze(-2)
        return param

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> FiLMOutput:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input features. Shape [B, C, ...] (if channel_first=True)
                or [B, ..., C] (if channel_first=False).
            cond (torch.Tensor): Conditioning input. Shape [B, ..., cond_features].

        Returns:
            FiLMOutput: Modulated features.
        '''
        if self.proj is None: return FiLMOutput(output=x)

        params = self.proj(cond)

        count = sum([self.use_gamma, self.use_beta, self.use_gate])
        if count > 1:
            params_list = params.chunk(count, dim=-1)
        else:
            params_list = [params]

        idx = 0
        gamma, beta, gate = None, None, None
        if self.use_gamma:
            gamma = params_list[idx]
            idx += 1
        if self.use_beta:
            beta = params_list[idx]
            idx += 1
        if self.use_gate:
            gate = params_list[idx]
            idx += 1

        out = x
        if gamma is not None:
            out = out * (1 + self._reshape(gamma, x.ndim))
        if beta is not None:
            out = out + self._reshape(beta, x.ndim)

        final_gate = None
        if self.use_context_gate:
            if cond.ndim < x.ndim:
                shape = list(x.shape)
                feat_dim = 1 if self.channel_first else -1
                shape[feat_dim] = -1
                cond_expanded = self._reshape(cond, x.ndim).expand(shape)
            else:
                cond_expanded = cond

            feat_dim = 1 if self.channel_first else -1
            context_input = torch.cat([x, cond_expanded], dim=feat_dim)

            if self.channel_first:
                context_input = context_input.movedim(1, -1)
                final_gate = self.gate_proj(context_input).movedim(-1, 1)
            else:
                final_gate = self.gate_proj(context_input)

        elif gate is not None:
            final_gate = self._reshape(gate, x.ndim)

        return FiLMOutput(output=out, gate=final_gate)
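A minimal usage sketch for the FiLM block, assuming it is importable from orbit.model.block.film (path inferred from the file listing). Because the projection is zero-initialized, gamma, beta, and the gate all start at 0, so the residual pattern below is an exact identity at initialization:

import torch

from orbit.model.block.film import FiLM

film = FiLM(in_features=64, cond_features=32)  # gamma + beta + gate enabled

x = torch.randn(2, 16, 64)     # [B, L, C], channel_first=False
cond = torch.randn(2, 32)      # one conditioning vector per sample

out = film(x, cond)            # FiLMOutput(output=..., gate=...)
y = x + out.gated_output       # zero-init gate -> y == x at initialization

assert torch.allclose(y, x)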