orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/bio.py
@@ -0,0 +1,537 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ from orbit.model import BaseBlock, register_model
+
+
+ @dataclass
+ class PredictiveCodingOutput:
+     output: torch.Tensor
+     reconstruction: Optional[torch.Tensor] = None
+     hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
+
+
+ @register_model()
+ class HebianLayer(BaseBlock):
+     ''' Hebbian Learning Layer.
+
+     Implements an unsupervised learning layer based on the Hebbian rule.
+     Supports the basic Hebbian rule and Oja's rule.
+     '''
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         lr: float = 1e-3,
+         mode: str = 'oja',
+         bias: bool = True,
+         auto_update: bool = True
+     ):
+         ''' Initialize the Hebbian learning layer.
+
+         Args:
+             in_features (int): Input feature dimension.
+             out_features (int): Output feature dimension.
+             lr (float, optional): Hebbian learning rate. Defaults to 1e-3.
+             mode (str, optional): Update mode, either 'basic' or 'oja'. Defaults to 'oja'.
+             bias (bool, optional): Whether to use a bias term. Defaults to True.
+             auto_update (bool, optional): Whether to update the weights automatically in forward. Defaults to True.
+         '''
+         super(HebianLayer, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.lr = lr
+         self.mode = mode.lower()
+         self.auto_update = auto_update
+
+         if self.mode not in ['basic', 'oja']:
+             raise ValueError(f"Unsupported mode: {mode}. Must be 'basic' or 'oja'.")
+
+         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
+
+         if bias:
+             self.bias = nn.Parameter(torch.Tensor(out_features))
+         else:
+             self.register_parameter('bias', None)
+
+         self._init_weights(self)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         ''' Forward pass.
+
+         Args:
+             x (torch.Tensor): Input tensor (Batch, ..., In_Features).
+
+         Returns:
+             torch.Tensor: Output tensor (Batch, ..., Out_Features).
+         '''
+         y = F.linear(x, self.weight, self.bias)
+
+         if self.training and self.auto_update:
+             if x.dim() > 2:
+                 x_flat = x.reshape(-1, x.size(-1))
+                 y_flat = y.reshape(-1, y.size(-1))
+                 self._update_weights(x_flat, y_flat)
+             else:
+                 self._update_weights(x, y)
+
+         return y
+
+     @torch.no_grad()
+     def _update_weights(self, x: torch.Tensor, y: torch.Tensor):
+         ''' Apply the configured weight-update rule. '''
+         if self.mode == 'basic':
+             self._basic_update(x, y)
+         elif self.mode == 'oja':
+             self._oja_update(x, y)
+
+     @torch.no_grad()
+     def _basic_update(self, x: torch.Tensor, y: torch.Tensor):
+         ''' Apply the basic Hebbian update rule.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             y (torch.Tensor): Output tensor.
+         '''
+         batch_size = x.size(0)
+
+         # y^T * x -> (M, N)
+         grad_w = torch.matmul(y.t(), x)
+
+         self.weight.data += self.lr * grad_w / batch_size
+
+         if self.bias is not None:
+             # db = lr * sum(y)
+             grad_b = y.sum(dim=0)
+             self.bias.data += self.lr * grad_b / batch_size
+
+     @torch.no_grad()
+     def _oja_update(self, x: torch.Tensor, y: torch.Tensor):
+         ''' Apply Oja's update rule.
+
+         Oja's rule normalizes the update so the weights cannot grow without bound.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             y (torch.Tensor): Output tensor.
+         '''
+         batch_size = x.size(0)
+
+         # y^T * x -> (M, N)
+         yx = torch.matmul(y.t(), x)
+
+         # y^2 -> (B, M), summed over the batch -> (M)
+         y_sq = torch.sum(y ** 2, dim=0)
+
+         # (M, 1) * (M, N) -> (M, N)
+         grad_w = yx - y_sq.unsqueeze(1) * self.weight
+
+         self.weight.data += self.lr * grad_w / batch_size
+
+         if self.bias is not None:
+             grad_b = y.sum(dim=0) - y_sq * self.bias
+             self.bias.data += self.lr * grad_b / batch_size
+
+
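For orientation: in 'oja' mode the update applied above is dW = lr * (y x^T - diag(sum y^2) W) / batch_size, which keeps the weight norm bounded. A minimal usage sketch follows (illustrative only, not part of the packaged diff; it assumes BaseBlock._init_weights leaves the layer with usable initial weights):

import torch
from orbit.model.block.bio import HebianLayer

layer = HebianLayer(in_features=64, out_features=16, lr=1e-3, mode='oja')
layer.train()              # updates only run in training mode with auto_update=True

x = torch.randn(32, 64)
y = layer(x)               # the forward pass also applies the Oja update in place
print(y.shape)             # torch.Size([32, 16])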
+ @register_model()
+ class PredictiveCodingLayer(BaseBlock):
+     ''' Predictive Coding Layer.
+
+     Implements a layer based on the predictive coding principle. The layer maintains an
+     internal state (representation) and updates that state by minimizing the prediction error.
+     '''
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         num_iter: int = 10,
+         lr_state: float = 0.01,
+         lr_weight: float = 1e-3,
+         weight_decay: float = 0.0,
+         auto_update: bool = True,
+         activation: nn.Module = nn.LeakyReLU(),
+         output_activation: nn.Module = nn.Identity(),
+         separate_weights: bool = False
+     ):
+         ''' Initialize the predictive coding layer.
+
+         Args:
+             in_features (int): Input feature dimension.
+             out_features (int): Output feature dimension (hidden-state dimension).
+             num_iter (int, optional): Number of inference iterations. Defaults to 10.
+             lr_state (float, optional): State update rate. Defaults to 0.01.
+             lr_weight (float, optional): Weight update rate. Defaults to 1e-3.
+             weight_decay (float, optional): Weight decay rate. Defaults to 0.0.
+             auto_update (bool, optional): Whether to update the weights automatically in forward. Defaults to True.
+             activation (nn.Module, optional): State activation function. Defaults to nn.LeakyReLU().
+             output_activation (nn.Module, optional): Activation applied when generating the output. Defaults to nn.Identity().
+             separate_weights (bool, optional): Whether to use separate encoder and decoder weights. Defaults to False.
+         '''
+         super(PredictiveCodingLayer, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.num_iter = num_iter
+         self.lr_state = lr_state
+         self.lr_weight = lr_weight
+         self.weight_decay = weight_decay
+         self.auto_update = auto_update
+         self.activation = activation
+         self.output_activation = output_activation
+         self.separate_weights = separate_weights
+
+         if self.separate_weights:
+             self.encoder = nn.Linear(in_features, out_features, bias=False)
+             self.decoder = nn.Linear(out_features, in_features, bias=False)
+             self.optimizer = torch.optim.Adam(
+                 list(self.encoder.parameters()) + list(self.decoder.parameters()),
+                 lr=lr_weight,
+                 weight_decay=weight_decay
+             )
+         else:
+             self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
+             self.optimizer = torch.optim.Adam([self.weight], lr=lr_weight, weight_decay=weight_decay)
+
+         self._init_weights(self)
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         ''' Project the input into the hidden-state space (linear transform). '''
+         if self.separate_weights:
+             return self.encoder(x)
+         return F.linear(x, self.weight)
+
+     def decode(self, state: torch.Tensor) -> torch.Tensor:
+         ''' Project the hidden state back into the input space (linear transform). '''
+         if self.separate_weights:
+             return self.decoder(state)
+         return F.linear(state, self.weight.t())
+
+     def step(
+         self,
+         x: torch.Tensor,
+         state: torch.Tensor,
+         mask: torch.Tensor = None,
+         top_down_input: torch.Tensor = None,
+         feature_weights: torch.Tensor = None
+     ) -> torch.Tensor:
+         ''' Run a single state-update step.
+
+         Uses autograd to compute the gradient of the energy function with respect to the
+         state, which also supports non-linear generative models.
+
+         Args:
+             x (torch.Tensor): Input observation.
+             state (torch.Tensor): Current hidden state.
+             mask (torch.Tensor, optional): Error mask.
+             top_down_input (torch.Tensor, optional): Prediction/prior from the layer above.
+             feature_weights (torch.Tensor, optional): Per-feature weights.
+
+         Returns:
+             torch.Tensor: Updated hidden state.
+         '''
+         with torch.enable_grad():
+             state = state.detach().requires_grad_(True)
+
+             # pred_x = g(state @ W.T)
+             pred_x = self.output_activation(self.decode(state))
+
+             # Energy = 0.5 * || (x - pred_x) * mask ||^2
+             error = x - pred_x
+             if mask is not None:
+                 error = error * mask
+
+             sq_error = error ** 2
+             if feature_weights is not None:
+                 sq_error = sq_error * feature_weights
+
+             energy = 0.5 * torch.sum(sq_error)
+
+             # Energy += 0.5 * || state - top_down_input ||^2
+             if top_down_input is not None:
+                 energy = energy + 0.5 * torch.sum((state - top_down_input) ** 2)
+
+             # dEnergy/dState
+             grad_state = torch.autograd.grad(energy, state)[0]
+
+             # state = state - lr * grad
+             new_state = state - self.lr_state * grad_state
+             new_state = self.activation(new_state)
+
+         return new_state.detach()
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: torch.Tensor = None,
+         top_down_input: torch.Tensor = None,
+         feature_weights: torch.Tensor = None,
+         num_iter: int = None
+     ) -> PredictiveCodingOutput:
+         ''' Forward pass.
+
+         Args:
+             x (torch.Tensor): Input observation (Batch, ..., In_Features).
+             mask (torch.Tensor, optional): Error mask.
+             top_down_input (torch.Tensor, optional): Prediction/prior from the layer above.
+             feature_weights (torch.Tensor, optional): Per-feature weights.
+             num_iter (int, optional): Number of inference iterations. If None, uses self.num_iter.
+
+         Returns:
+             PredictiveCodingOutput: Object holding the final hidden state and the reconstruction.
+         '''
+         original_shape = x.shape
+         if x.dim() > 2: x = x.reshape(-1, self.in_features)
+         if mask is not None and mask.dim() > 2: mask = mask.reshape(-1, self.in_features)
+         if top_down_input is not None and top_down_input.dim() > 2:
+             top_down_input = top_down_input.reshape(-1, self.out_features)
+         if feature_weights is not None and feature_weights.dim() > 2:
+             feature_weights = feature_weights.reshape(-1, self.in_features)
+
+         with torch.no_grad():
+             state = self.activation(self.encode(x))
+
+         n_iter = num_iter if num_iter is not None else self.num_iter
+         for _ in range(n_iter):
+             state = self.step(x, state, mask, top_down_input, feature_weights)
+
+         if self.training and self.auto_update:
+             self._update_weights(x, state, feature_weights)
+
+         if len(original_shape) > 2:
+             state = state.reshape(original_shape[:-1] + (self.out_features,))
+
+         with torch.no_grad():
+             state_flat = state.reshape(-1, self.out_features) if state.dim() > 2 else state
+             pred_x = self.output_activation(self.decode(state_flat))
+             if len(original_shape) > 2:
+                 pred_x = pred_x.reshape(original_shape)
+
+         return PredictiveCodingOutput(
+             output=state,
+             reconstruction=pred_x
+         )
+
+     def predict(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> torch.Tensor:
+         ''' Run inference and return the reconstructed input (including the unobserved parts).
+
+         Args:
+             x (torch.Tensor): Input observation.
+             mask (torch.Tensor, optional): Mask.
+             feature_weights (torch.Tensor, optional): Per-feature weights.
+             num_iter (int, optional): Number of inference iterations.
+
+         Returns:
+             torch.Tensor: Reconstructed/predicted input (Batch, ..., In_Features).
+         '''
+         pc_output = self.forward(x, mask, feature_weights=feature_weights, num_iter=num_iter)
+         return pc_output.reconstruction
+
+     def _update_weights(self, x: torch.Tensor, state: torch.Tensor, feature_weights: torch.Tensor = None):
+         ''' Update the weights to minimize the prediction error.
+
+         Args:
+             x (torch.Tensor): Input observation.
+             state (torch.Tensor): Hidden state.
+             feature_weights (torch.Tensor, optional): Per-feature weights.
+         '''
+         x = x.detach()
+         state = state.detach()
+
+         self.optimizer.zero_grad()
+
+         if self.separate_weights:
+             # || x - decoder(state) ||^2
+             pred_x = self.output_activation(self.decoder(state))
+             error = x - pred_x
+
+             sq_error = error ** 2
+             if feature_weights is not None:
+                 sq_error = sq_error * feature_weights
+             loss_decoder = 0.5 * torch.sum(sq_error)
+
+             # || state - encoder(x) ||^2 (Amortized Inference)
+             pred_state = self.activation(self.encoder(x))
+             loss_encoder = 0.5 * torch.sum((state - pred_state) ** 2)
+
+             loss = loss_decoder + loss_encoder
+             loss.backward()
+         else:
+             pred_x = self.output_activation(F.linear(state, self.weight.t()))
+
+             error = x - pred_x
+             sq_error = error ** 2
+             if feature_weights is not None:
+                 sq_error = sq_error * feature_weights
+             loss = 0.5 * torch.sum(sq_error)
+
+             loss.backward()
+
+         self.optimizer.step()
+
+     def get_prediction_error(self, x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+         ''' Compute the prediction error.
+
+         Args:
+             x (torch.Tensor): Input observation.
+             state (torch.Tensor): Hidden state.
+
+         Returns:
+             torch.Tensor: Prediction error (x - pred_x).
+         '''
+         if x.dim() > 2:
+             x = x.reshape(-1, self.in_features)
+             state = state.reshape(-1, self.out_features)
+
+         with torch.no_grad():
+             pred_x = self.output_activation(self.decode(state))
+             return x - pred_x
+
+
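As the code above shows, step() performs gradient descent on the energy 0.5 * ||(x - g(decode(state))) * mask||^2 (plus an optional top-down prior term) with respect to the state, while _update_weights() descends the same reconstruction error with respect to the weights. A hedged usage sketch, not part of the shipped file:

import torch
from orbit.model.block.bio import PredictiveCodingLayer

pc = PredictiveCodingLayer(in_features=128, out_features=32, num_iter=20, lr_state=0.05)
pc.train()

x = torch.randn(8, 128)
mask = (torch.rand_like(x) > 0.3).float()   # pretend ~30% of the features are unobserved

out = pc(x * mask, mask=mask)               # iterative inference plus one weight update
print(out.output.shape)                     # torch.Size([8, 32])  hidden state
print(out.reconstruction.shape)             # torch.Size([8, 128]) reconstruction

x_hat = pc.predict(x * mask, mask=mask)     # reconstruction, including masked features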
+ @register_model()
+ class PredictiveCodingBlock(BaseBlock):
+     ''' Hierarchical predictive coding block.
+
+     Automatically manages a stack of PredictiveCodingLayer modules to form a hierarchical
+     predictive coding network. Supports hierarchies of arbitrary depth and joint inference.
+     '''
+     def __init__(
+         self,
+         in_features: int,
+         hidden_dims: list[int] | int,
+         num_iter: int = 10,
+         lr_state: float = 0.1,
+         lr_weight: float = 1e-3,
+         weight_decay: float = 0.0,
+         auto_update: bool = True,
+         activation: nn.Module = nn.Tanh(),
+         output_activations: list[nn.Module] = None,
+         separate_weights: bool = False
+     ):
+         ''' Initialize the hierarchical predictive coding block.
+
+         Args:
+             in_features (int): Input feature dimension.
+             hidden_dims (list[int] | int): Hidden dimensions, one per layer.
+             num_iter (int, optional): Number of inference iterations.
+             lr_state (float, optional): State update rate.
+             lr_weight (float, optional): Weight update rate.
+             weight_decay (float, optional): Weight decay rate.
+             auto_update (bool, optional): Whether to update the weights automatically.
+             activation (nn.Module, optional): State activation function.
+             output_activations (list[nn.Module], optional): Output activation for each layer.
+             separate_weights (bool, optional): Whether to use separate encoder and decoder weights.
+         '''
+         super(PredictiveCodingBlock, self).__init__()
+
+         if isinstance(hidden_dims, int):
+             hidden_dims = [hidden_dims]
+
+         self.dims = [in_features] + hidden_dims
+         self.num_iter = num_iter
+         self.auto_update = auto_update
+
+         if output_activations is None:
+             output_activations = []
+             output_activations.append(nn.LeakyReLU())
+             for _ in range(len(self.dims) - 2):
+                 output_activations.append(nn.LeakyReLU())
+
+         self.layers: nn.ModuleList[PredictiveCodingLayer] = nn.ModuleList()
+         for i in range(len(self.dims) - 1):
+             out_act = output_activations[i] if i < len(output_activations) else nn.Identity()
+
+             self.layers.append(PredictiveCodingLayer(
+                 in_features=self.dims[i],
+                 out_features=self.dims[i+1],
+                 num_iter=num_iter,
+                 lr_state=lr_state,
+                 lr_weight=lr_weight,
+                 weight_decay=weight_decay,
+                 auto_update=False,
+                 activation=activation,
+                 output_activation=out_act,
+                 separate_weights=separate_weights
+             ))
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> PredictiveCodingOutput:
+         ''' Forward pass (joint inference).
+
+         Args:
+             x (torch.Tensor): Input observation.
+             mask (torch.Tensor, optional): Error mask for the input layer.
+             feature_weights (torch.Tensor, optional): Per-feature weights for the input layer.
+             num_iter (int, optional): Number of inference iterations. If None, uses self.num_iter.
+
+         Returns:
+             PredictiveCodingOutput: Object holding the first-layer hidden state, the reconstruction, and the states of all layers.
+         '''
+         original_shape = x.shape
+         if x.dim() > 2: x = x.reshape(-1, self.dims[0])
+         if mask is not None and mask.dim() > 2: mask = mask.reshape(-1, self.dims[0])
+         if feature_weights is not None and feature_weights.dim() > 2:
+             feature_weights = feature_weights.reshape(-1, self.dims[0])
+
+         states = []
+         curr_input = x
+         for layer in self.layers:
+             s = layer.activation(layer.encode(curr_input))
+             states.append(s)
+             curr_input = s
+
+         n_iter = num_iter if num_iter is not None else self.num_iter
+         for _ in range(n_iter):
+             top_down_preds = [None] * len(self.layers)
+             for i in range(len(self.layers) - 1):
+                 with torch.no_grad():
+                     top_down_preds[i] = self.layers[i+1].output_activation(
+                         self.layers[i+1].decode(states[i+1])
+                     )
+
+             new_states = []
+             for i, layer in enumerate(self.layers):
+                 inp = x if i == 0 else states[i-1]
+
+                 msk = mask if i == 0 else None
+                 fw = feature_weights if i == 0 else None
+
+                 new_s = layer.step(
+                     x=inp,
+                     state=states[i],
+                     mask=msk,
+                     top_down_input=top_down_preds[i],
+                     feature_weights=fw
+                 )
+                 new_states.append(new_s)
+             states = new_states
+
+         if self.training and self.auto_update:
+             for i, layer in enumerate(self.layers):
+                 inp = x if i == 0 else states[i-1]
+                 fw = feature_weights if i == 0 else None
+                 layer._update_weights(inp, states[i], feature_weights=fw)
+
+         state1 = states[0]
+
+         if len(original_shape) > 2:
+             state1 = state1.reshape(original_shape[:-1] + (self.dims[1],))
+
+         with torch.no_grad():
+             state1_flat = state1.reshape(-1, self.dims[1]) if state1.dim() > 2 else state1
+             pred_x = self.layers[0].output_activation(self.layers[0].decode(state1_flat))
+             if len(original_shape) > 2:
+                 pred_x = pred_x.reshape(original_shape)
+
+         return PredictiveCodingOutput(
+             output=state1,
+             reconstruction=pred_x,
+             hidden_states=tuple(states)
+         )
+
+     def predict(self, x: torch.Tensor, mask: torch.Tensor = None, feature_weights: torch.Tensor = None, num_iter: int = None) -> torch.Tensor:
+         ''' Run inference and return the reconstructed input. '''
+         pc_output = self.forward(x, mask, feature_weights=feature_weights, num_iter=num_iter)
+         return pc_output.reconstruction
+
+     def get_prediction_error(self, x: torch.Tensor, state1: torch.Tensor) -> torch.Tensor:
+         ''' Compute the prediction error at the input layer. '''
+         return self.layers[0].get_prediction_error(x, state1)
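A corresponding sketch for the hierarchical block, which stacks PredictiveCodingLayer instances and runs joint inference across levels (illustrative only, not part of the shipped code):

import torch
from orbit.model.block.bio import PredictiveCodingBlock

block = PredictiveCodingBlock(in_features=784, hidden_dims=[256, 64], num_iter=15)
block.train()

x = torch.randn(4, 784)
out = block(x)
print(out.output.shape)          # torch.Size([4, 256])  first-level state
print(len(out.hidden_states))    # 2, one state per layer
print(out.reconstruction.shape)  # torch.Size([4, 784])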
orbit/model/block/codebook.py
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+ from dataclasses import dataclass
+ from typing import Tuple
+
+ from orbit.model import BaseBlock, register_model
+
+ @dataclass
+ class QuantizerOutput:
+     z_q: torch.Tensor
+     loss: torch.Tensor
+     indices: torch.Tensor
+     entropy: torch.Tensor
+     perplexity: torch.Tensor
+
+ @register_model()
+ class LFQ(BaseBlock):
+     '''
+     Lookup-Free Quantization (LFQ) module.
+
+     Based on the MagViT-2 paper. Projects the latent directly into a low-dimensional
+     space, binarizes it with sign(), and packs the resulting bits into an integer index.
+
+     Advantages:
+     1. Computationally efficient (no nearest-neighbour search).
+     2. Supports very large vocabularies (e.g. codebook_dim=18 -> a vocabulary of 262144).
+     3. More stable training.
+     '''
+
+     def __init__(
+         self,
+         latent_dim: int = 256,
+         codebook_dim: int = 18,
+         entropy_weight: float = 0.1,
+         commitment_weight: float = 0.25,
+         diversity_gamma: float = 1.0,
+     ):
+         '''
+         Args:
+             latent_dim (int): Dimension of the input/output features (encoder output dimension).
+             codebook_dim (int): Dimension of the quantized space (number of bits). The vocabulary size is 2^codebook_dim.
+             entropy_weight (float): Weight of the entropy loss, which encourages codebook utilization.
+             commitment_weight (float): Weight of the commitment loss, which pulls the encoder output towards the quantized values.
+             diversity_gamma (float): Scaling factor for the entropy penalty.
+         '''
+         super().__init__()
+         self.latent_dim = latent_dim
+         self.codebook_dim = codebook_dim
+         self.entropy_weight = entropy_weight
+         self.commitment_weight = commitment_weight
+         self.diversity_gamma = diversity_gamma
+
+         self.project_in = nn.Linear(latent_dim, codebook_dim) if latent_dim != codebook_dim else nn.Identity()
+         self.project_out = nn.Linear(codebook_dim, latent_dim) if latent_dim != codebook_dim else nn.Identity()
+
+         self.register_buffer("basis", 2 ** torch.arange(codebook_dim))
+
+     def entropy_loss(self, affine_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         '''
+         Compute the bit-wise entropy loss.
+
+         Args:
+             affine_logits: Projected logits [B*H*W, codebook_dim].
+
+         Returns:
+             loss: Entropy loss scalar (maximizing entropy -> minimizing this loss).
+             avg_entropy: Mean entropy (for monitoring).
+         '''
+         probs = torch.sigmoid(affine_logits)
+
+         # [B*H*W, D] -> [D]
+         avg_probs = torch.mean(probs, dim=0)
+
+         entropy = - (avg_probs * torch.log(avg_probs + 1e-5) +
+                      (1 - avg_probs) * torch.log(1 - avg_probs + 1e-5))
+
+         loss = - torch.mean(entropy) * self.diversity_gamma
+
+         return loss, torch.mean(entropy)
+
+     def forward(self, z: torch.Tensor) -> QuantizerOutput:
+         '''
+         Args:
+             z (torch.Tensor): [B, C, H, W]
+
+         Returns:
+             QuantizerOutput
+         '''
+         B, C, H, W = z.shape
+
+         z_permuted = z.permute(0, 2, 3, 1).contiguous()
+         z_flattened = z_permuted.view(-1, C)
+
+         z_e = self.project_in(z_flattened)
+
+         z_q = torch.sign(z_e)
+
+         z_q = z_e + (z_q - z_e).detach()
+
+         commitment_loss = torch.mean((z_q.detach() - z_e) ** 2)
+
+         entropy_loss, avg_entropy = self.entropy_loss(z_e)
+
+         total_loss = self.commitment_weight * commitment_loss + self.entropy_weight * entropy_loss
+
+         # [N, codebook_dim] * [codebook_dim] -> sum -> [N]
+         is_positive = (z_q > 0).long()
+         indices = (is_positive * self.basis).sum(dim=1)
+         indices = indices.view(B, H, W)
+
+         z_out = self.project_out(z_q)
+         z_out = z_out.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
+
+         perplexity = 2 ** avg_entropy
+
+         return QuantizerOutput(
+             z_q=z_out,
+             loss=total_loss,
+             indices=indices,
+             entropy=avg_entropy,
+             perplexity=perplexity
+         )
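A brief sketch of how this quantizer is typically wired between an encoder and a decoder (illustrative; the surrounding encoder/decoder and loss handling are assumptions, only LFQ itself comes from this diff):

import torch
from orbit.model.block.codebook import LFQ

lfq = LFQ(latent_dim=256, codebook_dim=18)   # 2**18 = 262144 possible codes

z = torch.randn(2, 256, 16, 16)              # encoder feature map [B, C, H, W]
q = lfq(z)

print(q.z_q.shape)      # torch.Size([2, 256, 16, 16]) quantized features for the decoder
print(q.indices.shape)  # torch.Size([2, 16, 16]) integer token ids
loss = q.loss           # add this to the reconstruction loss during training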