gptmed 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
llm-med: A lightweight medical question-answering language model
|
|
3
|
+
|
|
4
|
+
This package provides a GPT-based transformer architecture trained on the MedQuAD dataset
|
|
5
|
+
for medical domain question answering.
|
|
6
|
+
|
|
7
|
+
Main Components:
|
|
8
|
+
- model: GPT transformer architecture
|
|
9
|
+
- inference: Text generation and sampling
|
|
10
|
+
- training: Training loop and utilities
|
|
11
|
+
- tokenizer: SentencePiece tokenizer
|
|
12
|
+
- configs: Configuration management
|
|
13
|
+
- utils: Utility functions
|
|
14
|
+
|
|
15
|
+
Example:
|
|
16
|
+
>>> from llm_med.model.architecture import GPTTransformer
|
|
17
|
+
>>> from llm_med.model.configs.model_config import get_small_config
|
|
18
|
+
>>> from llm_med.inference.generator import TextGenerator
|
|
19
|
+
>>>
|
|
20
|
+
>>> config = get_small_config()
|
|
21
|
+
>>> model = GPTTransformer(config)
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
__version__ = "0.2.0"
|
|
25
|
+
__author__ = "Sanjog Sigdel"
|
|
26
|
+
__email__ = "sigdelsanjog@gmail.com"
|
|
27
|
+
|
|
28
|
+
# Expose main components at package level for convenience
|
|
29
|
+
from llm_med.model.architecture import GPTTransformer
|
|
30
|
+
from llm_med.model.configs.model_config import ModelConfig, get_small_config, get_tiny_config
|
|
31
|
+
|
|
32
|
+
__all__ = [
|
|
33
|
+
"GPTTransformer",
|
|
34
|
+
"ModelConfig",
|
|
35
|
+
"get_small_config",
|
|
36
|
+
"get_tiny_config",
|
|
37
|
+
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Configs package."""
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Training Configuration
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Central place for all training hyperparameters. Separating training config
|
|
6
|
+
from model config follows separation of concerns - you can train the same
|
|
7
|
+
model architecture with different training strategies.
|
|
8
|
+
|
|
9
|
+
WHAT THIS FILE CONTAINS:
|
|
10
|
+
- Batch size, learning rate, epochs
|
|
11
|
+
- Optimizer settings (weight decay, betas)
|
|
12
|
+
- Learning rate schedule parameters
|
|
13
|
+
- Gradient clipping threshold
|
|
14
|
+
- Checkpoint and logging intervals
|
|
15
|
+
|
|
16
|
+
PACKAGES USED:
|
|
17
|
+
- dataclasses: Clean config structure
|
|
18
|
+
- json: Save/load configs
|
|
19
|
+
|
|
20
|
+
FILES FROM THIS PROJECT:
|
|
21
|
+
- None (base config)
|
|
22
|
+
|
|
23
|
+
DESIGN DECISIONS:
|
|
24
|
+
- Small batch size (16-32) for 8GB VRAM
|
|
25
|
+
- Learning rate ~1e-4 to 3e-4 (typical for small transformers)
|
|
26
|
+
- Weight decay 0.01 (L2 regularization)
|
|
27
|
+
- Gradient clipping at 1.0 (prevents exploding gradients)
|
|
28
|
+
- Warmup steps to stabilize early training
|
|
29
|
+
|
|
30
|
+
COMMON FAILURE MODES:
|
|
31
|
+
- LR too high → loss explodes, NaN
|
|
32
|
+
- LR too low → very slow convergence
|
|
33
|
+
- No warmup → unstable early training
|
|
34
|
+
- Batch size too large → OOM
|
|
35
|
+
- No gradient clipping → gradient explosion
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from dataclasses import dataclass
|
|
39
|
+
from pathlib import Path
|
|
40
|
+
import json
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class TrainingConfig:
    """Training hyperparameters.

    Groups every knob for one training run: data locations, batch/epoch
    sizing, optimizer settings, LR schedule, gradient clipping, and
    checkpoint/logging intervals.  Serializable to and from JSON so a
    run's exact configuration can be stored next to its checkpoints.
    """

    # Data
    train_data_path: str = "./data/tokenized/train.npy"
    val_data_path: str = "./data/tokenized/val.npy"

    # Batch size (adjust based on VRAM)
    batch_size: int = 16  # GTX 1080: 16-32 works well

    # Training duration
    num_epochs: int = 10  # Total training epochs
    max_steps: int = -1   # -1 means train for num_epochs

    # Optimization
    learning_rate: float = 3e-4  # Peak learning rate
    weight_decay: float = 0.01   # L2 regularization
    betas: tuple = (0.9, 0.999)  # Adam beta1, beta2
    eps: float = 1e-8            # Adam epsilon

    # Learning rate schedule
    warmup_steps: int = 100   # Linear warmup steps
    lr_decay: str = "cosine"  # 'cosine' or 'linear' or 'constant'
    min_lr: float = 1e-5      # Minimum LR for decay

    # Gradient clipping (CRITICAL for stability)
    grad_clip: float = 1.0  # Clip gradient norm to this value

    # Evaluation
    eval_interval: int = 500  # Evaluate every N steps
    eval_iters: int = 100     # Number of eval batches

    # Checkpointing
    checkpoint_dir: str = "./model/checkpoints"
    save_interval: int = 1000  # Save checkpoint every N steps
    keep_last_n: int = 3       # Keep only last N checkpoints

    # Logging
    log_interval: int = 10  # Log every N steps
    log_dir: str = "./logs"

    # Device
    device: str = "cuda"  # 'cuda' or 'cpu'

    # Mixed precision (optional, for faster training)
    use_amp: bool = False  # Automatic Mixed Precision

    # Reproducibility
    seed: int = 42

    def to_dict(self) -> dict:
        """Return all fields as a plain dict (keys = field names).

        Built from the instance __dict__ so newly added fields are
        picked up automatically instead of having to be listed by hand.
        """
        return dict(vars(self))

    def save(self, path: Path) -> None:
        """Save config to *path* as pretty-printed JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "TrainingConfig":
        """Build a config from a dictionary.

        JSON has no tuple type, so ``betas`` round-trips through
        ``save``/``from_file`` as a list; coerce it back to a tuple here
        so a loaded config compares equal to a freshly constructed one.
        """
        config_dict = dict(config_dict)  # don't mutate the caller's dict
        if "betas" in config_dict:
            config_dict["betas"] = tuple(config_dict["betas"])
        return cls(**config_dict)

    @classmethod
    def from_file(cls, path: Path) -> "TrainingConfig":
        """Load a config previously written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def get_default_config() -> TrainingConfig:
    """Return the stock training configuration (tuned for a GTX 1080)."""
    config = TrainingConfig()
    return config
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_quick_test_config() -> TrainingConfig:
    """Return a small smoke-test configuration (tiny batch, few steps)."""
    overrides = dict(
        batch_size=4,
        num_epochs=1,
        max_steps=100,
        eval_interval=50,
        save_interval=50,
        log_interval=5,
    )
    return TrainingConfig(**overrides)
|
gptmed/data/parsers/medquad_parser.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MedQuAD XML Parser
|
|
3
|
+
|
|
4
|
+
Parses MedQuAD XML files and extracts question-answer pairs.
|
|
5
|
+
Follows Single Responsibility Principle - handles only XML parsing logic.
|
|
6
|
+
|
|
7
|
+
Design decisions:
|
|
8
|
+
- Uses xml.etree.ElementTree from the standard library for XML parsing
|
|
9
|
+
- Filters empty answers (copyright-removed collections)
|
|
10
|
+
- Preserves question types and focus metadata for future analysis
|
|
11
|
+
- Returns structured data (not formatting) - separation of concerns
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import xml.etree.ElementTree as ET
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Dict, List, Optional
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class QAPair:
    """
    One medical question paired with its answer and source metadata.

    Attributes:
        question: The medical question text
        answer: The answer text (may be empty for copyright-removed content)
        qid: Unique question ID
        qtype: Question type (e.g., 'treatment', 'symptoms', 'causes')
        focus: Medical entity the question focuses on (disease, drug, etc.)
        focus_category: Category of focus (Disease, Drug, Other)
        source: Source collection name
    """

    question: str
    answer: str
    qid: str
    qtype: str
    focus: str
    focus_category: Optional[str]
    source: str

    def has_answer(self) -> bool:
        """Return True when the answer contains non-whitespace text."""
        if not self.answer:
            return False
        return bool(self.answer.strip())

    def to_dict(self) -> Dict:
        """Return a plain-dict view of this pair for serialization."""
        field_names = (
            "question",
            "answer",
            "qid",
            "qtype",
            "focus",
            "focus_category",
            "source",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class MedQuADParser:
    """
    Parser for MedQuAD XML files.

    This class is responsible only for parsing XML structure.
    Formatting to causal text is handled separately (SRP).
    """

    # Collections known to have removed answers (MedlinePlus copyright)
    EMPTY_COLLECTIONS = {"10_MPlus_ADAM_QA", "11_MPlusDrugs_QA", "12_MPlusHerbsSupplements_QA"}

    def __init__(self, dataset_root: Path, skip_empty: bool = True):
        """
        Initialize parser.

        Args:
            dataset_root: Path to MedQuAD root directory
            skip_empty: Whether to skip collections with removed answers

        Raises:
            ValueError: If dataset_root does not exist
        """
        self.dataset_root = Path(dataset_root)
        self.skip_empty = skip_empty

        if not self.dataset_root.exists():
            raise ValueError(f"Dataset root does not exist: {dataset_root}")

    def parse_xml_file(self, xml_path: Path) -> List["QAPair"]:
        """
        Parse a single XML file and extract Q&A pairs.

        Args:
            xml_path: Path to XML file

        Returns:
            List of QAPair objects (empty if the file has no <QAPairs> element)

        Raises:
            ET.ParseError: If XML is malformed
        """
        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()

            # Extract document-level metadata
            source = root.get("source", "Unknown")
            focus = root.findtext("Focus", default="Unknown")

            # Get focus category if available
            focus_category = None
            focus_annotations = root.find("FocusAnnotations")
            if focus_annotations is not None:
                category = focus_annotations.find("Category")
                if category is not None:
                    focus_category = category.text

            # Extract Q&A pairs
            qa_pairs = []
            qa_pairs_elem = root.find("QAPairs")

            if qa_pairs_elem is None:
                return qa_pairs

            for qa_elem in qa_pairs_elem.findall("QAPair"):
                question_elem = qa_elem.find("Question")
                answer_elem = qa_elem.find("Answer")

                # A pair without a question is unusable; skip it.
                if question_elem is None:
                    continue

                # Element.text may be None for empty tags.
                question = question_elem.text or ""
                answer = answer_elem.text if answer_elem is not None else ""
                qid = question_elem.get("qid", "")
                qtype = question_elem.get("qtype", "unknown")

                qa_pair = QAPair(
                    question=question.strip(),
                    answer=answer.strip() if answer else "",
                    qid=qid,
                    qtype=qtype,
                    focus=focus,
                    focus_category=focus_category,
                    source=source,
                )

                qa_pairs.append(qa_pair)

            return qa_pairs

        except ET.ParseError as e:
            # Chain with `from e` so the original exception (and its
            # line/column position) is preserved in the traceback.
            raise ET.ParseError(f"Failed to parse {xml_path}: {e}") from e

    def get_collection_paths(self) -> List[Path]:
        """
        Get all collection directories to process.

        Returns:
            Sorted list of collection directory paths.  Hidden directories
            and, when skip_empty is set, answer-less collections are excluded.
        """
        collections = []

        for item in self.dataset_root.iterdir():
            if not item.is_dir():
                continue

            # Skip hidden directories and git
            if item.name.startswith("."):
                continue

            # Skip empty collections if requested
            if self.skip_empty and item.name in self.EMPTY_COLLECTIONS:
                continue

            collections.append(item)

        return sorted(collections)

    def parse_collection(self, collection_path: Path) -> List["QAPair"]:
        """
        Parse all XML files in a collection directory.

        Malformed files are reported and skipped rather than aborting
        the whole collection.

        Args:
            collection_path: Path to collection directory

        Returns:
            List of all QAPair objects from this collection
        """
        qa_pairs = []
        xml_files = list(collection_path.glob("*.xml"))

        for xml_file in xml_files:
            try:
                pairs = self.parse_xml_file(xml_file)
                qa_pairs.extend(pairs)
            except ET.ParseError as e:
                print(f"Warning: Skipping malformed file {xml_file}: {e}")
                continue

        return qa_pairs

    def parse_all(self, filter_empty_answers: bool = True) -> List["QAPair"]:
        """
        Parse all collections in the dataset.

        Args:
            filter_empty_answers: Whether to filter out pairs with empty answers

        Returns:
            List of all QAPair objects from all collections
        """
        all_pairs = []
        collections = self.get_collection_paths()

        print(f"Found {len(collections)} collections to process")

        for collection in collections:
            print(f"Processing {collection.name}...")
            pairs = self.parse_collection(collection)

            if filter_empty_answers:
                pairs = [p for p in pairs if p.has_answer()]

            all_pairs.extend(pairs)
            print(f"  Extracted {len(pairs)} Q&A pairs")

        print(f"\nTotal Q&A pairs: {len(all_pairs)}")
        return all_pairs

    def get_statistics(self, qa_pairs: List["QAPair"]) -> Dict:
        """
        Get statistics about parsed Q&A pairs.

        Args:
            qa_pairs: List of QAPair objects

        Returns:
            Dictionary with total/answered/unanswered counts plus
            per-source, per-question-type, and per-focus-category counts
        """
        stats = {
            "total_pairs": len(qa_pairs),
            "pairs_with_answers": sum(1 for p in qa_pairs if p.has_answer()),
            "pairs_without_answers": sum(1 for p in qa_pairs if not p.has_answer()),
            "sources": {},
            "qtypes": {},
            "focus_categories": {},
        }

        for pair in qa_pairs:
            # Count by source
            stats["sources"][pair.source] = stats["sources"].get(pair.source, 0) + 1

            # Count by question type
            stats["qtypes"][pair.qtype] = stats["qtypes"].get(pair.qtype, 0) + 1

            # Count by focus category
            if pair.focus_category:
                cat = pair.focus_category
                stats["focus_categories"][cat] = stats["focus_categories"].get(cat, 0) + 1

        return stats
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Text Formatter for Causal Language Modeling
|
|
3
|
+
|
|
4
|
+
Converts structured Q&A pairs into causal text format suitable for
|
|
5
|
+
next-token prediction training.
|
|
6
|
+
|
|
7
|
+
Design decisions explained:
|
|
8
|
+
- Simple format preserves question-answer structure
|
|
9
|
+
- Special tokens ([Q], [A]) help model learn task boundaries
|
|
10
|
+
- Newlines create clear separation for tokenizer
|
|
11
|
+
- No complex templating - reduces failure modes
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from typing import List
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class FormatConfig:
    """Configuration for text formatting.

    Controls how CausalTextFormatter assembles one question/answer pair
    into a single training string: optional role prefixes, a separator
    between question and answer, and a trailing document boundary.
    """

    use_special_tokens: bool = True  # Prepend question_prefix / answer_prefix role markers
    add_separator: bool = True  # Insert `separator` between question and answer
    add_end_token: bool = True  # Append `end_token` after the answer
    question_prefix: str = "Q: "  # Marker placed before the question text
    answer_prefix: str = "A: "  # Marker placed before the answer text
    separator: str = "\n"  # Delimiter between question and answer
    end_token: str = "\n\n"  # Double newline = document boundary between pairs
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CausalTextFormatter:
    """
    Formats Q&A pairs for causal language modeling.

    Follows Open-Closed Principle: easy to extend with new formats
    without modifying existing code.
    """

    def __init__(self, config: FormatConfig = None):
        """
        Initialize formatter.

        Args:
            config: Formatting configuration (defaults to ``FormatConfig()``)
        """
        if config is None:
            config = FormatConfig()
        self.config = config

    def format_single_pair(self, question: str, answer: str) -> str:
        """
        Format a single Q&A pair.

        Args:
            question: Question text
            answer: Answer text

        Returns:
            Formatted text string

        Example:
            >>> formatter = CausalTextFormatter()
            >>> formatter.format_single_pair("What is cancer?", "Cancer is...")
            "Q: What is cancer?\\nA: Cancer is...\\n\\n"
        """
        cfg = self.config

        # Role prefixes are applied only when special tokens are enabled.
        if cfg.use_special_tokens:
            question_part = f"{cfg.question_prefix}{question}"
            answer_part = f"{cfg.answer_prefix}{answer}"
        else:
            question_part = question
            answer_part = answer

        text = question_part
        if cfg.add_separator:
            text += cfg.separator
        text += answer_part
        if cfg.add_end_token:
            text += cfg.end_token
        return text

    def format_batch(self, qa_pairs: List[tuple]) -> str:
        """
        Format multiple Q&A pairs into a single text corpus.

        Args:
            qa_pairs: List of (question, answer) tuples

        Returns:
            Single formatted text string
        """
        pieces = []
        for question, answer in qa_pairs:
            pieces.append(self.format_single_pair(question, answer))
        return "".join(pieces)

    def format_from_structured(self, qa_objects: List) -> str:
        """
        Format from structured QAPair objects.

        Args:
            qa_objects: List of QAPair objects (from parser)

        Returns:
            Formatted text corpus
        """
        return self.format_batch([(obj.question, obj.answer) for obj in qa_objects])
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class MinimalFormatter(CausalTextFormatter):
    """
    Minimal format without special tokens.
    Useful for baseline comparisons.
    """

    def __init__(self):
        # Plain question/answer text, newline-separated, with a blank
        # line as the document boundary.
        super().__init__(
            FormatConfig(
                use_special_tokens=False,
                add_separator=True,
                add_end_token=True,
                separator="\n",
                end_token="\n\n",
            )
        )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class StructuredFormatter(CausalTextFormatter):
    """
    More structured format with explicit markers.
    Better for instruction-tuning style training.
    """

    def __init__(self):
        # Markdown-style headers plus an explicit "---" document boundary.
        super().__init__(
            FormatConfig(
                use_special_tokens=True,
                add_separator=True,
                add_end_token=True,
                question_prefix="### Question: ",
                answer_prefix="### Answer: ",
                separator="\n\n",
                end_token="\n\n---\n\n",
            )
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Inference package."""
|