gptmed 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gptmed/__init__.py +37 -0
  2. gptmed/configs/__init__.py +1 -0
  3. gptmed/configs/train_config.py +154 -0
  4. gptmed/data/__init__.py +5 -0
  5. gptmed/data/parsers/__init__.py +10 -0
  6. gptmed/data/parsers/medquad_parser.py +257 -0
  7. gptmed/data/parsers/text_formatter.py +148 -0
  8. gptmed/inference/__init__.py +1 -0
  9. gptmed/inference/decoding_utils.py +190 -0
  10. gptmed/inference/generation_config.py +83 -0
  11. gptmed/inference/generator.py +253 -0
  12. gptmed/inference/sampling.py +261 -0
  13. gptmed/model/__init__.py +9 -0
  14. gptmed/model/architecture/__init__.py +35 -0
  15. gptmed/model/architecture/attention.py +188 -0
  16. gptmed/model/architecture/decoder_block.py +130 -0
  17. gptmed/model/architecture/embeddings.py +146 -0
  18. gptmed/model/architecture/feedforward.py +109 -0
  19. gptmed/model/architecture/transformer.py +204 -0
  20. gptmed/model/configs/__init__.py +17 -0
  21. gptmed/model/configs/model_config.py +155 -0
  22. gptmed/tokenizer/__init__.py +7 -0
  23. gptmed/tokenizer/tokenize_data.py +286 -0
  24. gptmed/tokenizer/train_tokenizer.py +218 -0
  25. gptmed/training/__init__.py +1 -0
  26. gptmed/training/dataset.py +183 -0
  27. gptmed/training/train.py +272 -0
  28. gptmed/training/trainer.py +331 -0
  29. gptmed/training/utils.py +212 -0
  30. gptmed/utils/__init__.py +1 -0
  31. gptmed/utils/checkpoints.py +224 -0
  32. gptmed/utils/logging.py +189 -0
  33. gptmed-0.0.1.dist-info/METADATA +325 -0
  34. gptmed-0.0.1.dist-info/RECORD +38 -0
  35. gptmed-0.0.1.dist-info/WHEEL +5 -0
  36. gptmed-0.0.1.dist-info/entry_points.txt +3 -0
  37. gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
  38. gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/__init__.py ADDED
@@ -0,0 +1,37 @@
+ """
+ gptmed: A lightweight medical question-answering language model
+
+ This package provides a GPT-based transformer architecture trained on the MedQuAD dataset
+ for medical domain question answering.
+
+ Main Components:
+ - model: GPT transformer architecture
+ - inference: Text generation and sampling
+ - training: Training loop and utilities
+ - tokenizer: SentencePiece tokenizer
+ - configs: Configuration management
+ - utils: Utility functions
+
+ Example:
+     >>> from gptmed.model.architecture import GPTTransformer
+     >>> from gptmed.model.configs.model_config import get_small_config
+     >>> from gptmed.inference.generator import TextGenerator
+     >>>
+     >>> config = get_small_config()
+     >>> model = GPTTransformer(config)
+ """
+
+ __version__ = "0.0.1"
+ __author__ = "Sanjog Sigdel"
+ __email__ = "sigdelsanjog@gmail.com"
+
+ # Expose main components at package level for convenience
+ from gptmed.model.architecture import GPTTransformer
+ from gptmed.model.configs.model_config import ModelConfig, get_small_config, get_tiny_config
+
+ __all__ = [
+     "GPTTransformer",
+     "ModelConfig",
+     "get_small_config",
+     "get_tiny_config",
+ ]
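
With these re-exports, the model entry points are usable straight from the package root. A minimal sketch (it builds an untrained model only, and assumes nothing beyond the names exported above):

    import gptmed

    # Build an untrained model from the re-exported tiny config
    config = gptmed.get_tiny_config()
    model = gptmed.GPTTransformer(config)
    print(type(model).__name__, gptmed.__version__)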
gptmed/configs/__init__.py ADDED
@@ -0,0 +1 @@
+ """Configs package."""
gptmed/configs/train_config.py ADDED
@@ -0,0 +1,154 @@
+ """
+ Training Configuration
+
+ PURPOSE:
+ Central place for all training hyperparameters. Separating training config
+ from model config follows separation of concerns - you can train the same
+ model architecture with different training strategies.
+
+ WHAT THIS FILE CONTAINS:
+ - Batch size, learning rate, epochs
+ - Optimizer settings (weight decay, betas)
+ - Learning rate schedule parameters
+ - Gradient clipping threshold
+ - Checkpoint and logging intervals
+
+ PACKAGES USED:
+ - dataclasses: Clean config structure
+ - json: Save/load configs
+
+ FILES FROM THIS PROJECT:
+ - None (base config)
+
+ DESIGN DECISIONS:
+ - Small batch size (16-32) for 8GB VRAM
+ - Learning rate ~1e-4 to 3e-4 (typical for small transformers)
+ - Weight decay 0.01 (L2 regularization)
+ - Gradient clipping at 1.0 (prevents exploding gradients)
+ - Warmup steps to stabilize early training
+
+ COMMON FAILURE MODES:
+ - LR too high → loss explodes, NaN
+ - LR too low → very slow convergence
+ - No warmup → unstable early training
+ - Batch size too large → OOM
+ - No gradient clipping → gradient explosion
+ """
+
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+
+
+ @dataclass
+ class TrainingConfig:
+     """Training hyperparameters."""
+
+     # Data
+     train_data_path: str = "./data/tokenized/train.npy"
+     val_data_path: str = "./data/tokenized/val.npy"
+
+     # Batch size (adjust based on VRAM)
+     batch_size: int = 16  # GTX 1080: 16-32 works well
+
+     # Training duration
+     num_epochs: int = 10  # Total training epochs
+     max_steps: int = -1   # -1 means train for num_epochs
+
+     # Optimization
+     learning_rate: float = 3e-4  # Peak learning rate
+     weight_decay: float = 0.01   # L2 regularization
+     betas: tuple = (0.9, 0.999)  # Adam beta1, beta2
+     eps: float = 1e-8            # Adam epsilon
+
+     # Learning rate schedule
+     warmup_steps: int = 100   # Linear warmup steps
+     lr_decay: str = "cosine"  # 'cosine', 'linear', or 'constant'
+     min_lr: float = 1e-5      # Minimum LR for decay
+
+     # Gradient clipping (critical for stability)
+     grad_clip: float = 1.0  # Clip gradient norm to this value
+
+     # Evaluation
+     eval_interval: int = 500  # Evaluate every N steps
+     eval_iters: int = 100     # Number of eval batches
+
+     # Checkpointing
+     checkpoint_dir: str = "./model/checkpoints"
+     save_interval: int = 1000  # Save a checkpoint every N steps
+     keep_last_n: int = 3       # Keep only the last N checkpoints
+
+     # Logging
+     log_interval: int = 10  # Log every N steps
+     log_dir: str = "./logs"
+
+     # Device
+     device: str = "cuda"  # 'cuda' or 'cpu'
+
+     # Mixed precision (optional, for faster training)
+     use_amp: bool = False  # Automatic Mixed Precision
+
+     # Reproducibility
+     seed: int = 42
+
+     def to_dict(self) -> dict:
+         """Convert to a dictionary."""
+         return {
+             "train_data_path": self.train_data_path,
+             "val_data_path": self.val_data_path,
+             "batch_size": self.batch_size,
+             "num_epochs": self.num_epochs,
+             "max_steps": self.max_steps,
+             "learning_rate": self.learning_rate,
+             "weight_decay": self.weight_decay,
+             "betas": self.betas,
+             "eps": self.eps,
+             "warmup_steps": self.warmup_steps,
+             "lr_decay": self.lr_decay,
+             "min_lr": self.min_lr,
+             "grad_clip": self.grad_clip,
+             "eval_interval": self.eval_interval,
+             "eval_iters": self.eval_iters,
+             "checkpoint_dir": self.checkpoint_dir,
+             "save_interval": self.save_interval,
+             "keep_last_n": self.keep_last_n,
+             "log_interval": self.log_interval,
+             "log_dir": self.log_dir,
+             "device": self.device,
+             "use_amp": self.use_amp,
+             "seed": self.seed,
+         }
+
+     def save(self, path: Path):
+         """Save config to JSON."""
+         with open(path, "w") as f:
+             json.dump(self.to_dict(), f, indent=2)
+
+     @classmethod
+     def from_dict(cls, config_dict: dict):
+         """Load from a dictionary."""
+         # JSON round-trips tuples as lists; restore betas to a tuple
+         if "betas" in config_dict:
+             config_dict["betas"] = tuple(config_dict["betas"])
+         return cls(**config_dict)
+
+     @classmethod
+     def from_file(cls, path: Path):
+         """Load from a JSON file."""
+         with open(path, "r") as f:
+             config_dict = json.load(f)
+         return cls.from_dict(config_dict)
+
+
+ def get_default_config() -> TrainingConfig:
+     """Default training config for a GTX 1080 (8GB VRAM)."""
+     return TrainingConfig()
+
+
+ def get_quick_test_config() -> TrainingConfig:
+     """Quick config for testing (small batch, few steps)."""
+     return TrainingConfig(
+         batch_size=4,
+         num_epochs=1,
+         max_steps=100,
+         eval_interval=50,
+         save_interval=50,
+         log_interval=5,
+     )
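
A short sketch of the save/load round-trip these methods define; the file path is illustrative:

    from pathlib import Path
    from gptmed.configs.train_config import TrainingConfig, get_quick_test_config

    cfg = get_quick_test_config()
    cfg.save(Path("train_config.json"))  # serialized via to_dict()

    restored = TrainingConfig.from_file(Path("train_config.json"))
    assert restored.batch_size == 4
    assert isinstance(restored.betas, tuple)  # from_dict converts the JSON list back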
gptmed/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """
+ Data processing module.
+ """
+
+ __all__ = ["parsers"]
gptmed/data/parsers/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Data parsers for the MedQuAD dataset.
+
+ This module contains parsers to extract and process medical Q&A pairs
+ from various XML sources.
+ """
+
+ from .medquad_parser import MedQuADParser
+
+ __all__ = ["MedQuADParser"]
gptmed/data/parsers/medquad_parser.py ADDED
@@ -0,0 +1,257 @@
+ """
+ MedQuAD XML Parser
+
+ Parses MedQuAD XML files and extracts question-answer pairs.
+ Follows the Single Responsibility Principle - handles only XML parsing logic.
+
+ Design decisions:
+ - Uses the standard library's xml.etree.ElementTree for XML parsing
+ - Filters empty answers (copyright-removed collections)
+ - Preserves question types and focus metadata for future analysis
+ - Returns structured data (not formatting) - separation of concerns
+ """
+
+ import xml.etree.ElementTree as ET
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+
+ @dataclass
+ class QAPair:
+     """
+     Structured representation of a question-answer pair.
+
+     Attributes:
+         question: The medical question text
+         answer: The answer text (may be empty for copyright-removed content)
+         qid: Unique question ID
+         qtype: Question type (e.g., 'treatment', 'symptoms', 'causes')
+         focus: Medical entity the question focuses on (disease, drug, etc.)
+         focus_category: Category of focus (Disease, Drug, Other)
+         source: Source collection name
+     """
+
+     question: str
+     answer: str
+     qid: str
+     qtype: str
+     focus: str
+     focus_category: Optional[str]
+     source: str
+
+     def has_answer(self) -> bool:
+         """Check if this pair has a non-empty answer."""
+         return bool(self.answer and self.answer.strip())
+
+     def to_dict(self) -> Dict:
+         """Convert to a dictionary for serialization."""
+         return {
+             "question": self.question,
+             "answer": self.answer,
+             "qid": self.qid,
+             "qtype": self.qtype,
+             "focus": self.focus,
+             "focus_category": self.focus_category,
+             "source": self.source,
+         }
+
+
+ class MedQuADParser:
+     """
+     Parser for MedQuAD XML files.
+
+     This class is responsible only for parsing XML structure.
+     Formatting to causal text is handled separately (SRP).
+     """
+
+     # Collections whose answers were removed (MedlinePlus copyright)
+     EMPTY_COLLECTIONS = {"10_MPlus_ADAM_QA", "11_MPlusDrugs_QA", "12_MPlusHerbsSupplements_QA"}
+
+     def __init__(self, dataset_root: Path, skip_empty: bool = True):
+         """
+         Initialize the parser.
+
+         Args:
+             dataset_root: Path to the MedQuAD root directory
+             skip_empty: Whether to skip collections with removed answers
+         """
+         self.dataset_root = Path(dataset_root)
+         self.skip_empty = skip_empty
+
+         if not self.dataset_root.exists():
+             raise ValueError(f"Dataset root does not exist: {dataset_root}")
+
+     def parse_xml_file(self, xml_path: Path) -> List[QAPair]:
+         """
+         Parse a single XML file and extract Q&A pairs.
+
+         Args:
+             xml_path: Path to the XML file
+
+         Returns:
+             List of QAPair objects
+
+         Raises:
+             ET.ParseError: If the XML is malformed
+         """
+         try:
+             tree = ET.parse(xml_path)
+             root = tree.getroot()
+
+             # Extract document-level metadata
+             source = root.get("source", "Unknown")
+             focus = root.findtext("Focus", default="Unknown")
+
+             # Get the focus category if available
+             focus_category = None
+             focus_annotations = root.find("FocusAnnotations")
+             if focus_annotations is not None:
+                 category = focus_annotations.find("Category")
+                 if category is not None:
+                     focus_category = category.text
+
+             # Extract Q&A pairs
+             qa_pairs = []
+             qa_pairs_elem = root.find("QAPairs")
+
+             if qa_pairs_elem is None:
+                 return qa_pairs
+
+             for qa_elem in qa_pairs_elem.findall("QAPair"):
+                 question_elem = qa_elem.find("Question")
+                 answer_elem = qa_elem.find("Answer")
+
+                 if question_elem is None:
+                     continue
+
+                 question = question_elem.text or ""
+                 answer = answer_elem.text if answer_elem is not None else ""
+                 qid = question_elem.get("qid", "")
+                 qtype = question_elem.get("qtype", "unknown")
+
+                 qa_pair = QAPair(
+                     question=question.strip(),
+                     answer=answer.strip() if answer else "",
+                     qid=qid,
+                     qtype=qtype,
+                     focus=focus,
+                     focus_category=focus_category,
+                     source=source,
+                 )
+
+                 qa_pairs.append(qa_pair)
+
+             return qa_pairs
+
+         except ET.ParseError as e:
+             raise ET.ParseError(f"Failed to parse {xml_path}: {e}") from e
+
+     def get_collection_paths(self) -> List[Path]:
+         """
+         Get all collection directories to process.
+
+         Returns:
+             List of collection directory paths
+         """
+         collections = []
+
+         for item in self.dataset_root.iterdir():
+             if not item.is_dir():
+                 continue
+
+             # Skip hidden directories (including .git)
+             if item.name.startswith("."):
+                 continue
+
+             # Skip empty collections if requested
+             if self.skip_empty and item.name in self.EMPTY_COLLECTIONS:
+                 continue
+
+             collections.append(item)
+
+         return sorted(collections)
+
+     def parse_collection(self, collection_path: Path) -> List[QAPair]:
+         """
+         Parse all XML files in a collection directory.
+
+         Args:
+             collection_path: Path to the collection directory
+
+         Returns:
+             List of all QAPair objects from this collection
+         """
+         qa_pairs = []
+         xml_files = list(collection_path.glob("*.xml"))
+
+         for xml_file in xml_files:
+             try:
+                 pairs = self.parse_xml_file(xml_file)
+                 qa_pairs.extend(pairs)
+             except ET.ParseError as e:
+                 print(f"Warning: Skipping malformed file {xml_file}: {e}")
+                 continue
+
+         return qa_pairs
+
+     def parse_all(self, filter_empty_answers: bool = True) -> List[QAPair]:
+         """
+         Parse all collections in the dataset.
+
+         Args:
+             filter_empty_answers: Whether to filter out pairs with empty answers
+
+         Returns:
+             List of all QAPair objects from all collections
+         """
+         all_pairs = []
+         collections = self.get_collection_paths()
+
+         print(f"Found {len(collections)} collections to process")
+
+         for collection in collections:
+             print(f"Processing {collection.name}...")
+             pairs = self.parse_collection(collection)
+
+             if filter_empty_answers:
+                 pairs = [p for p in pairs if p.has_answer()]
+
+             all_pairs.extend(pairs)
+             print(f"  Extracted {len(pairs)} Q&A pairs")
+
+         print(f"\nTotal Q&A pairs: {len(all_pairs)}")
+         return all_pairs
+
+     def get_statistics(self, qa_pairs: List[QAPair]) -> Dict:
+         """
+         Get statistics about parsed Q&A pairs.
+
+         Args:
+             qa_pairs: List of QAPair objects
+
+         Returns:
+             Dictionary with statistics
+         """
+         stats = {
+             "total_pairs": len(qa_pairs),
+             "pairs_with_answers": sum(1 for p in qa_pairs if p.has_answer()),
+             "pairs_without_answers": sum(1 for p in qa_pairs if not p.has_answer()),
+             "sources": {},
+             "qtypes": {},
+             "focus_categories": {},
+         }
+
+         for pair in qa_pairs:
+             # Count by source
+             stats["sources"][pair.source] = stats["sources"].get(pair.source, 0) + 1
+
+             # Count by question type
+             stats["qtypes"][pair.qtype] = stats["qtypes"].get(pair.qtype, 0) + 1
+
+             # Count by focus category
+             if pair.focus_category:
+                 cat = pair.focus_category
+                 stats["focus_categories"][cat] = stats["focus_categories"].get(cat, 0) + 1
+
+         return stats
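
A minimal driver for the parser, assuming a local MedQuAD checkout (the ./MedQuAD path is illustrative):

    from pathlib import Path
    from gptmed.data.parsers.medquad_parser import MedQuADParser

    # Point at a local clone of the MedQuAD dataset
    parser = MedQuADParser(Path("./MedQuAD"), skip_empty=True)
    pairs = parser.parse_all(filter_empty_answers=True)

    stats = parser.get_statistics(pairs)
    print(stats["total_pairs"], "answered pairs across", len(stats["sources"]), "sources")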
gptmed/data/parsers/text_formatter.py ADDED
@@ -0,0 +1,148 @@
+ """
+ Text Formatter for Causal Language Modeling
+
+ Converts structured Q&A pairs into a causal text format suitable for
+ next-token prediction training.
+
+ Design decisions explained:
+ - Simple format preserves question-answer structure
+ - Question/answer prefixes ("Q: ", "A: ") help the model learn task boundaries
+ - Newlines create clear separation for the tokenizer
+ - No complex templating - reduces failure modes
+ """
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+
+ @dataclass
+ class FormatConfig:
+     """Configuration for text formatting."""
+
+     use_special_tokens: bool = True  # Use question/answer prefix markers
+     add_separator: bool = True       # Add newline between Q and A
+     add_end_token: bool = True       # Add end-of-text marker
+     question_prefix: str = "Q: "
+     answer_prefix: str = "A: "
+     separator: str = "\n"
+     end_token: str = "\n\n"          # Double newline = document boundary
+
+
+ class CausalTextFormatter:
+     """
+     Formats Q&A pairs for causal language modeling.
+
+     Follows the Open-Closed Principle: easy to extend with new formats
+     without modifying existing code.
+     """
+
+     def __init__(self, config: Optional[FormatConfig] = None):
+         """
+         Initialize the formatter.
+
+         Args:
+             config: Formatting configuration (defaults to FormatConfig())
+         """
+         self.config = config or FormatConfig()
+
+     def format_single_pair(self, question: str, answer: str) -> str:
+         """
+         Format a single Q&A pair.
+
+         Args:
+             question: Question text
+             answer: Answer text
+
+         Returns:
+             Formatted text string
+
+         Example:
+             >>> formatter = CausalTextFormatter()
+             >>> formatter.format_single_pair("What is cancer?", "Cancer is...")
+             'Q: What is cancer?\\nA: Cancer is...\\n\\n'
+         """
+         parts = []
+
+         # Add the question
+         if self.config.use_special_tokens:
+             parts.append(f"{self.config.question_prefix}{question}")
+         else:
+             parts.append(question)
+
+         # Add the separator
+         if self.config.add_separator:
+             parts.append(self.config.separator)
+
+         # Add the answer
+         if self.config.use_special_tokens:
+             parts.append(f"{self.config.answer_prefix}{answer}")
+         else:
+             parts.append(answer)
+
+         # Add the end token
+         if self.config.add_end_token:
+             parts.append(self.config.end_token)
+
+         return "".join(parts)
+
+     def format_batch(self, qa_pairs: List[Tuple[str, str]]) -> str:
+         """
+         Format multiple Q&A pairs into a single text corpus.
+
+         Args:
+             qa_pairs: List of (question, answer) tuples
+
+         Returns:
+             Single formatted text string
+         """
+         formatted_pairs = [self.format_single_pair(q, a) for q, a in qa_pairs]
+         return "".join(formatted_pairs)
+
+     def format_from_structured(self, qa_objects: List) -> str:
+         """
+         Format from structured QAPair objects.
+
+         Args:
+             qa_objects: List of QAPair objects (from the parser)
+
+         Returns:
+             Formatted text corpus
+         """
+         pairs = [(obj.question, obj.answer) for obj in qa_objects]
+         return self.format_batch(pairs)
+
+
+ class MinimalFormatter(CausalTextFormatter):
+     """
+     Minimal format without prefix markers.
+     Useful for baseline comparisons.
+     """
+
+     def __init__(self):
+         config = FormatConfig(
+             use_special_tokens=False,
+             add_separator=True,
+             add_end_token=True,
+             separator="\n",
+             end_token="\n\n",
+         )
+         super().__init__(config)
+
+
+ class StructuredFormatter(CausalTextFormatter):
+     """
+     More structured format with explicit markers.
+     Better suited to instruction-tuning-style training.
+     """
+
+     def __init__(self):
+         config = FormatConfig(
+             use_special_tokens=True,
+             add_separator=True,
+             add_end_token=True,
+             question_prefix="### Question: ",
+             answer_prefix="### Answer: ",
+             separator="\n\n",
+             end_token="\n\n---\n\n",
+         )
+         super().__init__(config)
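
A quick comparison of the three formatters on one pair (the sample Q&A text is made up; the outputs follow from the configs above):

    from gptmed.data.parsers.text_formatter import (
        CausalTextFormatter,
        MinimalFormatter,
        StructuredFormatter,
    )

    q, a = "What causes anemia?", "Anemia is often caused by iron deficiency."

    for fmt in (CausalTextFormatter(), MinimalFormatter(), StructuredFormatter()):
        print(repr(fmt.format_single_pair(q, a)))
    # 'Q: What causes anemia?\nA: Anemia is often caused by iron deficiency.\n\n'
    # 'What causes anemia?\nAnemia is often caused by iron deficiency.\n\n'
    # '### Question: What causes anemia?\n\n### Answer: ...\n\n---\n\n'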
gptmed/inference/__init__.py ADDED
@@ -0,0 +1 @@
+ """Inference package."""