gptmed 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gptmed/api.py CHANGED
@@ -139,6 +139,14 @@ def train_from_config(
     # Print device information
     device_manager.print_device_info(verbose=verbose)
 
+    # Optionally enable Redis metrics storage
+    from gptmed.observability.redis_metrics_storage import RedisMetricsStorage
+    from gptmed.observability.metrics_tracker import MetricsTracker
+
+    redis_enabled = False
+    if 'redis' in config and config['redis'].get('enabled', False):
+        redis_enabled = True
+
     # Create TrainingService with DeviceManager
     training_service = TrainingService(
         device_manager=device_manager,
@@ -203,6 +211,7 @@ def train_from_config(
         log_dir=args['log_dir'],
         device=actual_device,  # Use actual device from DeviceManager
         seed=args['seed'],
+        resume_from=args.get('resume_from'),  # Pass resume checkpoint path
     )
 
     # Create optimizer
@@ -219,6 +228,23 @@ def train_from_config(
         weight_decay=args['weight_decay'],
     )
 
+
+    # If Redis is enabled, inject RedisMetricsStorage into MetricsTracker
+    observers = None
+    if redis_enabled:
+        if verbose:
+            print("\n🔗 Enabling Redis metrics storage...")
+        observers = [
+            MetricsTracker(
+                log_dir=train_config.log_dir,
+                experiment_name="gptmed_training",
+                moving_avg_window=100,
+                log_interval=train_config.log_interval,
+                verbose=verbose,
+                storage_backend=RedisMetricsStorage(),
+            )
+        ]
+
     # Execute training using TrainingService
     results = training_service.execute_training(
         model=model,
@@ -227,9 +253,9 @@ def train_from_config(
         optimizer=optimizer,
         train_config=train_config,
         device=actual_device,
-        model_config_dict=model.config.to_dict()
+        model_config_dict=model.config.to_dict(),
+        observers=observers,
     )
-
     return results
 
 
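Note that the new branch in train_from_config only activates when the loaded config carries a redis section with an enabled flag set to true; the defaults injected by config_loader.py below do not include that key, so Redis stays off unless a user opts in. A minimal sketch of the relevant config fragment (key names follow the diff; the values are placeholders, not an official example from the package):

    config = {
        # ... model / training / advanced sections as before ...
        "redis": {
            "enabled": True,   # flag checked in api.py; absent from the injected defaults
            "host": "localhost",
            "port": 6379,
            "db": 0,
            "password": None,
        },
    }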

gptmed/configs/config_loader.py CHANGED
@@ -35,6 +35,14 @@ def load_yaml_config(config_path: str) -> Dict[str, Any]:
     except yaml.YAMLError as e:
         raise ValueError(f"Error parsing YAML configuration: {e}")
 
+    # If redis section is missing, add defaults
+    if 'redis' not in config:
+        config['redis'] = {
+            'host': 'localhost',
+            'port': 6379,
+            'db': 0,
+            'password': None
+        }
     return config
 
 
@@ -86,6 +94,17 @@ def validate_config(config: Dict[str, Any]) -> None:
             f"Must be one of {valid_devices}"
         )
 
+    # Validate redis config
+    redis_cfg = config.get('redis', {})
+    if not isinstance(redis_cfg, dict):
+        raise ValueError("Redis config must be a dictionary.")
+    if 'host' not in redis_cfg or not redis_cfg['host']:
+        raise ValueError("Redis config missing 'host'.")
+    if 'port' not in redis_cfg or not isinstance(redis_cfg['port'], int):
+        raise ValueError("Redis config missing or invalid 'port'.")
+    if 'db' not in redis_cfg or not isinstance(redis_cfg['db'], int):
+        raise ValueError("Redis config missing or invalid 'db'.")
+
 
 def config_to_args(config: Dict[str, Any]) -> Dict[str, Any]:
     """
@@ -135,6 +154,14 @@ def config_to_args(config: Dict[str, Any]) -> Dict[str, Any]:
         'max_steps': config.get('advanced', {}).get('max_steps', -1),
         'resume_from': config.get('advanced', {}).get('resume_from'),
         'quick_test': config.get('advanced', {}).get('quick_test', False),
+
+        # Redis
+        'redis_config': config.get('redis', {
+            'host': 'localhost',
+            'port': 6379,
+            'db': 0,
+            'password': None
+        }),
     }
 
     return args
@@ -185,6 +212,12 @@ def create_default_config_file(output_path: str = 'training_config.yaml') -> Non
             'max_steps': -1,
             'resume_from': None,
             'quick_test': False
+        },
+        'redis': {
+            'host': 'localhost',
+            'port': 6379,
+            'db': 0,
+            'password': None
         }
     }
 
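Taken together, an older YAML file with no redis block still loads: load_yaml_config injects the defaults, validate_config accepts them, and config_to_args exposes them as redis_config. A quick round-trip sketch, assuming these functions are importable from gptmed.configs.config_loader (the module path is taken from the wheel's RECORD):

    from gptmed.configs.config_loader import (
        create_default_config_file,
        load_yaml_config,
        validate_config,
        config_to_args,
    )

    create_default_config_file("training_config.yaml")  # now includes a redis block
    config = load_yaml_config("training_config.yaml")   # injects defaults if the section is missing
    validate_config(config)                             # raises ValueError on a malformed redis block
    args = config_to_args(config)
    print(args["redis_config"])  # {'host': 'localhost', 'port': 6379, 'db': 0, 'password': None}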

gptmed/configs/configs.py ADDED
@@ -0,0 +1,17 @@
+"""
+Configuration for Redis connection for real-time training metrics storage.
+"""
+
+import os
+
+REDIS_HOST = "localhost"
+REDIS_PORT = 6379
+REDIS_DB = 0
+REDIS_PASSWORD = None
+
+REDIS_CONFIG = {
+    "host": REDIS_HOST,
+    "port": REDIS_PORT,
+    "db": REDIS_DB,
+    "password": REDIS_PASSWORD,
+}
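The new configs module hard-codes localhost values (import os is present but unused in this release, so there is no built-in environment-variable override). Because REDIS_CONFIG is a plain module-level dict that RedisMetricsStorage (added further below) reads at construction time, one unofficial workaround for pointing at a remote Redis is to mutate the dict before the tracker is built; a sketch under that assumption:

    import os

    from gptmed.configs.configs import REDIS_CONFIG

    # Hypothetical override: update the shared dict before RedisMetricsStorage() is constructed.
    REDIS_CONFIG.update({
        "host": os.environ.get("REDIS_HOST", "localhost"),
        "port": int(os.environ.get("REDIS_PORT", "6379")),
    })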

gptmed/configs/train_config.py CHANGED
@@ -37,6 +37,7 @@ COMMON FAILURE MODES:
 
 from dataclasses import dataclass
 from pathlib import Path
+from typing import Optional
 import json
 
 
@@ -91,6 +92,9 @@ class TrainingConfig:
     # Reproducibility
     seed: int = 42
 
+    # Resume training
+    resume_from: Optional[str] = None  # Path to checkpoint to resume from
+
     def to_dict(self) -> dict:
         """Convert to dictionary."""
         return {
@@ -117,6 +121,7 @@ class TrainingConfig:
             "device": self.device,
             "use_amp": self.use_amp,
             "seed": self.seed,
+            "resume_from": self.resume_from,
         }
 
     def save(self, path: Path):

gptmed/observability/metrics_tracker.py CHANGED
@@ -21,6 +21,7 @@ WHAT TO LOOK FOR:
 - Loss = NaN → Exploding gradients
 """
 
+
 import json
 import math
 import time
@@ -36,6 +37,9 @@ from gptmed.observability.base import (
     GradientMetrics,
 )
 
+# Import the interface but not the Redis implementation directly (for loose coupling)
+from gptmed.observability.redis_metrics_storage import MetricsStorageInterface
+
 
 @dataclass
 class LossCurvePoint:
@@ -79,6 +83,7 @@ class MetricsTracker(TrainingObserver):
         moving_avg_window: int = 100,
         log_interval: int = 10,
         verbose: bool = True,
+        storage_backend: Optional[MetricsStorageInterface] = None,
     ):
         """
         Initialize MetricsTracker.
@@ -91,26 +96,31 @@ class MetricsTracker(TrainingObserver):
             verbose: Whether to print progress
         """
         super().__init__(name="MetricsTracker")
-
+
         self.log_dir = Path(log_dir)
         self.log_dir.mkdir(parents=True, exist_ok=True)
-
+
         self.experiment_name = experiment_name
         self.moving_avg_window = moving_avg_window
         self.log_interval = log_interval
         self.verbose = verbose
-
+
+        # Optional metrics storage backend (e.g., Redis)
+        self.storage_backend = storage_backend
+
         # Initialize storage
         self._reset_storage()
-
+
        # File paths
         self.metrics_file = self.log_dir / f"{experiment_name}_metrics.jsonl"
         self.summary_file = self.log_dir / f"{experiment_name}_summary.json"
-
+
         if self.verbose:
             print(f"📊 MetricsTracker initialized")
             print(f" Log directory: {self.log_dir}")
             print(f" Moving average window: {moving_avg_window}")
+            if self.storage_backend:
+                print(f" Using external metrics storage: {type(self.storage_backend).__name__}")
 
     def _reset_storage(self) -> None:
         """Reset all metric storage."""
@@ -163,53 +173,69 @@
     def on_step(self, metrics: StepMetrics) -> None:
         """Called after each training step."""
         timestamp = time.time() - self.start_time if self.start_time else 0
-
+
         # Store loss
         self.train_losses.append(LossCurvePoint(
             step=metrics.step,
             loss=metrics.loss,
             timestamp=timestamp,
         ))
-
+
         # Update moving average buffer
         self._loss_buffer.append(metrics.loss)
-
+
         # Store learning rate
         self.learning_rates.append((metrics.step, metrics.learning_rate))
-
+
         # Store gradient norm
         self.gradient_norms.append((metrics.step, metrics.grad_norm))
-
+
         # Store perplexity
         self.train_perplexities.append((metrics.step, metrics.perplexity))
-
+
         # Log to file periodically
         if metrics.step % self.log_interval == 0:
             self._log_step(metrics, timestamp)
+            # Also log to external storage if available
+            if self.storage_backend:
+                self.storage_backend.save_step_metrics({
+                    "type": "step",
+                    "timestamp": timestamp,
+                    "moving_avg_loss": self.get_moving_average(),
+                    **metrics.to_dict(),
+                })
 
     def on_validation(self, metrics: ValidationMetrics) -> None:
         """Called after validation."""
         timestamp = time.time() - self.start_time if self.start_time else 0
-
+
         # Store validation loss
         self.val_losses.append(LossCurvePoint(
             step=metrics.step,
             loss=metrics.val_loss,
             timestamp=timestamp,
         ))
-
+
         # Store validation perplexity
         self.val_perplexities.append((metrics.step, metrics.val_perplexity))
-
+
         # Track best
         if metrics.val_loss < self.best_val_loss:
             self.best_val_loss = metrics.val_loss
             self.best_val_step = metrics.step
             if self.verbose:
                 print(f" ⭐ New best val_loss: {metrics.val_loss:.4f}")
-
+
         # Log to file
         self._log_validation(metrics, timestamp)
+        # Also log to external storage if available
+        if self.storage_backend:
+            self.storage_backend.save_validation_metrics({
+                "type": "validation",
+                "timestamp": timestamp,
+                "is_best": metrics.val_loss <= self.best_val_loss,
+                **metrics.to_dict(),
+            })
 
     def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
         """Called when training completes."""

gptmed/observability/redis_metrics_storage.py ADDED
@@ -0,0 +1,26 @@
+"""
+Redis client for real-time metrics storage.
+Follows SOLID principles: single responsibility, dependency inversion, and interface segregation.
+"""
+
+import redis
+from typing import Any, Dict
+from gptmed.configs.configs import REDIS_CONFIG
+
+class MetricsStorageInterface:
+    """Interface for metrics storage backends."""
+    def save_step_metrics(self, metrics: Dict[str, Any]):
+        raise NotImplementedError
+    def save_validation_metrics(self, metrics: Dict[str, Any]):
+        raise NotImplementedError
+
+class RedisMetricsStorage(MetricsStorageInterface):
+    """Redis implementation for metrics storage."""
+    def __init__(self):
+        self.client = redis.Redis(**REDIS_CONFIG)
+    def save_step_metrics(self, metrics: Dict[str, Any]):
+        # Use a Redis list for steps
+        self.client.rpush("training:steps", str(metrics))
+    def save_validation_metrics(self, metrics: Dict[str, Any]):
+        # Use a Redis list for validation
+        self.client.rpush("training:validation", str(metrics))

gptmed-0.4.0.dist-info/METADATA → gptmed-0.5.0.dist-info/METADATA CHANGED
@@ -1,9 +1,8 @@
 Metadata-Version: 2.4
 Name: gptmed
-Version: 0.4.0
+Version: 0.5.0
 Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
-Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
-Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
+Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>, Sanjog Sigdel <sanjog.sigdel@ku.edu.np>
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/sigdelsanjog/gptmed
 Project-URL: Documentation, https://github.com/sigdelsanjog/gptmed#readme

gptmed-0.4.0.dist-info/RECORD → gptmed-0.5.0.dist-info/RECORD CHANGED
@@ -1,8 +1,9 @@
 gptmed/__init__.py,sha256=lSCUt0jmB81dEG0UroQdrk8TMG9Hv-_a14nAvB6yYiQ,2725
-gptmed/api.py,sha256=k9a_1F2h__xgKnH2l0FaJqAqu-iTYt5tu_VfVO0UhrA,9806
+gptmed/api.py,sha256=BhR2I5CZNlYD34U4i85FQTUp8tLNyzhOrsT7monKpW0,10780
 gptmed/configs/__init__.py,sha256=yRa-zgPQ-OCzu8fvCrfWMG-CjF3dru3PZzknzm0oUaQ,23
-gptmed/configs/config_loader.py,sha256=3GQ1iCNpdJ5yALWXA3SPPHRkaUO-117vdArEL6u7sK8,6354
-gptmed/configs/train_config.py,sha256=KqfNBh9hdTTd_6gEAlrClU8sVFSlVDmZJOrf3cPwFe8,4657
+gptmed/configs/config_loader.py,sha256=NhIjmZ5ACcwZubdEcDq42PJuR03ulmZv_GYizhOIlPI,7466
+gptmed/configs/configs.py,sha256=704fWZS2OSRloEuhwlP6ezLly2paZfsdYCkJ1jfKuPE,293
+gptmed/configs/train_config.py,sha256=cuGE5o4N3TA65Sue8J3XrbmI5QKI7Ww3WeHd2M7yoHQ,4828
 gptmed/configs/training_config.yaml,sha256=EEZZa3kcsZr3g-_fKDPYZt4_NTpmS-3NvJrTYSWNc8g,2874
 gptmed/data/__init__.py,sha256=iAHeakB5pBAd7MkmarPPY0UKS9bTaO_winLZ23Y2O90,54
 gptmed/data/parsers/__init__.py,sha256=BgVzXuZgeE5DUCC4SzN7vflL40wQ4Q4_4DmJ1Y43_nw,211
@@ -25,7 +26,8 @@ gptmed/model/configs/model_config.py,sha256=wI-i2Dw_pTdIKCDe1pqLvP3ky3YedEy7DwZY
 gptmed/observability/__init__.py,sha256=AtGf0D8jEx2LGQ0Ro-Eh0SFDuA5ZjZkot7D1Y8j1jiM,1180
 gptmed/observability/base.py,sha256=Mi3F95bJ9Tw5scoSyw9AtKlcu9aG444G1UlycIIGCtI,10748
 gptmed/observability/callbacks.py,sha256=1b84_e86mfyt2EQGzf-6K2Sba3bZJt4I3bBJb52TAbA,13170
-gptmed/observability/metrics_tracker.py,sha256=Bs6tppQYG9AOb3rj2T1lhWKDyOw4R4ZG6nFGRiek8FQ,19441
+gptmed/observability/metrics_tracker.py,sha256=AEGcAjMTGMy--NIBxPEWfvwa3e5lvdkJEDNxHss6Dak,20493
+gptmed/observability/redis_metrics_storage.py,sha256=CYNRYB481-tGZ-BTMOSFlP-enn26dxg3nrUEni2hDXA,1014
 gptmed/services/__init__.py,sha256=FtM7NQ_S4VOfl2n6A6cLcOxG9-w7BK7DicQsUvOMmGE,369
 gptmed/services/device_manager.py,sha256=RSsu0RlsexCIO-p4eejOZAPLgpaVA0y9niTg8wf1luY,7513
 gptmed/services/training_service.py,sha256=cF3yYo8aZe7BfQ-paTN-l7EYs9h8L_JUyRhiI0GEP4E,16921
@@ -40,9 +42,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
 gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
 gptmed/utils/checkpoints.py,sha256=jPKJtO0YRZieGmpwqotgDkBzd__s_raDxS1kLpfjBJE,7113
 gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
-gptmed-0.4.0.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
-gptmed-0.4.0.dist-info/METADATA,sha256=kVsL6zbBoGw1jrlaDiPkBAr_D7YedPCSwZkjGCFz04c,13832
-gptmed-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-gptmed-0.4.0.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
-gptmed-0.4.0.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
-gptmed-0.4.0.dist-info/RECORD,,
+gptmed-0.5.0.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
+gptmed-0.5.0.dist-info/METADATA,sha256=EcZFQ_be5xb2uhu6x8HwReHloAEypxZUp2foN5Xn6VY,13816
+gptmed-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+gptmed-0.5.0.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
+gptmed-0.5.0.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
+gptmed-0.5.0.dist-info/RECORD,,