gptmed 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gptmed/api.py CHANGED
@@ -39,6 +39,8 @@ from gptmed.configs.train_config import TrainingConfig
39
39
  from gptmed.training.dataset import create_dataloaders
40
40
  from gptmed.training.trainer import Trainer
41
41
  from gptmed.inference.generator import TextGenerator
42
+ from gptmed.services.device_manager import DeviceManager
43
+ from gptmed.services.training_service import TrainingService
42
44
 
43
45
 
44
46
  def create_config(output_path: str = 'training_config.yaml') -> None:
@@ -58,7 +60,11 @@ def create_config(output_path: str = 'training_config.yaml') -> None:
58
60
  create_default_config_file(output_path)
59
61
 
60
62
 
61
- def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
63
+ def train_from_config(
64
+ config_path: str,
65
+ verbose: bool = True,
66
+ device: Optional[str] = None
67
+ ) -> Dict[str, Any]:
62
68
  """
63
69
  Train a GPT model using a YAML configuration file.
64
70
 
@@ -68,6 +74,8 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
68
74
  Args:
69
75
  config_path: Path to YAML configuration file
70
76
  verbose: Whether to print training progress (default: True)
77
+ device: Device to use ('cuda', 'cpu', or 'auto'). If None, uses config value.
78
+ 'auto' will select best available device.
71
79
 
72
80
  Returns:
73
81
  Dictionary with training results:
@@ -82,13 +90,16 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
82
90
  >>> gptmed.create_config('config.yaml')
83
91
  >>> # ... edit config.yaml ...
84
92
  >>>
85
- >>> # Train the model
86
- >>> results = gptmed.train_from_config('config.yaml')
93
+ >>> # Train the model on CPU
94
+ >>> results = gptmed.train_from_config('config.yaml', device='cpu')
87
95
  >>> print(f"Best model: {results['best_checkpoint']}")
96
+ >>>
97
+ >>> # Train with auto device selection
98
+ >>> results = gptmed.train_from_config('config.yaml', device='auto')
88
99
 
89
100
  Raises:
90
101
  FileNotFoundError: If config file or data files don't exist
91
- ValueError: If configuration is invalid
102
+ ValueError: If configuration is invalid or device is invalid
92
103
  """
93
104
  if verbose:
94
105
  print("=" * 60)
@@ -111,47 +122,42 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
111
122
  # Convert to arguments
112
123
  args = config_to_args(config)
113
124
 
114
- # Import here to avoid circular imports
115
- import random
116
- import numpy as np
125
+ # Override device if provided as parameter
126
+ if device is not None:
127
+ # Validate and normalize device
128
+ device = DeviceManager.validate_device(device)
129
+ if verbose:
130
+ print(f"\n⚙️ Device override: {device} (from parameter)")
131
+ args['device'] = device
117
132
 
118
- # Set random seed
119
- def set_seed(seed: int):
120
- random.seed(seed)
121
- np.random.seed(seed)
122
- torch.manual_seed(seed)
123
- if torch.cuda.is_available():
124
- torch.cuda.manual_seed(seed)
125
- torch.cuda.manual_seed_all(seed)
126
- torch.backends.cudnn.deterministic = True
127
- torch.backends.cudnn.benchmark = False
133
+ # Create DeviceManager with the selected device
134
+ device_manager = DeviceManager(
135
+ preferred_device=args['device'],
136
+ allow_fallback=True
137
+ )
138
+
139
+ # Print device information
140
+ device_manager.print_device_info(verbose=verbose)
141
+
142
+ # Create TrainingService with DeviceManager
143
+ training_service = TrainingService(
144
+ device_manager=device_manager,
145
+ verbose=verbose
146
+ )
128
147
 
148
+ # Set random seed
129
149
  if verbose:
130
150
  print(f"\n🎲 Setting random seed: {args['seed']}")
131
- set_seed(args['seed'])
151
+ training_service.set_seed(args['seed'])
132
152
 
133
- # Check device
134
- device = args['device']
135
- if device == 'cuda' and not torch.cuda.is_available():
136
- if verbose:
137
- print("⚠️ CUDA not available, using CPU")
138
- device = 'cpu'
153
+ # Get actual device to use
154
+ actual_device = device_manager.get_device()
139
155
 
140
156
  # Load model config
141
157
  if verbose:
142
158
  print(f"\n🧠 Creating model: {args['model_size']}")
143
159
 
144
- if args['model_size'] == 'tiny':
145
- model_config = get_tiny_config()
146
- elif args['model_size'] == 'small':
147
- model_config = get_small_config()
148
- elif args['model_size'] == 'medium':
149
- model_config = get_medium_config()
150
- else:
151
- raise ValueError(f"Unknown model size: {args['model_size']}")
152
-
153
- # Create model
154
- model = GPTTransformer(model_config)
160
+ model = training_service.create_model(args['model_size'])
155
161
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
156
162
 
157
163
  if verbose:
@@ -159,7 +165,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
159
165
  print(f" Parameters: {total_params:,}")
160
166
  print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")
161
167
 
162
- # Load data
168
+ # Load data using TrainingService
163
169
  if verbose:
164
170
  print(f"\n📊 Loading data...")
165
171
  print(f" Train: {args['train_data']}")
@@ -176,7 +182,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
176
182
  print(f" Train batches: {len(train_loader)}")
177
183
  print(f" Val batches: {len(val_loader)}")
178
184
 
179
- # Create training config
185
+ # Create training config with actual device
180
186
  train_config = TrainingConfig(
181
187
  batch_size=args['batch_size'],
182
188
  learning_rate=args['learning_rate'],
@@ -195,7 +201,7 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
195
201
  val_data_path=args['val_data'],
196
202
  checkpoint_dir=args['checkpoint_dir'],
197
203
  log_dir=args['log_dir'],
198
- device=device,
204
+ device=actual_device, # Use actual device from DeviceManager
199
205
  seed=args['seed'],
200
206
  )
201
207
 
@@ -213,70 +219,17 @@ def train_from_config(config_path: str, verbose: bool = True) -> Dict[str, Any]:
213
219
  weight_decay=args['weight_decay'],
214
220
  )
215
221
 
216
- # Create trainer
217
- if verbose:
218
- print(f"\n🎯 Initializing trainer...")
219
-
220
- trainer = Trainer(
222
+ # Execute training using TrainingService
223
+ results = training_service.execute_training(
221
224
  model=model,
222
225
  train_loader=train_loader,
223
226
  val_loader=val_loader,
224
227
  optimizer=optimizer,
225
- config=train_config,
226
- device=device,
228
+ train_config=train_config,
229
+ device=actual_device,
230
+ model_config_dict=model.config.to_dict()
227
231
  )
228
232
 
229
- # Resume if requested
230
- if args['resume_from'] is not None:
231
- if verbose:
232
- print(f"\n📥 Resuming from checkpoint: {args['resume_from']}")
233
- trainer.resume_from_checkpoint(Path(args['resume_from']))
234
-
235
- # Start training
236
- if verbose:
237
- print(f"\n{'='*60}")
238
- print("🚀 Starting Training!")
239
- print(f"{'='*60}\n")
240
-
241
- try:
242
- trainer.train()
243
- except KeyboardInterrupt:
244
- if verbose:
245
- print("\n\n⏸️ Training interrupted by user")
246
- print("💾 Saving checkpoint...")
247
- trainer.checkpoint_manager.save_checkpoint(
248
- model=model,
249
- optimizer=optimizer,
250
- step=trainer.global_step,
251
- epoch=trainer.current_epoch,
252
- val_loss=trainer.best_val_loss,
253
- model_config=model_config.to_dict(),
254
- train_config=train_config.to_dict(),
255
- )
256
- if verbose:
257
- print("✓ Checkpoint saved. Resume with resume_from in config.")
258
-
259
- # Return results
260
- best_checkpoint = Path(train_config.checkpoint_dir) / "best_model.pt"
261
-
262
- results = {
263
- 'best_checkpoint': str(best_checkpoint),
264
- 'final_val_loss': trainer.best_val_loss,
265
- 'total_epochs': trainer.current_epoch,
266
- 'checkpoint_dir': train_config.checkpoint_dir,
267
- 'log_dir': train_config.log_dir,
268
- }
269
-
270
- if verbose:
271
- print(f"\n{'='*60}")
272
- print("✅ Training Complete!")
273
- print(f"{'='*60}")
274
- print(f"\n📁 Results:")
275
- print(f" Best checkpoint: {results['best_checkpoint']}")
276
- print(f" Best val loss: {results['final_val_loss']:.4f}")
277
- print(f" Total epochs: {results['total_epochs']}")
278
- print(f" Logs: {results['log_dir']}")
279
-
280
233
  return results
281
234
 
282
235
 
@@ -76,6 +76,15 @@ def validate_config(config: Dict[str, Any]) -> None:
76
76
  raise ValueError("batch_size must be positive")
77
77
  if config['training']['learning_rate'] <= 0:
78
78
  raise ValueError("learning_rate must be positive")
79
+
80
+ # Validate device
81
+ valid_devices = ['cuda', 'cpu', 'auto']
82
+ device_value = config.get('device', {}).get('device', 'cuda').lower()
83
+ if device_value not in valid_devices:
84
+ raise ValueError(
85
+ f"Invalid device: {device_value}. "
86
+ f"Must be one of {valid_devices}"
87
+ )
79
88
 
80
89
 
81
90
  def config_to_args(config: Dict[str, Any]) -> Dict[str, Any]:
@@ -169,7 +178,7 @@ def create_default_config_file(output_path: str = 'training_config.yaml') -> Non
169
178
  'log_interval': 10
170
179
  },
171
180
  'device': {
172
- 'device': 'cuda',
181
+ 'device': 'cuda', # Options: 'cuda', 'cpu', or 'auto'
173
182
  'seed': 42
174
183
  },
175
184
  'advanced': {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gptmed
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
5
5
  Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
6
6
  Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
@@ -61,27 +61,6 @@ A lightweight GPT-based language model framework for training custom question-an
61
61
  - 📦 **Lightweight**: Small model size suitable for edge deployment
62
62
  - 🛠️ **Complete Toolkit**: Includes tokenizer training, model training, and inference utilities
63
63
 
64
- ## Table of Contents
65
-
66
- - [Features](#features)
67
- - [Installation](#installation)
68
- - [Quick Start](#quick-start)
69
- - [Package Structure](#package-structure)
70
- - [Core Modules](#core-modules)
71
- - [Model Components](#model-components)
72
- - [Training Components](#training-components)
73
- - [Inference Components](#inference-components)
74
- - [Data Processing](#data-processing)
75
- - [Utilities](#utilities)
76
- - [Model Architecture](#model-architecture)
77
- - [Configuration](#configuration)
78
- - [Documentation](#documentation)
79
- - [Performance](#performance)
80
- - [Examples](#examples)
81
- - [Contributing](#contributing)
82
- - [License](#license)
83
- - [Support](#support)
84
-
85
64
  ## Installation
86
65
 
87
66
  ### From PyPI (Recommended)
@@ -208,134 +187,27 @@ config = TrainingConfig(
208
187
  )
209
188
  ```
210
189
 
211
- ## Package Structure
212
-
213
- ### Core Modules
214
-
215
- The `gptmed` package contains the following main modules:
216
-
217
- ```
218
- gptmed/
219
- ├── model/ # Model architecture and configurations
220
- ├── inference/ # Text generation and sampling
221
- ├── training/ # Training loops and datasets
222
- ├── tokenizer/ # Tokenizer training and data processing
223
- ├── data/ # Data parsers and formatters
224
- ├── configs/ # Training configurations
225
- └── utils/ # Utilities (checkpoints, logging)
226
- ```
227
-
228
- ### Model Components
229
-
230
- **`gptmed.model.architecture`** - GPT Transformer Implementation
231
-
232
- - `GPTTransformer` - Main model class
233
- - `TransformerBlock` - Individual transformer layers
234
- - `MultiHeadAttention` - Attention mechanism
235
- - `FeedForward` - Feed-forward networks
236
- - `RoPEPositionalEncoding` - Rotary position embeddings
237
-
238
- **`gptmed.model.configs`** - Model Configurations
239
-
240
- - `get_tiny_config()` - ~2M parameters (testing)
241
- - `get_small_config()` - ~10M parameters (recommended)
242
- - `get_medium_config()` - ~50M parameters (high quality)
243
- - `ModelConfig` - Custom configuration class
244
-
245
- ### Training Components
246
-
247
- **`gptmed.training`** - Training Pipeline
248
-
249
- - `train.py` - Main training script (CLI: `gptmed-train`)
250
- - `Trainer` - Training loop with checkpointing
251
- - `TokenizedDataset` - PyTorch dataset for tokenized data
252
- - `create_dataloaders()` - DataLoader creation utilities
253
-
254
- **`gptmed.configs`** - Training Configurations
255
-
256
- - `TrainingConfig` - Training hyperparameters
257
- - `get_default_config()` - Default training settings
258
- - `get_quick_test_config()` - Fast testing configuration
259
-
260
- ### Inference Components
261
-
262
- **`gptmed.inference`** - Text Generation
263
-
264
- - `TextGenerator` - Main generation class
265
- - `generator.py` - CLI command (CLI: `gptmed-generate`)
266
- - `sampling.py` - Sampling strategies (top-k, top-p, temperature)
267
- - `decoding_utils.py` - Decoding utilities
268
- - `GenerationConfig` - Generation parameters
269
-
270
- ### Data Processing
271
-
272
- **`gptmed.tokenizer`** - Tokenizer Training & Data Processing
273
-
274
- - `train_tokenizer.py` - Train SentencePiece tokenizer
275
- - `tokenize_data.py` - Convert text to token sequences
276
- - SentencePiece BPE tokenizer support
277
-
278
- **`gptmed.data.parsers`** - Data Parsing & Formatting
279
-
280
- - `MedQuADParser` - XML Q&A parser (example)
281
- - `CausalTextFormatter` - Format Q&A pairs for training
282
- - `FormatConfig` - Formatting configuration
283
-
284
- ### Utilities
285
-
286
- **`gptmed.utils`** - Helper Functions
287
-
288
- - `checkpoints.py` - Model checkpoint management
289
- - `logging.py` - Training metrics logging
290
-
291
- ---
292
-
293
- ## Detailed Project Structure
190
+ ## Project Structure
294
191
 
295
192
  ```
296
193
  gptmed/
297
194
  ├── model/
298
- │ ├── architecture/
299
- │ ├── gpt.py # GPT transformer model
300
- │ │ ├── attention.py # Multi-head attention
301
- │ │ ├── feedforward.py # Feed-forward networks
302
- │ │ └── embeddings.py # Token + positional embeddings
303
- │ └── configs/
304
- │ └── model_config.py # Model size configurations
195
+ │ ├── architecture/ # GPT transformer implementation
196
+ └── configs/ # Model configurations
305
197
  ├── inference/
306
- │ ├── generator.py # Text generation (CLI command)
307
- ├── sampling.py # Sampling strategies
308
- │ ├── decoding_utils.py # Decoding utilities
309
- │ └── generation_config.py # Generation parameters
198
+ │ ├── generator.py # Text generation
199
+ └── sampling.py # Sampling strategies
310
200
  ├── training/
311
- │ ├── train.py # Main training script (CLI command)
312
- │ ├── trainer.py # Training loop
313
- ├── dataset.py # PyTorch dataset
314
- │ └── utils.py # Training utilities
201
+ │ ├── train.py # Training script
202
+ │ ├── trainer.py # Training loop
203
+ └── dataset.py # Data loading
315
204
  ├── tokenizer/
316
- ├── train_tokenizer.py # Train SentencePiece tokenizer
317
- │ └── tokenize_data.py # Tokenize text data
318
- ├── data/
319
- │ └── parsers/
320
- │ ├── medquad_parser.py # Example XML parser
321
- │ └── text_formatter.py # Q&A text formatter
205
+ └── train_tokenizer.py # SentencePiece tokenizer
322
206
  ├── configs/
323
- │ └── train_config.py # Training configurations
207
+ │ └── train_config.py # Training configurations
324
208
  └── utils/
325
- ├── checkpoints.py # Model checkpointing
326
- └── logging.py # Training logging
327
- ```
328
-
329
- ### Command-Line Interface
330
-
331
- The package provides two main CLI commands:
332
-
333
- ```bash
334
- # Train a model
335
- gptmed-train --model-size small --num-epochs 10 --batch-size 16
336
-
337
- # Generate text
338
- gptmed-generate --prompt "Your question?" --max-length 100
209
+ ├── checkpoints.py # Model checkpointing
210
+ └── logging.py # Training logging
339
211
  ```
340
212
 
341
213
  ## Requirements
@@ -1,7 +1,7 @@
1
1
  gptmed/__init__.py,sha256=mwzeW2Qc6j1z5f6HOvZ_BNOnFSncWEK2KEkdqq91yYY,1676
2
- gptmed/api.py,sha256=gUWooWsXDaGb1r22YnzS3w-sU-n-b4gB4-gh0fMsT4A,11109
2
+ gptmed/api.py,sha256=k9a_1F2h__xgKnH2l0FaJqAqu-iTYt5tu_VfVO0UhrA,9806
3
3
  gptmed/configs/__init__.py,sha256=yRa-zgPQ-OCzu8fvCrfWMG-CjF3dru3PZzknzm0oUaQ,23
4
- gptmed/configs/config_loader.py,sha256=ZWdH63XOOu0T8seWBiJFZtzlyFmzHzKmMxon6ZgZHlg,6000
4
+ gptmed/configs/config_loader.py,sha256=3GQ1iCNpdJ5yALWXA3SPPHRkaUO-117vdArEL6u7sK8,6354
5
5
  gptmed/configs/train_config.py,sha256=KqfNBh9hdTTd_6gEAlrClU8sVFSlVDmZJOrf3cPwFe8,4657
6
6
  gptmed/configs/training_config.yaml,sha256=EEZZa3kcsZr3g-_fKDPYZt4_NTpmS-3NvJrTYSWNc8g,2874
7
7
  gptmed/data/__init__.py,sha256=iAHeakB5pBAd7MkmarPPY0UKS9bTaO_winLZ23Y2O90,54
@@ -33,9 +33,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
33
33
  gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
34
34
  gptmed/utils/checkpoints.py,sha256=L4q1-_4GbHCoD7QuEKYeQ-xXDTF-6sqZOxKQ_LT8YmQ,7112
35
35
  gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
36
- gptmed-0.3.3.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
37
- gptmed-0.3.3.dist-info/METADATA,sha256=0ohKwsi3802GMhVUIx2n76i4QHhY0dkzdG4a_g1p_Hw,13605
38
- gptmed-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- gptmed-0.3.3.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
40
- gptmed-0.3.3.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
41
- gptmed-0.3.3.dist-info/RECORD,,
36
+ gptmed-0.3.4.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
37
+ gptmed-0.3.4.dist-info/METADATA,sha256=G86yfOKlnK4YfNvC6HAAY_z2Z_rhWoSF2_3508mebKA,9382
38
+ gptmed-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ gptmed-0.3.4.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
40
+ gptmed-0.3.4.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
41
+ gptmed-0.3.4.dist-info/RECORD,,
File without changes