codon-model 0.0.2__tar.gz → 0.0.3a1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {codon_model-0.0.2/codon_model.egg-info → codon_model-0.0.3a1}/PKG-INFO +1 -1
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/__init__.py +1 -1
- codon_model-0.0.3a1/codon/kit/train/__init__.py +6 -0
- codon_model-0.0.2/codon/kit/auto_vision_train.py → codon_model-0.0.3a1/codon/kit/train/vision.py +82 -40
- codon_model-0.0.3a1/codon/model/__init__.py +9 -0
- codon_model-0.0.3a1/codon/model/motif/__init__.py +20 -0
- codon_model-0.0.3a1/codon/model/motif/base.py +231 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/motif/motif_a1.py +5 -10
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/motif/motif_v1.py +120 -90
- codon_model-0.0.3a1/codon/utils/dataset/base.py +124 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/image.py +5 -4
- codon_model-0.0.3a1/codon/utils/transforms.py +22 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1/codon_model.egg-info}/PKG-INFO +1 -1
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/SOURCES.txt +4 -1
- {codon_model-0.0.2 → codon_model-0.0.3a1}/test/test_motifv1_train.py +9 -13
- codon_model-0.0.2/codon/kit/__init__.py +0 -6
- codon_model-0.0.2/codon/model/motif/__init__.py +0 -14
- codon_model-0.0.2/codon/utils/dataset/base.py +0 -46
- {codon_model-0.0.2 → codon_model-0.0.3a1}/LICENSE +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/base.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/attention.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/hebian.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/predictive.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/codebook.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/conv.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/embedding.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/film.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/fusion.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/lora.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/mlp.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/moe.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/pixelshuffle.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/transformer.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/exp/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/exp/moe.py +0 -0
- {codon_model-0.0.2/codon/model → codon_model-0.0.3a1/codon/kit}/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/patch_disc.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/resnet.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/tcn.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/attention.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/bio.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/pixelshuffle.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/__init__.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/corpus.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/dataviewer.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/flatdata.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/mask.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/safecode.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/seed.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/split.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/theta.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/token.py +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/dependency_links.txt +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/requires.txt +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/top_level.txt +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/setup.cfg +0 -0
- {codon_model-0.0.2 → codon_model-0.0.3a1}/setup.py +0 -0
codon_model-0.0.2/codon/kit/auto_vision_train.py → codon_model-0.0.3a1/codon/kit/train/vision.py
RENAMED
|
@@ -2,19 +2,23 @@ import torch
|
|
|
2
2
|
import torch.nn as nn
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
|
+
from PIL import Image
|
|
5
6
|
from dataclasses import dataclass
|
|
6
7
|
from typing import Union, Optional, Literal, Callable
|
|
7
|
-
from PIL import Image
|
|
8
8
|
|
|
9
|
-
from codon.model
|
|
10
|
-
from codon.model.
|
|
9
|
+
from codon.model import PatchDiscriminator
|
|
10
|
+
from codon.model.motif import (
|
|
11
|
+
AutoencoderVisionModel,
|
|
12
|
+
AutoVisionEncoderOutput,
|
|
13
|
+
AutoVisionDecoderOutput
|
|
14
|
+
)
|
|
11
15
|
from codon.utils.split import split_image, SplitedImage
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
@dataclass
|
|
15
|
-
class
|
|
19
|
+
class AutoVisionTrainResult:
|
|
16
20
|
'''
|
|
17
|
-
Dataclass to hold the outputs and metrics from a single
|
|
21
|
+
Dataclass to hold the outputs and metrics from a single auto_vision_train step.
|
|
18
22
|
|
|
19
23
|
Attributes:
|
|
20
24
|
loss_g (float): Total generator loss.
|
|
@@ -40,8 +44,30 @@ class AutoTrainMotifVisionOutput:
|
|
|
40
44
|
fake_patches: Optional[torch.Tensor] = None
|
|
41
45
|
|
|
42
46
|
|
|
43
|
-
def
|
|
44
|
-
|
|
47
|
+
def _patches_to_image(patches: torch.Tensor, grid_shape: tuple) -> torch.Tensor:
|
|
48
|
+
'''
|
|
49
|
+
Helper function to reconstruct a full image tensor from a sequence of patches.
|
|
50
|
+
This is used to supply a padded full image to the generic AutoencoderVisionModel.encode().
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
patches (torch.Tensor): Patches tensor with shape [num_patches_h * num_patches_w, channels, patch_size, patch_size].
|
|
54
|
+
grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
torch.Tensor: Reconstructed full image tensor with shape [1, channels, height, width].
|
|
58
|
+
'''
|
|
59
|
+
num_patches_h, num_patches_w = grid_shape
|
|
60
|
+
channels, patch_size = patches.shape[1], patches.shape[2]
|
|
61
|
+
|
|
62
|
+
patches = patches.view(1, num_patches_h, num_patches_w, channels, patch_size, patch_size)
|
|
63
|
+
patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()
|
|
64
|
+
patches = patches.view(1, channels, num_patches_h * patch_size, num_patches_w * patch_size)
|
|
65
|
+
|
|
66
|
+
return patches
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def auto_vision_train(
|
|
70
|
+
model: AutoencoderVisionModel,
|
|
45
71
|
discriminator: PatchDiscriminator,
|
|
46
72
|
optimizer_g: torch.optim.Optimizer,
|
|
47
73
|
optimizer_d: torch.optim.Optimizer,
|
|
@@ -53,36 +79,36 @@ def auto_train_motif_vision(
|
|
|
53
79
|
perceptual_weight: float = 1.0,
|
|
54
80
|
adv_weight: float = 0.1,
|
|
55
81
|
quant_weight: float = 1.0,
|
|
56
|
-
|
|
57
|
-
device: Union[str, torch.device] = 'cpu'
|
|
58
|
-
) -> AutoTrainMotifVisionOutput:
|
|
82
|
+
) -> AutoVisionTrainResult:
|
|
59
83
|
'''
|
|
60
|
-
Executes a single end-to-end training step for
|
|
84
|
+
Executes a single end-to-end training step for an AutoencoderVisionModel.
|
|
61
85
|
|
|
62
|
-
This function handles image splitting, forward passes for both
|
|
63
|
-
and the discriminator (PatchDiscriminator),
|
|
64
|
-
L1/MSE, and Quantization), and backpropagation.
|
|
86
|
+
This function handles image splitting (with necessary padding), forward passes for both
|
|
87
|
+
the generator (AutoencoderVisionModel) and the discriminator (PatchDiscriminator),
|
|
88
|
+
loss calculations (including GAN, LPIPS, L1/MSE, and Quantization), and backpropagation.
|
|
65
89
|
|
|
66
90
|
Args:
|
|
67
|
-
model (
|
|
91
|
+
model (AutoencoderVisionModel): The autoencoder vision model.
|
|
68
92
|
discriminator (PatchDiscriminator): The PatchGAN discriminator.
|
|
69
|
-
optimizer_g (torch.optim.Optimizer): Optimizer for the
|
|
93
|
+
optimizer_g (torch.optim.Optimizer): Optimizer for the autoencoder model.
|
|
70
94
|
optimizer_d (torch.optim.Optimizer): Optimizer for the discriminator.
|
|
71
95
|
image (Union[torch.Tensor, str, Image.Image, np.ndarray]): The input image.
|
|
72
|
-
patch_size (int): The patch size used by the
|
|
96
|
+
patch_size (int): The patch size used by the model. Defaults to 12.
|
|
73
97
|
recon_loss_type (Literal['l1', 'mse']): Type of reconstruction loss. Defaults to 'l1'.
|
|
74
98
|
recon_weight (float): Weight for the reconstruction loss. Defaults to 1.0.
|
|
75
99
|
perceptual_loss_fn (Callable, optional): Initialized LPIPS or other perceptual loss function. Defaults to None.
|
|
76
100
|
perceptual_weight (float): Weight for the perceptual loss. Defaults to 1.0.
|
|
77
101
|
adv_weight (float): Weight for the generator's adversarial GAN loss. Defaults to 0.1.
|
|
78
102
|
quant_weight (float): Weight for the lookup-free quantization loss. Defaults to 1.0.
|
|
79
|
-
codebook_size (int): The total capacity of the codebook. Defaults to 2^18 = 262144.
|
|
80
|
-
device (Union[str, torch.device]): Device to perform computations on. Defaults to 'cpu'.
|
|
81
103
|
|
|
82
104
|
Returns:
|
|
83
|
-
|
|
105
|
+
AutoVisionTrainResult: Dataclass containing all the calculated losses and metrics.
|
|
84
106
|
'''
|
|
85
|
-
#
|
|
107
|
+
# Fallback mechanisms to get device and codebook_size if they aren't explicitly properties
|
|
108
|
+
device = getattr(model, 'device', next(model.parameters()).device)
|
|
109
|
+
codebook_size = getattr(model, 'codebook_size', 2**18)
|
|
110
|
+
|
|
111
|
+
# 1. Process and split the input image with padding to handle arbitrary sizes
|
|
86
112
|
splited: SplitedImage = split_image(
|
|
87
113
|
image=image,
|
|
88
114
|
patch_size=patch_size,
|
|
@@ -102,10 +128,16 @@ def auto_train_motif_vision(
|
|
|
102
128
|
else:
|
|
103
129
|
recon_criterion = mse_criterion
|
|
104
130
|
|
|
105
|
-
# Forward pass through
|
|
106
|
-
|
|
107
|
-
|
|
131
|
+
# 2. Forward pass through generator (AutoencoderVisionModel)
|
|
132
|
+
# Reconstruct padded full image to feed into generic encode method
|
|
133
|
+
padded_full_image = _patches_to_image(real_patches, grid_shape).to(device)
|
|
134
|
+
|
|
135
|
+
encoder_out: AutoVisionEncoderOutput = model.encode(padded_full_image)
|
|
136
|
+
decoder_out: AutoVisionDecoderOutput = model.decode(encoder_out)
|
|
137
|
+
|
|
138
|
+
fake_patches = decoder_out.reconstructed
|
|
108
139
|
|
|
140
|
+
# 3. Discriminator Training
|
|
109
141
|
optimizer_d.zero_grad()
|
|
110
142
|
|
|
111
143
|
# Forward discriminator on real patches
|
|
@@ -121,51 +153,61 @@ def auto_train_motif_vision(
|
|
|
121
153
|
loss_d.backward()
|
|
122
154
|
optimizer_d.step()
|
|
123
155
|
|
|
156
|
+
# 4. Generator Training
|
|
124
157
|
optimizer_g.zero_grad()
|
|
125
158
|
|
|
126
|
-
#
|
|
159
|
+
# 4.1 Reconstruction Loss (L1 or MSE)
|
|
127
160
|
loss_recon = recon_criterion(fake_patches, real_patches)
|
|
128
161
|
|
|
129
|
-
#
|
|
162
|
+
# 4.2 Perceptual Loss (LPIPS)
|
|
130
163
|
loss_perceptual_val = torch.tensor(0.0, device=device)
|
|
131
164
|
if perceptual_loss_fn is not None:
|
|
132
|
-
#
|
|
165
|
+
# Expected image range handling: LPIPS usually expects [-1, 1], models might output [0, 1]
|
|
133
166
|
p_real = real_patches * 2.0 - 1.0
|
|
134
167
|
p_fake = fake_patches * 2.0 - 1.0
|
|
135
168
|
loss_perceptual_val = perceptual_loss_fn(p_real, p_fake).mean()
|
|
136
169
|
|
|
137
|
-
#
|
|
138
|
-
|
|
170
|
+
# 4.3 Quantization Loss
|
|
171
|
+
# Fallback to 0.0 if the encoder output does not provide a quantization loss (e.g., standard AE)
|
|
172
|
+
loss_quant_val = torch.tensor(0.0, device=device)
|
|
173
|
+
if encoder_out.loss is not None:
|
|
174
|
+
loss_quant_val = encoder_out.loss
|
|
139
175
|
|
|
140
|
-
#
|
|
176
|
+
# 4.4 Generator Adversarial Loss
|
|
141
177
|
d_out_fake_g = discriminator(fake_patches)
|
|
142
178
|
loss_adv = mse_criterion(d_out_fake_g, torch.ones_like(d_out_fake_g))
|
|
143
179
|
|
|
144
|
-
#
|
|
180
|
+
# 4.5 Total Generator Loss
|
|
145
181
|
loss_g = (
|
|
146
182
|
recon_weight * loss_recon +
|
|
147
183
|
perceptual_weight * loss_perceptual_val +
|
|
148
|
-
quant_weight *
|
|
184
|
+
quant_weight * loss_quant_val +
|
|
149
185
|
adv_weight * loss_adv
|
|
150
186
|
)
|
|
151
187
|
|
|
152
188
|
loss_g.backward()
|
|
153
189
|
optimizer_g.step()
|
|
154
190
|
|
|
155
|
-
# Calculate codebook utilization
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
191
|
+
# Calculate codebook utilization if applicable
|
|
192
|
+
usage_rate = 0.0
|
|
193
|
+
if encoder_out.indices is not None:
|
|
194
|
+
indices = encoder_out.indices
|
|
195
|
+
unique_indices = torch.unique(indices)
|
|
196
|
+
usage_rate = unique_indices.numel() / codebook_size
|
|
197
|
+
|
|
198
|
+
perplexity_val = 0.0
|
|
199
|
+
if encoder_out.perplexity is not None:
|
|
200
|
+
perplexity_val = encoder_out.perplexity.item()
|
|
201
|
+
|
|
202
|
+
return AutoVisionTrainResult(
|
|
161
203
|
loss_g=loss_g.item(),
|
|
162
204
|
loss_d=loss_d.item(),
|
|
163
205
|
loss_recon=loss_recon.item(),
|
|
164
206
|
loss_perceptual=loss_perceptual_val.item(),
|
|
165
|
-
loss_quant=
|
|
207
|
+
loss_quant=loss_quant_val.item(),
|
|
166
208
|
loss_adv=loss_adv.item(),
|
|
167
209
|
codebook_usage_rate=float(usage_rate),
|
|
168
|
-
perplexity=
|
|
210
|
+
perplexity=float(perplexity_val),
|
|
169
211
|
real_patches=real_patches,
|
|
170
212
|
fake_patches=fake_patches
|
|
171
213
|
)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .base import (
|
|
2
|
+
CausalLanguageModel,
|
|
3
|
+
CausalLanguageModelOutput,
|
|
4
|
+
AutoencoderVisionModel,
|
|
5
|
+
AutoVisionEncoderOutput,
|
|
6
|
+
AutoVisionDecoderOutput
|
|
7
|
+
)
|
|
8
|
+
from .motif_a1 import MotifA1
|
|
9
|
+
from .motif_v1 import MotifV1Encoder, MotifV1Decoder, MotifV1
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
'CausalLanguageModel',
|
|
14
|
+
'CausalLanguageModelOutput',
|
|
15
|
+
'AutoencoderVisionModel',
|
|
16
|
+
'AutoVisionEncoderOutput',
|
|
17
|
+
'AutoVisionDecoderOutput',
|
|
18
|
+
'MotifA1',
|
|
19
|
+
'MotifV1Encoder', 'MotifV1Decoder', 'MotifV1',
|
|
20
|
+
]
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Any, Iterator, Union, Optional, List, Tuple
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from codon.base import BasicModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
|
|
12
|
+
class AutoVisionEncoderOutput:
|
|
13
|
+
'''
|
|
14
|
+
Output of autoencoder vision model encoder.
|
|
15
|
+
|
|
16
|
+
Attributes:
|
|
17
|
+
z_q (torch.Tensor): Quantized latent tensor.
|
|
18
|
+
loss (torch.Tensor): Quantization loss.
|
|
19
|
+
indices (torch.Tensor): Quantized indices.
|
|
20
|
+
grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
|
|
21
|
+
entropy (torch.Tensor): Average bit-wise entropy from codebook.
|
|
22
|
+
perplexity (torch.Tensor): Perplexity calculated as 2^entropy.
|
|
23
|
+
hidden_states (torch.Tensor): Hidden states before quantization.
|
|
24
|
+
'''
|
|
25
|
+
z_q: torch.Tensor
|
|
26
|
+
loss: torch.Tensor = None
|
|
27
|
+
indices: torch.Tensor = None
|
|
28
|
+
grid_shape: tuple = None
|
|
29
|
+
entropy: torch.Tensor = None
|
|
30
|
+
perplexity: torch.Tensor = None
|
|
31
|
+
hidden_states: torch.Tensor = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
|
|
35
|
+
class AutoVisionDecoderOutput:
|
|
36
|
+
'''
|
|
37
|
+
Output of autoencoder vision model decoder.
|
|
38
|
+
|
|
39
|
+
Attributes:
|
|
40
|
+
reconstructed (torch.Tensor): Reconstructed output tensor.
|
|
41
|
+
grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
|
|
42
|
+
hidden_states (torch.Tensor): Hidden states after attention.
|
|
43
|
+
'''
|
|
44
|
+
reconstructed: torch.Tensor
|
|
45
|
+
grid_shape: tuple = None
|
|
46
|
+
hidden_states: torch.Tensor = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
|
|
50
|
+
class CausalLanguageModelOutput:
|
|
51
|
+
'''
|
|
52
|
+
Output of causal language model.
|
|
53
|
+
|
|
54
|
+
Attributes:
|
|
55
|
+
logits (torch.Tensor): Prediction logits.
|
|
56
|
+
past_key_values (list, optional): List of past key value states.
|
|
57
|
+
aux_loss (torch.Tensor, optional): Auxiliary loss.
|
|
58
|
+
attentions (list, optional): List of attention weights.
|
|
59
|
+
hidden_states (tuple, optional): Tuple of hidden states.
|
|
60
|
+
'''
|
|
61
|
+
logits: torch.Tensor
|
|
62
|
+
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
|
|
63
|
+
aux_loss: Optional[torch.Tensor] = None
|
|
64
|
+
attentions: Optional[List[torch.Tensor]] = None
|
|
65
|
+
hidden_states: Optional[Tuple[torch.Tensor]] = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class CausalLanguageModel(BasicModel):
|
|
69
|
+
'''
|
|
70
|
+
Base class for causal language models with text generation capabilities.
|
|
71
|
+
|
|
72
|
+
Attributes:
|
|
73
|
+
gradient_checkpointing (bool): Whether gradient checkpointing is enabled.
|
|
74
|
+
'''
|
|
75
|
+
|
|
76
|
+
def generate(
|
|
77
|
+
self,
|
|
78
|
+
input_ids: torch.Tensor,
|
|
79
|
+
max_new_tokens: int = 100,
|
|
80
|
+
temperature: float = 1.0,
|
|
81
|
+
top_k: int = None,
|
|
82
|
+
eos_token_id: int = None
|
|
83
|
+
) -> torch.Tensor:
|
|
84
|
+
'''
|
|
85
|
+
Generate text tokens autoregressively.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
input_ids (torch.Tensor): Input token IDs with shape [batch, seq_len].
|
|
89
|
+
max_new_tokens (int): Maximum number of new tokens to generate. Defaults to 100.
|
|
90
|
+
temperature (float): Sampling temperature. Higher values increase randomness.
|
|
91
|
+
Defaults to 1.0.
|
|
92
|
+
top_k (int, optional): If set, sample only from top k tokens. Defaults to None.
|
|
93
|
+
eos_token_id (int, optional): End-of-sequence token ID. If None, generation
|
|
94
|
+
stops after max_new_tokens. Defaults to None.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
torch.Tensor: Generated token IDs with shape [batch, seq_len + num_generated].
|
|
98
|
+
'''
|
|
99
|
+
self.eval()
|
|
100
|
+
with torch.no_grad():
|
|
101
|
+
batch_size, seq_len = input_ids.shape
|
|
102
|
+
generated = input_ids.clone()
|
|
103
|
+
|
|
104
|
+
past_key_values = None
|
|
105
|
+
for _ in range(max_new_tokens):
|
|
106
|
+
if seq_len > 1:
|
|
107
|
+
outputs = self.forward(
|
|
108
|
+
input_ids=generated,
|
|
109
|
+
past_key_values=past_key_values,
|
|
110
|
+
use_cache=True
|
|
111
|
+
)
|
|
112
|
+
past_key_values = outputs.past_key_values
|
|
113
|
+
logits = outputs.logits[:, -1, :]
|
|
114
|
+
else:
|
|
115
|
+
outputs = self.forward(input_ids=generated)
|
|
116
|
+
logits = outputs.logits[:, -1, :]
|
|
117
|
+
|
|
118
|
+
logits = logits / temperature
|
|
119
|
+
|
|
120
|
+
if top_k is not None:
|
|
121
|
+
top_k_vals = torch.topk(logits, top_k).values[:, -1]
|
|
122
|
+
logits = torch.where(logits < top_k_vals.unsqueeze(1), torch.full_like(logits, float('-inf')), logits)
|
|
123
|
+
|
|
124
|
+
probs = F.softmax(logits, dim=-1)
|
|
125
|
+
next_token = torch.multinomial(probs, num_samples=1)
|
|
126
|
+
|
|
127
|
+
generated = torch.cat([generated, next_token], dim=1)
|
|
128
|
+
|
|
129
|
+
if eos_token_id is not None and (next_token == eos_token_id).all():
|
|
130
|
+
break
|
|
131
|
+
|
|
132
|
+
return generated
|
|
133
|
+
|
|
134
|
+
def compute_perplexity(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
|
135
|
+
'''
|
|
136
|
+
Compute perplexity from logits and target tokens.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
logits (torch.Tensor): Model output logits with shape [batch, seq_len, vocab_size].
|
|
140
|
+
targets (torch.Tensor): Target token IDs with shape [batch, seq_len].
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
torch.Tensor: Perplexity value (lower is better).
|
|
144
|
+
'''
|
|
145
|
+
batch_size, seq_len, vocab_size = logits.shape
|
|
146
|
+
|
|
147
|
+
logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
|
|
148
|
+
targets_flat = targets.reshape(batch_size * seq_len)
|
|
149
|
+
|
|
150
|
+
loss = F.cross_entropy(logits_flat, targets_flat, reduction='mean')
|
|
151
|
+
perplexity = torch.exp(loss)
|
|
152
|
+
|
|
153
|
+
return perplexity
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class AutoencoderVisionModel(BasicModel):
|
|
157
|
+
'''
|
|
158
|
+
Base class for autoencoder vision models with encoding/decoding capabilities.
|
|
159
|
+
|
|
160
|
+
Attributes:
|
|
161
|
+
gradient_checkpointing (bool): Whether gradient checkpointing is enabled.
|
|
162
|
+
'''
|
|
163
|
+
def __init__(self):
|
|
164
|
+
super().__init__()
|
|
165
|
+
self.codebook_size: int = 0
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def compute_psnr(img1: torch.Tensor, img2: torch.Tensor, max_value: float = 1.0) -> torch.Tensor:
|
|
169
|
+
'''
|
|
170
|
+
Compute Peak Signal-to-Noise Ratio between two images.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
img1 (torch.Tensor): Reference image tensor.
|
|
174
|
+
img2 (torch.Tensor): Comparison image tensor.
|
|
175
|
+
max_value (float): Maximum possible pixel value. Defaults to 1.0.
|
|
176
|
+
|
|
177
|
+
Returns:
|
|
178
|
+
torch.Tensor: PSNR value in dB (higher is better).
|
|
179
|
+
'''
|
|
180
|
+
mse = torch.mean((img1 - img2) ** 2)
|
|
181
|
+
psnr = 10 * torch.log10(max_value ** 2 / mse)
|
|
182
|
+
return psnr
|
|
183
|
+
|
|
184
|
+
def encode(self, x: torch.Tensor) -> AutoVisionEncoderOutput:
|
|
185
|
+
'''
|
|
186
|
+
Encode an image to latent representation.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
x (torch.Tensor): Input image tensor with shape [batch, channels, height, width].
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
AutoVisionEncoderOutput: Output containing latent representation and grid_shape.
|
|
193
|
+
'''
|
|
194
|
+
return self._encode(x)
|
|
195
|
+
|
|
196
|
+
def decode(self, encoder_output: AutoVisionEncoderOutput) -> AutoVisionDecoderOutput:
|
|
197
|
+
'''
|
|
198
|
+
Decode a latent representation to an image.
|
|
199
|
+
|
|
200
|
+
Args:
|
|
201
|
+
encoder_output (AutoVisionEncoderOutput): Output from encode method containing
|
|
202
|
+
latent representation and grid_shape.
|
|
203
|
+
|
|
204
|
+
Returns:
|
|
205
|
+
AutoVisionDecoderOutput: Output containing reconstructed image and grid_shape.
|
|
206
|
+
'''
|
|
207
|
+
return self._decode(encoder_output)
|
|
208
|
+
|
|
209
|
+
def _encode(self, x: torch.Tensor) -> AutoVisionEncoderOutput:
|
|
210
|
+
'''
|
|
211
|
+
Internal encoding method to be implemented by subclasses.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
x (torch.Tensor): Input image tensor.
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
AutoVisionEncoderOutput: Output containing latent representation and grid_shape.
|
|
218
|
+
'''
|
|
219
|
+
raise NotImplementedError('Subclasses must implement _encode method')
|
|
220
|
+
|
|
221
|
+
def _decode(self, encoder_output: AutoVisionEncoderOutput) -> AutoVisionDecoderOutput:
|
|
222
|
+
'''
|
|
223
|
+
Internal decoding method to be implemented by subclasses.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
encoder_output (AutoVisionEncoderOutput): Output from encode method.
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
AutoVisionDecoderOutput: Output containing reconstructed image.
|
|
230
|
+
'''
|
|
231
|
+
raise NotImplementedError('Subclasses must implement _decode method')
|
|
@@ -3,18 +3,13 @@ from codon.base import *
|
|
|
3
3
|
from codon.block.transformer import TransformerMoEDecoder
|
|
4
4
|
from codon.block.embedding import RotaryEmbedding
|
|
5
5
|
|
|
6
|
+
from .base import CausalLanguageModel, CausalLanguageModelOutput
|
|
7
|
+
|
|
6
8
|
from typing import Optional, List, Tuple
|
|
7
9
|
from dataclasses import dataclass
|
|
8
10
|
|
|
9
|
-
@dataclass
|
|
10
|
-
class MotifA1Output:
|
|
11
|
-
logits: torch.Tensor
|
|
12
|
-
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
|
|
13
|
-
aux_loss: Optional[torch.Tensor] = None
|
|
14
|
-
attentions: Optional[List[torch.Tensor]] = None
|
|
15
|
-
|
|
16
11
|
|
|
17
|
-
class MotifA1(
|
|
12
|
+
class MotifA1(CausalLanguageModel):
|
|
18
13
|
def __init__(
|
|
19
14
|
self,
|
|
20
15
|
vocab_size: int = 32000,
|
|
@@ -75,7 +70,7 @@ class MotifA1(BasicModel):
|
|
|
75
70
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
|
76
71
|
use_cache: bool = False,
|
|
77
72
|
output_attentions: bool = False
|
|
78
|
-
) ->
|
|
73
|
+
) -> CausalLanguageModelOutput:
|
|
79
74
|
x = self.token_emb(input_ids)
|
|
80
75
|
x = self.dropout(x)
|
|
81
76
|
|
|
@@ -113,7 +108,7 @@ class MotifA1(BasicModel):
|
|
|
113
108
|
x = self.norm(x)
|
|
114
109
|
logits = self.proj_out(x)
|
|
115
110
|
|
|
116
|
-
return
|
|
111
|
+
return CausalLanguageModelOutput(
|
|
117
112
|
logits=logits,
|
|
118
113
|
past_key_values=new_kv_cache,
|
|
119
114
|
aux_loss=aux_loss,
|