langtune-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/auth.py ADDED
@@ -0,0 +1,434 @@
+ """
+ auth.py: API Key authentication and usage tracking for Langtune
+
+ Users must obtain an API key from https://langtrain.xyz to use this package.
+ """
+
+ import os
+ import sys
+ import json
+ import hashlib
+ import time
+ from pathlib import Path
+ from typing import Optional, Dict, Any
+ from datetime import datetime, timedelta
+
+ # Try to import requests for API calls
+ try:
+     import requests
+     REQUESTS_AVAILABLE = True
+ except ImportError:
+     REQUESTS_AVAILABLE = False
+
+ # Try to import rich for beautiful output
+ try:
+     from rich.console import Console
+     from rich.panel import Panel
+     from rich.text import Text
+     from rich import box
+     RICH_AVAILABLE = True
+     console = Console()
+ except ImportError:
+     RICH_AVAILABLE = False
+     console = None
+
+ # API configuration
+ API_BASE_URL = "https://api.langtrain.xyz"
+ AUTH_ENDPOINT = f"{API_BASE_URL}/v1/auth/verify"
+ USAGE_ENDPOINT = f"{API_BASE_URL}/v1/usage"
+
+ # Config paths
+ CONFIG_DIR = Path.home() / ".langtune"
+ CONFIG_FILE = CONFIG_DIR / "config.json"
+ CACHE_FILE = CONFIG_DIR / ".auth_cache"
+
+ # Environment variable names
+ API_KEY_ENV = "LANGTUNE_API_KEY"
+
+
+ class AuthenticationError(Exception):
+     """Raised when API key authentication fails."""
+     pass
+
+
+ class UsageLimitError(Exception):
+     """Raised when usage limit is exceeded."""
+     pass
+
+
+ def _get_config_dir() -> Path:
+     """Get or create the config directory."""
+     CONFIG_DIR.mkdir(parents=True, exist_ok=True)
+     return CONFIG_DIR
+
+
+ def _load_config() -> Dict[str, Any]:
+     """Load configuration from file."""
+     if CONFIG_FILE.exists():
+         try:
+             with open(CONFIG_FILE, 'r') as f:
+                 return json.load(f)
+         except (json.JSONDecodeError, IOError):
+             return {}
+     return {}
+
+
+ def _save_config(config: Dict[str, Any]) -> None:
+     """Save configuration to file."""
+     _get_config_dir()
+     with open(CONFIG_FILE, 'w') as f:
+         json.dump(config, f, indent=2)
+
+
+ def _load_auth_cache() -> Dict[str, Any]:
+     """Load cached authentication data."""
+     if CACHE_FILE.exists():
+         try:
+             with open(CACHE_FILE, 'r') as f:
+                 return json.load(f)
+         except (json.JSONDecodeError, IOError):
+             return {}
+     return {}
+
+
+ def _save_auth_cache(cache: Dict[str, Any]) -> None:
+     """Save authentication cache."""
+     _get_config_dir()
+     with open(CACHE_FILE, 'w') as f:
+         json.dump(cache, f)
+
+
+ def get_api_key() -> Optional[str]:
+     """
+     Get the API key from environment or config file.
+
+     Priority:
+     1. LANGTUNE_API_KEY environment variable
+     2. Config file (~/.langtune/config.json)
+     """
+     # Check environment variable first
+     api_key = os.environ.get(API_KEY_ENV)
+     if api_key:
+         return api_key
+
+     # Check config file
+     config = _load_config()
+     return config.get("api_key")
+
+
+ def set_api_key(api_key: str) -> None:
+     """Save API key to config file."""
+     config = _load_config()
+     config["api_key"] = api_key
+     _save_config(config)
+
+     if RICH_AVAILABLE:
+         console.print("[green]✓[/] API key saved to ~/.langtune/config.json")
+     else:
+         print("✓ API key saved to ~/.langtune/config.json")
+
+
+ def _hash_key(api_key: str) -> str:
+     """Hash API key for cache lookup."""
+     return hashlib.sha256(api_key.encode()).hexdigest()[:16]
+
+
+ def verify_api_key(api_key: str, force_refresh: bool = False) -> Dict[str, Any]:
+     """
+     Verify API key with the Langtrain API.
+
+     Returns user info and usage limits on success.
+     Raises AuthenticationError on failure.
+     """
+     # Check cache first (valid for 1 hour)
+     cache = _load_auth_cache()
+     key_hash = _hash_key(api_key)
+
+     if not force_refresh and key_hash in cache:
+         cached_data = cache[key_hash]
+         cache_time = cached_data.get("cached_at", 0)
+         if time.time() - cache_time < 3600:  # 1 hour cache
+             return cached_data.get("data", {})
+
+     # For now, simulate API verification (offline mode)
+     # In production, this would make an actual API call
+     if not REQUESTS_AVAILABLE:
+         # Offline verification - accept keys that match pattern
+         if api_key.startswith("lt_") and len(api_key) >= 32:
+             user_data = {
+                 "valid": True,
+                 "user_id": key_hash,
+                 "plan": "free",
+                 "usage": {
+                     "tokens_used": 0,
+                     "tokens_limit": 100000,
+                     "requests_used": 0,
+                     "requests_limit": 1000
+                 },
+                 "offline_mode": True
+             }
+             # Cache the result
+             cache[key_hash] = {
+                 "cached_at": time.time(),
+                 "data": user_data
+             }
+             _save_auth_cache(cache)
+             return user_data
+         else:
+             raise AuthenticationError(
+                 "Invalid API key format. Keys should start with 'lt_' and be at least 32 characters.\n"
+                 "Get your API key at: https://app.langtrain.xyz"
+             )
+
+     # Make API call to verify key
+     try:
+         headers = {"Authorization": f"Bearer {api_key}"}
+         response = requests.post(AUTH_ENDPOINT, headers=headers, timeout=10)
+
+         if response.status_code == 200:
+             user_data = response.json()
+             # Cache the result
+             cache[key_hash] = {
+                 "cached_at": time.time(),
+                 "data": user_data
+             }
+             _save_auth_cache(cache)
+             return user_data
+         elif response.status_code == 401:
+             raise AuthenticationError(
+                 "Invalid API key. Please check your key at: https://app.langtrain.xyz"
+             )
+         elif response.status_code == 403:
+             raise UsageLimitError(
+                 "API key is valid but access is denied. Your subscription may have expired.\n"
+                 "Manage your subscription at: https://billing.langtrain.xyz"
+             )
+         else:
+             raise AuthenticationError(
+                 f"Authentication failed with status {response.status_code}. "
+                 "Please try again or contact support."
+             )
+     except requests.exceptions.RequestException as e:
+         # If we can't reach the API, use cached data if available
+         if key_hash in cache:
+             return cache[key_hash].get("data", {})
+         raise AuthenticationError(
+             f"Could not verify API key: {e}\n"
+             "Please check your internet connection."
+         )
+
+
+ def check_usage(api_key: str) -> Dict[str, Any]:
+     """Check current usage against limits."""
+     user_data = verify_api_key(api_key)
+     usage = user_data.get("usage", {})
+
+     tokens_used = usage.get("tokens_used", 0)
+     tokens_limit = usage.get("tokens_limit", 100000)
+     requests_used = usage.get("requests_used", 0)
+     requests_limit = usage.get("requests_limit", 1000)
+
+     if tokens_used >= tokens_limit:
+         raise UsageLimitError(
+             f"Token limit exceeded ({tokens_used:,}/{tokens_limit:,}).\n"
+             "Upgrade your plan at: https://billing.langtrain.xyz"
+         )
+
+     if requests_used >= requests_limit:
+         raise UsageLimitError(
+             f"Request limit exceeded ({requests_used:,}/{requests_limit:,}).\n"
+             "Upgrade your plan at: https://billing.langtrain.xyz"
+         )
+
+     return {
+         "tokens_used": tokens_used,
+         "tokens_limit": tokens_limit,
+         "tokens_remaining": tokens_limit - tokens_used,
+         "requests_used": requests_used,
+         "requests_limit": requests_limit,
+         "requests_remaining": requests_limit - requests_used,
+         "plan": user_data.get("plan", "free")
+     }
+
+
+ def require_auth(func):
+     """Decorator to require API key authentication for a function."""
+     def wrapper(*args, **kwargs):
+         api_key = get_api_key()
+
+         if not api_key:
+             _print_auth_required()
+             sys.exit(1)
+
+         try:
+             verify_api_key(api_key)
+         except AuthenticationError as e:
+             _print_auth_error(str(e))
+             sys.exit(1)
+         except UsageLimitError as e:
+             _print_usage_error(str(e))
+             sys.exit(1)
+
+         return func(*args, **kwargs)
+
+     return wrapper
+
+
+ def _print_auth_required():
+     """Print authentication required message."""
+     if RICH_AVAILABLE:
+         text = Text()
+         text.append("\n🔐 ", style="")
+         text.append("API Key Required\n\n", style="bold red")
+         text.append("Langtune requires an API key to run. Get your free key at:\n", style="")
+         text.append("https://app.langtrain.xyz\n\n", style="blue underline")
+         text.append("Once you have your key, authenticate using:\n\n", style="")
+         text.append(" langtune auth login\n\n", style="cyan")
+         text.append("Or set the environment variable:\n\n", style="")
+         text.append(f" export {API_KEY_ENV}=lt_your_api_key_here\n", style="cyan")
+
+         panel = Panel(text, title="[bold]Authentication Required[/]", border_style="red", box=box.ROUNDED)
+         console.print(panel)
+     else:
+         print(f"""
+ 🔐 API Key Required
+
+ Langtune requires an API key to run. Get your free key at:
+ https://app.langtrain.xyz
+
+ Once you have your key, authenticate using:
+
+ langtune auth login
+
+ Or set the environment variable:
+
+ export {API_KEY_ENV}=lt_your_api_key_here
+ """)
+
+
+ def _print_auth_error(message: str):
+     """Print authentication error message."""
+     if RICH_AVAILABLE:
+         console.print(f"\n[bold red]❌ Authentication Error[/]\n")
+         console.print(f"[red]{message}[/]\n")
+     else:
+         print(f"\n❌ Authentication Error\n")
+         print(f"{message}\n")
+
+
+ def _print_usage_error(message: str):
+     """Print usage limit error message."""
+     if RICH_AVAILABLE:
+         console.print(f"\n[bold yellow]⚠️ Usage Limit Reached[/]\n")
+         console.print(f"[yellow]{message}[/]\n")
+     else:
+         print(f"\n⚠️ Usage Limit Reached\n")
+         print(f"{message}\n")
+
+
+ def print_usage_info():
+     """Print current usage information."""
+     api_key = get_api_key()
+
+     if not api_key:
+         _print_auth_required()
+         return
+
+     try:
+         usage = check_usage(api_key)
+
+         if RICH_AVAILABLE:
+             from rich.table import Table
+
+             table = Table(title="Langtune Usage", box=box.ROUNDED, title_style="bold cyan")
+             table.add_column("Metric", style="cyan", no_wrap=True)
+             table.add_column("Used", style="white", justify="right")
+             table.add_column("Limit", style="white", justify="right")
+             table.add_column("Remaining", style="green", justify="right")
+
+             table.add_row(
+                 "Tokens",
+                 f"{usage['tokens_used']:,}",
+                 f"{usage['tokens_limit']:,}",
+                 f"{usage['tokens_remaining']:,}"
+             )
+             table.add_row(
+                 "Requests",
+                 f"{usage['requests_used']:,}",
+                 f"{usage['requests_limit']:,}",
+                 f"{usage['requests_remaining']:,}"
+             )
+
+             console.print()
+             console.print(f"[dim]Plan:[/] [bold]{usage['plan'].title()}[/]")
+             console.print(table)
+             console.print()
+             console.print("[dim]Manage your plan at:[/] [blue underline]https://billing.langtrain.xyz[/]\n")
+         else:
+             print(f"\nPlan: {usage['plan'].title()}")
+             print(f"Tokens: {usage['tokens_used']:,} / {usage['tokens_limit']:,}")
+             print(f"Requests: {usage['requests_used']:,} / {usage['requests_limit']:,}")
+             print(f"\nManage your plan at: https://billing.langtrain.xyz\n")
+
+     except (AuthenticationError, UsageLimitError) as e:
+         if RICH_AVAILABLE:
+             console.print(f"[red]{e}[/]")
+         else:
+             print(str(e))
+
+
+ def interactive_login():
+     """Interactive login flow."""
+     if RICH_AVAILABLE:
+         console.print("\n[bold cyan]🔐 Langtune Authentication[/]\n")
+         console.print("Get your API key at: [blue underline]https://app.langtrain.xyz[/]\n")
+         api_key = console.input("[bold]Enter your API key:[/] ")
+     else:
+         print("\n🔐 Langtune Authentication\n")
+         print("Get your API key at: https://app.langtrain.xyz\n")
+         api_key = input("Enter your API key: ")
+
+     api_key = api_key.strip()
+
+     if not api_key:
+         if RICH_AVAILABLE:
+             console.print("[red]No API key entered.[/]")
+         else:
+             print("No API key entered.")
+         return False
+
+     try:
+         user_data = verify_api_key(api_key, force_refresh=True)
+         set_api_key(api_key)
+
+         if RICH_AVAILABLE:
+             console.print(f"\n[bold green]✓ Authentication successful![/]")
+             console.print(f"[dim]Plan:[/] [bold]{user_data.get('plan', 'free').title()}[/]")
+             console.print(f"\n[dim]You're ready to use Langtune. Run[/] [cyan]langtune info[/] [dim]to get started.[/]\n")
+         else:
+             print(f"\n✓ Authentication successful!")
+             print(f"Plan: {user_data.get('plan', 'free').title()}")
+             print(f"\nYou're ready to use Langtune. Run 'langtune info' to get started.\n")
+
+         return True
+
+     except AuthenticationError as e:
+         _print_auth_error(str(e))
+         return False
+
+
+ def logout():
+     """Remove stored API key."""
+     config = _load_config()
+     if "api_key" in config:
+         del config["api_key"]
+         _save_config(config)
+
+     # Clear cache
+     if CACHE_FILE.exists():
+         CACHE_FILE.unlink()
+
+     if RICH_AVAILABLE:
+         console.print("[green]✓[/] Logged out successfully. API key removed from ~/.langtune/config.json")
+     else:
+         print("✓ Logged out successfully. API key removed from ~/.langtune/config.json")
langtune/callbacks.py ADDED
@@ -0,0 +1,268 @@
+ """
+ callbacks.py: Training callbacks for Langtune
+
+ Provides extensible callback system for training hooks.
+ """
+
+ import logging
+ from typing import Dict, Any, Optional, List, Callable
+ from pathlib import Path
+ import json
+ import time
+
+ logger = logging.getLogger(__name__)
+
+
+ class Callback:
+     """Base callback class. Override methods to customize training behavior."""
+
+     def on_train_begin(self, trainer, **kwargs):
+         """Called at the start of training."""
+         pass
+
+     def on_train_end(self, trainer, **kwargs):
+         """Called at the end of training."""
+         pass
+
+     def on_epoch_begin(self, trainer, epoch: int, **kwargs):
+         """Called at the start of each epoch."""
+         pass
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         """Called at the end of each epoch."""
+         pass
+
+     def on_batch_begin(self, trainer, batch_idx: int, **kwargs):
+         """Called at the start of each batch."""
+         pass
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         """Called at the end of each batch."""
+         pass
+
+     def on_validation_begin(self, trainer, **kwargs):
+         """Called at the start of validation."""
+         pass
+
+     def on_validation_end(self, trainer, metrics: Dict[str, float], **kwargs):
+         """Called at the end of validation."""
+         pass
+
+
+ class CallbackList:
+     """Container for multiple callbacks."""
+
+     def __init__(self, callbacks: Optional[List[Callback]] = None):
+         self.callbacks = callbacks or []
+
+     def add(self, callback: Callback):
+         """Add a callback."""
+         self.callbacks.append(callback)
+
+     def on_train_begin(self, trainer, **kwargs):
+         for cb in self.callbacks:
+             cb.on_train_begin(trainer, **kwargs)
+
+     def on_train_end(self, trainer, **kwargs):
+         for cb in self.callbacks:
+             cb.on_train_end(trainer, **kwargs)
+
+     def on_epoch_begin(self, trainer, epoch: int, **kwargs):
+         for cb in self.callbacks:
+             cb.on_epoch_begin(trainer, epoch, **kwargs)
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         for cb in self.callbacks:
+             cb.on_epoch_end(trainer, epoch, metrics, **kwargs)
+
+     def on_batch_begin(self, trainer, batch_idx: int, **kwargs):
+         for cb in self.callbacks:
+             cb.on_batch_begin(trainer, batch_idx, **kwargs)
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         for cb in self.callbacks:
+             cb.on_batch_end(trainer, batch_idx, loss, **kwargs)
+
+     def on_validation_begin(self, trainer, **kwargs):
+         for cb in self.callbacks:
+             cb.on_validation_begin(trainer, **kwargs)
+
+     def on_validation_end(self, trainer, metrics: Dict[str, float], **kwargs):
+         for cb in self.callbacks:
+             cb.on_validation_end(trainer, metrics, **kwargs)
+
+
+ class ProgressCallback(Callback):
+     """Logs training progress."""
+
+     def __init__(self, log_every: int = 10):
+         self.log_every = log_every
+         self.batch_count = 0
+
+     def on_epoch_begin(self, trainer, epoch: int, **kwargs):
+         self.batch_count = 0
+         logger.info(f"Starting epoch {epoch + 1}")
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         self.batch_count += 1
+         if self.batch_count % self.log_every == 0:
+             logger.info(f"Batch {batch_idx}, Loss: {loss:.4f}")
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         metrics_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
+         logger.info(f"Epoch {epoch + 1} completed - {metrics_str}")
+
+
+ class LearningRateMonitorCallback(Callback):
+     """Monitor and log learning rate."""
+
+     def __init__(self):
+         self.lrs = []
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         if hasattr(trainer, 'optimizer'):
+             lr = trainer.optimizer.param_groups[0]['lr']
+             self.lrs.append(lr)
+
+
+ class GradientMonitorCallback(Callback):
+     """Monitor gradient statistics for debugging."""
+
+     def __init__(self, log_every: int = 100):
+         self.log_every = log_every
+         self.step = 0
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         self.step += 1
+         if self.step % self.log_every == 0 and hasattr(trainer, 'model'):
+             total_norm = 0.0
+             param_count = 0
+             for p in trainer.model.parameters():
+                 if p.grad is not None:
+                     total_norm += p.grad.data.norm(2).item() ** 2
+                     param_count += 1
+
+             if param_count > 0:
+                 total_norm = total_norm ** 0.5
+                 logger.info(f"Step {self.step}: Gradient norm = {total_norm:.4f}")
+
+
+ class ModelSizeCallback(Callback):
+     """Log model size information at training start."""
+
+     def on_train_begin(self, trainer, **kwargs):
+         if hasattr(trainer, 'model'):
+             total = sum(p.numel() for p in trainer.model.parameters())
+             trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
+
+             logger.info(f"Model size: {total:,} total params, {trainable:,} trainable ({100*trainable/total:.1f}%)")
+
+
+ class TimerCallback(Callback):
+     """Track training time."""
+
+     def __init__(self):
+         self.train_start = None
+         self.epoch_start = None
+         self.epoch_times = []
+
+     def on_train_begin(self, trainer, **kwargs):
+         self.train_start = time.time()
+
+     def on_epoch_begin(self, trainer, epoch: int, **kwargs):
+         self.epoch_start = time.time()
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         if self.epoch_start:
+             elapsed = time.time() - self.epoch_start
+             self.epoch_times.append(elapsed)
+             logger.info(f"Epoch {epoch + 1} took {elapsed:.1f}s")
+
+     def on_train_end(self, trainer, **kwargs):
+         if self.train_start:
+             total = time.time() - self.train_start
+             logger.info(f"Total training time: {total:.1f}s ({total/60:.1f}m)")
+
+
+ class SaveHistoryCallback(Callback):
+     """Save training history to JSON."""
+
+     def __init__(self, save_path: str = "training_history.json"):
+         self.save_path = save_path
+         self.history = {"epochs": [], "metrics": []}
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         self.history["epochs"].append(epoch + 1)
+         self.history["metrics"].append(metrics)
+
+         with open(self.save_path, 'w') as f:
+             json.dump(self.history, f, indent=2)
+
+
+ class MemoryMonitorCallback(Callback):
+     """Monitor GPU memory usage."""
+
+     def __init__(self, log_every: int = 50):
+         self.log_every = log_every
+         self.step = 0
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         import torch
+         self.step += 1
+
+         if self.step % self.log_every == 0 and torch.cuda.is_available():
+             allocated = torch.cuda.memory_allocated() / 1e9
+             reserved = torch.cuda.memory_reserved() / 1e9
+             logger.info(f"Step {self.step}: GPU Memory - {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
+
+
+ class WandbCallback(Callback):
+     """Log to Weights & Biases."""
+
+     def __init__(self, project: str = "langtune", run_name: Optional[str] = None):
+         self.project = project
+         self.run_name = run_name
+         self.wandb = None
+
+     def on_train_begin(self, trainer, **kwargs):
+         try:
+             import wandb
+             wandb.init(project=self.project, name=self.run_name)
+             self.wandb = wandb
+         except ImportError:
+             logger.warning("wandb not installed, skipping W&B logging")
+
+     def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
+         if self.wandb:
+             self.wandb.log({"train_loss": loss})
+
+     def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
+         if self.wandb:
+             self.wandb.log({"epoch": epoch + 1, **metrics})
+
+     def on_train_end(self, trainer, **kwargs):
+         if self.wandb:
+             self.wandb.finish()
+
+
+ # Default callback presets
+ def get_default_callbacks() -> CallbackList:
+     """Get a list of recommended default callbacks."""
+     return CallbackList([
+         ModelSizeCallback(),
+         ProgressCallback(log_every=10),
+         TimerCallback(),
+         LearningRateMonitorCallback()
+     ])
+
+
+ def get_verbose_callbacks() -> CallbackList:
+     """Get verbose callbacks for debugging."""
+     return CallbackList([
+         ModelSizeCallback(),
+         ProgressCallback(log_every=1),
+         TimerCallback(),
+         LearningRateMonitorCallback(),
+         GradientMonitorCallback(log_every=50),
+         MemoryMonitorCallback(log_every=50)
+     ])
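
For orientation, a minimal sketch of extending the callback system above. Callback, CallbackList, and get_default_callbacks come from callbacks.py; the trainer object and its model/optimizer attributes are assumptions here, since no trainer is included in this diff:

from langtune.callbacks import Callback, get_default_callbacks

class BestLossCallback(Callback):
    """Hypothetical custom callback: remember the best validation loss seen so far."""
    def __init__(self):
        self.best = float("inf")

    def on_validation_end(self, trainer, metrics, **kwargs):
        self.best = min(self.best, metrics.get("val_loss", float("inf")))

callbacks = get_default_callbacks()   # ModelSize, Progress, Timer, and LR-monitor presets
callbacks.add(BestLossCallback())

# A trainer (not part of this diff) is expected to drive the hooks, e.g.:
# callbacks.on_train_begin(trainer)
# callbacks.on_batch_end(trainer, batch_idx=0, loss=1.23)
# callbacks.on_validation_end(trainer, metrics={"val_loss": 0.87})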