gptmed 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/services/__init__.py +15 -0
- gptmed/services/device_manager.py +252 -0
- gptmed/services/training_service.py +335 -0
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/METADATA +1 -1
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/RECORD +9 -6
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/WHEEL +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.3.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Services Layer
|
|
3
|
+
|
|
4
|
+
Business logic services following SOLID principles.
|
|
5
|
+
This layer implements the service pattern to encapsulate complex operations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from gptmed.services.device_manager import DeviceManager, DeviceStrategy
|
|
9
|
+
from gptmed.services.training_service import TrainingService
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
'DeviceManager',
|
|
13
|
+
'DeviceStrategy',
|
|
14
|
+
'TrainingService',
|
|
15
|
+
]
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Device Manager Service
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Manages device selection and configuration for model training and inference.
|
|
6
|
+
Implements Strategy Pattern for flexible device handling.
|
|
7
|
+
|
|
8
|
+
DESIGN PATTERNS:
|
|
9
|
+
- Strategy Pattern: Different strategies for CPU vs GPU
|
|
10
|
+
- Dependency Injection: DeviceManager can be injected into services
|
|
11
|
+
- Single Responsibility: Only handles device-related concerns
|
|
12
|
+
|
|
13
|
+
WHAT THIS FILE DOES:
|
|
14
|
+
1. Validates device availability (CUDA check)
|
|
15
|
+
2. Provides device selection logic with fallback
|
|
16
|
+
3. Manages device-specific configurations
|
|
17
|
+
4. Ensures consistent device handling across the codebase
|
|
18
|
+
|
|
19
|
+
PACKAGES USED:
|
|
20
|
+
- torch: Device detection and management
|
|
21
|
+
- abc: Abstract base classes for strategy pattern
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from abc import ABC, abstractmethod
|
|
25
|
+
from typing import Optional
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DeviceStrategy(ABC):
    """
    Abstract interface for device-selection strategies.

    Each concrete strategy encapsulates the handling of one device
    type (Strategy Pattern) so callers can swap implementations
    without changing their own code.
    """

    @abstractmethod
    def get_device(self) -> str:
        """
        Return the device string to pass to PyTorch.

        Returns:
            Device string ('cuda' or 'cpu')
        """
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """
        Report whether this device can currently be used.

        Returns:
            True if device is available, False otherwise
        """
        ...

    @abstractmethod
    def get_device_info(self) -> dict:
        """
        Describe the device.

        Returns:
            Dictionary with device information
        """
        ...
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CUDAStrategy(DeviceStrategy):
    """Device strategy backed by CUDA-capable GPUs."""

    def is_available(self) -> bool:
        """Report whether PyTorch can see a CUDA device."""
        return torch.cuda.is_available()

    def get_device(self) -> str:
        """Return 'cuda' when a GPU is usable, otherwise fall back to 'cpu'."""
        if self.is_available():
            return 'cuda'
        return 'cpu'

    def get_device_info(self) -> dict:
        """Describe the CUDA device, or report that none is available."""
        if not self.is_available():
            return {
                'device': 'cuda',
                'available': False,
                'message': 'CUDA not available',
            }

        # NOTE: reports device 0; multi-GPU detail is limited to the count.
        return {
            'device': 'cuda',
            'available': True,
            'device_name': torch.cuda.get_device_name(0),
            'device_count': torch.cuda.device_count(),
            'cuda_version': torch.version.cuda or 'N/A',
        }
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class CPUStrategy(DeviceStrategy):
    """Device strategy for plain CPU execution."""

    def get_device(self) -> str:
        """The CPU strategy always selects 'cpu'."""
        return 'cpu'

    def is_available(self) -> bool:
        """The host CPU can always be used."""
        return True

    def get_device_info(self) -> dict:
        """Describe the CPU, including the configured thread count."""
        return dict(
            device='cpu',
            available=True,
            num_threads=torch.get_num_threads(),
        )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class DeviceManager:
    """
    Central point for device selection and configuration.

    Delegates device-specific behavior to a DeviceStrategy
    (Strategy Pattern) and keeps all device concerns in one place
    (Single Responsibility Principle).

    Example:
        >>> device_manager = DeviceManager(preferred_device='cuda')
        >>> device = device_manager.get_device()
        >>> print(f"Using device: {device}")
    """

    def __init__(self, preferred_device: str = 'cuda', allow_fallback: bool = True):
        """
        Initialize DeviceManager.

        Args:
            preferred_device: Preferred device ('cuda' or 'cpu')
            allow_fallback: If True, fallback to CPU if CUDA unavailable

        Raises:
            ValueError: If preferred_device is not 'cuda' or 'cpu'
        """
        normalized = preferred_device.lower()
        self.preferred_device = normalized
        self.allow_fallback = allow_fallback

        if normalized not in ('cuda', 'cpu'):
            raise ValueError(
                f"Invalid device: {preferred_device}. Must be 'cuda' or 'cpu'"
            )

        # Pick the strategy matching the requested device.
        self.strategy = CUDAStrategy() if normalized == 'cuda' else CPUStrategy()

    def get_device(self) -> str:
        """
        Resolve the device that will actually be used.

        Falls back to CPU when CUDA was requested but is unavailable
        and fallback is allowed.

        Returns:
            Device string ('cuda' or 'cpu')

        Raises:
            RuntimeError: If preferred device unavailable and fallback disabled
        """
        if self.strategy.is_available():
            return self.strategy.get_device()

        # Preferred device is missing; use CPU when permitted.
        if self.allow_fallback and self.preferred_device == 'cuda':
            return 'cpu'

        reason = 'disabled' if not self.allow_fallback else 'not applicable'
        raise RuntimeError(
            f"Device '{self.preferred_device}' is not available and "
            f"fallback is {reason}"
        )

    def get_device_info(self) -> dict:
        """
        Collect information about the current device.

        Returns:
            Dictionary with device information
        """
        info = self.strategy.get_device_info()
        info.update(
            preferred_device=self.preferred_device,
            actual_device=self.get_device(),
            allow_fallback=self.allow_fallback,
        )
        return info

    def print_device_info(self, verbose: bool = True) -> None:
        """
        Print device information.

        Args:
            verbose: If True, print detailed information
        """
        if not verbose:
            return

        info = self.get_device_info()
        preferred = info['preferred_device']
        actual = info['actual_device']

        print(f"\n💻 Device Configuration:")
        print(f" Preferred: {preferred}")
        print(f" Using: {actual}")

        if preferred != actual:
            print(f" ⚠️ Fallback to CPU (CUDA not available)")

        if actual == 'cuda' and info.get('available'):
            print(f" GPU: {info.get('device_name', 'Unknown')}")
            print(f" CUDA Version: {info.get('cuda_version', 'N/A')}")
            print(f" GPU Count: {info.get('device_count', 0)}")
        elif actual == 'cpu':
            print(f" CPU Threads: {info.get('num_threads', 'N/A')}")

    @staticmethod
    def validate_device(device: str) -> str:
        """
        Validate and normalize a device string.

        Args:
            device: Device string to validate

        Returns:
            Normalized device string; 'auto' resolves to the best
            available device

        Raises:
            ValueError: If device is invalid
        """
        normalized = device.lower().strip()

        if normalized not in ('cuda', 'cpu', 'auto'):
            raise ValueError(
                f"Invalid device: '{normalized}'. Must be 'cuda', 'cpu', or 'auto'"
            )

        if normalized == 'auto':
            # Resolve 'auto' to the best device currently present.
            return DeviceManager.get_optimal_device()

        return normalized

    @staticmethod
    def get_optimal_device() -> str:
        """
        Pick the best device for the current environment.

        Returns:
            'cuda' if available, otherwise 'cpu'
        """
        if torch.cuda.is_available():
            return 'cuda'
        return 'cpu'
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Training Service
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Encapsulates training logic following Service Layer Pattern.
|
|
6
|
+
Provides a high-level interface for model training with device flexibility.
|
|
7
|
+
|
|
8
|
+
DESIGN PATTERNS:
|
|
9
|
+
- Service Layer Pattern: Business logic separated from API layer
|
|
10
|
+
- Dependency Injection: DeviceManager injected for flexibility
|
|
11
|
+
- Single Responsibility: Only handles training orchestration
|
|
12
|
+
- Open/Closed Principle: Extensible without modification
|
|
13
|
+
|
|
14
|
+
WHAT THIS FILE DOES:
|
|
15
|
+
1. Orchestrates the training process
|
|
16
|
+
2. Manages device configuration via DeviceManager
|
|
17
|
+
3. Coordinates model, data, optimizer, and trainer
|
|
18
|
+
4. Provides clean interface for training operations
|
|
19
|
+
|
|
20
|
+
PACKAGES USED:
|
|
21
|
+
- torch: PyTorch training
|
|
22
|
+
- pathlib: Path handling
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import random
|
|
27
|
+
import numpy as np
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import Dict, Any, Optional
|
|
30
|
+
|
|
31
|
+
from gptmed.services.device_manager import DeviceManager
|
|
32
|
+
from gptmed.model.architecture import GPTTransformer
|
|
33
|
+
from gptmed.model.configs.model_config import get_tiny_config, get_small_config, get_medium_config
|
|
34
|
+
from gptmed.configs.train_config import TrainingConfig
|
|
35
|
+
from gptmed.training.dataset import create_dataloaders
|
|
36
|
+
from gptmed.training.trainer import Trainer
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TrainingService:
    """
    High-level service for model training.

    Implements Service Layer Pattern to encapsulate training logic.
    Uses Dependency Injection for DeviceManager.

    Example:
        >>> device_manager = DeviceManager(preferred_device='cpu')
        >>> service = TrainingService(device_manager=device_manager)
        >>> results = service.train(
        ...     model_size='tiny',
        ...     train_data_path='train.bin',
        ...     val_data_path='val.bin',
        ... )
    """

    def __init__(
        self,
        device_manager: Optional[DeviceManager] = None,
        verbose: bool = True
    ):
        """
        Initialize TrainingService.

        Args:
            device_manager: DeviceManager instance (if None, creates default)
            verbose: Whether to print training information
        """
        # Default prefers CUDA; DeviceManager itself handles the CPU fallback.
        self.device_manager = device_manager or DeviceManager(preferred_device='cuda')
        self.verbose = verbose

    def set_seed(self, seed: int) -> None:
        """
        Set random seeds for reproducibility.

        Seeds Python's `random`, NumPy, and PyTorch (CPU and all GPUs),
        and forces deterministic cuDNN kernels.

        Args:
            seed: Random seed value
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # Deterministic kernels trade some speed for reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_model(self, model_size: str) -> GPTTransformer:
        """
        Create model based on size specification.

        Args:
            model_size: Model size ('tiny', 'small', or 'medium')

        Returns:
            GPTTransformer model instance

        Raises:
            ValueError: If model_size is invalid
        """
        # Dispatch table instead of an if/elif chain; easy to extend.
        config_factories = {
            'tiny': get_tiny_config,
            'small': get_small_config,
            'medium': get_medium_config,
        }
        try:
            model_config = config_factories[model_size]()
        except KeyError:
            raise ValueError(f"Unknown model size: {model_size}") from None

        return GPTTransformer(model_config)

    def prepare_training(
        self,
        model: GPTTransformer,
        train_config: TrainingConfig,
        device: str
    ) -> tuple:
        """
        Prepare components for training.

        Builds the train/val dataloaders and the AdamW optimizer.

        Args:
            model: Model to train
            train_config: Training configuration
            device: Device to use (unused here; kept for interface
                stability — the Trainer receives the device directly)

        Returns:
            Tuple of (train_loader, val_loader, optimizer)
        """
        if self.verbose:
            print(f"\n📊 Loading data...")
            print(f" Train: {train_config.train_data_path}")
            print(f" Val: {train_config.val_data_path}")

        train_loader, val_loader = create_dataloaders(
            train_path=Path(train_config.train_data_path),
            val_path=Path(train_config.val_data_path),
            batch_size=train_config.batch_size,
            num_workers=0,
        )

        if self.verbose:
            print(f" Train batches: {len(train_loader)}")
            print(f" Val batches: {len(val_loader)}")
            print(f"\n⚙️ Setting up optimizer...")
            print(f" Learning rate: {train_config.learning_rate}")
            print(f" Weight decay: {train_config.weight_decay}")

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config.learning_rate,
            betas=train_config.betas,
            eps=train_config.eps,
            weight_decay=train_config.weight_decay,
        )

        return train_loader, val_loader, optimizer

    def execute_training(
        self,
        model: GPTTransformer,
        train_loader,
        val_loader,
        optimizer,
        train_config: TrainingConfig,
        device: str,
        model_config_dict: dict
    ) -> Dict[str, Any]:
        """
        Execute the training process.

        Creates the Trainer, optionally resumes from a checkpoint, runs
        training (saving a checkpoint on Ctrl-C), and reports results.

        Args:
            model: Model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            optimizer: Optimizer
            train_config: Training configuration
            device: Device to use
            model_config_dict: Model configuration as dictionary

        Returns:
            Dictionary with training results
        """
        if self.verbose:
            print(f"\n🎯 Initializing trainer...")

        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            config=train_config,
            device=device,
        )

        # Resume if requested.
        # BUGFIX: the original wrote
        #   `elif train_config.checkpoint_dir and hasattr(train_config, 'checkpoint_dir')`
        # which accessed the attribute BEFORE the hasattr() guard, raising
        # AttributeError for configs lacking the field. getattr with a
        # default makes both lookups safe.
        resume_from = getattr(train_config, 'resume_from', None)
        checkpoint_dir = getattr(train_config, 'checkpoint_dir', None)
        if resume_from is not None:
            if self.verbose:
                print(f"\n📥 Resuming from checkpoint: {resume_from}")
            trainer.resume_from_checkpoint(Path(resume_from))
        elif checkpoint_dir:
            # Check if there's a resume_from in the checkpoint dir
            resume_path = Path(checkpoint_dir) / "resume_from.pt"
            if resume_path.exists() and self.verbose:
                print(f"\n📥 Found checkpoint to resume: {resume_path}")

        # Start training
        if self.verbose:
            print(f"\n{'='*60}")
            print("🚀 Starting Training!")
            print(f"{'='*60}\n")

        try:
            trainer.train()
        except KeyboardInterrupt:
            # Preserve progress so the run can be resumed via resume_from.
            if self.verbose:
                print("\n\n⏸️ Training interrupted by user")
                print("💾 Saving checkpoint...")
            trainer.checkpoint_manager.save_checkpoint(
                model=model,
                optimizer=optimizer,
                step=trainer.global_step,
                epoch=trainer.current_epoch,
                val_loss=trainer.best_val_loss,
                model_config=model_config_dict,
                train_config=train_config.to_dict(),
            )
            if self.verbose:
                print("✓ Checkpoint saved. Resume with resume_from in config.")

        # Return results
        best_checkpoint = Path(train_config.checkpoint_dir) / "best_model.pt"

        results = {
            'best_checkpoint': str(best_checkpoint),
            'final_val_loss': trainer.best_val_loss,
            'total_epochs': trainer.current_epoch,
            'checkpoint_dir': train_config.checkpoint_dir,
            'log_dir': train_config.log_dir,
        }

        if self.verbose:
            print(f"\n{'='*60}")
            print("✅ Training Complete!")
            print(f"{'='*60}")
            print(f"\n📁 Results:")
            print(f" Best checkpoint: {results['best_checkpoint']}")
            print(f" Best val loss: {results['final_val_loss']:.4f}")
            print(f" Total epochs: {results['total_epochs']}")
            print(f" Logs: {results['log_dir']}")

        return results

    def train(
        self,
        model_size: str,
        train_data_path: str,
        val_data_path: str,
        batch_size: int = 16,
        learning_rate: float = 3e-4,
        num_epochs: int = 10,
        checkpoint_dir: str = "./model/checkpoints",
        log_dir: str = "./logs",
        seed: int = 42,
        **kwargs
    ) -> Dict[str, Any]:
        """
        High-level training interface.

        Orchestrates the full pipeline: seeding, device resolution,
        model creation, config assembly, data/optimizer preparation,
        and training execution.

        Args:
            model_size: Model size ('tiny', 'small', 'medium')
            train_data_path: Path to training data
            val_data_path: Path to validation data
            batch_size: Training batch size
            learning_rate: Learning rate
            num_epochs: Number of training epochs
            checkpoint_dir: Directory for checkpoints
            log_dir: Directory for logs
            seed: Random seed
            **kwargs: Additional training config parameters; keys not
                declared on TrainingConfig are silently dropped

        Returns:
            Dictionary with training results
        """
        # Set seed
        if self.verbose:
            print(f"\n🎲 Setting random seed: {seed}")
        self.set_seed(seed)

        # Resolve the device once and report it.
        device = self.device_manager.get_device()
        self.device_manager.print_device_info(verbose=self.verbose)

        # Create model
        if self.verbose:
            print(f"\n🧠 Creating model: {model_size}")

        model = self.create_model(model_size)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        if self.verbose:
            print(f" Model size: {model_size}")
            print(f" Parameters: {total_params:,}")
            # fp32 parameters: 4 bytes each.
            print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")

        # Create training config
        train_config = TrainingConfig(
            train_data_path=train_data_path,
            val_data_path=val_data_path,
            batch_size=batch_size,
            learning_rate=learning_rate,
            num_epochs=num_epochs,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            device=device,
            seed=seed,
            # Only forward kwargs TrainingConfig actually declares.
            **{k: v for k, v in kwargs.items() if hasattr(TrainingConfig, k)}
        )

        # Prepare training components
        train_loader, val_loader, optimizer = self.prepare_training(
            model, train_config, device
        )

        # Execute training
        return self.execute_training(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            train_config=train_config,
            device=device,
            model_config_dict=model.config.to_dict()
        )
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gptmed
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.5
|
|
4
4
|
Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
|
|
5
5
|
Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
|
|
6
6
|
Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
|
|
@@ -22,6 +22,9 @@ gptmed/model/architecture/feedforward.py,sha256=uJ5QOlWX0ritKDQLUE7GPmMojelR9-sT
|
|
|
22
22
|
gptmed/model/architecture/transformer.py,sha256=H1njPoy0Uam59JbA24C0olEDwPfhh3ev4HsUFRIC_0Y,6626
|
|
23
23
|
gptmed/model/configs/__init__.py,sha256=LDCWhlCDOU7490wcfSId_jXBPfQrtYQEw8FoD67rqBs,275
|
|
24
24
|
gptmed/model/configs/model_config.py,sha256=wI-i2Dw_pTdIKCDe1pqLvP3ky3YedEy7DwZYN5lwmKE,4673
|
|
25
|
+
gptmed/services/__init__.py,sha256=FtM7NQ_S4VOfl2n6A6cLcOxG9-w7BK7DicQsUvOMmGE,369
|
|
26
|
+
gptmed/services/device_manager.py,sha256=RSsu0RlsexCIO-p4eejOZAPLgpaVA0y9niTg8wf1luY,7513
|
|
27
|
+
gptmed/services/training_service.py,sha256=o9Kxxoi6CVDvvM9rwGYNX426qTnmqxLXLt_bVi1ZSK4,11253
|
|
25
28
|
gptmed/tokenizer/__init__.py,sha256=KhLAHPmQyoWhnKDenyIJRxgFflKI7xklip28j4cKfKw,157
|
|
26
29
|
gptmed/tokenizer/tokenize_data.py,sha256=KgMtMfaz_RtOhN_CrvC267k9ujxRdO89rToVJ6nzdwg,9139
|
|
27
30
|
gptmed/tokenizer/train_tokenizer.py,sha256=f0Hucyft9e8LU2RtpTqg8h_0SpOC_oMABl0_me-wfL8,7068
|
|
@@ -33,9 +36,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
|
|
|
33
36
|
gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
|
|
34
37
|
gptmed/utils/checkpoints.py,sha256=L4q1-_4GbHCoD7QuEKYeQ-xXDTF-6sqZOxKQ_LT8YmQ,7112
|
|
35
38
|
gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
|
|
36
|
-
gptmed-0.3.
|
|
37
|
-
gptmed-0.3.
|
|
38
|
-
gptmed-0.3.
|
|
39
|
-
gptmed-0.3.
|
|
40
|
-
gptmed-0.3.
|
|
41
|
-
gptmed-0.3.
|
|
39
|
+
gptmed-0.3.5.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
|
|
40
|
+
gptmed-0.3.5.dist-info/METADATA,sha256=Zx3kFlZiBdkXco_VkEqOnIeasCYrgWl2XP21D2QcmuA,9382
|
|
41
|
+
gptmed-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
42
|
+
gptmed-0.3.5.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
|
|
43
|
+
gptmed-0.3.5.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
|
|
44
|
+
gptmed-0.3.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|