orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/bio.py
ADDED
@@ -0,0 +1,537 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple

from orbit.model import BaseBlock, register_model


@dataclass
class PredictiveCodingOutput:
    output: torch.Tensor
    reconstruction: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None


@register_model()
class HebianLayer(BaseBlock):
    ''' Hebbian Learning Layer.

    Implements an unsupervised learning layer based on Hebbian rules. Supports the
    standard Hebbian rule and Oja's rule.
    '''
    def __init__(
        self,
        in_features: int,
        out_features: int,
        lr: float = 1e-3,
        mode: str = 'oja',
        bias: bool = True,
        auto_update: bool = True
    ):
        ''' Initialize the Hebbian learning layer.

        Args:
            in_features (int): Input feature dimension.
            out_features (int): Output feature dimension.
            lr (float, optional): Hebbian learning rate. Defaults to 1e-3.
            mode (str, optional): Update mode, either 'basic' or 'oja'. Defaults to 'oja'.
            bias (bool, optional): Whether to use a bias term. Defaults to True.
            auto_update (bool, optional): Whether to update weights automatically during forward. Defaults to True.
        '''
        super(HebianLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lr = lr
        self.mode = mode.lower()
        self.auto_update = auto_update

        if self.mode not in ['basic', 'oja']:
            raise ValueError(f"Unsupported mode: {mode}. Must be 'basic' or 'oja'.")

        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        self._init_weights(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input tensor (Batch, ..., In_Features).

        Returns:
            torch.Tensor: Output tensor (Batch, ..., Out_Features).
        '''
        y = F.linear(x, self.weight, self.bias)

        if self.training and self.auto_update:
            if x.dim() > 2:
                x_flat = x.reshape(-1, x.size(-1))
                y_flat = y.reshape(-1, y.size(-1))
                self._update_weights(x_flat, y_flat)
            else:
                self._update_weights(x, y)

        return y

    @torch.no_grad()
    def _update_weights(self, x: torch.Tensor, y: torch.Tensor):
        ''' Apply the weight update. '''
        if self.mode == 'basic':
            self._basic_update(x, y)
        elif self.mode == 'oja':
            self._oja_update(x, y)

    @torch.no_grad()
    def _basic_update(self, x: torch.Tensor, y: torch.Tensor):
        ''' Apply the standard Hebbian update rule.

        Args:
            x (torch.Tensor): Input tensor.
            y (torch.Tensor): Output tensor.
        '''
        batch_size = x.size(0)

        # y^T * x -> (M, N)
        grad_w = torch.matmul(y.t(), x)

        self.weight.data += self.lr * grad_w / batch_size

        if self.bias is not None:
            # db = lr * sum(y)
            grad_b = y.sum(dim=0)
            self.bias.data += self.lr * grad_b / batch_size

    @torch.no_grad()
    def _oja_update(self, x: torch.Tensor, y: torch.Tensor):
        ''' Apply Oja's update rule.

        Oja's rule prevents unbounded weight growth through normalization.

        Args:
            x (torch.Tensor): Input tensor.
            y (torch.Tensor): Output tensor.
        '''
        batch_size = x.size(0)

        # y^T * x -> (M, N)
        yx = torch.matmul(y.t(), x)

        # y^2 -> (B, M), summed over the batch -> (M)
        y_sq = torch.sum(y ** 2, dim=0)

        # (M, 1) * (M, N) -> (M, N)
        grad_w = yx - y_sq.unsqueeze(1) * self.weight

        self.weight.data += self.lr * grad_w / batch_size

        if self.bias is not None:
            grad_b = y.sum(dim=0) - y_sq * self.bias
            self.bias.data += self.lr * grad_b / batch_size
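
A minimal usage sketch for the Hebbian layer above. The import path orbit.model.block.bio is inferred from the file layout, and the example assumes BaseBlock._init_weights initializes the bare weight parameter; the batch size and feature dimensions are illustrative only.

    import torch
    from orbit.model.block.bio import HebianLayer  # assumed import path based on the file layout

    # assumes BaseBlock._init_weights fills self.weight / self.bias with sensible values
    layer = HebianLayer(in_features=64, out_features=32, lr=1e-3, mode='oja')
    layer.train()                        # weight updates run only in training mode with auto_update=True

    x = torch.randn(16, 64)              # (Batch, In_Features)
    y = layer(x)                         # forward pass also applies the Oja update in place
    print(y.shape)                       # torch.Size([16, 32])
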

@register_model()
class PredictiveCodingLayer(BaseBlock):
    ''' Predictive Coding Layer.

    Implements a layer based on the predictive coding principle. The layer maintains an
    internal state (representation) and updates it by minimizing the prediction error.
    '''
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_iter: int = 10,
        lr_state: float = 0.01,
        lr_weight: float = 1e-3,
        weight_decay: float = 0.0,
        auto_update: bool = True,
        activation: nn.Module = nn.LeakyReLU(),
        output_activation: nn.Module = nn.Identity(),
        separate_weights: bool = False
    ):
        ''' Initialize the predictive coding layer.

        Args:
            in_features (int): Input feature dimension.
            out_features (int): Output feature dimension (hidden state dimension).
            num_iter (int, optional): Number of inference iterations. Defaults to 10.
            lr_state (float, optional): State update rate. Defaults to 0.01.
            lr_weight (float, optional): Weight update rate. Defaults to 1e-3.
            weight_decay (float, optional): Weight decay rate. Defaults to 0.0.
            auto_update (bool, optional): Whether to update weights automatically during forward. Defaults to True.
            activation (nn.Module, optional): State activation function. Defaults to nn.LeakyReLU().
            output_activation (nn.Module, optional): Activation used when generating the output. Defaults to nn.Identity().
            separate_weights (bool, optional): Whether to use separate encoder and decoder weights. Defaults to False.
        '''
        super(PredictiveCodingLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_iter = num_iter
        self.lr_state = lr_state
        self.lr_weight = lr_weight
        self.weight_decay = weight_decay
        self.auto_update = auto_update
        self.activation = activation
        self.output_activation = output_activation
        self.separate_weights = separate_weights

        if self.separate_weights:
            self.encoder = nn.Linear(in_features, out_features, bias=False)
            self.decoder = nn.Linear(out_features, in_features, bias=False)
            self.optimizer = torch.optim.Adam(
                list(self.encoder.parameters()) + list(self.decoder.parameters()),
                lr=lr_weight,
                weight_decay=weight_decay
            )
        else:
            self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
            self.optimizer = torch.optim.Adam([self.weight], lr=lr_weight, weight_decay=weight_decay)

        self._init_weights(self)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        ''' Project the input into the hidden state space (linear transform). '''
        if self.separate_weights:
            return self.encoder(x)
        return F.linear(x, self.weight)

    def decode(self, state: torch.Tensor) -> torch.Tensor:
        ''' Project the hidden state back into the input space (linear transform). '''
        if self.separate_weights:
            return self.decoder(state)
        return F.linear(state, self.weight.t())

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
        mask: torch.Tensor = None,
        top_down_input: torch.Tensor = None,
        feature_weights: torch.Tensor = None
    ) -> torch.Tensor:
        ''' Perform a single state update step.

        Uses autograd to compute the gradient of the energy function with respect to the
        state, which also supports nonlinear generative models.

        Args:
            x (torch.Tensor): Input observation.
            state (torch.Tensor): Current hidden state.
            mask (torch.Tensor, optional): Error mask.
            top_down_input (torch.Tensor, optional): Prediction/prior from a higher layer.
            feature_weights (torch.Tensor, optional): Feature weights.

        Returns:
            torch.Tensor: Updated hidden state.
        '''
        with torch.enable_grad():
            state = state.detach().requires_grad_(True)

            # pred_x = g(state @ W.T)
            pred_x = self.output_activation(self.decode(state))

            # Energy = 0.5 * || (x - pred_x) * mask ||^2
            error = x - pred_x
            if mask is not None:
                error = error * mask

            sq_error = error ** 2
            if feature_weights is not None:
                sq_error = sq_error * feature_weights

            energy = 0.5 * torch.sum(sq_error)

            # Energy += 0.5 * || state - top_down_input ||^2
            if top_down_input is not None:
                energy = energy + 0.5 * torch.sum((state - top_down_input) ** 2)

            # dEnergy/dState
            grad_state = torch.autograd.grad(energy, state)[0]

            # state = state - lr * grad
            new_state = state - self.lr_state * grad_state
            new_state = self.activation(new_state)

            return new_state.detach()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        top_down_input: torch.Tensor = None,
        feature_weights: torch.Tensor = None,
        num_iter: int = None
    ) -> PredictiveCodingOutput:
        ''' Forward pass.

        Args:
            x (torch.Tensor): Input observation (Batch, ..., In_Features).
            mask (torch.Tensor, optional): Error mask.
            top_down_input (torch.Tensor, optional): Prediction/prior from a higher layer.
            feature_weights (torch.Tensor, optional): Feature weights.
            num_iter (int, optional): Number of inference iterations. If None, self.num_iter is used.

        Returns:
            PredictiveCodingOutput: Object containing the final hidden state and the reconstruction.
        '''
        original_shape = x.shape
        if x.dim() > 2: x = x.reshape(-1, self.in_features)
        if mask is not None and mask.dim() > 2: mask = mask.reshape(-1, self.in_features)
        if top_down_input is not None and top_down_input.dim() > 2:
            top_down_input = top_down_input.reshape(-1, self.out_features)
        if feature_weights is not None and feature_weights.dim() > 2:
            feature_weights = feature_weights.reshape(-1, self.in_features)

        with torch.no_grad():
            state = self.activation(self.encode(x))

        n_iter = num_iter if num_iter is not None else self.num_iter
        for _ in range(n_iter):
            state = self.step(x, state, mask, top_down_input, feature_weights)

        if self.training and self.auto_update:
            self._update_weights(x, state, feature_weights)

        if len(original_shape) > 2:
            state = state.reshape(original_shape[:-1] + (self.out_features,))

        with torch.no_grad():
            state_flat = state.reshape(-1, self.out_features) if state.dim() > 2 else state
            pred_x = self.output_activation(self.decode(state_flat))
            if len(original_shape) > 2:
                pred_x = pred_x.reshape(original_shape)

        return PredictiveCodingOutput(
            output=state,
            reconstruction=pred_x
        )

    def predict(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> torch.Tensor:
        ''' Run inference and return the reconstructed input (including unobserved parts).

        Args:
            x (torch.Tensor): Input observation.
            mask (torch.Tensor, optional): Mask.
            feature_weights (torch.Tensor, optional): Feature weights.
            num_iter (int, optional): Number of inference iterations.

        Returns:
            torch.Tensor: Reconstructed/predicted input (Batch, ..., In_Features).
        '''
        pc_output = self.forward(x, mask, feature_weights=feature_weights, num_iter=num_iter)
        return pc_output.reconstruction

    def _update_weights(self, x: torch.Tensor, state: torch.Tensor, feature_weights: torch.Tensor = None):
        ''' Update the weights to minimize the prediction error.

        Args:
            x (torch.Tensor): Input observation.
            state (torch.Tensor): Hidden state.
            feature_weights (torch.Tensor, optional): Feature weights.
        '''
        x = x.detach()
        state = state.detach()

        self.optimizer.zero_grad()

        if self.separate_weights:
            # || x - decoder(state) ||^2
            pred_x = self.output_activation(self.decoder(state))
            error = x - pred_x

            sq_error = error ** 2
            if feature_weights is not None:
                sq_error = sq_error * feature_weights
            loss_decoder = 0.5 * torch.sum(sq_error)

            # || state - encoder(x) ||^2 (Amortized Inference)
            pred_state = self.activation(self.encoder(x))
            loss_encoder = 0.5 * torch.sum((state - pred_state) ** 2)

            loss = loss_decoder + loss_encoder
            loss.backward()
        else:
            pred_x = self.output_activation(F.linear(state, self.weight.t()))

            error = x - pred_x
            sq_error = error ** 2
            if feature_weights is not None:
                sq_error = sq_error * feature_weights
            loss = 0.5 * torch.sum(sq_error)

            loss.backward()

        self.optimizer.step()

    def get_prediction_error(self, x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        ''' Compute the prediction error.

        Args:
            x (torch.Tensor): Input observation.
            state (torch.Tensor): Hidden state.

        Returns:
            torch.Tensor: Prediction error (x - pred_x).
        '''
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)
            state = state.reshape(-1, self.out_features)

        with torch.no_grad():
            pred_x = self.output_activation(self.decode(state))
            return x - pred_x

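A minimal sketch of driving the predictive coding layer above for masked reconstruction. The import path is again inferred from the file layout, separate_weights=True is chosen so the nn.Linear encoder/decoder initialize themselves, and all shapes are illustrative.

    import torch
    from orbit.model.block.bio import PredictiveCodingLayer  # assumed import path based on the file layout

    pc = PredictiveCodingLayer(in_features=64, out_features=32, num_iter=20, separate_weights=True)
    pc.eval()                                   # inference only; automatic weight updates run only in train mode

    x = torch.randn(8, 64)
    mask = (torch.rand(8, 64) > 0.5).float()    # 1 = observed feature, 0 = missing feature

    out = pc(x * mask, mask=mask)               # iterative state inference against the masked input
    recon = pc.predict(x * mask, mask=mask)     # reconstruction, including the unobserved features
    print(out.output.shape, recon.shape)        # torch.Size([8, 32]) torch.Size([8, 64])
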

@register_model()
class PredictiveCodingBlock(BaseBlock):
    ''' Hierarchical predictive coding block.

    Automatically manages a stack of PredictiveCodingLayer modules, forming a hierarchical
    predictive coding network. Supports arbitrary depth and joint inference across layers.
    '''
    def __init__(
        self,
        in_features: int,
        hidden_dims: list[int] | int,
        num_iter: int = 10,
        lr_state: float = 0.1,
        lr_weight: float = 1e-3,
        weight_decay: float = 0.0,
        auto_update: bool = True,
        activation: nn.Module = nn.Tanh(),
        output_activations: list[nn.Module] = None,
        separate_weights: bool = False
    ):
        ''' Initialize the hierarchical predictive coding block.

        Args:
            in_features (int): Input feature dimension.
            hidden_dims (list[int] | int): Hidden layer dimensions.
            num_iter (int, optional): Number of inference iterations.
            lr_state (float, optional): State update rate.
            lr_weight (float, optional): Weight update rate.
            weight_decay (float, optional): Weight decay rate.
            auto_update (bool, optional): Whether to update weights automatically.
            activation (nn.Module, optional): State activation function.
            output_activations (list[nn.Module], optional): Output activation function for each layer.
            separate_weights (bool, optional): Whether to use separate encoder and decoder weights.
        '''
        super(PredictiveCodingBlock, self).__init__()

        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]

        self.dims = [in_features] + hidden_dims
        self.num_iter = num_iter
        self.auto_update = auto_update

        if output_activations is None:
            output_activations = []
            output_activations.append(nn.LeakyReLU())
            for _ in range(len(self.dims) - 2):
                output_activations.append(nn.LeakyReLU())

        self.layers: nn.ModuleList[PredictiveCodingLayer] = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            out_act = output_activations[i] if i < len(output_activations) else nn.Identity()

            self.layers.append(PredictiveCodingLayer(
                in_features=self.dims[i],
                out_features=self.dims[i+1],
                num_iter=num_iter,
                lr_state=lr_state,
                lr_weight=lr_weight,
                weight_decay=weight_decay,
                auto_update=False,
                activation=activation,
                output_activation=out_act,
                separate_weights=separate_weights
            ))

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> PredictiveCodingOutput:
        ''' Forward pass (joint inference).

        Args:
            x (torch.Tensor): Input observation.
            mask (torch.Tensor, optional): Error mask for the input layer.
            feature_weights (torch.Tensor, optional): Feature weights for the input layer.
            num_iter (int, optional): Number of inference iterations. If None, self.num_iter is used.

        Returns:
            PredictiveCodingOutput: Object containing the first layer's hidden state, the reconstruction, and the states of all layers.
        '''
        original_shape = x.shape
        if x.dim() > 2: x = x.reshape(-1, self.dims[0])
        if mask is not None and mask.dim() > 2: mask = mask.reshape(-1, self.dims[0])
        if feature_weights is not None and feature_weights.dim() > 2:
            feature_weights = feature_weights.reshape(-1, self.dims[0])

        states = []
        curr_input = x
        for layer in self.layers:
            s = layer.activation(layer.encode(curr_input))
            states.append(s)
            curr_input = s

        n_iter = num_iter if num_iter is not None else self.num_iter
        for _ in range(n_iter):
            top_down_preds = [None] * len(self.layers)
            for i in range(len(self.layers) - 1):
                with torch.no_grad():
                    top_down_preds[i] = self.layers[i+1].output_activation(
                        self.layers[i+1].decode(states[i+1])
                    )

            new_states = []
            for i, layer in enumerate(self.layers):
                inp = x if i == 0 else states[i-1]

                msk = mask if i == 0 else None
                fw = feature_weights if i == 0 else None

                new_s = layer.step(
                    x=inp,
                    state=states[i],
                    mask=msk,
                    top_down_input=top_down_preds[i],
                    feature_weights=fw
                )
                new_states.append(new_s)
            states = new_states

        if self.training and self.auto_update:
            for i, layer in enumerate(self.layers):
                inp = x if i == 0 else states[i-1]
                fw = feature_weights if i == 0 else None
                layer._update_weights(inp, states[i], feature_weights=fw)

        state1 = states[0]

        if len(original_shape) > 2:
            state1 = state1.reshape(original_shape[:-1] + (self.dims[1],))

        with torch.no_grad():
            state1_flat = state1.reshape(-1, self.dims[1]) if state1.dim() > 2 else state1
            pred_x = self.layers[0].output_activation(self.layers[0].decode(state1_flat))
            if len(original_shape) > 2:
                pred_x = pred_x.reshape(original_shape)

        return PredictiveCodingOutput(
            output=state1,
            reconstruction=pred_x,
            hidden_states=tuple(states)
        )

    def predict(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> torch.Tensor:
        ''' Run inference and return the reconstructed input. '''
        pc_output = self.forward(x, mask, feature_weights=feature_weights, num_iter=num_iter)
        return pc_output.reconstruction

    def get_prediction_error(self, x: torch.Tensor, state1: torch.Tensor) -> torch.Tensor:
        ''' Compute the input-layer prediction error. '''
        return self.layers[0].get_prediction_error(x, state1)

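A minimal sketch of the hierarchical block above, under the same assumptions (inferred import path, illustrative dimensions, separate_weights=True to avoid relying on BaseBlock._init_weights):

    import torch
    from orbit.model.block.bio import PredictiveCodingBlock  # assumed import path based on the file layout

    block = PredictiveCodingBlock(in_features=64, hidden_dims=[32, 16], num_iter=15, separate_weights=True)
    block.eval()                                        # joint inference only; no weight updates

    x = torch.randn(4, 64)
    out = block(x)
    print(out.output.shape)                             # torch.Size([4, 32])  first-layer hidden state
    print(out.reconstruction.shape)                     # torch.Size([4, 64])
    print([tuple(s.shape) for s in out.hidden_states])  # [(4, 32), (4, 16)]  states of every layer
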
orbit/model/block/codebook.py
ADDED
@@ -0,0 +1,122 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Tuple

from orbit.model import BaseBlock, register_model

@dataclass
class QuantizerOutput:
    z_q: torch.Tensor
    loss: torch.Tensor
    indices: torch.Tensor
    entropy: torch.Tensor
    perplexity: torch.Tensor

@register_model()
class LFQ(BaseBlock):
    '''
    Lookup-Free Quantization (LFQ) module.

    Based on the MagViT-2 paper. Projects the latent directly into a low-dimensional space,
    binarizes it with sign(), and packs the bits into integer indices.

    Advantages:
    1. Computationally efficient (no nearest-neighbor search).
    2. Supports very large vocabularies (e.g. codebook_dim=18 -> 262144 codes).
    3. More stable training.
    '''

    def __init__(
        self,
        latent_dim: int = 256,
        codebook_dim: int = 18,
        entropy_weight: float = 0.1,
        commitment_weight: float = 0.25,
        diversity_gamma: float = 1.0,
    ):
        '''
        Args:
            latent_dim (int): Dimension of the input/output features (encoder output dimension).
            codebook_dim (int): Dimension of the quantized space (number of bits). Vocabulary size is 2^codebook_dim.
            entropy_weight (float): Entropy loss weight; encourages codebook utilization.
            commitment_weight (float): Commitment loss weight; pulls the encoder output toward the quantized values.
            diversity_gamma (float): Scaling factor for the entropy penalty.
        '''
        super().__init__()
        self.latent_dim = latent_dim
        self.codebook_dim = codebook_dim
        self.entropy_weight = entropy_weight
        self.commitment_weight = commitment_weight
        self.diversity_gamma = diversity_gamma

        self.project_in = nn.Linear(latent_dim, codebook_dim) if latent_dim != codebook_dim else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, latent_dim) if latent_dim != codebook_dim else nn.Identity()

        self.register_buffer("basis", 2 ** torch.arange(codebook_dim))

    def entropy_loss(self, affine_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Compute the bit-wise entropy loss.

        Args:
            affine_logits: Projected logits [B*H*W, codebook_dim]

        Returns:
            loss: Entropy loss scalar (maximizing entropy -> minimizing the loss)
            avg_entropy: Mean entropy (for monitoring)
        '''
        probs = torch.sigmoid(affine_logits)

        # [B*H*W, D] -> [D]
        avg_probs = torch.mean(probs, dim=0)

        entropy = - (avg_probs * torch.log(avg_probs + 1e-5) +
                     (1 - avg_probs) * torch.log(1 - avg_probs + 1e-5))

        loss = - torch.mean(entropy) * self.diversity_gamma

        return loss, torch.mean(entropy)

    def forward(self, z: torch.Tensor) -> QuantizerOutput:
        '''
        Args:
            z (torch.Tensor): [B, C, H, W]

        Returns:
            QuantizerOutput
        '''
        B, C, H, W = z.shape

        z_permuted = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z_permuted.view(-1, C)

        z_e = self.project_in(z_flattened)

        z_q = torch.sign(z_e)

        z_q = z_e + (z_q - z_e).detach()

        commitment_loss = torch.mean((z_q.detach() - z_e) ** 2)

        entropy_loss, avg_entropy = self.entropy_loss(z_e)

        total_loss = self.commitment_weight * commitment_loss + self.entropy_weight * entropy_loss

        # [N, codebook_dim] * [codebook_dim] -> sum -> [N]
        is_positive = (z_q > 0).long()
        indices = (is_positive * self.basis).sum(dim=1)
        indices = indices.view(B, H, W)

        z_out = self.project_out(z_q)
        z_out = z_out.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        perplexity = 2 ** avg_entropy

        return QuantizerOutput(
            z_q=z_out,
            loss=total_loss,
            indices=indices,
            entropy=avg_entropy,
            perplexity=perplexity
        )
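
A minimal sketch of the LFQ quantizer above on a dummy encoder feature map. The import path orbit.model.block.codebook is inferred from the changed-file list; tensor sizes are illustrative.

    import torch
    from orbit.model.block.codebook import LFQ  # assumed import path based on the changed-file list

    lfq = LFQ(latent_dim=256, codebook_dim=18)

    z = torch.randn(2, 256, 8, 8)                # encoder feature map [B, C, H, W]
    q = lfq(z)
    print(q.z_q.shape)                           # torch.Size([2, 256, 8, 8]) straight-through quantized features
    print(q.indices.shape, q.indices.max())      # torch.Size([2, 8, 8]), indices in [0, 2**18)
    print(float(q.loss), float(q.perplexity))    # commitment + entropy loss, codebook perplexity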