gptmed-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gptmed/__init__.py +37 -0
  2. gptmed/configs/__init__.py +1 -0
  3. gptmed/configs/train_config.py +154 -0
  4. gptmed/data/__init__.py +5 -0
  5. gptmed/data/parsers/__init__.py +10 -0
  6. gptmed/data/parsers/medquad_parser.py +257 -0
  7. gptmed/data/parsers/text_formatter.py +148 -0
  8. gptmed/inference/__init__.py +1 -0
  9. gptmed/inference/decoding_utils.py +190 -0
  10. gptmed/inference/generation_config.py +83 -0
  11. gptmed/inference/generator.py +253 -0
  12. gptmed/inference/sampling.py +261 -0
  13. gptmed/model/__init__.py +9 -0
  14. gptmed/model/architecture/__init__.py +35 -0
  15. gptmed/model/architecture/attention.py +188 -0
  16. gptmed/model/architecture/decoder_block.py +130 -0
  17. gptmed/model/architecture/embeddings.py +146 -0
  18. gptmed/model/architecture/feedforward.py +109 -0
  19. gptmed/model/architecture/transformer.py +204 -0
  20. gptmed/model/configs/__init__.py +17 -0
  21. gptmed/model/configs/model_config.py +155 -0
  22. gptmed/tokenizer/__init__.py +7 -0
  23. gptmed/tokenizer/tokenize_data.py +286 -0
  24. gptmed/tokenizer/train_tokenizer.py +218 -0
  25. gptmed/training/__init__.py +1 -0
  26. gptmed/training/dataset.py +183 -0
  27. gptmed/training/train.py +272 -0
  28. gptmed/training/trainer.py +331 -0
  29. gptmed/training/utils.py +212 -0
  30. gptmed/utils/__init__.py +1 -0
  31. gptmed/utils/checkpoints.py +224 -0
  32. gptmed/utils/logging.py +189 -0
  33. gptmed-0.0.1.dist-info/METADATA +325 -0
  34. gptmed-0.0.1.dist-info/RECORD +38 -0
  35. gptmed-0.0.1.dist-info/WHEEL +5 -0
  36. gptmed-0.0.1.dist-info/entry_points.txt +3 -0
  37. gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
  38. gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/inference/decoding_utils.py
@@ -0,0 +1,190 @@
+"""
+Decoding Utilities
+
+PURPOSE:
+Helper functions for text generation:
+- Repetition penalty
+- N-gram blocking
+- Stopping criteria
+
+WHAT THIS FILE DOES:
+1. Apply repetition penalty to discourage repeated tokens
+2. Block n-gram repetition (prevents "the the the")
+3. Check stopping conditions
+
+WHY THESE ARE NEEDED:
+- Models often get stuck in repetition loops
+- "The patient has has has has..."
+- N-gram blocking prevents this
+- Repetition penalty makes it less likely
+
+PACKAGES USED:
+- torch: PyTorch tensors
+
+FILES FROM THIS PROJECT:
+- None (utility functions)
+"""
+
+import torch
+from typing import List, Set
+
+
+def apply_repetition_penalty(
+    logits: torch.Tensor, generated_tokens: torch.Tensor, penalty: float = 1.0
+) -> torch.Tensor:
+    """
+    Apply repetition penalty to discourage repeated tokens.
+
+    Args:
+        logits: Current logits [batch_size, vocab_size]
+        generated_tokens: Previously generated tokens [batch_size, seq_len]
+        penalty: Penalty factor (>1.0 penalizes, <1.0 encourages)
+
+    Returns:
+        Modified logits
+
+    How it works:
+    - For each token that appeared before:
+      - If its logit is positive: divide by penalty
+      - If its logit is negative: multiply by penalty
+    - This makes repeated tokens less likely
+
+    Typical values:
+    - penalty=1.0: No penalty
+    - penalty=1.1: Mild penalty
+    - penalty=1.2: Moderate (recommended)
+    - penalty=1.5: Strong penalty
+
+    Warning: Too high a penalty can make the model avoid common words!
+    """
+    if penalty == 1.0:
+        return logits
+
+    batch_size = logits.size(0)
+
+    for batch_idx in range(batch_size):
+        # Get unique tokens in this sequence
+        unique_tokens = generated_tokens[batch_idx].unique()
+
+        for token_id in unique_tokens:
+            token_id = token_id.item()
+
+            # Apply penalty
+            if logits[batch_idx, token_id] > 0:
+                logits[batch_idx, token_id] /= penalty
+            else:
+                logits[batch_idx, token_id] *= penalty
+
+    return logits
+
+
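
For reference, a minimal sketch of what the penalty does to one row of logits; the values are made up for illustration, and the import path assumes the package installs under the gptmed name shown in the file listing above.

    import torch
    from gptmed.inference.decoding_utils import apply_repetition_penalty

    # Toy 1 x 5 vocabulary; token 2 has already been generated twice.
    logits = torch.tensor([[1.0, -0.5, 2.0, 0.3, -1.2]])
    generated = torch.tensor([[2, 2]])

    apply_repetition_penalty(logits, generated, penalty=1.2)
    # Token 2 had a positive logit, so it is divided by 1.2: 2.0 -> ~1.67.
    # A repeated token with a negative logit would instead be multiplied by 1.2.
    # Note the function modifies the logits tensor in place and also returns it.
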
+def get_ngrams(tokens: List[int], n: int) -> Set[tuple]:
+    """
+    Extract all n-grams from a token list.
+
+    Args:
+        tokens: List of token IDs
+        n: N-gram size
+
+    Returns:
+        Set of n-grams (tuples)
+
+    Example:
+        tokens = [1, 2, 3, 4]
+        get_ngrams(tokens, 2) = {(1,2), (2,3), (3,4)}
+    """
+    if len(tokens) < n:
+        return set()
+
+    ngrams = set()
+    for i in range(len(tokens) - n + 1):
+        ngram = tuple(tokens[i : i + n])
+        ngrams.add(ngram)
+
+    return ngrams
+
+
+def block_ngram_repeats(
+    logits: torch.Tensor, generated_tokens: List[int], ngram_size: int = 3
+) -> torch.Tensor:
+    """
+    Block tokens that would create repeated n-grams.
+
+    Args:
+        logits: Current logits [vocab_size]
+        generated_tokens: Previously generated tokens (list)
+        ngram_size: N-gram size to block
+
+    Returns:
+        Modified logits with blocked tokens set to -inf
+
+    How it works:
+    - Look at the last (n-1) tokens
+    - Find every token that previously followed this (n-1)-gram
+    - Set their logits to -inf (so they cannot be sampled)
+
+    Example with n=3:
+    - Generated: "The patient has the patient"
+    - Last 2 tokens: "the patient"
+    - Previously after "the patient": "has"
+    - Block "has" from being generated again
+
+    This prevents: "The patient has the patient has the patient has..."
+    """
+    if ngram_size == 0 or len(generated_tokens) < ngram_size:
+        return logits
+
+    # Get context (last n-1 tokens)
+    context = tuple(generated_tokens[-(ngram_size - 1) :])
+
+    # Find all tokens that appeared after this context
+    blocked_tokens = set()
+
+    for i in range(len(generated_tokens) - ngram_size + 1):
+        # Check if this position matches our context
+        if tuple(generated_tokens[i : i + ngram_size - 1]) == context:
+            # The next token creates a repeated n-gram
+            next_token = generated_tokens[i + ngram_size - 1]
+            blocked_tokens.add(next_token)
+
+    # Block these tokens
+    for token_id in blocked_tokens:
+        logits[token_id] = float("-inf")
+
+    return logits
+
+
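
A small illustrative sketch of the two helpers together, with made-up token IDs (think of 5, 9, 7 as "the", "patient", "has"); again assuming the gptmed import path from the file listing.

    import torch
    from gptmed.inference.decoding_utils import get_ngrams, block_ngram_repeats

    history = [5, 9, 7, 5, 9]          # "... the patient has the patient"
    print(get_ngrams(history, 2))      # {(5, 9), (9, 7), (7, 5)} (set order may vary)

    logits = torch.zeros(10)           # toy vocabulary of 10 tokens
    block_ngram_repeats(logits, history, ngram_size=3)
    print(logits[7])                   # -inf: token 7 ("has") already followed (5, 9)
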
+def should_stop_generation(
+    generated_tokens: List[int], stop_tokens: List[int], max_length: int, min_length: int
+) -> bool:
+    """
+    Check if generation should stop.
+
+    Args:
+        generated_tokens: Generated token IDs
+        stop_tokens: Token IDs that trigger stopping (e.g., EOS)
+        max_length: Maximum allowed length
+        min_length: Minimum required length
+
+    Returns:
+        True if generation should stop, False otherwise
+
+    Stopping criteria:
+    1. Reached max_length
+    2. Generated a stop token (and past min_length)
+    """
+    current_length = len(generated_tokens)
+
+    # Must generate at least min_length tokens
+    if current_length < min_length:
+        return False
+
+    # Stop if reached max length
+    if current_length >= max_length:
+        return True
+
+    # Stop if generated a stop token
+    if generated_tokens[-1] in stop_tokens:
+        return True
+
+    return False
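
A quick illustration of the stopping rules, using the default eos_token_id (3) from the GenerationConfig in the next file; the other numbers are arbitrary.

    from gptmed.inference.decoding_utils import should_stop_generation

    tokens = [2, 41, 17, 3]  # ends with EOS

    # Past min_length and an EOS was generated, so generation stops.
    print(should_stop_generation(tokens, stop_tokens=[3], max_length=200, min_length=2))   # True

    # Below min_length the EOS is ignored and generation continues.
    print(should_stop_generation(tokens, stop_tokens=[3], max_length=200, min_length=10))  # False
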
gptmed/inference/generation_config.py
@@ -0,0 +1,83 @@
+"""
+Generation Configuration
+
+PURPOSE:
+Hyperparameters for text generation (inference).
+Controls how creative vs conservative the model's outputs are.
+
+WHAT THIS FILE CONTAINS:
+- Temperature: Randomness control
+- Top-k, top-p: Sampling constraints
+- Repetition penalty: Prevent repetitive text
+- Max length: Stop generation
+
+PACKAGES USED:
+- dataclasses: Clean config structure
+
+FILES FROM THIS PROJECT:
+- None (base config)
+
+KEY PARAMETERS EXPLAINED:
+- temperature: 0.0 = greedy, 0.7 = balanced, 1.5 = very creative
+- top_k: Only sample from top k tokens (50-100 typical)
+- top_p: Sample from tokens with cumulative prob p (0.9-0.95 typical)
+- repetition_penalty: >1.0 discourages repetition (1.2 is good)
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class GenerationConfig:
+    """Configuration for text generation."""
+
+    # Sampling strategy
+    temperature: float = 0.8  # Higher = more random (0.0 = greedy)
+    top_k: int = 50  # Only sample from top k tokens (0 = disabled)
+    top_p: float = 0.95  # Nucleus sampling threshold (1.0 = disabled)
+
+    # Repetition control
+    repetition_penalty: float = 1.2  # >1.0 penalizes repetition
+    no_repeat_ngram_size: int = 3  # Block repeating n-grams (0 = disabled)
+
+    # Length control
+    max_length: int = 200  # Maximum tokens to generate
+    min_length: int = 10  # Minimum tokens to generate
+
+    # Stopping criteria
+    stop_tokens: list = None  # Token IDs that stop generation
+
+    # Special tokens
+    bos_token_id: int = 2  # Beginning of sequence
+    eos_token_id: int = 3  # End of sequence
+    pad_token_id: int = 0  # Padding
+
+    def __post_init__(self):
+        """Validate config."""
+        if self.stop_tokens is None:
+            self.stop_tokens = [self.eos_token_id]
+
+        assert self.temperature >= 0.0, "temperature must be >= 0"
+        assert self.top_k >= 0, "top_k must be >= 0"
+        assert 0.0 <= self.top_p <= 1.0, "top_p must be in [0, 1]"
+        assert self.repetition_penalty >= 1.0, "repetition_penalty must be >= 1.0"
+
+
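
A short usage sketch (not part of the package) showing the __post_init__ behavior: stop_tokens defaults to the EOS id, and out-of-range values trip the asserts.

    from gptmed.inference.generation_config import GenerationConfig

    cfg = GenerationConfig(temperature=0.7, max_length=150)
    print(cfg.stop_tokens)    # [3] -- filled in from eos_token_id

    try:
        GenerationConfig(top_p=1.5)
    except AssertionError as err:
        print(err)            # top_p must be in [0, 1]
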
+def get_greedy_config() -> GenerationConfig:
+    """Greedy decoding (deterministic, picks highest prob)."""
+    return GenerationConfig(temperature=0.0, top_k=0, top_p=1.0, repetition_penalty=1.0)
+
+
+def get_balanced_config() -> GenerationConfig:
+    """Balanced sampling (good default)."""
+    return GenerationConfig(temperature=0.8, top_k=50, top_p=0.95, repetition_penalty=1.2)
+
+
+def get_creative_config() -> GenerationConfig:
+    """Creative sampling (more diverse, less coherent)."""
+    return GenerationConfig(temperature=1.2, top_k=100, top_p=0.95, repetition_penalty=1.3)
+
+
+def get_conservative_config() -> GenerationConfig:
+    """Conservative sampling (safe, coherent, less diverse)."""
+    return GenerationConfig(temperature=0.5, top_k=30, top_p=0.9, repetition_penalty=1.1)
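
Because the presets return plain dataclasses, a one-off variant can be derived with dataclasses.replace rather than a new helper; a small sketch, not part of the package.

    from dataclasses import replace
    from gptmed.inference.generation_config import get_balanced_config

    cfg = get_balanced_config()                # temperature=0.8, top_k=50, top_p=0.95
    longer_cfg = replace(cfg, max_length=400)  # same preset, but allow longer answers
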
gptmed/inference/generator.py
@@ -0,0 +1,253 @@
+"""
+Text Generator
+
+PURPOSE:
+Main class for text generation using a trained GPT model.
+Combines model loading, tokenization, and decoding strategies.
+
+WHAT THIS FILE DOES:
+1. Load trained model from checkpoint
+2. Generate text autoregressively (token by token)
+3. Apply sampling strategies and repetition control
+4. Convert tokens back to text
+
+GENERATION PROCESS:
+1. Start with prompt tokens
+2. Loop:
+   a. Model forward pass → logits
+   b. Apply repetition penalty
+   c. Sample next token
+   d. Append to sequence
+   e. Check stopping criteria
+3. Decode tokens to text
+
+PACKAGES USED:
+- torch: PyTorch
+- sentencepiece: Tokenizer
+
+FILES FROM THIS PROJECT:
+- model/architecture/transformer.py: GPT model
+- model/configs/model_config.py: Model config
+- inference/sampling.py: Sampling strategies
+- inference/decoding_utils.py: Repetition control
+- utils/checkpoints.py: Load checkpoints
+"""
+
+import torch
+import sentencepiece as spm
+from pathlib import Path
+from typing import List, Optional
+
+from gptmed.model.architecture import GPTTransformer
+from gptmed.model.configs.model_config import ModelConfig
+from gptmed.inference.generation_config import GenerationConfig
+from gptmed.inference.sampling import sample_next_token
+from gptmed.inference.decoding_utils import (
+    apply_repetition_penalty,
+    block_ngram_repeats,
+    should_stop_generation,
+)
+
+
+class TextGenerator:
+    """
+    Text generation with a trained GPT model.
+
+    This is your interface for inference.
+    """
+
+    def __init__(
+        self, model: GPTTransformer, tokenizer: spm.SentencePieceProcessor, device: str = "cuda"
+    ):
+        """
+        Args:
+            model: Trained GPT model
+            tokenizer: SentencePiece tokenizer
+            device: Device to run on
+        """
+        self.model = model.to(device)
+        self.model.eval()  # Set to evaluation mode
+        self.tokenizer = tokenizer
+        self.device = device
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint_path: Path, tokenizer_path: Path, device: str = "cuda"):
+        """
+        Load generator from checkpoint.
+
+        Args:
+            checkpoint_path: Path to model checkpoint
+            tokenizer_path: Path to tokenizer .model file
+            device: Device to load on
+
+        Returns:
+            TextGenerator instance
+        """
+        # Load checkpoint
+        print(f"Loading checkpoint: {checkpoint_path}")
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+        # Create model from config
+        model_config_dict = checkpoint["model_config"]
+        model_config = ModelConfig(**model_config_dict)
+        model = GPTTransformer(model_config)
+
+        # Load weights
+        model.load_state_dict(checkpoint["model_state_dict"])
+        model.to(device)
+        model.eval()
+
+        print(f"Model loaded (step {checkpoint['step']})")
+        print(f"Validation loss: {checkpoint.get('val_loss', 'N/A')}")
+
+        # Load tokenizer
+        print(f"Loading tokenizer: {tokenizer_path}")
+        tokenizer = spm.SentencePieceProcessor()
+        tokenizer.load(str(tokenizer_path))
+
+        return cls(model, tokenizer, device)
+
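
From the loader above, a compatible checkpoint is a dict containing at least the keys read here; the files themselves are presumably written by the training code (gptmed/training, gptmed/utils/checkpoints.py), which is not shown in this excerpt.

    # Keys read by TextGenerator.from_checkpoint:
    #   "model_config"     -> dict of kwargs for ModelConfig(**...)
    #   "model_state_dict" -> weights for GPTTransformer
    #   "step"             -> training step, printed on load
    #   "val_loss"         -> optional validation loss ("N/A" if absent)
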
+    def encode(self, text: str) -> List[int]:
+        """Encode text to token IDs."""
+        return self.tokenizer.encode_as_ids(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token IDs to text."""
+        return self.tokenizer.decode_ids(token_ids)
+
+    def generate(
+        self, prompt: str, gen_config: GenerationConfig = None, verbose: bool = False
+    ) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: Input text prompt
+            gen_config: Generation configuration
+            verbose: Print generation progress
+
+        Returns:
+            Generated text
+
+        Process:
+        1. Encode prompt to tokens
+        2. Generate tokens autoregressively
+        3. Decode back to text
+        """
+        if gen_config is None:
+            gen_config = GenerationConfig()
+
+        # Encode prompt
+        prompt_tokens = self.encode(prompt)
+
+        if verbose:
+            print(f"Prompt: {prompt}")
+            print(f"Prompt tokens: {prompt_tokens}")
+            print(f"Generating...")
+
+        # Generate tokens
+        generated_tokens = self.generate_tokens(
+            prompt_tokens=prompt_tokens, gen_config=gen_config, verbose=verbose
+        )
+
+        # Decode to text
+        generated_text = self.decode(generated_tokens)
+
+        if verbose:
+            print(f"\nGenerated {len(generated_tokens)} tokens")
+            print(f"Output: {generated_text}")
+
+        return generated_text
+
+    def generate_tokens(
+        self, prompt_tokens: List[int], gen_config: GenerationConfig, verbose: bool = False
+    ) -> List[int]:
+        """
+        Generate tokens autoregressively.
+
+        Args:
+            prompt_tokens: Input token IDs
+            gen_config: Generation config
+            verbose: Print progress
+
+        Returns:
+            List of generated token IDs (including prompt)
+        """
+        # Start with prompt
+        generated = prompt_tokens.copy()
+
+        # Generation loop
+        with torch.no_grad():
+            for step in range(gen_config.max_length):
+                # Check stopping criteria
+                if should_stop_generation(
+                    generated_tokens=generated,
+                    stop_tokens=gen_config.stop_tokens,
+                    max_length=gen_config.max_length,
+                    min_length=gen_config.min_length,
+                ):
+                    break
+
+                # Prepare input (last max_seq_len tokens)
+                max_seq_len = self.model.config.max_seq_len
+                input_ids = generated[-max_seq_len:]
+                input_tensor = torch.tensor([input_ids], device=self.device)
+
+                # Forward pass
+                logits = self.model(input_tensor)
+
+                # Get logits for last position
+                next_token_logits = logits[0, -1, :]  # [vocab_size]
+
+                # Apply repetition penalty
+                if gen_config.repetition_penalty != 1.0:
+                    generated_tensor = torch.tensor([generated], device=self.device)
+                    next_token_logits = apply_repetition_penalty(
+                        next_token_logits.unsqueeze(0),
+                        generated_tensor,
+                        gen_config.repetition_penalty,
+                    ).squeeze(0)
+
+                # Block n-gram repeats
+                if gen_config.no_repeat_ngram_size > 0:
+                    next_token_logits = block_ngram_repeats(
+                        next_token_logits, generated, gen_config.no_repeat_ngram_size
+                    )
+
+                # Sample next token
+                next_token = sample_next_token(
+                    next_token_logits.unsqueeze(0),
+                    temperature=gen_config.temperature,
+                    top_k=gen_config.top_k,
+                    top_p=gen_config.top_p,
+                )
+
+                next_token_id = next_token.item()
+                generated.append(next_token_id)
+
+                if verbose and step % 10 == 0:
+                    partial_text = self.decode(generated)
+                    print(f"Step {step}: {partial_text}")
+
+        return generated
+
+    def generate_batch(self, prompts: List[str], gen_config: GenerationConfig = None) -> List[str]:
+        """
+        Generate text for multiple prompts.
+
+        Args:
+            prompts: List of input prompts
+            gen_config: Generation config
+
+        Returns:
+            List of generated texts
+        """
+        if gen_config is None:
+            gen_config = GenerationConfig()
+
+        results = []
+        for prompt in prompts:
+            output = self.generate(prompt, gen_config, verbose=False)
+            results.append(output)
+
+        return results
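
Putting the three files together, a hypothetical end-to-end usage sketch; the checkpoint and tokenizer paths are placeholders rather than files shipped in the wheel.

    from pathlib import Path
    from gptmed.inference.generator import TextGenerator
    from gptmed.inference.generation_config import get_balanced_config

    generator = TextGenerator.from_checkpoint(
        checkpoint_path=Path("checkpoints/best.pt"),     # placeholder path
        tokenizer_path=Path("tokenizer/medquad.model"),  # placeholder path
        device="cpu",
    )

    answer = generator.generate("What are the symptoms of anemia?", get_balanced_config())
    print(answer)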