gptmed-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gptmed/__init__.py +37 -0
  2. gptmed/configs/__init__.py +1 -0
  3. gptmed/configs/train_config.py +154 -0
  4. gptmed/data/__init__.py +5 -0
  5. gptmed/data/parsers/__init__.py +10 -0
  6. gptmed/data/parsers/medquad_parser.py +257 -0
  7. gptmed/data/parsers/text_formatter.py +148 -0
  8. gptmed/inference/__init__.py +1 -0
  9. gptmed/inference/decoding_utils.py +190 -0
  10. gptmed/inference/generation_config.py +83 -0
  11. gptmed/inference/generator.py +253 -0
  12. gptmed/inference/sampling.py +261 -0
  13. gptmed/model/__init__.py +9 -0
  14. gptmed/model/architecture/__init__.py +35 -0
  15. gptmed/model/architecture/attention.py +188 -0
  16. gptmed/model/architecture/decoder_block.py +130 -0
  17. gptmed/model/architecture/embeddings.py +146 -0
  18. gptmed/model/architecture/feedforward.py +109 -0
  19. gptmed/model/architecture/transformer.py +204 -0
  20. gptmed/model/configs/__init__.py +17 -0
  21. gptmed/model/configs/model_config.py +155 -0
  22. gptmed/tokenizer/__init__.py +7 -0
  23. gptmed/tokenizer/tokenize_data.py +286 -0
  24. gptmed/tokenizer/train_tokenizer.py +218 -0
  25. gptmed/training/__init__.py +1 -0
  26. gptmed/training/dataset.py +183 -0
  27. gptmed/training/train.py +272 -0
  28. gptmed/training/trainer.py +331 -0
  29. gptmed/training/utils.py +212 -0
  30. gptmed/utils/__init__.py +1 -0
  31. gptmed/utils/checkpoints.py +224 -0
  32. gptmed/utils/logging.py +189 -0
  33. gptmed-0.0.1.dist-info/METADATA +325 -0
  34. gptmed-0.0.1.dist-info/RECORD +38 -0
  35. gptmed-0.0.1.dist-info/WHEEL +5 -0
  36. gptmed-0.0.1.dist-info/entry_points.txt +3 -0
  37. gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
  38. gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/tokenizer/tokenize_data.py
@@ -0,0 +1,286 @@
+ """
+ Dataset Tokenization
+
+ Applies the trained tokenizer to the processed text and creates train/validation splits.
+
+ Design decisions:
+ - Sequence length: 512 tokens (fits a GTX 1080, captures most Q&A pairs)
+ - Train/val split: 90/10 (enough validation data to detect overfitting)
+ - Padding: full chunks need no padding; only a trailing short chunk is right-padded with the PAD id (0)
+ - No data augmentation: keep it simple for Phase 1
+
+ Common failure modes:
+ - Truncating answers → model never learns to generate long responses
+ - Too-short sequences → can't learn context
+ - Too-long sequences → OOM on GPU
+ - Wrong padding → breaks attention masks
+ """
+
+ import argparse
+ import json
+ from pathlib import Path
+ from typing import List, Tuple
+ import random
+
+ import sentencepiece as spm
+ import numpy as np
+
+
+ def load_tokenizer(model_path: Path) -> spm.SentencePieceProcessor:
+     """Load trained SentencePiece tokenizer."""
+     sp = spm.SentencePieceProcessor()
+     sp.load(str(model_path))
+     return sp
+
+
+ def tokenize_text(text: str, tokenizer: spm.SentencePieceProcessor) -> List[int]:
+     """
+     Tokenize text to IDs.
+
+     Args:
+         text: Input text
+         tokenizer: SentencePiece processor
+
+     Returns:
+         List of token IDs
+     """
+     return tokenizer.encode_as_ids(text)
+
+
+ def create_sequences(token_ids: List[int], max_length: int, stride: int = None) -> List[List[int]]:
+     """
+     Split token IDs into fixed-length sequences.
+
+     Args:
+         token_ids: Full token ID sequence
+         max_length: Maximum sequence length
+         stride: Stride for overlapping windows (None = no overlap)
+
+     Returns:
+         List of sequences
+
+     Design note: For causal LM, we create non-overlapping chunks.
+     Each chunk is a separate training example for next-token prediction.
+     """
+     if stride is None:
+         stride = max_length  # No overlap
+
+     sequences = []
+     for i in range(0, len(token_ids) - max_length + 1, stride):
+         seq = token_ids[i : i + max_length]
+         sequences.append(seq)
+
+     # Handle remaining tokens (last incomplete sequence)
+     if sequences:
+         last_covered_end = (len(sequences) - 1) * stride + max_length
+         remainder = len(token_ids) - last_covered_end
+         # Include the trailing chunk only if it's at least half the max_length
+         if remainder >= max_length // 2:
+             last_seq = token_ids[last_covered_end:]
+             # Right-pad to max_length with the PAD id (0)
+             last_seq = last_seq + [0] * (max_length - len(last_seq))
+             sequences.append(last_seq)
+
+     return sequences
+
+
+ def analyze_lengths(text_file: Path, tokenizer: spm.SentencePieceProcessor) -> dict:
+     """
+     Analyze token length distribution to choose optimal sequence length.
+
+     This helps avoid:
+     - Truncating too much data (information loss)
+     - Wasting memory on padding (inefficiency)
+     """
+     lengths = []
+
+     with open(text_file, "r", encoding="utf-8") as f:
+         content = f.read()
+
+     # Split by double newline (our document separator)
+     documents = content.split("\n\n")
+
+     for doc in documents:
+         if doc.strip():
+             tokens = tokenizer.encode_as_ids(doc.strip())
+             lengths.append(len(tokens))
+
+     lengths = np.array(lengths)
+
+     stats = {
+         "num_documents": len(lengths),
+         "mean": float(np.mean(lengths)),
+         "median": float(np.median(lengths)),
+         "std": float(np.std(lengths)),
+         "min": int(np.min(lengths)),
+         "max": int(np.max(lengths)),
+         "percentile_50": float(np.percentile(lengths, 50)),
+         "percentile_75": float(np.percentile(lengths, 75)),
+         "percentile_90": float(np.percentile(lengths, 90)),
+         "percentile_95": float(np.percentile(lengths, 95)),
+         "percentile_99": float(np.percentile(lengths, 99)),
+     }
+
+     return stats
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Tokenize MedQuAD dataset with trained tokenizer")
+     parser.add_argument(
+         "--input-file",
+         type=str,
+         default="./data/processed/medquad_simple.txt",
+         help="Input text file",
+     )
+     parser.add_argument(
+         "--tokenizer-model",
+         type=str,
+         default="./tokenizer/medquad_tokenizer.model",
+         help="Path to trained tokenizer model",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="./data/tokenized",
+         help="Output directory for tokenized data",
+     )
+     parser.add_argument(
+         "--max-length",
+         type=int,
+         default=512,
+         help="Maximum sequence length (default: 512 for GTX 1080)",
+     )
+     parser.add_argument(
+         "--train-ratio", type=float, default=0.9, help="Train/val split ratio (default: 0.9)"
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+
+     args = parser.parse_args()
+
+     # Set random seed
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+
+     print("=" * 60)
+     print("Dataset Tokenization")
+     print("=" * 60)
+     print(f"Input: {args.input_file}")
+     print(f"Tokenizer: {args.tokenizer_model}")
+     print(f"Max length: {args.max_length}")
+     print(f"Train ratio: {args.train_ratio}")
+     print()
+
+     # Load tokenizer
+     input_file = Path(args.input_file)
+     tokenizer_model = Path(args.tokenizer_model)
+
+     if not input_file.exists():
+         print(f"❌ Error: Input file not found: {input_file}")
+         return
+
+     if not tokenizer_model.exists():
+         print(f"❌ Error: Tokenizer not found: {tokenizer_model}")
+         print("\nPlease train tokenizer first:")
+         print(" python tokenizer/train_tokenizer.py")
+         return
+
+     print("Loading tokenizer...")
+     tokenizer = load_tokenizer(tokenizer_model)
+     print(f"✅ Tokenizer loaded (vocab size: {tokenizer.vocab_size()})")
+
+     # Analyze sequence lengths
+     print("\nAnalyzing sequence lengths...")
+     stats = analyze_lengths(input_file, tokenizer)
+
+     print("\nSequence Length Statistics:")
+     print(f" Documents: {stats['num_documents']}")
+     print(f" Mean: {stats['mean']:.1f} tokens")
+     print(f" Median: {stats['median']:.1f} tokens")
+     print(f" Std dev: {stats['std']:.1f} tokens")
+     print(f" Min: {stats['min']} tokens")
+     print(f" Max: {stats['max']} tokens")
+     print(f"\nPercentiles:")
+     print(f" 50th: {stats['percentile_50']:.1f} tokens")
+     print(f" 75th: {stats['percentile_75']:.1f} tokens")
+     print(f" 90th: {stats['percentile_90']:.1f} tokens")
+     print(f" 95th: {stats['percentile_95']:.1f} tokens")
+     print(f" 99th: {stats['percentile_99']:.1f} tokens")
+
+     # Warn when the chosen max_length would truncate a meaningful share of documents
+     if stats["percentile_95"] > args.max_length:
+         print(
+             f"\n⚠️ WARNING: max_length={args.max_length} will truncate more than 5% of documents"
+         )
+         print(f" Consider increasing to {int(stats['percentile_95'])} to capture 95% of data")
+
+     # Tokenize full dataset
+     print("\nTokenizing dataset...")
+     with open(input_file, "r", encoding="utf-8") as f:
+         full_text = f.read()
+
+     token_ids = tokenizer.encode_as_ids(full_text)
+     print(f"Total tokens: {len(token_ids):,}")
+
+     # Create sequences
+     print(f"\nCreating sequences (max_length={args.max_length})...")
+     sequences = create_sequences(token_ids, max_length=args.max_length)
+     print(f"Total sequences: {len(sequences):,}")
+
+     # Train/val split
+     random.shuffle(sequences)
+     split_idx = int(len(sequences) * args.train_ratio)
+
+     train_sequences = sequences[:split_idx]
+     val_sequences = sequences[split_idx:]
+
+     print(f"\nSplit:")
+     print(f" Train: {len(train_sequences):,} sequences")
+     print(f" Val: {len(val_sequences):,} sequences")
+
+     # Save tokenized data
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     print("\nSaving tokenized data...")
+
+     # Save as numpy arrays (efficient for PyTorch DataLoader)
+     train_array = np.array(train_sequences, dtype=np.int32)
+     val_array = np.array(val_sequences, dtype=np.int32)
+
+     np.save(output_dir / "train.npy", train_array)
+     np.save(output_dir / "val.npy", val_array)
+
+     print(f"✅ Train data saved: {output_dir / 'train.npy'}")
+     print(f"✅ Val data saved: {output_dir / 'val.npy'}")
+
+     # Save metadata
+     metadata = {
+         "vocab_size": tokenizer.vocab_size(),
+         "max_length": args.max_length,
+         "num_train_sequences": len(train_sequences),
+         "num_val_sequences": len(val_sequences),
+         "total_tokens": len(token_ids),
+         "length_stats": stats,
+         "tokenizer_model": str(tokenizer_model),
+         "seed": args.seed,
+     }
+
+     with open(output_dir / "metadata.json", "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     print(f"✅ Metadata saved: {output_dir / 'metadata.json'}")
+
+     print("\n" + "=" * 60)
+     print("✅ Tokenization complete!")
+     print("=" * 60)
+     print("\nNext steps (Phase 2):")
+     print("1. Implement Transformer architecture")
+     print("2. Create PyTorch Dataset and DataLoader")
+     print("3. Define training loop")
+
+
+ if __name__ == "__main__":
+     main()
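
A note on the chunking rule above: create_sequences emits fixed-length windows and keeps a trailing partial window only when it is at least half of max_length, right-padding it with the PAD id 0. A minimal sketch of that behaviour, assuming the module is importable as gptmed.tokenizer.tokenize_data; the toy token IDs and max_length=8 are purely illustrative:

    from gptmed.tokenizer.tokenize_data import create_sequences

    token_ids = list(range(1, 21))  # 20 fake token IDs
    chunks = create_sequences(token_ids, max_length=8)

    # Two full windows: [1..8] and [9..16].
    # The 4-token tail [17..20] is exactly half of max_length, so it is kept
    # and right-padded with PAD id 0 to [17, 18, 19, 20, 0, 0, 0, 0].
    for chunk in chunks:
        print(chunk)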
gptmed/tokenizer/train_tokenizer.py
@@ -0,0 +1,218 @@
+ """
+ SentencePiece Tokenizer Training
+
+ Trains a BPE (Byte-Pair Encoding) tokenizer on the processed MedQuAD text.
+
+ Design decisions explained:
+ - BPE over WordPiece: Better for medical terminology (handles rare words via subwords)
+ - Vocab size: Trade-off between model size and expressiveness
+ - Character coverage: 0.9995 captures medical unicode (Greek letters, symbols)
+ - No pre-tokenization: Let BPE learn from raw text
+
+ Common failure modes to avoid:
+ - Too small vocab → repetitive, generic text
+ - Too large vocab → overfitting, slow training
+ - Missing special tokens → can't mark boundaries
+ - Wrong normalization → "COVID-19" vs "covid-19" treated differently
+ """
+
+ import argparse
+ from pathlib import Path
+ import sentencepiece as spm
+
+
+ def train_sentencepiece_tokenizer(
+     input_file: Path,
+     output_prefix: Path,
+     vocab_size: int = 8000,
+     model_type: str = "bpe",
+     character_coverage: float = 0.9995,
+     add_special_tokens: bool = True,
+ ):
+     """
+     Train a SentencePiece tokenizer.
+
+     Args:
+         input_file: Path to training text file
+         output_prefix: Output path prefix (will create .model and .vocab files)
+         vocab_size: Vocabulary size (default 8000 for small dataset)
+         model_type: 'bpe' or 'unigram'
+         character_coverage: Character coverage for unicode (0.9995 = medical terms)
+         add_special_tokens: Whether to add custom special tokens
+
+     Design notes:
+     - vocab_size=8000: Good for 40K examples on GTX 1080
+       * Too small (2K): Poor compression, repetitive outputs
+       * Too large (32K): Overfits, wastes memory on rare terms
+     - BPE: Deterministic, easier to debug than unigram
+     - character_coverage=0.9995: Captures medical unicode (μ, α, β, etc.)
+     """
+
+     # Build training command
+     train_args = [
+         f"--input={input_file}",
+         f"--model_prefix={output_prefix}",
+         f"--vocab_size={vocab_size}",
+         f"--model_type={model_type}",
+         f"--character_coverage={character_coverage}",
+         "--pad_id=0",  # <pad> token at ID 0
+         "--unk_id=1",  # <unk> token at ID 1
+         "--bos_id=2",  # <bos> (beginning of sequence) at ID 2
+         "--eos_id=3",  # <eos> (end of sequence) at ID 3
+         "--pad_piece=[PAD]",
+         "--unk_piece=[UNK]",
+         "--bos_piece=[BOS]",
+         "--eos_piece=[EOS]",
+     ]
+
+     # Add user-defined special tokens if requested
+     if add_special_tokens:
+         # These match our text formatting (Q:, A:)
+         user_defined = "[Q],[A]"
+         train_args.append(f"--user_defined_symbols={user_defined}")
+
+     # Normalization rules (CAREFUL: medical text is case-sensitive)
+     # We use minimal normalization to preserve:
+     # - "COVID-19" vs "covid-19"
+     # - Drug names (proper capitalization matters)
+     # - Abbreviations (BP vs bp)
+     train_args.extend(
+         [
+             "--normalization_rule_name=identity",  # No normalization
+             "--remove_extra_whitespaces=true",  # Clean whitespace
+             "--split_by_unicode_script=true",  # Split CJK if present
+             "--split_by_whitespace=true",
+             "--split_by_number=true",  # Split numbers
+             "--max_sentence_length=4096",  # Long medical answers
+         ]
+     )
+
+     print("=" * 60)
+     print("Training SentencePiece Tokenizer")
+     print("=" * 60)
+     print(f"Input: {input_file}")
+     print(f"Output prefix: {output_prefix}")
+     print(f"Vocab size: {vocab_size}")
+     print(f"Model type: {model_type}")
+     print(f"Character coverage: {character_coverage}")
+     print()
+
+     # Train the tokenizer
+     spm.SentencePieceTrainer.Train(" ".join(train_args))
+
+     print("\n✅ Tokenizer training complete!")
+     print(f"Model saved: {output_prefix}.model")
+     print(f"Vocab saved: {output_prefix}.vocab")
+
+     # Load and inspect the tokenizer
+     sp = spm.SentencePieceProcessor()
+     sp.load(f"{output_prefix}.model")
+
+     print("\n" + "=" * 60)
+     print("Tokenizer Statistics")
+     print("=" * 60)
+     print(f"Vocabulary size: {sp.vocab_size()}")
+     print(f"PAD token: {sp.id_to_piece(0)} (ID: 0)")
+     print(f"UNK token: {sp.id_to_piece(1)} (ID: 1)")
+     print(f"BOS token: {sp.id_to_piece(2)} (ID: 2)")
+     print(f"EOS token: {sp.id_to_piece(3)} (ID: 3)")
+
+     # Test tokenization on medical example
+     print("\n" + "=" * 60)
+     print("Sample Tokenization")
+     print("=" * 60)
+
+     test_texts = [
+         "What is diabetes?",
+         "COVID-19 vaccination side effects",
+         "Hypertension treatment guidelines",
+     ]
+
+     for text in test_texts:
+         tokens = sp.encode_as_pieces(text)
+         ids = sp.encode_as_ids(text)
+         print(f"\nText: {text}")
+         print(f"Tokens: {tokens}")
+         print(f"IDs: {ids}")
+         print(f"Token count: {len(tokens)}")
+
+     # Vocabulary inspection
+     print("\n" + "=" * 60)
+     print("Sample Vocabulary (first 20 tokens)")
+     print("=" * 60)
+     for i in range(min(20, sp.vocab_size())):
+         piece = sp.id_to_piece(i)
+         print(f"ID {i:4d}: {piece}")
+
+     return sp
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train SentencePiece tokenizer for MedQuAD")
+     parser.add_argument(
+         "--input-file",
+         type=str,
+         default="./data/processed/medquad_simple.txt",
+         help="Input text file for training",
+     )
+     parser.add_argument(
+         "--output-dir", type=str, default="./tokenizer", help="Output directory for tokenizer files"
+     )
+     parser.add_argument(
+         "--vocab-size",
+         type=int,
+         default=8000,
+         help="Vocabulary size (default: 8000 for small dataset)",
+     )
+     parser.add_argument(
+         "--model-type",
+         type=str,
+         choices=["bpe", "unigram"],
+         default="bpe",
+         help="Tokenizer algorithm",
+     )
+     parser.add_argument(
+         "--character-coverage",
+         type=float,
+         default=0.9995,
+         help="Character coverage (0.9995 = captures medical unicode)",
+     )
+
+     args = parser.parse_args()
+
+     # Validate input file
+     input_file = Path(args.input_file)
+     if not input_file.exists():
+         print(f"❌ Error: Input file not found: {input_file}")
+         print("\nPlease run preprocessing first:")
+         print(" python preprocess.py")
+         return
+
+     # Create output directory
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Output prefix for .model and .vocab files
+     output_prefix = output_dir / "medquad_tokenizer"
+
+     # Train tokenizer
+     tokenizer = train_sentencepiece_tokenizer(
+         input_file=input_file,
+         output_prefix=output_prefix,
+         vocab_size=args.vocab_size,
+         model_type=args.model_type,
+         character_coverage=args.character_coverage,
+         add_special_tokens=True,
+     )
+
+     print("\n" + "=" * 60)
+     print("✅ Tokenizer training complete!")
+     print("=" * 60)
+     print("\nNext steps:")
+     print("1. Inspect vocabulary for medical terms")
+     print("2. Test on sample medical texts")
+     print("3. Tokenize full dataset: python tokenizer/tokenize_data.py")
+
+
+ if __name__ == "__main__":
+     main()
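
Once training has produced the default ./tokenizer/medquad_tokenizer.model, the resulting tokenizer can be exercised directly through SentencePiece. A minimal sketch, assuming that model file already exists; the sample sentence is arbitrary, and the special-token IDs follow from the --pad_id/--unk_id/--bos_id/--eos_id flags set above:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("./tokenizer/medquad_tokenizer.model")

    # Special token IDs fixed by the training flags: [PAD]=0, [UNK]=1, [BOS]=2, [EOS]=3
    print(sp.pad_id(), sp.unk_id(), sp.bos_id(), sp.eos_id())

    text = "What are the symptoms of hypertension?"
    pieces = sp.encode_as_pieces(text)
    ids = sp.encode_as_ids(text)

    print(pieces)
    print(ids)
    print(sp.decode_ids(ids))  # should reconstruct the original sentence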
gptmed/training/__init__.py
@@ -0,0 +1 @@
+ """Training package."""
gptmed/training/dataset.py
@@ -0,0 +1,183 @@
+ """
+ PyTorch Dataset for Tokenized Data
+
+ PURPOSE:
+ Load tokenized sequences from numpy arrays and provide them as PyTorch tensors
+ for training. Handles batching and shuffling efficiently.
+
+ WHAT THIS FILE DOES:
+ 1. Load tokenized data from .npy files
+ 2. Create training examples for next-token prediction
+ 3. Provide batches to DataLoader
+
+ TRAINING OBJECTIVE EXPLAINED:
+ For causal language modeling, each sequence is both input and target:
+ - Input: [token_0, token_1, token_2, ..., token_n-1]
+ - Target: [token_1, token_2, token_3, ..., token_n]
+
+ The model learns to predict token_i given tokens [0, 1, ..., i-1].
+
+ PACKAGES USED:
+ - torch: PyTorch tensors and Dataset
+ - numpy: Load .npy files
+
+ FILES FROM THIS PROJECT:
+ - data/tokenized/train.npy (created by tokenize_data.py)
+ - data/tokenized/val.npy
+
+ TENSOR SHAPES:
+ - Loaded data: [num_sequences, seq_len]
+ - Each batch: a pair of (input, target) tensors
+ - Input: [batch_size, seq_len - 1] (sequence without its last token)
+ - Target: [batch_size, seq_len - 1] (sequence shifted left by one)
+
+ COMMON FAILURE MODES:
+ - Wrong data path → FileNotFoundError
+ - Mismatched shapes → crashes during training
+ - Not shuffling train data → poor generalization
+ - Loading entire dataset into memory → OOM (for large datasets)
+ """
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+
+
+ class TokenizedDataset(Dataset):
+     """
+     Dataset for tokenized sequences.
+
+     This is a simple in-memory dataset. For larger datasets, you'd use
+     memory-mapped arrays or streaming.
+
+     Each item returns (input_ids, target_ids) for next-token prediction.
+     """
+
+     def __init__(self, data_path: Path):
+         """
+         Args:
+             data_path: Path to .npy file with tokenized sequences
+         """
+         self.data_path = Path(data_path)
+
+         if not self.data_path.exists():
+             raise FileNotFoundError(f"Data file not found: {data_path}")
+
+         # Load tokenized sequences
+         # Shape: [num_sequences, seq_len]
+         self.data = np.load(self.data_path)
+
+         print(f"Loaded dataset from {data_path}")
+         print(f" Shape: {self.data.shape}")
+         print(f" Dtype: {self.data.dtype}")
+         print(f" Num sequences: {len(self.data)}")
+
+     def __len__(self) -> int:
+         """Number of sequences in dataset."""
+         return len(self.data)
+
+     def __getitem__(self, idx: int) -> tuple:
+         """
+         Get a single training example.
+
+         Args:
+             idx: Sequence index
+
+         Returns:
+             (input_ids, target_ids) tuple
+
+         Training objective:
+             input:  [t0, t1, t2, ..., tn-1]
+             target: [t1, t2, t3, ..., tn]
+
+         The model predicts token at position i using tokens [0, ..., i-1].
+         """
+         sequence = self.data[idx]
+
+         # Convert to torch tensor
+         sequence = torch.from_numpy(sequence).long()
+
+         # For causal LM, input and target are the same sequence, shifted by 1
+         # Input:  [0, 1, 2, 3, ..., n-1]
+         # Target: [1, 2, 3, 4, ..., n]
+         input_ids = sequence[:-1]  # All tokens except last
+         target_ids = sequence[1:]  # All tokens except first
+
+         return input_ids, target_ids
+
+
+ def create_dataloaders(
+     train_path: Path, val_path: Path, batch_size: int, num_workers: int = 0
+ ) -> tuple:
+     """
+     Create train and validation dataloaders.
+
+     Args:
+         train_path: Path to train .npy file
+         val_path: Path to validation .npy file
+         batch_size: Batch size
+         num_workers: Number of dataloader workers (0 = main process)
+
+     Returns:
+         (train_loader, val_loader) tuple
+
+     Design decisions:
+     - Shuffle train data (prevents overfitting to sequence order)
+     - Don't shuffle val data (consistent evaluation)
+     - num_workers=0 on small datasets (overhead not worth it)
+     - drop_last=True to avoid partial batches (can cause issues with batch norm)
+     """
+     # Create datasets
+     train_dataset = TokenizedDataset(train_path)
+     val_dataset = TokenizedDataset(val_path)
+
+     # Create dataloaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,  # Shuffle for training
+         num_workers=num_workers,
+         pin_memory=True,  # Faster GPU transfer
+         drop_last=True,  # Drop incomplete last batch
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,  # Don't shuffle validation
+         num_workers=num_workers,
+         pin_memory=True,
+         drop_last=False,  # Keep all validation data
+     )
+
+     print(f"\nDataLoaders created:")
+     print(f" Train batches: {len(train_loader)}")
+     print(f" Val batches: {len(val_loader)}")
+     print(f" Batch size: {batch_size}")
+
+     return train_loader, val_loader
+
+
+ def get_batch_info(batch: tuple) -> dict:
+     """
+     Get information about a batch (for debugging).
+
+     Args:
+         batch: (input_ids, target_ids) tuple
+
+     Returns:
+         Dictionary with batch statistics
+     """
+     input_ids, target_ids = batch
+
+     return {
+         "batch_size": input_ids.size(0),
+         "seq_len": input_ids.size(1),
+         "input_shape": tuple(input_ids.shape),
+         "target_shape": tuple(target_ids.shape),
+         "input_dtype": input_ids.dtype,
+         "target_dtype": target_ids.dtype,
+         "input_device": input_ids.device,
+         "target_device": target_ids.device,
+     }
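
Because TokenizedDataset already returns the shifted (input_ids, target_ids) pair, a training step reduces to a standard next-token cross-entropy loss over the logits. A minimal sketch of one such step; the tiny embedding-plus-linear model is a stand-in for illustration only (not the package's Transformer), the data paths assume tokenize_data.py has been run with its defaults, and a real loop would typically also ignore PAD targets (e.g. CrossEntropyLoss(ignore_index=0)):

    import torch.nn as nn

    from gptmed.training.dataset import create_dataloaders

    train_loader, val_loader = create_dataloaders(
        train_path="data/tokenized/train.npy",
        val_path="data/tokenized/val.npy",
        batch_size=8,
    )

    vocab_size = 8000  # must match the trained tokenizer
    toy_model = nn.Sequential(nn.Embedding(vocab_size, 64), nn.Linear(64, vocab_size))
    criterion = nn.CrossEntropyLoss()

    input_ids, target_ids = next(iter(train_loader))  # each: [batch_size, seq_len - 1]
    logits = toy_model(input_ids)                     # [batch_size, seq_len - 1, vocab_size]

    # Every position contributes one next-token prediction to the loss.
    loss = criterion(logits.reshape(-1, vocab_size), target_ids.reshape(-1))
    print(loss.item())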