gptmed-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/inference/decoding_utils.py
@@ -0,0 +1,190 @@
+"""
+Decoding Utilities
+
+PURPOSE:
+Helper functions for text generation:
+- Repetition penalty
+- N-gram blocking
+- Stopping criteria
+
+WHAT THIS FILE DOES:
+1. Apply repetition penalty to discourage repeated tokens
+2. Block n-gram repetition (prevents "the the the")
+3. Check stopping conditions
+
+WHY THESE ARE NEEDED:
+- Models often get stuck in repetition loops
+- "The patient has has has has..."
+- N-gram blocking prevents this
+- Repetition penalty makes it less likely
+
+PACKAGES USED:
+- torch: PyTorch tensors
+
+FILES FROM THIS PROJECT:
+- None (utility functions)
+"""
+
+import torch
+from typing import List, Set
+
+
+def apply_repetition_penalty(
+    logits: torch.Tensor, generated_tokens: torch.Tensor, penalty: float = 1.0
+) -> torch.Tensor:
+    """
+    Apply repetition penalty to discourage repeated tokens.
+
+    Args:
+        logits: Current logits [batch_size, vocab_size]
+        generated_tokens: Previously generated tokens [batch_size, seq_len]
+        penalty: Penalty factor (>1.0 penalizes, <1.0 encourages)
+
+    Returns:
+        Modified logits
+
+    How it works:
+    - For each token that appeared before:
+      - If its logit is positive: divide by penalty
+      - If its logit is negative: multiply by penalty
+    - This makes repeated tokens less likely
+
+    Typical values:
+    - penalty=1.0: No penalty
+    - penalty=1.1: Mild penalty
+    - penalty=1.2: Moderate (recommended)
+    - penalty=1.5: Strong penalty
+
+    Warning: Too high penalty can make model avoid common words!
+    """
+    if penalty == 1.0:
+        return logits
+
+    batch_size = logits.size(0)
+
+    for batch_idx in range(batch_size):
+        # Get unique tokens in this sequence
+        unique_tokens = generated_tokens[batch_idx].unique()
+
+        for token_id in unique_tokens:
+            token_id = token_id.item()
+
+            # Apply penalty
+            if logits[batch_idx, token_id] > 0:
+                logits[batch_idx, token_id] /= penalty
+            else:
+                logits[batch_idx, token_id] *= penalty
+
+    return logits
+
+
+def get_ngrams(tokens: List[int], n: int) -> Set[tuple]:
+    """
+    Extract all n-grams from token list.
+
+    Args:
+        tokens: List of token IDs
+        n: N-gram size
+
+    Returns:
+        Set of n-grams (tuples)
+
+    Example:
+        tokens = [1, 2, 3, 4]
+        get_ngrams(tokens, 2) = {(1,2), (2,3), (3,4)}
+    """
+    if len(tokens) < n:
+        return set()
+
+    ngrams = set()
+    for i in range(len(tokens) - n + 1):
+        ngram = tuple(tokens[i : i + n])
+        ngrams.add(ngram)
+
+    return ngrams
+
+
+def block_ngram_repeats(
+    logits: torch.Tensor, generated_tokens: List[int], ngram_size: int = 3
+) -> torch.Tensor:
+    """
+    Block tokens that would create repeated n-grams.
+
+    Args:
+        logits: Current logits [vocab_size]
+        generated_tokens: Previously generated tokens (list)
+        ngram_size: N-gram size to block
+
+    Returns:
+        Modified logits with blocked tokens set to -inf
+
+    How it works:
+    - Look at last (n-1) tokens
+    - Find all tokens that appeared after this (n-1)-gram before
+    - Set their logits to -inf (can't be sampled)
+
+    Example with n=3:
+    - Generated: "The patient has the patient"
+    - Last 2 tokens: "the patient"
+    - Previously after "the patient": "has"
+    - Block "has" from being generated again
+
+    This prevents: "The patient has the patient has the patient has..."
+    """
+    if ngram_size == 0 or len(generated_tokens) < ngram_size:
+        return logits
+
+    # Get context (last n-1 tokens)
+    context = tuple(generated_tokens[-(ngram_size - 1) :])
+
+    # Find all tokens that appeared after this context
+    blocked_tokens = set()
+
+    for i in range(len(generated_tokens) - ngram_size + 1):
+        # Check if this position matches our context
+        if tuple(generated_tokens[i : i + ngram_size - 1]) == context:
+            # The next token creates a repeated n-gram
+            next_token = generated_tokens[i + ngram_size - 1]
+            blocked_tokens.add(next_token)
+
+    # Block these tokens
+    for token_id in blocked_tokens:
+        logits[token_id] = float("-inf")
+
+    return logits
+
+
+def should_stop_generation(
+    generated_tokens: List[int], stop_tokens: List[int], max_length: int, min_length: int
+) -> bool:
+    """
+    Check if generation should stop.
+
+    Args:
+        generated_tokens: Generated token IDs
+        stop_tokens: Token IDs that trigger stopping (e.g., EOS)
+        max_length: Maximum allowed length
+        min_length: Minimum required length
+
+    Returns:
+        True if should stop, False otherwise
+
+    Stopping criteria:
+    1. Reached max_length
+    2. Generated a stop token (and past min_length)
+    """
+    current_length = len(generated_tokens)
+
+    # Must generate at least min_length tokens
+    if current_length < min_length:
+        return False
+
+    # Stop if reached max length
+    if current_length >= max_length:
+        return True
+
+    # Stop if generated a stop token
+    if generated_tokens[-1] in stop_tokens:
+        return True
+
+    return False
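The three helpers above are pure functions over logits and token lists, so they can be exercised in isolation. The following is a minimal sketch (not part of the package) of how they combine at one decoding step; the vocabulary size, token IDs, and the EOS ID of 3 are made-up illustration values:

import torch

from gptmed.inference.decoding_utils import (
    apply_repetition_penalty,
    block_ngram_repeats,
    should_stop_generation,
)

generated = [4, 7, 2, 4, 7]  # toy token IDs generated so far
logits = torch.randn(1, 10)  # [batch_size=1, vocab_size=10]

# Tokens 2, 4, 7 already appeared: their positive logits get divided by 1.2,
# negative ones multiplied, so all three become less likely.
logits = apply_repetition_penalty(logits, torch.tensor([generated]), penalty=1.2)

# The last 2-token context is (4, 7); token 2 followed it earlier, so with
# ngram_size=3 its logit is forced to -inf and that 3-gram cannot repeat.
step_logits = block_ngram_repeats(logits[0], generated, ngram_size=3)
assert step_logits[2] == float("-inf")

# Stops as soon as the latest token is a stop token and min_length is met.
print(should_stop_generation(generated + [3], stop_tokens=[3], max_length=200, min_length=2))
# True

Note that apply_repetition_penalty modifies the logits tensor in place (and also returns it); that is fine inside a no-grad generation loop, but worth knowing if the same logits are reused elsewhere.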
gptmed/inference/generation_config.py
@@ -0,0 +1,83 @@
+"""
+Generation Configuration
+
+PURPOSE:
+Hyperparameters for text generation (inference).
+Controls how creative vs conservative the model's outputs are.
+
+WHAT THIS FILE CONTAINS:
+- Temperature: Randomness control
+- Top-k, top-p: Sampling constraints
+- Repetition penalty: Prevent repetitive text
+- Max length: Stop generation
+
+PACKAGES USED:
+- dataclasses: Clean config structure
+
+FILES FROM THIS PROJECT:
+- None (base config)
+
+KEY PARAMETERS EXPLAINED:
+- temperature: 0.0 = greedy, 0.7 = balanced, 1.5 = very creative
+- top_k: Only sample from top k tokens (50-100 typical)
+- top_p: Sample from tokens with cumulative prob p (0.9-0.95 typical)
+- repetition_penalty: >1.0 discourages repetition (1.2 is good)
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class GenerationConfig:
+    """Configuration for text generation."""
+
+    # Sampling strategy
+    temperature: float = 0.8  # Higher = more random (0.0 = greedy)
+    top_k: int = 50  # Only sample from top k tokens (0 = disabled)
+    top_p: float = 0.95  # Nucleus sampling threshold (1.0 = disabled)
+
+    # Repetition control
+    repetition_penalty: float = 1.2  # >1.0 penalizes repetition
+    no_repeat_ngram_size: int = 3  # Block repeating n-grams (0 = disabled)
+
+    # Length control
+    max_length: int = 200  # Maximum tokens to generate
+    min_length: int = 10  # Minimum tokens to generate
+
+    # Stopping criteria
+    stop_tokens: list = None  # Token IDs that stop generation
+
+    # Special tokens
+    bos_token_id: int = 2  # Beginning of sequence
+    eos_token_id: int = 3  # End of sequence
+    pad_token_id: int = 0  # Padding
+
+    def __post_init__(self):
+        """Validate config."""
+        if self.stop_tokens is None:
+            self.stop_tokens = [self.eos_token_id]
+
+        assert self.temperature >= 0.0, "temperature must be >= 0"
+        assert self.top_k >= 0, "top_k must be >= 0"
+        assert 0.0 <= self.top_p <= 1.0, "top_p must be in [0, 1]"
+        assert self.repetition_penalty >= 1.0, "repetition_penalty must be >= 1.0"
+
+
+def get_greedy_config() -> GenerationConfig:
+    """Greedy decoding (deterministic, picks highest prob)."""
+    return GenerationConfig(temperature=0.0, top_k=0, top_p=1.0, repetition_penalty=1.0)
+
+
+def get_balanced_config() -> GenerationConfig:
+    """Balanced sampling (good default)."""
+    return GenerationConfig(temperature=0.8, top_k=50, top_p=0.95, repetition_penalty=1.2)
+
+
+def get_creative_config() -> GenerationConfig:
+    """Creative sampling (more diverse, less coherent)."""
+    return GenerationConfig(temperature=1.2, top_k=100, top_p=0.95, repetition_penalty=1.3)
+
+
+def get_conservative_config() -> GenerationConfig:
+    """Conservative sampling (safe, coherent, less diverse)."""
+    return GenerationConfig(temperature=0.5, top_k=30, top_p=0.9, repetition_penalty=1.1)
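A quick illustration (not part of the package) of how the dataclass and its presets behave; the printed values follow directly from the defaults above:

from gptmed.inference.generation_config import (
    GenerationConfig,
    get_conservative_config,
    get_greedy_config,
)

# __post_init__ fills stop_tokens with the EOS ID when none is given.
cfg = GenerationConfig()
print(cfg.stop_tokens)  # [3]
print(cfg.temperature, cfg.top_k, cfg.top_p)  # 0.8 50 0.95

# Presets only differ in their sampling knobs.
greedy = get_greedy_config()  # temperature=0.0, top_k=0, top_p=1.0, repetition_penalty=1.0
safe = get_conservative_config()  # temperature=0.5, top_k=30, top_p=0.9, repetition_penalty=1.1

# Invalid values fail fast at construction time.
try:
    GenerationConfig(top_p=1.5)
except AssertionError as err:
    print(err)  # top_p must be in [0, 1]

The greedy preset relies on sample_next_token (in inference/sampling.py, not shown in this section) treating temperature=0.0 as greedy argmax rather than dividing by it.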
gptmed/inference/generator.py
@@ -0,0 +1,253 @@
+"""
+Text Generator
+
+PURPOSE:
+Main class for text generation using trained GPT model.
+Combines model loading, tokenization, and decoding strategies.
+
+WHAT THIS FILE DOES:
+1. Load trained model from checkpoint
+2. Generate text autoregressively (token by token)
+3. Apply sampling strategies and repetition control
+4. Convert tokens back to text
+
+GENERATION PROCESS:
+1. Start with prompt tokens
+2. Loop:
+   a. Model forward pass → logits
+   b. Apply repetition penalty
+   c. Sample next token
+   d. Append to sequence
+   e. Check stopping criteria
+3. Decode tokens to text
+
+PACKAGES USED:
+- torch: PyTorch
+- sentencepiece: Tokenizer
+
+FILES FROM THIS PROJECT:
+- model/architecture/transformer.py: GPT model
+- model/configs/model_config.py: Model config
+- inference/sampling.py: Sampling strategies
+- inference/decoding_utils.py: Repetition control
+- utils/checkpoints.py: Load checkpoints
+"""
+
+import torch
+import sentencepiece as spm
+from pathlib import Path
+from typing import List, Optional
+
+from gptmed.model.architecture import GPTTransformer
+from gptmed.model.configs.model_config import ModelConfig
+from gptmed.inference.generation_config import GenerationConfig
+from gptmed.inference.sampling import sample_next_token
+from gptmed.inference.decoding_utils import (
+    apply_repetition_penalty,
+    block_ngram_repeats,
+    should_stop_generation,
+)
+
+
+class TextGenerator:
+    """
+    Text generation with trained GPT model.
+
+    This is your interface for inference.
+    """
+
+    def __init__(
+        self, model: GPTTransformer, tokenizer: spm.SentencePieceProcessor, device: str = "cuda"
+    ):
+        """
+        Args:
+            model: Trained GPT model
+            tokenizer: SentencePiece tokenizer
+            device: Device to run on
+        """
+        self.model = model.to(device)
+        self.model.eval()  # Set to evaluation mode
+        self.tokenizer = tokenizer
+        self.device = device
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint_path: Path, tokenizer_path: Path, device: str = "cuda"):
+        """
+        Load generator from checkpoint.
+
+        Args:
+            checkpoint_path: Path to model checkpoint
+            tokenizer_path: Path to tokenizer .model file
+            device: Device to load on
+
+        Returns:
+            TextGenerator instance
+        """
+        # Load checkpoint
+        print(f"Loading checkpoint: {checkpoint_path}")
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+        # Create model from config
+        model_config_dict = checkpoint["model_config"]
+        model_config = ModelConfig(**model_config_dict)
+        model = GPTTransformer(model_config)
+
+        # Load weights
+        model.load_state_dict(checkpoint["model_state_dict"])
+        model.to(device)
+        model.eval()
+
+        print(f"Model loaded (step {checkpoint['step']})")
+        print(f"Validation loss: {checkpoint.get('val_loss', 'N/A')}")
+
+        # Load tokenizer
+        print(f"Loading tokenizer: {tokenizer_path}")
+        tokenizer = spm.SentencePieceProcessor()
+        tokenizer.load(str(tokenizer_path))
+
+        return cls(model, tokenizer, device)
+
+    def encode(self, text: str) -> List[int]:
+        """Encode text to token IDs."""
+        return self.tokenizer.encode_as_ids(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token IDs to text."""
+        return self.tokenizer.decode_ids(token_ids)
+
+    def generate(
+        self, prompt: str, gen_config: GenerationConfig = None, verbose: bool = False
+    ) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: Input text prompt
+            gen_config: Generation configuration
+            verbose: Print generation progress
+
+        Returns:
+            Generated text
+
+        Process:
+        1. Encode prompt to tokens
+        2. Generate tokens autoregressively
+        3. Decode back to text
+        """
+        if gen_config is None:
+            gen_config = GenerationConfig()
+
+        # Encode prompt
+        prompt_tokens = self.encode(prompt)
+
+        if verbose:
+            print(f"Prompt: {prompt}")
+            print(f"Prompt tokens: {prompt_tokens}")
+            print("Generating...")
+
+        # Generate tokens
+        generated_tokens = self.generate_tokens(
+            prompt_tokens=prompt_tokens, gen_config=gen_config, verbose=verbose
+        )
+
+        # Decode to text
+        generated_text = self.decode(generated_tokens)
+
+        if verbose:
+            print(f"\nGenerated {len(generated_tokens)} tokens")
+            print(f"Output: {generated_text}")
+
+        return generated_text
+
+    def generate_tokens(
+        self, prompt_tokens: List[int], gen_config: GenerationConfig, verbose: bool = False
+    ) -> List[int]:
+        """
+        Generate tokens autoregressively.
+
+        Args:
+            prompt_tokens: Input token IDs
+            gen_config: Generation config
+            verbose: Print progress
+
+        Returns:
+            List of generated token IDs (including prompt)
+        """
+        # Start with prompt
+        generated = prompt_tokens.copy()
+
+        # Generation loop
+        with torch.no_grad():
+            for step in range(gen_config.max_length):
+                # Check stopping criteria
+                if should_stop_generation(
+                    generated_tokens=generated,
+                    stop_tokens=gen_config.stop_tokens,
+                    max_length=gen_config.max_length,
+                    min_length=gen_config.min_length,
+                ):
+                    break
+
+                # Prepare input (last max_seq_len tokens)
+                max_seq_len = self.model.config.max_seq_len
+                input_ids = generated[-max_seq_len:]
+                input_tensor = torch.tensor([input_ids], device=self.device)
+
+                # Forward pass
+                logits = self.model(input_tensor)
+
+                # Get logits for last position
+                next_token_logits = logits[0, -1, :]  # [vocab_size]
+
+                # Apply repetition penalty
+                if gen_config.repetition_penalty != 1.0:
+                    generated_tensor = torch.tensor([generated], device=self.device)
+                    next_token_logits = apply_repetition_penalty(
+                        next_token_logits.unsqueeze(0),
+                        generated_tensor,
+                        gen_config.repetition_penalty,
+                    ).squeeze(0)
+
+                # Block n-gram repeats
+                if gen_config.no_repeat_ngram_size > 0:
+                    next_token_logits = block_ngram_repeats(
+                        next_token_logits, generated, gen_config.no_repeat_ngram_size
+                    )
+
+                # Sample next token
+                next_token = sample_next_token(
+                    next_token_logits.unsqueeze(0),
+                    temperature=gen_config.temperature,
+                    top_k=gen_config.top_k,
+                    top_p=gen_config.top_p,
+                )
+
+                next_token_id = next_token.item()
+                generated.append(next_token_id)
+
+                if verbose and step % 10 == 0:
+                    partial_text = self.decode(generated)
+                    print(f"Step {step}: {partial_text}")
+
+        return generated
+
+    def generate_batch(self, prompts: List[str], gen_config: GenerationConfig = None) -> List[str]:
+        """
+        Generate text for multiple prompts.
+
+        Args:
+            prompts: List of input prompts
+            gen_config: Generation config
+
+        Returns:
+            List of generated texts
+        """
+        if gen_config is None:
+            gen_config = GenerationConfig()
+
+        results = []
+        for prompt in prompts:
+            output = self.generate(prompt, gen_config, verbose=False)
+            results.append(output)
+
+        return results
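End-to-end, the class above is meant to be driven roughly as follows. This is an illustrative sketch, not packaged code: the file paths and prompts are placeholders, and the checkpoint is assumed to contain the "model_config", "model_state_dict", and "step" entries that from_checkpoint reads.

from gptmed.inference.generation_config import get_conservative_config
from gptmed.inference.generator import TextGenerator

# Placeholder paths: point these at a real checkpoint and SentencePiece model.
generator = TextGenerator.from_checkpoint(
    checkpoint_path="checkpoints/best_model.pt",
    tokenizer_path="tokenizer/tokenizer.model",
    device="cpu",  # or "cuda" when a GPU is available
)

# Single prompt with the conservative preset (lower temperature, milder sampling).
answer = generator.generate(
    "What are the symptoms of diabetes?",
    gen_config=get_conservative_config(),
    verbose=True,
)
print(answer)

# Several prompts; generate_batch simply loops over generate() sequentially.
answers = generator.generate_batch(
    ["What causes high blood pressure?", "How is asthma treated?"]
)

Because generate_tokens returns the prompt tokens together with the newly generated ones, the decoded output repeats the prompt at the start; callers that want only the continuation need to strip it themselves.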