gptmed 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gptmed/api.py CHANGED
@@ -39,6 +39,8 @@ from gptmed.configs.train_config import TrainingConfig
39
39
  from gptmed.training.dataset import create_dataloaders
40
40
  from gptmed.training.trainer import Trainer
41
41
  from gptmed.inference.generator import TextGenerator
42
+ from gptmed.services.device_manager import DeviceManager
43
+ from gptmed.services.training_service import TrainingService
42
44
 
43
45
 
44
46
  def create_config(output_path: str = 'training_config.yaml') -> None:
@@ -58,7 +60,11 @@ def create_config(output_path: str = 'training_config.yaml') -> None:
58
60
  create_default_config_file(output_path)
59
61
 
60
62
 
61
- def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
63
+ def train_from_config(
64
+ config_path: str,
65
+ verbose: bool = True,
66
+ device: Optional[str] = None
67
+ ) -> Dict[str, Any]:
62
68
  """
63
69
  Train a GPT model using a YAML configuration file.
64
70
 
@@ -68,6 +74,8 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
68
74
  Args:
69
75
  config_path: Path to YAML configuration file
70
76
  verbose: Whether to print training progress (default: True)
77
+ device: Device to use ('cuda', 'cpu', or 'auto'). If None, uses config value.
78
+ 'auto' will select best available device.
71
79
 
72
80
  Returns:
73
81
  Dictionary with training results:
@@ -82,13 +90,16 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
82
90
  >>> gptmed.create_config('config.yaml')
83
91
  >>> # ... edit config.yaml ...
84
92
  >>>
85
- >>> # Train the model
86
- >>> results = gptmed.train_from_config('config.yaml')
93
+ >>> # Train the model on CPU
94
+ >>> results = gptmed.train_from_config('config.yaml', device='cpu')
87
95
  >>> print(f"Best model: {results['best_checkpoint']}")
96
+ >>>
97
+ >>> # Train with auto device selection
98
+ >>> results = gptmed.train_from_config('config.yaml', device='auto')
88
99
 
89
100
  Raises:
90
101
  FileNotFoundError: If config file or data files don't exist
91
- ValueError: If configuration is invalid
102
+ ValueError: If configuration is invalid or device is invalid
92
103
  """
93
104
  if verbose:
94
105
  print("=" * 60)
@@ -111,47 +122,42 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
111
122
  # Convert to arguments
112
123
  args = config_to_args(config)
113
124
 
114
- # Import here to avoid circular imports
115
- import random
116
- import numpy as np
125
+ # Override device if provided as parameter
126
+ if device is not None:
127
+ # Validate and normalize device
128
+ device = DeviceManager.validate_device(device)
129
+ if verbose:
130
+ print(f"\n⚙️ Device override: {device} (from parameter)")
131
+ args['device'] = device
117
132
 
118
- # Set random seed
119
- def set_seed(seed: int):
120
- random.seed(seed)
121
- np.random.seed(seed)
122
- torch.manual_seed(seed)
123
- if torch.cuda.is_available():
124
- torch.cuda.manual_seed(seed)
125
- torch.cuda.manual_seed_all(seed)
126
- torch.backends.cudnn.deterministic = True
127
- torch.backends.cudnn.benchmark = False
133
+ # Create DeviceManager with the selected device
134
+ device_manager = DeviceManager(
135
+ preferred_device=args['device'],
136
+ allow_fallback=True
137
+ )
138
+
139
+ # Print device information
140
+ device_manager.print_device_info(verbose=verbose)
141
+
142
+ # Create TrainingService with DeviceManager
143
+ training_service = TrainingService(
144
+ device_manager=device_manager,
145
+ verbose=verbose
146
+ )
128
147
 
148
+ # Set random seed
129
149
  if verbose:
130
150
  print(f"\n🎲 Setting random seed: {args['seed']}")
131
- set_seed(args['seed'])
151
+ training_service.set_seed(args['seed'])
132
152
 
133
- # Check device
134
- device = args['device']
135
- if device == 'cuda' and not torch.cuda.is_available():
136
- if verbose:
137
- print("⚠️ CUDA not available, using CPU")
138
- device = 'cpu'
153
+ # Get actual device to use
154
+ actual_device = device_manager.get_device()
139
155
 
140
156
  # Load model config
141
157
  if verbose:
142
158
  print(f"\n🧠 Creating model: {args['model_size']}")
143
159
 
144
- if args['model_size'] == 'tiny':
145
- model_config = get_tiny_config()
146
- elif args['model_size'] == 'small':
147
- model_config = get_small_config()
148
- elif args['model_size'] == 'medium':
149
- model_config = get_medium_config()
150
- else:
151
- raise ValueError(f"Unknown model size: {args['model_size']}")
152
-
153
- # Create model
154
- model = GPTTransformer(model_config)
160
+ model = training_service.create_model(args['model_size'])
155
161
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
156
162
 
157
163
  if verbose:
@@ -159,7 +165,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
159
165
  print(f" Parameters: {total_params:,}")
160
166
  print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")
161
167
 
162
- # Load data
168
+ # Load data using TrainingService
163
169
  if verbose:
164
170
  print(f"\n📊 Loading data...")
165
171
  print(f" Train: {args['train_data']}")
@@ -176,7 +182,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
176
182
  print(f" Train batches: {len(train_loader)}")
177
183
  print(f" Val batches: {len(val_loader)}")
178
184
 
179
- # Create training config
185
+ # Create training config with actual device
180
186
  train_config = TrainingConfig(
181
187
  batch_size=args['batch_size'],
182
188
  learning_rate=args['learning_rate'],
@@ -195,7 +201,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
195
201
  val_data_path=args['val_data'],
196
202
  checkpoint_dir=args['checkpoint_dir'],
197
203
  log_dir=args['log_dir'],
198
- device=device,
204
+ device=actual_device, # Use actual device from DeviceManager
199
205
  seed=args['seed'],
200
206
  )
201
207
 
@@ -213,70 +219,17 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
213
219
  weight_decay=args['weight_decay'],
214
220
  )
215
221
 
216
- # Create trainer
217
- if verbose:
218
- print(f"\n🎯 Initializing trainer...")
219
-
220
- trainer = Trainer(
222
+ # Execute training using TrainingService
223
+ results = training_service.execute_training(
221
224
  model=model,
222
225
  train_loader=train_loader,
223
226
  val_loader=val_loader,
224
227
  optimizer=optimizer,
225
- config=train_config,
226
- device=device,
228
+ train_config=train_config,
229
+ device=actual_device,
230
+ model_config_dict=model.config.to_dict()
227
231
  )
228
232
 
229
- # Resume if requested
230
- if args['resume_from'] is not None:
231
- if verbose:
232
- print(f"\n📥 Resuming from checkpoint: {args['resume_from']}")
233
- trainer.resume_from_checkpoint(Path(args['resume_from']))
234
-
235
- # Start training
236
- if verbose:
237
- print(f"\n{'='*60}")
238
- print("🚀 Starting Training!")
239
- print(f"{'='*60}\n")
240
-
241
- try:
242
- trainer.train()
243
- except KeyboardInterrupt:
244
- if verbose:
245
- print("\n\n⏸️ Training interrupted by user")
246
- print("💾 Saving checkpoint...")
247
- trainer.checkpoint_manager.save_checkpoint(
248
- model=model,
249
- optimizer=optimizer,
250
- step=trainer.global_step,
251
- epoch=trainer.current_epoch,
252
- val_loss=trainer.best_val_loss,
253
- model_config=model_config.to_dict(),
254
- train_config=train_config.to_dict(),
255
- )
256
- if verbose:
257
- print("✓ Checkpoint saved. Resume with resume_from in config.")
258
-
259
- # Return results
260
- best_checkpoint = Path(train_config.checkpoint_dir) / "best_model.pt"
261
-
262
- results = {
263
- 'best_checkpoint': str(best_checkpoint),
264
- 'final_val_loss': trainer.best_val_loss,
265
- 'total_epochs': trainer.current_epoch,
266
- 'checkpoint_dir': train_config.checkpoint_dir,
267
- 'log_dir': train_config.log_dir,
268
- }
269
-
270
- if verbose:
271
- print(f"\n{'='*60}")
272
- print("✅ Training Complete!")
273
- print(f"{'='*60}")
274
- print(f"\n📁 Results:")
275
- print(f" Best checkpoint: {results['best_checkpoint']}")
276
- print(f" Best val loss: {results['final_val_loss']:.4f}")
277
- print(f" Total epochs: {results['total_epochs']}")
278
- print(f" Logs: {results['log_dir']}")
279
-
280
233
  return results
281
234
 
282
235
 
@@ -76,6 +76,15 @@ def validate_config(config: Dict[str, Any]) -> None:
76
76
  raise ValueError("batch_size must be positive")
77
77
  if config['training']['learning_rate'] <= 0:
78
78
  raise ValueError("learning_rate must be positive")
79
+
80
+ # Validate device
81
+ valid_devices = ['cuda', 'cpu', 'auto']
82
+ device_value = config.get('device', {}).get('device', 'cuda').lower()
83
+ if device_value not in valid_devices:
84
+ raise ValueError(
85
+ f"Invalid device: {device_value}. "
86
+ f"Must be one of {valid_devices}"
87
+ )
79
88
 
80
89
 
81
90
  def config_to_args(config: Dict[str, Any]) -> Dict[str, Any]:
@@ -169,7 +178,7 @@ def create_default_config_file(output_path: str = 'training_config.yaml') -> Non
169
178
  'log_interval': 10
170
179
  },
171
180
  'device': {
172
- 'device': 'cuda',
181
+ 'device': 'cuda', # Options: 'cuda', 'cpu', or 'auto'
173
182
  'seed': 42
174
183
  },
175
184
  'advanced': {
@@ -0,0 +1,15 @@
1
+ """
2
+ Services Layer
3
+
4
+ Business logic services following SOLID principles.
5
+ This layer implements the service pattern to encapsulate complex operations.
6
+ """
7
+
8
+ from gptmed.services.device_manager import DeviceManager, DeviceStrategy
9
+ from gptmed.services.training_service import TrainingService
10
+
11
+ __all__ = [
12
+ 'DeviceManager',
13
+ 'DeviceStrategy',
14
+ 'TrainingService',
15
+ ]
@@ -0,0 +1,252 @@
1
+ """
2
+ Device Manager Service
3
+
4
+ PURPOSE:
5
+ Manages device selection and configuration for model training and inference.
6
+ Implements Strategy Pattern for flexible device handling.
7
+
8
+ DESIGN PATTERNS:
9
+ - Strategy Pattern: Different strategies for CPU vs GPU
10
+ - Dependency Injection: DeviceManager can be injected into services
11
+ - Single Responsibility: Only handles device-related concerns
12
+
13
+ WHAT THIS FILE DOES:
14
+ 1. Validates device availability (CUDA check)
15
+ 2. Provides device selection logic with fallback
16
+ 3. Manages device-specific configurations
17
+ 4. Ensures consistent device handling across the codebase
18
+
19
+ PACKAGES USED:
20
+ - torch: Device detection and management
21
+ - abc: Abstract base classes for strategy pattern
22
+ """
23
+
24
+ from abc import ABC, abstractmethod
25
+ from typing import Optional
26
+ import torch
27
+
28
+
29
class DeviceStrategy(ABC):
    """
    Common interface for device-selection strategies.

    Concrete subclasses (CPU, CUDA, ...) provide the actual detection
    logic; callers interact only through this interface (Strategy Pattern).
    """

    @abstractmethod
    def get_device(self) -> str:
        """Return the PyTorch device string ('cuda' or 'cpu')."""
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Return True when the underlying device can actually be used."""
        ...

    @abstractmethod
    def get_device_info(self) -> dict:
        """Return a dictionary describing the device."""
        ...
class CUDAStrategy(DeviceStrategy):
    """Device strategy backed by CUDA-capable GPUs."""

    def is_available(self) -> bool:
        """True when PyTorch can see at least one CUDA device."""
        return torch.cuda.is_available()

    def get_device(self) -> str:
        """Prefer 'cuda'; degrade to 'cpu' when no GPU is visible."""
        if self.is_available():
            return 'cuda'
        return 'cpu'

    def get_device_info(self) -> dict:
        """Describe the CUDA environment (or report its absence)."""
        if not self.is_available():
            return {
                'device': 'cuda',
                'available': False,
                'message': 'CUDA not available'
            }

        cuda_version = torch.version.cuda if torch.version.cuda else 'N/A'
        return {
            'device': 'cuda',
            'available': True,
            'device_name': torch.cuda.get_device_name(0),
            'device_count': torch.cuda.device_count(),
            'cuda_version': cuda_version,
        }
class CPUStrategy(DeviceStrategy):
    """Device strategy for plain CPU execution."""

    def get_device(self) -> str:
        """The CPU device string is always 'cpu'."""
        return 'cpu'

    def is_available(self) -> bool:
        """The CPU can never be absent."""
        return True

    def get_device_info(self) -> dict:
        """Describe the CPU backend, including the worker thread count."""
        info = {
            'device': 'cpu',
            'available': True,
            'num_threads': torch.get_num_threads(),
        }
        return info
class DeviceManager:
    """
    Manages device selection and configuration.

    Follows Single Responsibility Principle - only handles device concerns.
    Uses Strategy Pattern for different device types.

    Example:
        >>> device_manager = DeviceManager(preferred_device='cuda')
        >>> device = device_manager.get_device()
        >>> print(f"Using device: {device}")
    """

    def __init__(self, preferred_device: str = 'cuda', allow_fallback: bool = True):
        """
        Initialize DeviceManager.

        Args:
            preferred_device: Preferred device ('cuda', 'cpu', or 'auto').
                'auto' resolves immediately to the best available device,
                matching what validate_device() already accepted.
            allow_fallback: If True, fallback to CPU if CUDA unavailable

        Raises:
            ValueError: If preferred_device is not a recognized device name.
        """
        self.preferred_device = preferred_device.lower()
        self.allow_fallback = allow_fallback

        # Generalization: accept 'auto' here as well (validate_device did,
        # but the constructor previously rejected it).  Resolve it up front
        # so the rest of the class only deals with concrete device names.
        if self.preferred_device == 'auto':
            self.preferred_device = self.get_optimal_device()

        # Validate device input
        if self.preferred_device not in ['cuda', 'cpu']:
            raise ValueError(
                f"Invalid device: {preferred_device}. "
                f"Must be 'cuda', 'cpu', or 'auto'"
            )

        # Select strategy based on preferred device
        if self.preferred_device == 'cuda':
            self.strategy = CUDAStrategy()
        else:
            self.strategy = CPUStrategy()

    def get_device(self) -> str:
        """
        Get the actual device to use.

        Returns fallback device if preferred is unavailable and fallback is allowed.

        Returns:
            Device string ('cuda' or 'cpu')

        Raises:
            RuntimeError: If preferred device unavailable and fallback disabled
        """
        if self.strategy.is_available():
            return self.strategy.get_device()

        # Preferred device is unavailable: fall back to CPU only when the
        # caller opted in and the preferred device was CUDA.
        if self.allow_fallback and self.preferred_device == 'cuda':
            return 'cpu'

        raise RuntimeError(
            f"Device '{self.preferred_device}' is not available and "
            f"fallback is {'disabled' if not self.allow_fallback else 'not applicable'}"
        )

    def get_device_info(self) -> dict:
        """
        Get information about the current device.

        Returns:
            Dictionary with device information

        Raises:
            RuntimeError: Propagated from get_device() when the preferred
                device is unavailable and fallback is disabled.
        """
        info = self.strategy.get_device_info()
        info['preferred_device'] = self.preferred_device
        info['actual_device'] = self.get_device()
        info['allow_fallback'] = self.allow_fallback
        return info

    def print_device_info(self, verbose: bool = True) -> None:
        """
        Print device information.

        Args:
            verbose: If True, print detailed information
        """
        if not verbose:
            return

        info = self.get_device_info()
        actual = info['actual_device']
        preferred = info['preferred_device']

        print(f"\n💻 Device Configuration:")
        print(f" Preferred: {preferred}")
        print(f" Using: {actual}")

        if preferred != actual:
            print(f" ⚠️ Fallback to CPU (CUDA not available)")

        if actual == 'cuda' and info.get('available'):
            print(f" GPU: {info.get('device_name', 'Unknown')}")
            print(f" CUDA Version: {info.get('cuda_version', 'N/A')}")
            print(f" GPU Count: {info.get('device_count', 0)}")
        elif actual == 'cpu':
            print(f" CPU Threads: {info.get('num_threads', 'N/A')}")

    @staticmethod
    def validate_device(device: str) -> str:
        """
        Validate and normalize device string.

        Args:
            device: Device string to validate

        Returns:
            Normalized device string

        Raises:
            ValueError: If device is invalid
        """
        device = device.lower().strip()

        if device not in ['cuda', 'cpu', 'auto']:
            raise ValueError(
                f"Invalid device: '{device}'. Must be 'cuda', 'cpu', or 'auto'"
            )

        # Auto-select best available device
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'

        return device

    @staticmethod
    def get_optimal_device() -> str:
        """
        Get the optimal device for the current environment.

        Returns:
            'cuda' if available, otherwise 'cpu'
        """
        return 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -0,0 +1,335 @@
1
+ """
2
+ Training Service
3
+
4
+ PURPOSE:
5
+ Encapsulates training logic following Service Layer Pattern.
6
+ Provides a high-level interface for model training with device flexibility.
7
+
8
+ DESIGN PATTERNS:
9
+ - Service Layer Pattern: Business logic separated from API layer
10
+ - Dependency Injection: DeviceManager injected for flexibility
11
+ - Single Responsibility: Only handles training orchestration
12
+ - Open/Closed Principle: Extensible without modification
13
+
14
+ WHAT THIS FILE DOES:
15
+ 1. Orchestrates the training process
16
+ 2. Manages device configuration via DeviceManager
17
+ 3. Coordinates model, data, optimizer, and trainer
18
+ 4. Provides clean interface for training operations
19
+
20
+ PACKAGES USED:
21
+ - torch: PyTorch training
22
+ - pathlib: Path handling
23
+ """
24
+
25
+ import torch
26
+ import random
27
+ import numpy as np
28
+ from pathlib import Path
29
+ from typing import Dict, Any, Optional
30
+
31
+ from gptmed.services.device_manager import DeviceManager
32
+ from gptmed.model.architecture import GPTTransformer
33
+ from gptmed.model.configs.model_config import get_tiny_config, get_small_config, get_medium_config
34
+ from gptmed.configs.train_config import TrainingConfig
35
+ from gptmed.training.dataset import create_dataloaders
36
+ from gptmed.training.trainer import Trainer
37
+
38
+
39
class TrainingService:
    """
    High-level service for model training.

    Implements Service Layer Pattern to encapsulate training logic.
    Uses Dependency Injection for DeviceManager.

    Example:
        >>> device_manager = DeviceManager(preferred_device='cpu')
        >>> service = TrainingService(device_manager=device_manager)
        >>> results = service.train_from_config('config.yaml', verbose=True)
    """

    def __init__(
        self,
        device_manager: Optional[DeviceManager] = None,
        verbose: bool = True
    ):
        """
        Initialize TrainingService.

        Args:
            device_manager: DeviceManager instance (if None, creates default)
            verbose: Whether to print training information
        """
        # Dependency injection with a sensible default; DeviceManager
        # itself handles CPU fallback when CUDA is unavailable.
        self.device_manager = device_manager or DeviceManager(preferred_device='cuda')
        self.verbose = verbose

    def set_seed(self, seed: int) -> None:
        """
        Set random seeds for reproducibility.

        Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
        devices), and forces deterministic cuDNN behavior.

        Args:
            seed: Random seed value
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # Deterministic cuDNN trades speed for reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_model(self, model_size: str) -> GPTTransformer:
        """
        Create model based on size specification.

        Args:
            model_size: Model size ('tiny', 'small', or 'medium')

        Returns:
            GPTTransformer model instance

        Raises:
            ValueError: If model_size is invalid
        """
        # Dispatch table keeps the size -> config mapping in one place.
        config_factories = {
            'tiny': get_tiny_config,
            'small': get_small_config,
            'medium': get_medium_config,
        }
        try:
            model_config = config_factories[model_size]()
        except KeyError:
            raise ValueError(f"Unknown model size: {model_size}") from None

        return GPTTransformer(model_config)

    def prepare_training(
        self,
        model: GPTTransformer,
        train_config: TrainingConfig,
        device: str
    ) -> tuple:
        """
        Prepare components for training.

        Builds train/val dataloaders from the paths on ``train_config``
        and an AdamW optimizer over the model's parameters.

        Args:
            model: Model to train
            train_config: Training configuration
            device: Device to use (kept for interface stability; the
                loaders and optimizer built here are device-agnostic)

        Returns:
            Tuple of (train_loader, val_loader, optimizer)
        """
        # Load data
        if self.verbose:
            print(f"\n📊 Loading data...")
            print(f" Train: {train_config.train_data_path}")
            print(f" Val: {train_config.val_data_path}")

        train_loader, val_loader = create_dataloaders(
            train_path=Path(train_config.train_data_path),
            val_path=Path(train_config.val_data_path),
            batch_size=train_config.batch_size,
            num_workers=0,
        )

        if self.verbose:
            print(f" Train batches: {len(train_loader)}")
            print(f" Val batches: {len(val_loader)}")

        # Create optimizer
        if self.verbose:
            print(f"\n⚙️ Setting up optimizer...")
            print(f" Learning rate: {train_config.learning_rate}")
            print(f" Weight decay: {train_config.weight_decay}")

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config.learning_rate,
            betas=train_config.betas,
            eps=train_config.eps,
            weight_decay=train_config.weight_decay,
        )

        return train_loader, val_loader, optimizer

    def execute_training(
        self,
        model: GPTTransformer,
        train_loader,
        val_loader,
        optimizer,
        train_config: TrainingConfig,
        device: str,
        model_config_dict: dict
    ) -> Dict[str, Any]:
        """
        Execute the training process.

        Creates the Trainer, optionally resumes from a checkpoint, runs the
        training loop (saving a checkpoint on Ctrl-C), and returns a summary.

        Args:
            model: Model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            optimizer: Optimizer
            train_config: Training configuration
            device: Device to use
            model_config_dict: Model configuration as dictionary

        Returns:
            Dictionary with training results:
            best_checkpoint, final_val_loss, total_epochs,
            checkpoint_dir, log_dir
        """
        # Create trainer
        if self.verbose:
            print(f"\n🎯 Initializing trainer...")

        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            config=train_config,
            device=device,
        )

        # Resume if requested.  ``resume_from`` may not exist on every
        # TrainingConfig, so probe it with getattr.
        resume_from = getattr(train_config, 'resume_from', None)
        if resume_from is not None:
            if self.verbose:
                print(f"\n📥 Resuming from checkpoint: {resume_from}")
            trainer.resume_from_checkpoint(Path(resume_from))
        elif getattr(train_config, 'checkpoint_dir', None):
            # BUGFIX: the original accessed ``train_config.checkpoint_dir``
            # *before* its hasattr guard, so the guard could never help;
            # getattr performs the safe lookup in one step.
            resume_path = Path(train_config.checkpoint_dir) / "resume_from.pt"
            if resume_path.exists() and self.verbose:
                # NOTE(review): this branch only announces the checkpoint;
                # it does not resume from it. Behavior preserved — confirm
                # whether a resume was intended here.
                print(f"\n📥 Found checkpoint to resume: {resume_path}")

        # Start training
        if self.verbose:
            print(f"\n{'='*60}")
            print("🚀 Starting Training!")
            print(f"{'='*60}\n")

        try:
            trainer.train()
        except KeyboardInterrupt:
            # Graceful interrupt: persist progress so the run can resume.
            if self.verbose:
                print("\n\n⏸️ Training interrupted by user")
                print("💾 Saving checkpoint...")
            trainer.checkpoint_manager.save_checkpoint(
                model=model,
                optimizer=optimizer,
                step=trainer.global_step,
                epoch=trainer.current_epoch,
                val_loss=trainer.best_val_loss,
                model_config=model_config_dict,
                train_config=train_config.to_dict(),
            )
            if self.verbose:
                print("✓ Checkpoint saved. Resume with resume_from in config.")

        # Return results
        best_checkpoint = Path(train_config.checkpoint_dir) / "best_model.pt"

        results = {
            'best_checkpoint': str(best_checkpoint),
            'final_val_loss': trainer.best_val_loss,
            'total_epochs': trainer.current_epoch,
            'checkpoint_dir': train_config.checkpoint_dir,
            'log_dir': train_config.log_dir,
        }

        if self.verbose:
            print(f"\n{'='*60}")
            print("✅ Training Complete!")
            print(f"{'='*60}")
            print(f"\n📁 Results:")
            print(f" Best checkpoint: {results['best_checkpoint']}")
            print(f" Best val loss: {results['final_val_loss']:.4f}")
            print(f" Total epochs: {results['total_epochs']}")
            print(f" Logs: {results['log_dir']}")

        return results

    def train(
        self,
        model_size: str,
        train_data_path: str,
        val_data_path: str,
        batch_size: int = 16,
        learning_rate: float = 3e-4,
        num_epochs: int = 10,
        checkpoint_dir: str = "./model/checkpoints",
        log_dir: str = "./logs",
        seed: int = 42,
        **kwargs
    ) -> Dict[str, Any]:
        """
        High-level training interface.

        Args:
            model_size: Model size ('tiny', 'small', 'medium')
            train_data_path: Path to training data
            val_data_path: Path to validation data
            batch_size: Training batch size
            learning_rate: Learning rate
            num_epochs: Number of training epochs
            checkpoint_dir: Directory for checkpoints
            log_dir: Directory for logs
            seed: Random seed
            **kwargs: Additional training config parameters

        Returns:
            Dictionary with training results
        """
        # Set seed
        if self.verbose:
            print(f"\n🎲 Setting random seed: {seed}")
        self.set_seed(seed)

        # Get device
        device = self.device_manager.get_device()
        self.device_manager.print_device_info(verbose=self.verbose)

        # Create model
        if self.verbose:
            print(f"\n🧠 Creating model: {model_size}")

        model = self.create_model(model_size)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        if self.verbose:
            print(f" Model size: {model_size}")
            print(f" Parameters: {total_params:,}")
            print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")

        # Create training config.
        # NOTE(review): hasattr(TrainingConfig, k) filters kwargs against
        # *class* attributes; dataclass fields without class-level defaults
        # would be dropped by this filter — confirm against TrainingConfig.
        train_config = TrainingConfig(
            train_data_path=train_data_path,
            val_data_path=val_data_path,
            batch_size=batch_size,
            learning_rate=learning_rate,
            num_epochs=num_epochs,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            device=device,
            seed=seed,
            **{k: v for k, v in kwargs.items() if hasattr(TrainingConfig, k)}
        )

        # Prepare training components
        train_loader, val_loader, optimizer = self.prepare_training(
            model, train_config, device
        )

        # Execute training
        results = self.execute_training(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            train_config=train_config,
            device=device,
            model_config_dict=model.config.to_dict()
        )

        return results
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gptmed
3
- Version: 0.3.3
3
+ Version: 0.3.5
4
4
  Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
5
5
  Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
6
6
  Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
@@ -61,27 +61,6 @@ A lightweight GPT-based language model framework for training custom question-an
61
61
  - 📦 **Lightweight**: Small model size suitable for edge deployment
62
62
  - 🛠️ **Complete Toolkit**: Includes tokenizer training, model training, and inference utilities
63
63
 
64
- ## Table of Contents
65
-
66
- - [Features](#features)
67
- - [Installation](#installation)
68
- - [Quick Start](#quick-start)
69
- - [Package Structure](#package-structure)
70
- - [Core Modules](#core-modules)
71
- - [Model Components](#model-components)
72
- - [Training Components](#training-components)
73
- - [Inference Components](#inference-components)
74
- - [Data Processing](#data-processing)
75
- - [Utilities](#utilities)
76
- - [Model Architecture](#model-architecture)
77
- - [Configuration](#configuration)
78
- - [Documentation](#documentation)
79
- - [Performance](#performance)
80
- - [Examples](#examples)
81
- - [Contributing](#contributing)
82
- - [License](#license)
83
- - [Support](#support)
84
-
85
64
  ## Installation
86
65
 
87
66
  ### From PyPI (Recommended)
@@ -208,134 +187,27 @@ config = TrainingConfig(
208
187
  )
209
188
  ```
210
189
 
211
- ## Package Structure
212
-
213
- ### Core Modules
214
-
215
- The `gptmed` package contains the following main modules:
216
-
217
- ```
218
- gptmed/
219
- ├── model/ # Model architecture and configurations
220
- ├── inference/ # Text generation and sampling
221
- ├── training/ # Training loops and datasets
222
- ├── tokenizer/ # Tokenizer training and data processing
223
- ├── data/ # Data parsers and formatters
224
- ├── configs/ # Training configurations
225
- └── utils/ # Utilities (checkpoints, logging)
226
- ```
227
-
228
- ### Model Components
229
-
230
- **`gptmed.model.architecture`** - GPT Transformer Implementation
231
-
232
- - `GPTTransformer` - Main model class
233
- - `TransformerBlock` - Individual transformer layers
234
- - `MultiHeadAttention` - Attention mechanism
235
- - `FeedForward` - Feed-forward networks
236
- - `RoPEPositionalEncoding` - Rotary position embeddings
237
-
238
- **`gptmed.model.configs`** - Model Configurations
239
-
240
- - `get_tiny_config()` - ~2M parameters (testing)
241
- - `get_small_config()` - ~10M parameters (recommended)
242
- - `get_medium_config()` - ~50M parameters (high quality)
243
- - `ModelConfig` - Custom configuration class
244
-
245
- ### Training Components
246
-
247
- **`gptmed.training`** - Training Pipeline
248
-
249
- - `train.py` - Main training script (CLI: `gptmed-train`)
250
- - `Trainer` - Training loop with checkpointing
251
- - `TokenizedDataset` - PyTorch dataset for tokenized data
252
- - `create_dataloaders()` - DataLoader creation utilities
253
-
254
- **`gptmed.configs`** - Training Configurations
255
-
256
- - `TrainingConfig` - Training hyperparameters
257
- - `get_default_config()` - Default training settings
258
- - `get_quick_test_config()` - Fast testing configuration
259
-
260
- ### Inference Components
261
-
262
- **`gptmed.inference`** - Text Generation
263
-
264
- - `TextGenerator` - Main generation class
265
- - `generator.py` - CLI command (CLI: `gptmed-generate`)
266
- - `sampling.py` - Sampling strategies (top-k, top-p, temperature)
267
- - `decoding_utils.py` - Decoding utilities
268
- - `GenerationConfig` - Generation parameters
269
-
270
- ### Data Processing
271
-
272
- **`gptmed.tokenizer`** - Tokenizer Training & Data Processing
273
-
274
- - `train_tokenizer.py` - Train SentencePiece tokenizer
275
- - `tokenize_data.py` - Convert text to token sequences
276
- - SentencePiece BPE tokenizer support
277
-
278
- **`gptmed.data.parsers`** - Data Parsing & Formatting
279
-
280
- - `MedQuADParser` - XML Q&A parser (example)
281
- - `CausalTextFormatter` - Format Q&A pairs for training
282
- - `FormatConfig` - Formatting configuration
283
-
284
- ### Utilities
285
-
286
- **`gptmed.utils`** - Helper Functions
287
-
288
- - `checkpoints.py` - Model checkpoint management
289
- - `logging.py` - Training metrics logging
290
-
291
- ---
292
-
293
- ## Detailed Project Structure
190
+ ## Project Structure
294
191
 
295
192
  ```
296
193
  gptmed/
297
194
  ├── model/
298
- │ ├── architecture/
299
- │ ├── gpt.py # GPT transformer model
300
- │ │ ├── attention.py # Multi-head attention
301
- │ │ ├── feedforward.py # Feed-forward networks
302
- │ │ └── embeddings.py # Token + positional embeddings
303
- │ └── configs/
304
- │ └── model_config.py # Model size configurations
195
+ │ ├── architecture/ # GPT transformer implementation
196
+ └── configs/ # Model configurations
305
197
  ├── inference/
306
- │ ├── generator.py # Text generation (CLI command)
307
- ├── sampling.py # Sampling strategies
308
- │ ├── decoding_utils.py # Decoding utilities
309
- │ └── generation_config.py # Generation parameters
198
+ │ ├── generator.py # Text generation
199
+ └── sampling.py # Sampling strategies
310
200
  ├── training/
311
- │ ├── train.py # Main training script (CLI command)
312
- │ ├── trainer.py # Training loop
313
- ├── dataset.py # PyTorch dataset
314
- │ └── utils.py # Training utilities
201
+ │ ├── train.py # Training script
202
+ │ ├── trainer.py # Training loop
203
+ └── dataset.py # Data loading
315
204
  ├── tokenizer/
316
- ├── train_tokenizer.py # Train SentencePiece tokenizer
317
- │ └── tokenize_data.py # Tokenize text data
318
- ├── data/
319
- │ └── parsers/
320
- │ ├── medquad_parser.py # Example XML parser
321
- │ └── text_formatter.py # Q&A text formatter
205
+ └── train_tokenizer.py # SentencePiece tokenizer
322
206
  ├── configs/
323
- │ └── train_config.py # Training configurations
207
+ │ └── train_config.py # Training configurations
324
208
  └── utils/
325
- ├── checkpoints.py # Model checkpointing
326
- └── logging.py # Training logging
327
- ```
328
-
329
- ### Command-Line Interface
330
-
331
- The package provides two main CLI commands:
332
-
333
- ```bash
334
- # Train a model
335
- gptmed-train --model-size small --num-epochs 10 --batch-size 16
336
-
337
- # Generate text
338
- gptmed-generate --prompt "Your question?" --max-length 100
209
+ ├── checkpoints.py # Model checkpointing
210
+ └── logging.py # Training logging
339
211
  ```
340
212
 
341
213
  ## Requirements
@@ -1,7 +1,7 @@
1
1
  gptmed/__init__.py,sha256=mwzeW2Qc6j1z5f6HOvZ_BNOnFSncWEK2KEkdqq91yYY,1676
2
- gptmed/api.py,sha256=gUWooWsXDaGb1r22YnzS3w-sU-n-b4gB4-gh0fMsT4A,11109
2
+ gptmed/api.py,sha256=k9a_1F2h__xgKnH2l0FaJqAqu-iTYt5tu_VfVO0UhrA,9806
3
3
  gptmed/configs/__init__.py,sha256=yRa-zgPQ-OCzu8fvCrfWMG-CjF3dru3PZzknzm0oUaQ,23
4
- gptmed/configs/config_loader.py,sha256=ZWdH63XOOu0T8seWBiJFZtzlyFmzHzKmMxon6ZgZHlg,6000
4
+ gptmed/configs/config_loader.py,sha256=3GQ1iCNpdJ5yALWXA3SPPHRkaUO-117vdArEL6u7sK8,6354
5
5
  gptmed/configs/train_config.py,sha256=KqfNBh9hdTTd_6gEAlrClU8sVFSlVDmZJOrf3cPwFe8,4657
6
6
  gptmed/configs/training_config.yaml,sha256=EEZZa3kcsZr3g-_fKDPYZt4_NTpmS-3NvJrTYSWNc8g,2874
7
7
  gptmed/data/__init__.py,sha256=iAHeakB5pBAd7MkmarPPY0UKS9bTaO_winLZ23Y2O90,54
@@ -22,6 +22,9 @@ gptmed/model/architecture/feedforward.py,sha256=uJ5QOlWX0ritKDQLUE7GPmMojelR9-sT
22
22
  gptmed/model/architecture/transformer.py,sha256=H1njPoy0Uam59JbA24C0olEDwPfhh3ev4HsUFRIC_0Y,6626
23
23
  gptmed/model/configs/__init__.py,sha256=LDCWhlCDOU7490wcfSId_jXBPfQrtYQEw8FoD67rqBs,275
24
24
  gptmed/model/configs/model_config.py,sha256=wI-i2Dw_pTdIKCDe1pqLvP3ky3YedEy7DwZYN5lwmKE,4673
25
+ gptmed/services/__init__.py,sha256=FtM7NQ_S4VOfl2n6A6cLcOxG9-w7BK7DicQsUvOMmGE,369
26
+ gptmed/services/device_manager.py,sha256=RSsu0RlsexCIO-p4eejOZAPLgpaVA0y9niTg8wf1luY,7513
27
+ gptmed/services/training_service.py,sha256=o9Kxxoi6CVDvvM9rwGYNX426qTnmqxLXLt_bVi1ZSK4,11253
25
28
  gptmed/tokenizer/__init__.py,sha256=KhLAHPmQyoWhnKDenyIJRxgFflKI7xklip28j4cKfKw,157
26
29
  gptmed/tokenizer/tokenize_data.py,sha256=KgMtMfaz_RtOhN_CrvC267k9ujxRdO89rToVJ6nzdwg,9139
27
30
  gptmed/tokenizer/train_tokenizer.py,sha256=f0Hucyft9e8LU2RtpTqg8h_0SpOC_oMABl0_me-wfL8,7068
@@ -33,9 +36,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
33
36
  gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
34
37
  gptmed/utils/checkpoints.py,sha256=L4q1-_4GbHCoD7QuEKYeQ-xXDTF-6sqZOxKQ_LT8YmQ,7112
35
38
  gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
36
- gptmed-0.3.3.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
37
- gptmed-0.3.3.dist-info/METADATA,sha256=0ohKwsi3802GMhVUIx2n76i4QHhY0dkzdG4a_g1p_Hw,13605
38
- gptmed-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- gptmed-0.3.3.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
40
- gptmed-0.3.3.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
41
- gptmed-0.3.3.dist-info/RECORD,,
39
+ gptmed-0.3.5.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
40
+ gptmed-0.3.5.dist-info/METADATA,sha256=Zx3kFlZiBdkXco_VkEqOnIeasCYrgWl2XP21D2QcmuA,9382
41
+ gptmed-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
42
+ gptmed-0.3.5.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
43
+ gptmed-0.3.5.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
44
+ gptmed-0.3.5.dist-info/RECORD,,
File without changes