gptmed 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0

@@ -0,0 +1,286 @@ gptmed/tokenizer/tokenize_data.py
"""
|
|
2
|
+
Dataset Tokenization
|
|
3
|
+
|
|
4
|
+
Applies trained tokenizer to processed text and creates train/validation splits.
|
|
5
|
+
|
|
6
|
+
Design decisions:
|
|
7
|
+
- Sequence length: 512 tokens (fits GTX 1080, captures most Q&A pairs)
|
|
8
|
+
- Train/val split: 90/10 (enough validation data to detect overfitting)
|
|
9
|
+
- Padding: Left-pad or truncate (causal LM sees left-to-right)
|
|
10
|
+
- No data augmentation: Keep it simple for Phase 1
|
|
11
|
+
|
|
12
|
+
Common failure modes:
|
|
13
|
+
- Truncating answers → model never learns to generate long responses
|
|
14
|
+
- Too short sequences → can't learn context
|
|
15
|
+
- Too long sequences → OOM on GPU
|
|
16
|
+
- Wrong padding → breaks attention masks
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import argparse
|
|
20
|
+
import json
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import List, Tuple
|
|
23
|
+
import random
|
|
24
|
+
|
|
25
|
+
import sentencepiece as spm
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def load_tokenizer(model_path: Path) -> spm.SentencePieceProcessor:
    """Load a trained SentencePiece tokenizer."""
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_path))
    return sp


def tokenize_text(text: str, tokenizer: spm.SentencePieceProcessor) -> List[int]:
    """
    Tokenize text to IDs.

    Args:
        text: Input text
        tokenizer: SentencePiece processor

    Returns:
        List of token IDs
    """
    return tokenizer.encode_as_ids(text)


def create_sequences(token_ids: List[int], max_length: int, stride: Optional[int] = None) -> List[List[int]]:
    """
    Split token IDs into fixed-length sequences.

    Args:
        token_ids: Full token ID sequence
        max_length: Maximum sequence length
        stride: Stride for overlapping windows (None = no overlap)

    Returns:
        List of sequences

    Design note: For causal LM we create non-overlapping chunks.
    Each chunk is a separate training example for next-token prediction.
    """
    if stride is None:
        stride = max_length  # No overlap

    sequences = []
    for i in range(0, len(token_ids) - max_length + 1, stride):
        sequences.append(token_ids[i : i + max_length])

    # Handle remaining tokens (the last incomplete chunk).
    # NOTE: this tail handling assumes non-overlapping windows (stride == max_length);
    # with a smaller stride the modulo below does not measure the uncovered tail.
    remainder = len(token_ids) % stride
    if remainder > 0 and len(sequences) > 0:
        # Keep the last incomplete chunk only if it is at least half of max_length,
        # right-padding it with PAD (ID 0).
        if remainder >= max_length // 2:
            last_seq = token_ids[len(token_ids) - remainder :]
            last_seq = last_seq + [0] * (max_length - len(last_seq))
            sequences.append(last_seq)

    return sequences
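
To make the chunking behaviour concrete, here is a minimal usage sketch (hypothetical numbers; assumes create_sequences above is in scope). With 1,300 tokens and max_length=512 we get two full chunks plus one padded tail; with 1,200 tokens the 176-token tail is below max_length // 2 = 256 and is dropped:

    # Hypothetical sketch of the chunking behaviour described above.
    token_ids = list(range(1300))  # stand-in for real token IDs

    chunks = create_sequences(token_ids, max_length=512)
    print(len(chunks))        # 3: two full chunks + one padded tail
    print(len(chunks[-1]))    # 512 (276 real tokens + 236 PAD ids)
    print(chunks[-1][-1])     # 0, the PAD id used for right-padding

    print(len(create_sequences(list(range(1200)), max_length=512)))  # 2: tail dropped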


def analyze_lengths(text_file: Path, tokenizer: spm.SentencePieceProcessor) -> dict:
    """
    Analyze the token length distribution to choose an optimal sequence length.

    This helps avoid:
    - Truncating too much data (information loss)
    - Wasting memory on padding (inefficiency)
    """
    lengths = []

    with open(text_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Split on double newlines (our document separator)
    documents = content.split("\n\n")

    for doc in documents:
        if doc.strip():
            tokens = tokenizer.encode_as_ids(doc.strip())
            lengths.append(len(tokens))

    lengths = np.array(lengths)

    stats = {
        "num_documents": len(lengths),
        "mean": float(np.mean(lengths)),
        "median": float(np.median(lengths)),
        "std": float(np.std(lengths)),
        "min": int(np.min(lengths)),
        "max": int(np.max(lengths)),
        "percentile_50": float(np.percentile(lengths, 50)),
        "percentile_75": float(np.percentile(lengths, 75)),
        "percentile_90": float(np.percentile(lengths, 90)),
        "percentile_95": float(np.percentile(lengths, 95)),
        "percentile_99": float(np.percentile(lengths, 99)),
    }

    return stats
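
Given these statistics, one could pick a GPU-friendly sequence length programmatically. A minimal sketch (suggest_max_length is a hypothetical helper, not part of the package):

    # Hypothetical helper: choose a max_length covering ~95% of documents,
    # rounded up to a multiple of 64 for GPU efficiency.
    def suggest_max_length(stats: dict, multiple: int = 64) -> int:
        p95 = int(stats["percentile_95"])
        return ((p95 + multiple - 1) // multiple) * multiple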


def main():
    parser = argparse.ArgumentParser(description="Tokenize MedQuAD dataset with trained tokenizer")
    parser.add_argument(
        "--input-file",
        type=str,
        default="./data/processed/medquad_simple.txt",
        help="Input text file",
    )
    parser.add_argument(
        "--tokenizer-model",
        type=str,
        default="./tokenizer/medquad_tokenizer.model",
        help="Path to trained tokenizer model",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./data/tokenized",
        help="Output directory for tokenized data",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length (default: 512 for GTX 1080)",
    )
    parser.add_argument(
        "--train-ratio", type=float, default=0.9, help="Train/val split ratio (default: 0.9)"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    args = parser.parse_args()

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)

    print("=" * 60)
    print("Dataset Tokenization")
    print("=" * 60)
    print(f"Input: {args.input_file}")
    print(f"Tokenizer: {args.tokenizer_model}")
    print(f"Max length: {args.max_length}")
    print(f"Train ratio: {args.train_ratio}")
    print()

    # Load tokenizer
    input_file = Path(args.input_file)
    tokenizer_model = Path(args.tokenizer_model)

    if not input_file.exists():
        print(f"❌ Error: Input file not found: {input_file}")
        return

    if not tokenizer_model.exists():
        print(f"❌ Error: Tokenizer not found: {tokenizer_model}")
        print("\nPlease train the tokenizer first:")
        print("  python tokenizer/train_tokenizer.py")
        return

    print("Loading tokenizer...")
    tokenizer = load_tokenizer(tokenizer_model)
    print(f"✅ Tokenizer loaded (vocab size: {tokenizer.vocab_size()})")

    # Analyze sequence lengths
    print("\nAnalyzing sequence lengths...")
    stats = analyze_lengths(input_file, tokenizer)

    print("\nSequence Length Statistics:")
    print(f"  Documents: {stats['num_documents']}")
    print(f"  Mean: {stats['mean']:.1f} tokens")
    print(f"  Median: {stats['median']:.1f} tokens")
    print(f"  Std dev: {stats['std']:.1f} tokens")
    print(f"  Min: {stats['min']} tokens")
    print(f"  Max: {stats['max']} tokens")
    print("\nPercentiles:")
    print(f"  50th: {stats['percentile_50']:.1f} tokens")
    print(f"  75th: {stats['percentile_75']:.1f} tokens")
    print(f"  90th: {stats['percentile_90']:.1f} tokens")
    print(f"  95th: {stats['percentile_95']:.1f} tokens")
    print(f"  99th: {stats['percentile_99']:.1f} tokens")

    # Warn about truncation
    if stats["percentile_95"] > args.max_length:
        print(f"\n⚠️  WARNING: max_length={args.max_length} is below the 95th-percentile document length")
        print(f"   Consider increasing to {int(stats['percentile_95'])} to capture 95% of the data")

    # Tokenize the full dataset
    print("\nTokenizing dataset...")
    with open(input_file, "r", encoding="utf-8") as f:
        full_text = f.read()

    token_ids = tokenizer.encode_as_ids(full_text)
    print(f"Total tokens: {len(token_ids):,}")

    # Create sequences
    print(f"\nCreating sequences (max_length={args.max_length})...")
    sequences = create_sequences(token_ids, max_length=args.max_length)
    print(f"Total sequences: {len(sequences):,}")

    # Train/val split
    random.shuffle(sequences)
    split_idx = int(len(sequences) * args.train_ratio)

    train_sequences = sequences[:split_idx]
    val_sequences = sequences[split_idx:]

    print("\nSplit:")
    print(f"  Train: {len(train_sequences):,} sequences")
    print(f"  Val:   {len(val_sequences):,} sequences")

    # Save tokenized data
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nSaving tokenized data...")

    # Save as numpy arrays (efficient for the PyTorch DataLoader)
    train_array = np.array(train_sequences, dtype=np.int32)
    val_array = np.array(val_sequences, dtype=np.int32)

    np.save(output_dir / "train.npy", train_array)
    np.save(output_dir / "val.npy", val_array)

    print(f"✅ Train data saved: {output_dir / 'train.npy'}")
    print(f"✅ Val data saved: {output_dir / 'val.npy'}")

    # Save metadata
    metadata = {
        "vocab_size": tokenizer.vocab_size(),
        "max_length": args.max_length,
        "num_train_sequences": len(train_sequences),
        "num_val_sequences": len(val_sequences),
        "total_tokens": len(token_ids),
        "length_stats": stats,
        "tokenizer_model": str(tokenizer_model),
        "seed": args.seed,
    }

    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"✅ Metadata saved: {output_dir / 'metadata.json'}")

    print("\n" + "=" * 60)
    print("✅ Tokenization complete!")
    print("=" * 60)
    print("\nNext steps (Phase 2):")
    print("1. Implement the Transformer architecture")
    print("2. Create the PyTorch Dataset and DataLoader")
    print("3. Define the training loop")


if __name__ == "__main__":
    main()
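
To verify the artifacts this script writes, a quick load-back check (a hypothetical snippet; the paths match the default --output-dir above):

    # Hypothetical load-back check of train.npy and metadata.json.
    import json
    import numpy as np

    train = np.load("data/tokenized/train.npy")
    with open("data/tokenized/metadata.json") as f:
        meta = json.load(f)

    assert train.shape == (meta["num_train_sequences"], meta["max_length"])
    assert train.dtype == np.int32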

@@ -0,0 +1,218 @@ gptmed/tokenizer/train_tokenizer.py
"""
|
|
2
|
+
SentencePiece Tokenizer Training
|
|
3
|
+
|
|
4
|
+
Trains a BPE (Byte-Pair Encoding) tokenizer on the processed MedQuAD text.
|
|
5
|
+
|
|
6
|
+
Design decisions explained:
|
|
7
|
+
- BPE over WordPiece: Better for medical terminology (handles rare words via subwords)
|
|
8
|
+
- Vocab size: Trade-off between model size and expressiveness
|
|
9
|
+
- Character coverage: 0.9995 captures medical unicode (Greek letters, symbols)
|
|
10
|
+
- No pre-tokenization: Let BPE learn from raw text
|
|
11
|
+
|
|
12
|
+
Common failure modes to avoid:
|
|
13
|
+
- Too small vocab → repetitive, generic text
|
|
14
|
+
- Too large vocab → overfitting, slow training
|
|
15
|
+
- Missing special tokens → can't mark boundaries
|
|
16
|
+
- Wrong normalization → "COVID-19" vs "covid-19" treated differently
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import argparse
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
import sentencepiece as spm
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def train_sentencepiece_tokenizer(
    input_file: Path,
    output_prefix: Path,
    vocab_size: int = 8000,
    model_type: str = "bpe",
    character_coverage: float = 0.9995,
    add_special_tokens: bool = True,
):
    """
    Train a SentencePiece tokenizer.

    Args:
        input_file: Path to the training text file
        output_prefix: Output path prefix (creates .model and .vocab files)
        vocab_size: Vocabulary size (default 8000 for a small dataset)
        model_type: 'bpe' or 'unigram'
        character_coverage: Character coverage for unicode (0.9995 covers medical terms)
        add_special_tokens: Whether to add custom special tokens

    Design notes:
    - vocab_size=8000: good for ~40K examples on a GTX 1080
      * Too small (2K): poor compression, repetitive outputs
      * Too large (32K): overfits, wastes memory on rare terms
    - BPE: deterministic, easier to debug than unigram
    - character_coverage=0.9995: captures medical unicode (μ, α, β, etc.)
    """

    # Build the training arguments
    train_args = [
        f"--input={input_file}",
        f"--model_prefix={output_prefix}",
        f"--vocab_size={vocab_size}",
        f"--model_type={model_type}",
        f"--character_coverage={character_coverage}",
        "--pad_id=0",  # <pad> token at ID 0
        "--unk_id=1",  # <unk> token at ID 1
        "--bos_id=2",  # <bos> (beginning of sequence) at ID 2
        "--eos_id=3",  # <eos> (end of sequence) at ID 3
        "--pad_piece=[PAD]",
        "--unk_piece=[UNK]",
        "--bos_piece=[BOS]",
        "--eos_piece=[EOS]",
    ]

    # Add user-defined special tokens if requested
    if add_special_tokens:
        # These match our text formatting (Q:, A:)
        user_defined = "[Q],[A]"
        train_args.append(f"--user_defined_symbols={user_defined}")

    # Normalization rules (CAREFUL: medical text is case-sensitive).
    # We use minimal normalization to preserve:
    # - "COVID-19" vs "covid-19"
    # - Drug names (proper capitalization matters)
    # - Abbreviations (BP vs bp)
    train_args.extend(
        [
            "--normalization_rule_name=identity",  # No normalization
            "--remove_extra_whitespaces=true",  # Clean whitespace
            "--split_by_unicode_script=true",  # Split CJK if present
            "--split_by_whitespace=true",
            "--split_by_number=true",  # Split numbers
            "--max_sentence_length=4096",  # Long medical answers
        ]
    )

    print("=" * 60)
    print("Training SentencePiece Tokenizer")
    print("=" * 60)
    print(f"Input: {input_file}")
    print(f"Output prefix: {output_prefix}")
    print(f"Vocab size: {vocab_size}")
    print(f"Model type: {model_type}")
    print(f"Character coverage: {character_coverage}")
    print()

    # Train the tokenizer.
    # NOTE: the arguments are joined with spaces, so the input/output paths
    # must not themselves contain spaces.
    spm.SentencePieceTrainer.Train(" ".join(train_args))

    print("\n✅ Tokenizer training complete!")
    print(f"Model saved: {output_prefix}.model")
    print(f"Vocab saved: {output_prefix}.vocab")

    # Load and inspect the tokenizer
    sp = spm.SentencePieceProcessor()
    sp.load(f"{output_prefix}.model")

    print("\n" + "=" * 60)
    print("Tokenizer Statistics")
    print("=" * 60)
    print(f"Vocabulary size: {sp.vocab_size()}")
    print(f"PAD token: {sp.id_to_piece(0)} (ID: 0)")
    print(f"UNK token: {sp.id_to_piece(1)} (ID: 1)")
    print(f"BOS token: {sp.id_to_piece(2)} (ID: 2)")
    print(f"EOS token: {sp.id_to_piece(3)} (ID: 3)")

    # Test tokenization on medical examples
    print("\n" + "=" * 60)
    print("Sample Tokenization")
    print("=" * 60)

    test_texts = [
        "What is diabetes?",
        "COVID-19 vaccination side effects",
        "Hypertension treatment guidelines",
    ]

    for text in test_texts:
        tokens = sp.encode_as_pieces(text)
        ids = sp.encode_as_ids(text)
        print(f"\nText: {text}")
        print(f"Tokens: {tokens}")
        print(f"IDs: {ids}")
        print(f"Token count: {len(tokens)}")

    # Vocabulary inspection
    print("\n" + "=" * 60)
    print("Sample Vocabulary (first 20 tokens)")
    print("=" * 60)
    for i in range(min(20, sp.vocab_size())):
        piece = sp.id_to_piece(i)
        print(f"ID {i:4d}: {piece}")

    return sp
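
As a quick sanity check that the user-defined symbols registered as single pieces, one can encode a formatted example (a hypothetical snippet; the model path matches the default written by main below):

    # Hypothetical check: [Q] and [A] should be single pieces, never split into subwords.
    sp = spm.SentencePieceProcessor()
    sp.load("tokenizer/medquad_tokenizer.model")

    pieces = sp.encode_as_pieces("[Q] What is diabetes? [A] Diabetes is ...")
    assert "[Q]" in pieces and "[A]" in pieces
    print(pieces)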


def main():
    parser = argparse.ArgumentParser(description="Train SentencePiece tokenizer for MedQuAD")
    parser.add_argument(
        "--input-file",
        type=str,
        default="./data/processed/medquad_simple.txt",
        help="Input text file for training",
    )
    parser.add_argument(
        "--output-dir", type=str, default="./tokenizer", help="Output directory for tokenizer files"
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=8000,
        help="Vocabulary size (default: 8000 for small dataset)",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["bpe", "unigram"],
        default="bpe",
        help="Tokenizer algorithm",
    )
    parser.add_argument(
        "--character-coverage",
        type=float,
        default=0.9995,
        help="Character coverage (0.9995 = captures medical unicode)",
    )

    args = parser.parse_args()

    # Validate the input file
    input_file = Path(args.input_file)
    if not input_file.exists():
        print(f"❌ Error: Input file not found: {input_file}")
        print("\nPlease run preprocessing first:")
        print("  python preprocess.py")
        return

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Output prefix for the .model and .vocab files
    output_prefix = output_dir / "medquad_tokenizer"

    # Train the tokenizer
    tokenizer = train_sentencepiece_tokenizer(
        input_file=input_file,
        output_prefix=output_prefix,
        vocab_size=args.vocab_size,
        model_type=args.model_type,
        character_coverage=args.character_coverage,
        add_special_tokens=True,
    )

    print("\n" + "=" * 60)
    print("✅ Tokenizer training complete!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Inspect the vocabulary for medical terms")
    print("2. Test on sample medical texts")
    print("3. Tokenize the full dataset: python tokenizer/tokenize_data.py")


if __name__ == "__main__":
    main()

@@ -0,0 +1 @@ gptmed/training/__init__.py

"""Training package."""

@@ -0,0 +1,183 @@ gptmed/training/dataset.py
"""
|
|
2
|
+
PyTorch Dataset for Tokenized Data
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Load tokenized sequences from numpy arrays and provide them as PyTorch tensors
|
|
6
|
+
for training. Handles batching and shuffling efficiently.
|
|
7
|
+
|
|
8
|
+
WHAT THIS FILE DOES:
|
|
9
|
+
1. Load tokenized data from .npy files
|
|
10
|
+
2. Create training examples for next-token prediction
|
|
11
|
+
3. Provide batches to DataLoader
|
|
12
|
+
|
|
13
|
+
TRAINING OBJECTIVE EXPLAINED:
|
|
14
|
+
For causal language modeling, each sequence is both input and target:
|
|
15
|
+
- Input: [token_0, token_1, token_2, ..., token_n-1]
|
|
16
|
+
- Target: [token_1, token_2, token_3, ..., token_n]
|
|
17
|
+
|
|
18
|
+
The model learns to predict token_i given tokens [0, 1, ..., i-1].
|
|
19
|
+
|
|
20
|
+
PACKAGES USED:
|
|
21
|
+
- torch: PyTorch tensors and Dataset
|
|
22
|
+
- numpy: Load .npy files
|
|
23
|
+
|
|
24
|
+
FILES FROM THIS PROJECT:
|
|
25
|
+
- data/tokenized/train.npy (created by tokenize_data.py)
|
|
26
|
+
- data/tokenized/val.npy
|
|
27
|
+
|
|
28
|
+
TENSOR SHAPES:
|
|
29
|
+
- Loaded data: [num_sequences, seq_len]
|
|
30
|
+
- Each batch: [batch_size, seq_len]
|
|
31
|
+
- Input: [batch_size, seq_len]
|
|
32
|
+
- Target: [batch_size, seq_len]
|
|
33
|
+
|
|
34
|
+
COMMON FAILURE MODES:
|
|
35
|
+
- Wrong data path → FileNotFoundError
|
|
36
|
+
- Mismatched shapes → crashes during training
|
|
37
|
+
- Not shuffling train data → poor generalization
|
|
38
|
+
- Loading entire dataset into memory → OOM (for large datasets)
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
import numpy as np
|
|
42
|
+
import torch
|
|
43
|
+
from torch.utils.data import Dataset, DataLoader
|
|
44
|
+
from pathlib import Path
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class TokenizedDataset(Dataset):
    """
    Dataset for tokenized sequences.

    This is a simple in-memory dataset. For larger datasets you would use
    memory-mapped arrays or streaming.

    Each item returns (input_ids, target_ids) for next-token prediction.
    """

    def __init__(self, data_path: Path):
        """
        Args:
            data_path: Path to a .npy file with tokenized sequences
        """
        self.data_path = Path(data_path)

        if not self.data_path.exists():
            raise FileNotFoundError(f"Data file not found: {data_path}")

        # Load tokenized sequences; shape: [num_sequences, seq_len]
        self.data = np.load(self.data_path)

        print(f"Loaded dataset from {data_path}")
        print(f"  Shape: {self.data.shape}")
        print(f"  Dtype: {self.data.dtype}")
        print(f"  Num sequences: {len(self.data)}")

    def __len__(self) -> int:
        """Number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple:
        """
        Get a single training example.

        Args:
            idx: Sequence index

        Returns:
            (input_ids, target_ids) tuple

        Training objective:
            input:  [t0, t1, t2, ..., tn-1]
            target: [t1, t2, t3, ..., tn]

        The model predicts the token at position i using tokens [0, ..., i-1].
        """
        sequence = self.data[idx]

        # Convert to a torch tensor
        sequence = torch.from_numpy(sequence).long()

        # For causal LM, input and target are the same sequence, shifted by one:
        # Input:  [0, 1, 2, 3, ..., n-1]
        # Target: [1, 2, 3, 4, ..., n]
        input_ids = sequence[:-1]  # All tokens except the last
        target_ids = sequence[1:]  # All tokens except the first

        return input_ids, target_ids
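
A minimal smoke test of the one-token shift (hypothetical file name; assumes TokenizedDataset above is in scope):

    # Hypothetical smoke test: write a tiny tokenized array, then check the shift.
    import numpy as np
    import torch

    arr = np.arange(20, dtype=np.int32).reshape(2, 10)  # 2 sequences of length 10
    np.save("tiny.npy", arr)

    ds = TokenizedDataset("tiny.npy")
    x, y = ds[0]
    assert torch.equal(y[:-1], x[1:])  # targets are inputs shifted left by one
    assert x.shape == y.shape == (9,)  # seq_len 10 → 9 positions each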


def create_dataloaders(
    train_path: Path, val_path: Path, batch_size: int, num_workers: int = 0
) -> tuple:
    """
    Create train and validation dataloaders.

    Args:
        train_path: Path to the train .npy file
        val_path: Path to the validation .npy file
        batch_size: Batch size
        num_workers: Number of dataloader workers (0 = main process)

    Returns:
        (train_loader, val_loader) tuple

    Design decisions:
    - Shuffle train data (prevents overfitting to sequence order)
    - Don't shuffle val data (consistent evaluation)
    - num_workers=0 on small datasets (the worker overhead isn't worth it)
    - drop_last=True on train to avoid partial batches (can cause issues with batch norm)
    """
    # Create datasets
    train_dataset = TokenizedDataset(train_path)
    val_dataset = TokenizedDataset(val_path)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # Shuffle for training
        num_workers=num_workers,
        pin_memory=True,  # Faster GPU transfer
        drop_last=True,  # Drop the incomplete last batch
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # Don't shuffle validation
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,  # Keep all validation data
    )

    print("\nDataLoaders created:")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches: {len(val_loader)}")
    print(f"  Batch size: {batch_size}")

    return train_loader, val_loader


def get_batch_info(batch: tuple) -> dict:
    """
    Get information about a batch (for debugging).

    Args:
        batch: (input_ids, target_ids) tuple

    Returns:
        Dictionary with batch statistics
    """
    input_ids, target_ids = batch

    return {
        "batch_size": input_ids.size(0),
        "seq_len": input_ids.size(1),
        "input_shape": tuple(input_ids.shape),
        "target_shape": tuple(target_ids.shape),
        "input_dtype": input_ids.dtype,
        "target_dtype": target_ids.dtype,
        "input_device": input_ids.device,
        "target_device": target_ids.device,
    }
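
Typical wiring of these helpers (a hypothetical usage sketch; the paths match the defaults written by tokenize_data.py above):

    # Hypothetical usage: build loaders and inspect the first training batch.
    train_loader, val_loader = create_dataloaders(
        train_path=Path("data/tokenized/train.npy"),
        val_path=Path("data/tokenized/val.npy"),
        batch_size=16,
    )

    for batch in train_loader:
        print(get_batch_info(batch))
        break  # inspect the first batch only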