codon-model 0.0.2__tar.gz → 0.0.3a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {codon_model-0.0.2/codon_model.egg-info → codon_model-0.0.3a1}/PKG-INFO +1 -1
  2. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/__init__.py +1 -1
  3. codon_model-0.0.3a1/codon/kit/train/__init__.py +6 -0
  4. codon_model-0.0.2/codon/kit/auto_vision_train.py → codon_model-0.0.3a1/codon/kit/train/vision.py +82 -40
  5. codon_model-0.0.3a1/codon/model/__init__.py +9 -0
  6. codon_model-0.0.3a1/codon/model/motif/__init__.py +20 -0
  7. codon_model-0.0.3a1/codon/model/motif/base.py +231 -0
  8. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/motif/motif_a1.py +5 -10
  9. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/motif/motif_v1.py +120 -90
  10. codon_model-0.0.3a1/codon/utils/dataset/base.py +124 -0
  11. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/image.py +5 -4
  12. codon_model-0.0.3a1/codon/utils/transforms.py +22 -0
  13. {codon_model-0.0.2 → codon_model-0.0.3a1/codon_model.egg-info}/PKG-INFO +1 -1
  14. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/SOURCES.txt +4 -1
  15. {codon_model-0.0.2 → codon_model-0.0.3a1}/test/test_motifv1_train.py +9 -13
  16. codon_model-0.0.2/codon/kit/__init__.py +0 -6
  17. codon_model-0.0.2/codon/model/motif/__init__.py +0 -14
  18. codon_model-0.0.2/codon/utils/dataset/base.py +0 -46
  19. {codon_model-0.0.2 → codon_model-0.0.3a1}/LICENSE +0 -0
  20. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/base.py +0 -0
  21. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/__init__.py +0 -0
  22. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/attention.py +0 -0
  23. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/__init__.py +0 -0
  24. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/hebian.py +0 -0
  25. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/bio/predictive.py +0 -0
  26. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/codebook.py +0 -0
  27. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/conv.py +0 -0
  28. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/embedding.py +0 -0
  29. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/film.py +0 -0
  30. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/fusion.py +0 -0
  31. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/lora.py +0 -0
  32. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/mlp.py +0 -0
  33. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/moe.py +0 -0
  34. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/pixelshuffle.py +0 -0
  35. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/block/transformer.py +0 -0
  36. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/exp/__init__.py +0 -0
  37. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/exp/moe.py +0 -0
  38. {codon_model-0.0.2/codon/model → codon_model-0.0.3a1/codon/kit}/__init__.py +0 -0
  39. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/patch_disc.py +0 -0
  40. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/resnet.py +0 -0
  41. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/model/tcn.py +0 -0
  42. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/__init__.py +0 -0
  43. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/attention.py +0 -0
  44. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/bio.py +0 -0
  45. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/ops/pixelshuffle.py +0 -0
  46. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/__init__.py +0 -0
  47. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/__init__.py +0 -0
  48. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/corpus.py +0 -0
  49. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/dataviewer.py +0 -0
  50. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/dataset/flatdata.py +0 -0
  51. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/mask.py +0 -0
  52. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/safecode.py +0 -0
  53. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/seed.py +0 -0
  54. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/split.py +0 -0
  55. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/theta.py +0 -0
  56. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon/utils/token.py +0 -0
  57. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/dependency_links.txt +0 -0
  58. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/requires.txt +0 -0
  59. {codon_model-0.0.2 → codon_model-0.0.3a1}/codon_model.egg-info/top_level.txt +0 -0
  60. {codon_model-0.0.2 → codon_model-0.0.3a1}/setup.cfg +0 -0
  61. {codon_model-0.0.2 → codon_model-0.0.3a1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: codon-model
3
- Version: 0.0.2
3
+ Version: 0.0.3a1
4
4
  Summary: Codon model package
5
5
  Author: CodonTeam
6
6
  Requires-Python: >=3.8
@@ -1,5 +1,5 @@
1
1
  from typing import Optional
2
2
 
3
- __version__ = '0.0.2'
3
+ __version__ = '0.0.3a1'
4
4
 
5
5
  __seed__: Optional[int] = None
@@ -0,0 +1,6 @@
1
+ from .vision import auto_vision_train, AutoVisionTrainResult
2
+
3
+ __all__ = [
4
+ 'auto_vision_train',
5
+ 'AutoVisionTrainResult'
6
+ ]
@@ -2,19 +2,23 @@ import torch
2
2
  import torch.nn as nn
3
3
  import numpy as np
4
4
 
5
+ from PIL import Image
5
6
  from dataclasses import dataclass
6
7
  from typing import Union, Optional, Literal, Callable
7
- from PIL import Image
8
8
 
9
- from codon.model.motif.motif_v1 import MotifV1, MotifV1Output
10
- from codon.model.patch_disc import PatchDiscriminator
9
+ from codon.model import PatchDiscriminator
10
+ from codon.model.motif import (
11
+ AutoencoderVisionModel,
12
+ AutoVisionEncoderOutput,
13
+ AutoVisionDecoderOutput
14
+ )
11
15
  from codon.utils.split import split_image, SplitedImage
12
16
 
13
17
 
14
18
  @dataclass
15
- class AutoTrainMotifVisionOutput:
19
+ class AutoVisionTrainResult:
16
20
  '''
17
- Dataclass to hold the outputs and metrics from a single auto_train step.
21
+ Dataclass to hold the outputs and metrics from a single auto_vision_train step.
18
22
 
19
23
  Attributes:
20
24
  loss_g (float): Total generator loss.
@@ -40,8 +44,30 @@ class AutoTrainMotifVisionOutput:
40
44
  fake_patches: Optional[torch.Tensor] = None
41
45
 
42
46
 
43
- def auto_train_motif_vision(
44
- model: MotifV1,
47
+ def _patches_to_image(patches: torch.Tensor, grid_shape: tuple) -> torch.Tensor:
48
+ '''
49
+ Helper function to reconstruct a full image tensor from a sequence of patches.
50
+ This is used to supply a padded full image to the generic AutoencoderVisionModel.encode().
51
+
52
+ Args:
53
+ patches (torch.Tensor): Patches tensor with shape [num_patches_h * num_patches_w, channels, patch_size, patch_size].
54
+ grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
55
+
56
+ Returns:
57
+ torch.Tensor: Reconstructed full image tensor with shape [1, channels, height, width].
58
+ '''
59
+ num_patches_h, num_patches_w = grid_shape
60
+ channels, patch_size = patches.shape[1], patches.shape[2]
61
+
62
+ patches = patches.view(1, num_patches_h, num_patches_w, channels, patch_size, patch_size)
63
+ patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()
64
+ patches = patches.view(1, channels, num_patches_h * patch_size, num_patches_w * patch_size)
65
+
66
+ return patches
67
+
68
+
69
+ def auto_vision_train(
70
+ model: AutoencoderVisionModel,
45
71
  discriminator: PatchDiscriminator,
46
72
  optimizer_g: torch.optim.Optimizer,
47
73
  optimizer_d: torch.optim.Optimizer,
@@ -53,36 +79,36 @@ def auto_train_motif_vision(
53
79
  perceptual_weight: float = 1.0,
54
80
  adv_weight: float = 0.1,
55
81
  quant_weight: float = 1.0,
56
- codebook_size: int = 2**18,
57
- device: Union[str, torch.device] = 'cpu'
58
- ) -> AutoTrainMotifVisionOutput:
82
+ ) -> AutoVisionTrainResult:
59
83
  '''
60
- Executes a single end-to-end training step for the MotifV1 autoencoder.
84
+ Executes a single end-to-end training step for an AutoencoderVisionModel.
61
85
 
62
- This function handles image splitting, forward passes for both the generator (MotifV1)
63
- and the discriminator (PatchDiscriminator), loss calculations (including GAN, LPIPS,
64
- L1/MSE, and Quantization), and backpropagation.
86
+ This function handles image splitting (with necessary padding), forward passes for both
87
+ the generator (AutoencoderVisionModel) and the discriminator (PatchDiscriminator),
88
+ loss calculations (including GAN, LPIPS, L1/MSE, and Quantization), and backpropagation.
65
89
 
66
90
  Args:
67
- model (MotifV1): The MotifV1 autoencoder model.
91
+ model (AutoencoderVisionModel): The autoencoder vision model.
68
92
  discriminator (PatchDiscriminator): The PatchGAN discriminator.
69
- optimizer_g (torch.optim.Optimizer): Optimizer for the MotifV1 model.
93
+ optimizer_g (torch.optim.Optimizer): Optimizer for the autoencoder model.
70
94
  optimizer_d (torch.optim.Optimizer): Optimizer for the discriminator.
71
95
  image (Union[torch.Tensor, str, Image.Image, np.ndarray]): The input image.
72
- patch_size (int): The patch size used by the MotifV1 model. Defaults to 12.
96
+ patch_size (int): The patch size used by the model. Defaults to 12.
73
97
  recon_loss_type (Literal['l1', 'mse']): Type of reconstruction loss. Defaults to 'l1'.
74
98
  recon_weight (float): Weight for the reconstruction loss. Defaults to 1.0.
75
99
  perceptual_loss_fn (Callable, optional): Initialized LPIPS or other perceptual loss function. Defaults to None.
76
100
  perceptual_weight (float): Weight for the perceptual loss. Defaults to 1.0.
77
101
  adv_weight (float): Weight for the generator's adversarial GAN loss. Defaults to 0.1.
78
102
  quant_weight (float): Weight for the lookup-free quantization loss. Defaults to 1.0.
79
- codebook_size (int): The total capacity of the codebook. Defaults to 2^18 = 262144.
80
- device (Union[str, torch.device]): Device to perform computations on. Defaults to 'cpu'.
81
103
 
82
104
  Returns:
83
- AutoTrainMotifVisionOutput: Dataclass containing all the calculated losses and metrics.
105
+ AutoVisionTrainResult: Dataclass containing all the calculated losses and metrics.
84
106
  '''
85
- # 1. Process and split the input image
107
+ # Fallback mechanisms to get device and codebook_size if they aren't explicitly properties
108
+ device = getattr(model, 'device', next(model.parameters()).device)
109
+ codebook_size = getattr(model, 'codebook_size', 2**18)
110
+
111
+ # 1. Process and split the input image with padding to handle arbitrary sizes
86
112
  splited: SplitedImage = split_image(
87
113
  image=image,
88
114
  patch_size=patch_size,
@@ -102,10 +128,16 @@ def auto_train_motif_vision(
102
128
  else:
103
129
  recon_criterion = mse_criterion
104
130
 
105
- # Forward pass through MotifV1 once, reusing outputs for both discriminator and generator
106
- motif_out: MotifV1Output = model(real_patches, grid_shape)
107
- fake_patches = motif_out.reconstructed_image
131
+ # 2. Forward pass through generator (AutoencoderVisionModel)
132
+ # Reconstruct padded full image to feed into generic encode method
133
+ padded_full_image = _patches_to_image(real_patches, grid_shape).to(device)
134
+
135
+ encoder_out: AutoVisionEncoderOutput = model.encode(padded_full_image)
136
+ decoder_out: AutoVisionDecoderOutput = model.decode(encoder_out)
137
+
138
+ fake_patches = decoder_out.reconstructed
108
139
 
140
+ # 3. Discriminator Training
109
141
  optimizer_d.zero_grad()
110
142
 
111
143
  # Forward discriminator on real patches
@@ -121,51 +153,61 @@ def auto_train_motif_vision(
121
153
  loss_d.backward()
122
154
  optimizer_d.step()
123
155
 
156
+ # 4. Generator Training
124
157
  optimizer_g.zero_grad()
125
158
 
126
- # 2.1 Reconstruction Loss (L1 or MSE)
159
+ # 4.1 Reconstruction Loss (L1 or MSE)
127
160
  loss_recon = recon_criterion(fake_patches, real_patches)
128
161
 
129
- # 2.2 Perceptual Loss (LPIPS)
162
+ # 4.2 Perceptual Loss (LPIPS)
130
163
  loss_perceptual_val = torch.tensor(0.0, device=device)
131
164
  if perceptual_loss_fn is not None:
132
- # LPIPS expects input in range [-1, 1], Motif uses [0, 1]
165
+ # Expected image range handling: LPIPS usually expects [-1, 1], models might output [0, 1]
133
166
  p_real = real_patches * 2.0 - 1.0
134
167
  p_fake = fake_patches * 2.0 - 1.0
135
168
  loss_perceptual_val = perceptual_loss_fn(p_real, p_fake).mean()
136
169
 
137
- # 2.3 Quantization Loss
138
- loss_quant = motif_out.quantization_loss
170
+ # 4.3 Quantization Loss
171
+ # Fallback to 0.0 if the encoder output does not provide a quantization loss (e.g., standard AE)
172
+ loss_quant_val = torch.tensor(0.0, device=device)
173
+ if encoder_out.loss is not None:
174
+ loss_quant_val = encoder_out.loss
139
175
 
140
- # 2.4 Generator Adversarial Loss
176
+ # 4.4 Generator Adversarial Loss
141
177
  d_out_fake_g = discriminator(fake_patches)
142
178
  loss_adv = mse_criterion(d_out_fake_g, torch.ones_like(d_out_fake_g))
143
179
 
144
- # 2.5 Total Generator Loss
180
+ # 4.5 Total Generator Loss
145
181
  loss_g = (
146
182
  recon_weight * loss_recon +
147
183
  perceptual_weight * loss_perceptual_val +
148
- quant_weight * loss_quant +
184
+ quant_weight * loss_quant_val +
149
185
  adv_weight * loss_adv
150
186
  )
151
187
 
152
188
  loss_g.backward()
153
189
  optimizer_g.step()
154
190
 
155
- # Calculate codebook utilization
156
- indices = motif_out.indices
157
- unique_indices = torch.unique(indices)
158
- usage_rate = unique_indices.numel() / codebook_size
159
-
160
- return AutoTrainMotifVisionOutput(
191
+ # Calculate codebook utilization if applicable
192
+ usage_rate = 0.0
193
+ if encoder_out.indices is not None:
194
+ indices = encoder_out.indices
195
+ unique_indices = torch.unique(indices)
196
+ usage_rate = unique_indices.numel() / codebook_size
197
+
198
+ perplexity_val = 0.0
199
+ if encoder_out.perplexity is not None:
200
+ perplexity_val = encoder_out.perplexity.item()
201
+
202
+ return AutoVisionTrainResult(
161
203
  loss_g=loss_g.item(),
162
204
  loss_d=loss_d.item(),
163
205
  loss_recon=loss_recon.item(),
164
206
  loss_perceptual=loss_perceptual_val.item(),
165
- loss_quant=loss_quant.item(),
207
+ loss_quant=loss_quant_val.item(),
166
208
  loss_adv=loss_adv.item(),
167
209
  codebook_usage_rate=float(usage_rate),
168
- perplexity=motif_out.perplexity.item(),
210
+ perplexity=float(perplexity_val),
169
211
  real_patches=real_patches,
170
212
  fake_patches=fake_patches
171
213
  )
@@ -0,0 +1,9 @@
1
+ from .resnet import ResNet
2
+ from .patch_disc import PatchDiscriminator
3
+ from .tcn import TemporalConvNet
4
+
5
+ __all__ = [
6
+ 'ResNet',
7
+ 'PatchDiscriminator',
8
+ 'TemporalConvNet'
9
+ ]
@@ -0,0 +1,20 @@
1
+ from .base import (
2
+ CausalLanguageModel,
3
+ CausalLanguageModelOutput,
4
+ AutoencoderVisionModel,
5
+ AutoVisionEncoderOutput,
6
+ AutoVisionDecoderOutput
7
+ )
8
+ from .motif_a1 import MotifA1
9
+ from .motif_v1 import MotifV1Encoder, MotifV1Decoder, MotifV1
10
+
11
+
12
+ __all__ = [
13
+ 'CausalLanguageModel',
14
+ 'CausalLanguageModelOutput',
15
+ 'AutoencoderVisionModel',
16
+ 'AutoVisionEncoderOutput',
17
+ 'AutoVisionDecoderOutput',
18
+ 'MotifA1',
19
+ 'MotifV1Encoder', 'MotifV1Decoder', 'MotifV1',
20
+ ]
@@ -0,0 +1,231 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Callable, Any, Iterator, Union, Optional, List, Tuple
6
+ from dataclasses import dataclass
7
+
8
+ from codon.base import BasicModel
9
+
10
+
11
+ @dataclass
12
+ class AutoVisionEncoderOutput:
13
+ '''
14
+ Output of autoencoder vision model encoder.
15
+
16
+ Attributes:
17
+ z_q (torch.Tensor): Quantized latent tensor.
18
+ loss (torch.Tensor): Quantization loss.
19
+ indices (torch.Tensor): Quantized indices.
20
+ grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
21
+ entropy (torch.Tensor): Average bit-wise entropy from codebook.
22
+ perplexity (torch.Tensor): Perplexity calculated as 2^entropy.
23
+ hidden_states (torch.Tensor): Hidden states before quantization.
24
+ '''
25
+ z_q: torch.Tensor
26
+ loss: torch.Tensor = None
27
+ indices: torch.Tensor = None
28
+ grid_shape: tuple = None
29
+ entropy: torch.Tensor = None
30
+ perplexity: torch.Tensor = None
31
+ hidden_states: torch.Tensor = None
32
+
33
+
34
+ @dataclass
35
+ class AutoVisionDecoderOutput:
36
+ '''
37
+ Output of autoencoder vision model decoder.
38
+
39
+ Attributes:
40
+ reconstructed (torch.Tensor): Reconstructed output tensor.
41
+ grid_shape (tuple): Grid shape as (num_patches_h, num_patches_w).
42
+ hidden_states (torch.Tensor): Hidden states after attention.
43
+ '''
44
+ reconstructed: torch.Tensor
45
+ grid_shape: tuple = None
46
+ hidden_states: torch.Tensor = None
47
+
48
+
49
+ @dataclass
50
+ class CausalLanguageModelOutput:
51
+ '''
52
+ Output of causal language model.
53
+
54
+ Attributes:
55
+ logits (torch.Tensor): Prediction logits.
56
+ past_key_values (list, optional): List of past key value states.
57
+ aux_loss (torch.Tensor, optional): Auxiliary loss.
58
+ attentions (list, optional): List of attention weights.
59
+ hidden_states (tuple, optional): Tuple of hidden states.
60
+ '''
61
+ logits: torch.Tensor
62
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
63
+ aux_loss: Optional[torch.Tensor] = None
64
+ attentions: Optional[List[torch.Tensor]] = None
65
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
66
+
67
+
68
+ class CausalLanguageModel(BasicModel):
69
+ '''
70
+ Base class for causal language models with text generation capabilities.
71
+
72
+ Attributes:
73
+ gradient_checkpointing (bool): Whether gradient checkpointing is enabled.
74
+ '''
75
+
76
+ def generate(
77
+ self,
78
+ input_ids: torch.Tensor,
79
+ max_new_tokens: int = 100,
80
+ temperature: float = 1.0,
81
+ top_k: int = None,
82
+ eos_token_id: int = None
83
+ ) -> torch.Tensor:
84
+ '''
85
+ Generate text tokens autoregressively.
86
+
87
+ Args:
88
+ input_ids (torch.Tensor): Input token IDs with shape [batch, seq_len].
89
+ max_new_tokens (int): Maximum number of new tokens to generate. Defaults to 100.
90
+ temperature (float): Sampling temperature. Higher values increase randomness.
91
+ Defaults to 1.0.
92
+ top_k (int, optional): If set, sample only from top k tokens. Defaults to None.
93
+ eos_token_id (int, optional): End-of-sequence token ID. If None, generation
94
+ stops after max_new_tokens. Defaults to None.
95
+
96
+ Returns:
97
+ torch.Tensor: Generated token IDs with shape [batch, seq_len + num_generated].
98
+ '''
99
+ self.eval()
100
+ with torch.no_grad():
101
+ batch_size, seq_len = input_ids.shape
102
+ generated = input_ids.clone()
103
+
104
+ past_key_values = None
105
+ for _ in range(max_new_tokens):
106
+ if seq_len > 1:
107
+ outputs = self.forward(
108
+ input_ids=generated,
109
+ past_key_values=past_key_values,
110
+ use_cache=True
111
+ )
112
+ past_key_values = outputs.past_key_values
113
+ logits = outputs.logits[:, -1, :]
114
+ else:
115
+ outputs = self.forward(input_ids=generated)
116
+ logits = outputs.logits[:, -1, :]
117
+
118
+ logits = logits / temperature
119
+
120
+ if top_k is not None:
121
+ top_k_vals = torch.topk(logits, top_k).values[:, -1]
122
+ logits = torch.where(logits < top_k_vals.unsqueeze(1), torch.full_like(logits, float('-inf')), logits)
123
+
124
+ probs = F.softmax(logits, dim=-1)
125
+ next_token = torch.multinomial(probs, num_samples=1)
126
+
127
+ generated = torch.cat([generated, next_token], dim=1)
128
+
129
+ if eos_token_id is not None and (next_token == eos_token_id).all():
130
+ break
131
+
132
+ return generated
133
+
134
+ def compute_perplexity(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
135
+ '''
136
+ Compute perplexity from logits and target tokens.
137
+
138
+ Args:
139
+ logits (torch.Tensor): Model output logits with shape [batch, seq_len, vocab_size].
140
+ targets (torch.Tensor): Target token IDs with shape [batch, seq_len].
141
+
142
+ Returns:
143
+ torch.Tensor: Perplexity value (lower is better).
144
+ '''
145
+ batch_size, seq_len, vocab_size = logits.shape
146
+
147
+ logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
148
+ targets_flat = targets.reshape(batch_size * seq_len)
149
+
150
+ loss = F.cross_entropy(logits_flat, targets_flat, reduction='mean')
151
+ perplexity = torch.exp(loss)
152
+
153
+ return perplexity
154
+
155
+
156
+ class AutoencoderVisionModel(BasicModel):
157
+ '''
158
+ Base class for autoencoder vision models with encoding/decoding capabilities.
159
+
160
+ Attributes:
161
+ gradient_checkpointing (bool): Whether gradient checkpointing is enabled.
162
+ '''
163
+ def __init__(self):
164
+ super().__init__()
165
+ self.codebook_size: int = 0
166
+
167
+ @staticmethod
168
+ def compute_psnr(img1: torch.Tensor, img2: torch.Tensor, max_value: float = 1.0) -> torch.Tensor:
169
+ '''
170
+ Compute Peak Signal-to-Noise Ratio between two images.
171
+
172
+ Args:
173
+ img1 (torch.Tensor): Reference image tensor.
174
+ img2 (torch.Tensor): Comparison image tensor.
175
+ max_value (float): Maximum possible pixel value. Defaults to 1.0.
176
+
177
+ Returns:
178
+ torch.Tensor: PSNR value in dB (higher is better).
179
+ '''
180
+ mse = torch.mean((img1 - img2) ** 2)
181
+ psnr = 10 * torch.log10(max_value ** 2 / mse)
182
+ return psnr
183
+
184
+ def encode(self, x: torch.Tensor) -> AutoVisionEncoderOutput:
185
+ '''
186
+ Encode an image to latent representation.
187
+
188
+ Args:
189
+ x (torch.Tensor): Input image tensor with shape [batch, channels, height, width].
190
+
191
+ Returns:
192
+ AutoVisionEncoderOutput: Output containing latent representation and grid_shape.
193
+ '''
194
+ return self._encode(x)
195
+
196
+ def decode(self, encoder_output: AutoVisionEncoderOutput) -> AutoVisionDecoderOutput:
197
+ '''
198
+ Decode a latent representation to an image.
199
+
200
+ Args:
201
+ encoder_output (AutoVisionEncoderOutput): Output from encode method containing
202
+ latent representation and grid_shape.
203
+
204
+ Returns:
205
+ AutoVisionDecoderOutput: Output containing reconstructed image and grid_shape.
206
+ '''
207
+ return self._decode(encoder_output)
208
+
209
+ def _encode(self, x: torch.Tensor) -> AutoVisionEncoderOutput:
210
+ '''
211
+ Internal encoding method to be implemented by subclasses.
212
+
213
+ Args:
214
+ x (torch.Tensor): Input image tensor.
215
+
216
+ Returns:
217
+ AutoVisionEncoderOutput: Output containing latent representation and grid_shape.
218
+ '''
219
+ raise NotImplementedError('Subclasses must implement _encode method')
220
+
221
+ def _decode(self, encoder_output: AutoVisionEncoderOutput) -> AutoVisionDecoderOutput:
222
+ '''
223
+ Internal decoding method to be implemented by subclasses.
224
+
225
+ Args:
226
+ encoder_output (AutoVisionEncoderOutput): Output from encode method.
227
+
228
+ Returns:
229
+ AutoVisionDecoderOutput: Output containing reconstructed image.
230
+ '''
231
+ raise NotImplementedError('Subclasses must implement _decode method')
@@ -3,18 +3,13 @@ from codon.base import *
3
3
  from codon.block.transformer import TransformerMoEDecoder
4
4
  from codon.block.embedding import RotaryEmbedding
5
5
 
6
+ from .base import CausalLanguageModel, CausalLanguageModelOutput
7
+
6
8
  from typing import Optional, List, Tuple
7
9
  from dataclasses import dataclass
8
10
 
9
- @dataclass
10
- class MotifA1Output:
11
- logits: torch.Tensor
12
- past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
13
- aux_loss: Optional[torch.Tensor] = None
14
- attentions: Optional[List[torch.Tensor]] = None
15
-
16
11
 
17
- class MotifA1(BasicModel):
12
+ class MotifA1(CausalLanguageModel):
18
13
  def __init__(
19
14
  self,
20
15
  vocab_size: int = 32000,
@@ -75,7 +70,7 @@ class MotifA1(BasicModel):
75
70
  past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
76
71
  use_cache: bool = False,
77
72
  output_attentions: bool = False
78
- ) -> MotifA1Output:
73
+ ) -> CausalLanguageModelOutput:
79
74
  x = self.token_emb(input_ids)
80
75
  x = self.dropout(x)
81
76
 
@@ -113,7 +108,7 @@ class MotifA1(BasicModel):
113
108
  x = self.norm(x)
114
109
  logits = self.proj_out(x)
115
110
 
116
- return MotifA1Output(
111
+ return CausalLanguageModelOutput(
117
112
  logits=logits,
118
113
  past_key_values=new_kv_cache,
119
114
  aux_loss=aux_loss,