langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/auth.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
"""
|
|
2
|
+
auth.py: API Key authentication and usage tracking for Langtune
|
|
3
|
+
|
|
4
|
+
Users must obtain an API key from https://langtrain.xyz to use this package.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import json
|
|
10
|
+
import hashlib
|
|
11
|
+
import time
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Optional, Dict, Any
|
|
14
|
+
from datetime import datetime, timedelta
|
|
15
|
+
|
|
16
|
+
# Try to import requests for API calls.
# REQUESTS_AVAILABLE gates the online verification path in verify_api_key().
try:
    import requests
    REQUESTS_AVAILABLE = True
except ImportError:
    REQUESTS_AVAILABLE = False

# Try to import rich for beautiful output.
# When rich is missing, all UI helpers fall back to plain print().
try:
    from rich.console import Console
    from rich.panel import Panel
    from rich.text import Text
    from rich import box
    RICH_AVAILABLE = True
    console = Console()
except ImportError:
    RICH_AVAILABLE = False
    console = None  # callers must check RICH_AVAILABLE before using console

# API configuration
API_BASE_URL = "https://api.langtrain.xyz"
AUTH_ENDPOINT = f"{API_BASE_URL}/v1/auth/verify"
USAGE_ENDPOINT = f"{API_BASE_URL}/v1/usage"

# Config paths (per-user, under the home directory)
CONFIG_DIR = Path.home() / ".langtune"
CONFIG_FILE = CONFIG_DIR / "config.json"
CACHE_FILE = CONFIG_DIR / ".auth_cache"

# Environment variable names
API_KEY_ENV = "LANGTUNE_API_KEY"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class AuthenticationError(Exception):
    """Raised when API key authentication fails."""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class UsageLimitError(Exception):
    """Raised when usage limit is exceeded."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _get_config_dir() -> Path:
    """Ensure the ~/.langtune directory exists and return its path."""
    path = CONFIG_DIR
    path.mkdir(parents=True, exist_ok=True)
    return path
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _load_config() -> Dict[str, Any]:
    """Read the JSON config file; return an empty dict if missing or unreadable."""
    if not CONFIG_FILE.exists():
        return {}
    try:
        with open(CONFIG_FILE, 'r') as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError):
        # Corrupt or unreadable config is treated the same as no config.
        return {}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _save_config(config: Dict[str, Any]) -> None:
    """Persist *config* as pretty-printed JSON in the user's config file."""
    _get_config_dir()  # make sure ~/.langtune exists before writing
    with open(CONFIG_FILE, 'w') as fh:
        json.dump(config, fh, indent=2)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _load_auth_cache() -> Dict[str, Any]:
    """Read the cached authentication data; empty dict if missing or corrupt."""
    if not CACHE_FILE.exists():
        return {}
    try:
        with open(CACHE_FILE, 'r') as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError):
        return {}
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _save_auth_cache(cache: Dict[str, Any]) -> None:
    """Write the authentication cache to disk as compact JSON."""
    _get_config_dir()  # create the directory on first use
    with open(CACHE_FILE, 'w') as fh:
        json.dump(cache, fh)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_api_key() -> Optional[str]:
    """
    Get the API key from environment or config file.

    Priority:
    1. LANGTUNE_API_KEY environment variable
    2. Config file (~/.langtune/config.json)

    Returns None when neither source provides a key.
    """
    env_key = os.environ.get(API_KEY_ENV)
    if env_key:
        return env_key
    return _load_config().get("api_key")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def set_api_key(api_key: str) -> None:
    """Store *api_key* in the config file and confirm to the user."""
    cfg = _load_config()
    cfg["api_key"] = api_key
    _save_config(cfg)

    if RICH_AVAILABLE:
        console.print("[green]✓[/] API key saved to ~/.langtune/config.json")
        return
    print("✓ API key saved to ~/.langtune/config.json")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _hash_key(api_key: str) -> str:
    """Return a short, stable fingerprint of *api_key* for cache lookups."""
    digest = hashlib.sha256(api_key.encode()).hexdigest()
    return digest[:16]  # 16 hex chars are plenty for a cache key
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def verify_api_key(api_key: str, force_refresh: bool = False) -> Dict[str, Any]:
    """
    Verify API key with the Langtrain API.

    Returns user info and usage limits on success.
    Raises AuthenticationError on failure, and UsageLimitError when the
    server answers 403 (subscription problem).

    A successful verification is cached on disk for one hour; pass
    force_refresh=True to bypass the cache.  When requests is not installed
    the key is only checked against the expected format (offline mode).
    """
    # Check cache first (valid for 1 hour)
    cache = _load_auth_cache()
    key_hash = _hash_key(api_key)

    if not force_refresh and key_hash in cache:
        cached_data = cache[key_hash]
        cache_time = cached_data.get("cached_at", 0)
        if time.time() - cache_time < 3600:  # 1 hour cache
            return cached_data.get("data", {})

    # For now, simulate API verification (offline mode)
    # In production, this would make an actual API call
    if not REQUESTS_AVAILABLE:
        # Offline verification - accept keys that match pattern
        if api_key.startswith("lt_") and len(api_key) >= 32:
            # Synthesize a free-plan payload with default quotas.
            user_data = {
                "valid": True,
                "user_id": key_hash,
                "plan": "free",
                "usage": {
                    "tokens_used": 0,
                    "tokens_limit": 100000,
                    "requests_used": 0,
                    "requests_limit": 1000
                },
                "offline_mode": True
            }
            # Cache the result
            cache[key_hash] = {
                "cached_at": time.time(),
                "data": user_data
            }
            _save_auth_cache(cache)
            return user_data
        else:
            raise AuthenticationError(
                "Invalid API key format. Keys should start with 'lt_' and be at least 32 characters.\n"
                "Get your API key at: https://app.langtrain.xyz"
            )

    # Make API call to verify key
    try:
        headers = {"Authorization": f"Bearer {api_key}"}
        response = requests.post(AUTH_ENDPOINT, headers=headers, timeout=10)

        if response.status_code == 200:
            user_data = response.json()
            # Cache the result
            cache[key_hash] = {
                "cached_at": time.time(),
                "data": user_data
            }
            _save_auth_cache(cache)
            return user_data
        elif response.status_code == 401:
            raise AuthenticationError(
                "Invalid API key. Please check your key at: https://app.langtrain.xyz"
            )
        elif response.status_code == 403:
            raise UsageLimitError(
                "API key is valid but access is denied. Your subscription may have expired.\n"
                "Manage your subscription at: https://billing.langtrain.xyz"
            )
        else:
            raise AuthenticationError(
                f"Authentication failed with status {response.status_code}. "
                "Please try again or contact support."
            )
    except requests.exceptions.RequestException as e:
        # If we can't reach the API, use cached data if available.
        # NOTE(review): this fallback ignores the 1-hour expiry, so stale
        # cached data may be returned while offline — confirm this is intended.
        if key_hash in cache:
            return cache[key_hash].get("data", {})
        raise AuthenticationError(
            f"Could not verify API key: {e}\n"
            "Please check your internet connection."
        )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def check_usage(api_key: str) -> Dict[str, Any]:
    """Check current usage against limits.

    Raises UsageLimitError when the token or request quota is exhausted;
    otherwise returns a summary dict with used/limit/remaining figures
    and the account's plan name.
    """
    info = verify_api_key(api_key)
    usage = info.get("usage", {})

    tokens_used = usage.get("tokens_used", 0)
    tokens_limit = usage.get("tokens_limit", 100000)
    requests_used = usage.get("requests_used", 0)
    requests_limit = usage.get("requests_limit", 1000)

    if tokens_used >= tokens_limit:
        raise UsageLimitError(
            f"Token limit exceeded ({tokens_used:,}/{tokens_limit:,}).\n"
            "Upgrade your plan at: https://billing.langtrain.xyz"
        )
    if requests_used >= requests_limit:
        raise UsageLimitError(
            f"Request limit exceeded ({requests_used:,}/{requests_limit:,}).\n"
            "Upgrade your plan at: https://billing.langtrain.xyz"
        )

    return {
        "tokens_used": tokens_used,
        "tokens_limit": tokens_limit,
        "tokens_remaining": tokens_limit - tokens_used,
        "requests_used": requests_used,
        "requests_limit": requests_limit,
        "requests_remaining": requests_limit - requests_used,
        "plan": info.get("plan", "free"),
    }
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def require_auth(func):
    """Decorator that requires a verified API key before calling *func*.

    Looks up the key (environment first, then config file), verifies it,
    and on any failure prints a helpful message and exits with status 1.

    Fix: the wrapper now uses functools.wraps so the decorated function
    keeps its __name__/__doc__/etc. (the original wrapper clobbered them).
    """
    from functools import wraps  # local import keeps the module's import surface unchanged

    @wraps(func)
    def wrapper(*args, **kwargs):
        api_key = get_api_key()

        if not api_key:
            _print_auth_required()
            sys.exit(1)

        try:
            verify_api_key(api_key)
        except AuthenticationError as e:
            _print_auth_error(str(e))
            sys.exit(1)
        except UsageLimitError as e:
            _print_usage_error(str(e))
            sys.exit(1)

        # Only reached when verification succeeded.
        return func(*args, **kwargs)

    return wrapper
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _print_auth_required():
    """Print authentication required message.

    Renders a rich panel when rich is installed, otherwise a plain-text
    equivalent.  Pure output helper; does not raise or exit.
    """
    if RICH_AVAILABLE:
        # Build the styled message piece by piece, then wrap it in a panel.
        text = Text()
        text.append("\n🔐 ", style="")
        text.append("API Key Required\n\n", style="bold red")
        text.append("Langtune requires an API key to run. Get your free key at:\n", style="")
        text.append("https://app.langtrain.xyz\n\n", style="blue underline")
        text.append("Once you have your key, authenticate using:\n\n", style="")
        text.append("    langtune auth login\n\n", style="cyan")
        text.append("Or set the environment variable:\n\n", style="")
        text.append(f"    export {API_KEY_ENV}=lt_your_api_key_here\n", style="cyan")

        panel = Panel(text, title="[bold]Authentication Required[/]", border_style="red", box=box.ROUNDED)
        console.print(panel)
    else:
        # Plain fallback mirrors the rich panel's content.
        print(f"""
🔐 API Key Required

Langtune requires an API key to run. Get your free key at:
https://app.langtrain.xyz

Once you have your key, authenticate using:

    langtune auth login

Or set the environment variable:

    export {API_KEY_ENV}=lt_your_api_key_here
""")
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _print_auth_error(message: str):
    """Show an authentication failure to the user (rich if available)."""
    if not RICH_AVAILABLE:
        print(f"\n❌ Authentication Error\n")
        print(f"{message}\n")
        return
    console.print(f"\n[bold red]❌ Authentication Error[/]\n")
    console.print(f"[red]{message}[/]\n")
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _print_usage_error(message: str):
    """Show a usage-limit warning to the user (rich if available)."""
    if not RICH_AVAILABLE:
        print(f"\n⚠️  Usage Limit Reached\n")
        print(f"{message}\n")
        return
    console.print(f"\n[bold yellow]⚠️  Usage Limit Reached[/]\n")
    console.print(f"[yellow]{message}[/]\n")
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def print_usage_info():
    """Print current usage information.

    Requires a configured API key; prints the auth-required help and
    returns early when no key is found.  Auth/usage errors are printed
    rather than propagated.
    """
    api_key = get_api_key()

    if not api_key:
        _print_auth_required()
        return

    try:
        # check_usage raises UsageLimitError when a quota is exhausted.
        usage = check_usage(api_key)

        if RICH_AVAILABLE:
            # Imported lazily: only needed on this path.
            from rich.table import Table

            table = Table(title="Langtune Usage", box=box.ROUNDED, title_style="bold cyan")
            table.add_column("Metric", style="cyan", no_wrap=True)
            table.add_column("Used", style="white", justify="right")
            table.add_column("Limit", style="white", justify="right")
            table.add_column("Remaining", style="green", justify="right")

            table.add_row(
                "Tokens",
                f"{usage['tokens_used']:,}",
                f"{usage['tokens_limit']:,}",
                f"{usage['tokens_remaining']:,}"
            )
            table.add_row(
                "Requests",
                f"{usage['requests_used']:,}",
                f"{usage['requests_limit']:,}",
                f"{usage['requests_remaining']:,}"
            )

            console.print()
            console.print(f"[dim]Plan:[/] [bold]{usage['plan'].title()}[/]")
            console.print(table)
            console.print()
            console.print("[dim]Manage your plan at:[/] [blue underline]https://billing.langtrain.xyz[/]\n")
        else:
            # Plain-text fallback with the same figures.
            print(f"\nPlan: {usage['plan'].title()}")
            print(f"Tokens: {usage['tokens_used']:,} / {usage['tokens_limit']:,}")
            print(f"Requests: {usage['requests_used']:,} / {usage['requests_limit']:,}")
            print(f"\nManage your plan at: https://billing.langtrain.xyz\n")

    except (AuthenticationError, UsageLimitError) as e:
        if RICH_AVAILABLE:
            console.print(f"[red]{e}[/]")
        else:
            print(str(e))
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def interactive_login():
    """Interactive login flow.

    Prompts for an API key, verifies it against the service (bypassing
    the local cache), and on success stores it in the config file.

    Returns:
        True when the key was verified and saved, False otherwise.
    """
    if RICH_AVAILABLE:
        console.print("\n[bold cyan]🔐 Langtune Authentication[/]\n")
        console.print("Get your API key at: [blue underline]https://app.langtrain.xyz[/]\n")
        api_key = console.input("[bold]Enter your API key:[/] ")
    else:
        print("\n🔐 Langtune Authentication\n")
        print("Get your API key at: https://app.langtrain.xyz\n")
        api_key = input("Enter your API key: ")

    api_key = api_key.strip()

    if not api_key:
        if RICH_AVAILABLE:
            console.print("[red]No API key entered.[/]")
        else:
            print("No API key entered.")
        return False

    try:
        # force_refresh skips the 1-hour cache so the key is really checked.
        user_data = verify_api_key(api_key, force_refresh=True)
        set_api_key(api_key)

        if RICH_AVAILABLE:
            console.print(f"\n[bold green]✓ Authentication successful![/]")
            console.print(f"[dim]Plan:[/] [bold]{user_data.get('plan', 'free').title()}[/]")
            console.print(f"\n[dim]You're ready to use Langtune. Run[/] [cyan]langtune info[/] [dim]to get started.[/]\n")
        else:
            print(f"\n✓ Authentication successful!")
            print(f"Plan: {user_data.get('plan', 'free').title()}")
            print(f"\nYou're ready to use Langtune. Run 'langtune info' to get started.\n")

        return True

    except AuthenticationError as e:
        _print_auth_error(str(e))
        return False
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def logout():
    """Forget the stored API key and clear the verification cache."""
    config = _load_config()
    if "api_key" in config:
        config.pop("api_key")
        _save_config(config)  # only rewrite the file when something changed

    # Drop any cached verification results as well.
    if CACHE_FILE.exists():
        CACHE_FILE.unlink()

    if RICH_AVAILABLE:
        console.print("[green]✓[/] Logged out successfully. API key removed from ~/.langtune/config.json")
    else:
        print("✓ Logged out successfully. API key removed from ~/.langtune/config.json")
|
langtune/callbacks.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""
|
|
2
|
+
callbacks.py: Training callbacks for Langtune
|
|
3
|
+
|
|
4
|
+
Provides extensible callback system for training hooks.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Dict, Any, Optional, List, Callable
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import json
|
|
11
|
+
import time
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Callback:
    """Base callback class. Override methods to customize training behavior.

    Every hook is a no-op by default, so subclasses only override the
    events they care about.
    """

    def on_train_begin(self, trainer, **kwargs):
        """Called at the start of training."""

    def on_train_end(self, trainer, **kwargs):
        """Called at the end of training."""

    def on_epoch_begin(self, trainer, epoch: int, **kwargs):
        """Called at the start of each epoch."""

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        """Called at the end of each epoch."""

    def on_batch_begin(self, trainer, batch_idx: int, **kwargs):
        """Called at the start of each batch."""

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        """Called at the end of each batch."""

    def on_validation_begin(self, trainer, **kwargs):
        """Called at the start of validation."""

    def on_validation_end(self, trainer, metrics: Dict[str, float], **kwargs):
        """Called at the end of validation."""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class CallbackList:
    """Container for multiple callbacks.

    Fans every hook out to each registered callback in registration order.
    """

    def __init__(self, callbacks: Optional[List[Callback]] = None):
        self.callbacks = callbacks or []

    def add(self, callback: Callback):
        """Add a callback."""
        self.callbacks.append(callback)

    def _dispatch(self, hook: str, *args, **kwargs):
        # Invoke the named hook on every registered callback.
        for cb in self.callbacks:
            getattr(cb, hook)(*args, **kwargs)

    def on_train_begin(self, trainer, **kwargs):
        self._dispatch("on_train_begin", trainer, **kwargs)

    def on_train_end(self, trainer, **kwargs):
        self._dispatch("on_train_end", trainer, **kwargs)

    def on_epoch_begin(self, trainer, epoch: int, **kwargs):
        self._dispatch("on_epoch_begin", trainer, epoch, **kwargs)

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        self._dispatch("on_epoch_end", trainer, epoch, metrics, **kwargs)

    def on_batch_begin(self, trainer, batch_idx: int, **kwargs):
        self._dispatch("on_batch_begin", trainer, batch_idx, **kwargs)

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        self._dispatch("on_batch_end", trainer, batch_idx, loss, **kwargs)

    def on_validation_begin(self, trainer, **kwargs):
        self._dispatch("on_validation_begin", trainer, **kwargs)

    def on_validation_end(self, trainer, metrics: Dict[str, float], **kwargs):
        self._dispatch("on_validation_end", trainer, metrics, **kwargs)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ProgressCallback(Callback):
    """Logs training progress at a configurable batch interval."""

    def __init__(self, log_every: int = 10):
        self.log_every = log_every  # emit a log line every N batches
        self.batch_count = 0        # batches seen in the current epoch

    def on_epoch_begin(self, trainer, epoch: int, **kwargs):
        self.batch_count = 0
        logger.info(f"Starting epoch {epoch + 1}")

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        self.batch_count += 1
        if self.batch_count % self.log_every == 0:
            logger.info(f"Batch {batch_idx}, Loss: {loss:.4f}")

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        summary = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
        logger.info(f"Epoch {epoch + 1} completed - {summary}")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class LearningRateMonitorCallback(Callback):
    """Record the optimizer's learning rate after every batch."""

    def __init__(self):
        self.lrs = []  # one entry per observed batch

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        if hasattr(trainer, 'optimizer'):
            # First param group's lr is taken as the representative rate.
            current_lr = trainer.optimizer.param_groups[0]['lr']
            self.lrs.append(current_lr)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class GradientMonitorCallback(Callback):
    """Periodically log the global L2 gradient norm for debugging."""

    def __init__(self, log_every: int = 100):
        self.log_every = log_every
        self.step = 0

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        self.step += 1
        if self.step % self.log_every != 0 or not hasattr(trainer, 'model'):
            return

        # Global norm = sqrt of the sum of per-parameter squared norms.
        squared = [
            p.grad.data.norm(2).item() ** 2
            for p in trainer.model.parameters()
            if p.grad is not None
        ]
        if squared:
            total_norm = sum(squared) ** 0.5
            logger.info(f"Step {self.step}: Gradient norm = {total_norm:.4f}")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class ModelSizeCallback(Callback):
    """Log model size information at training start."""

    def on_train_begin(self, trainer, **kwargs):
        """Log total and trainable parameter counts for trainer.model.

        No-op when the trainer has no ``model`` attribute.
        """
        if hasattr(trainer, 'model'):
            total = sum(p.numel() for p in trainer.model.parameters())
            trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)

            # Fix: guard against a parameterless model, which previously
            # raised ZeroDivisionError when computing the percentage.
            pct = 100 * trainable / total if total else 0.0
            logger.info(f"Model size: {total:,} total params, {trainable:,} trainable ({pct:.1f}%)")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class TimerCallback(Callback):
    """Track wall-clock time for each epoch and for the whole run."""

    def __init__(self):
        self.train_start = None  # set when training begins
        self.epoch_start = None  # set at the top of each epoch
        self.epoch_times = []    # per-epoch durations in seconds

    def on_train_begin(self, trainer, **kwargs):
        self.train_start = time.time()

    def on_epoch_begin(self, trainer, epoch: int, **kwargs):
        self.epoch_start = time.time()

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        if not self.epoch_start:
            return
        elapsed = time.time() - self.epoch_start
        self.epoch_times.append(elapsed)
        logger.info(f"Epoch {epoch + 1} took {elapsed:.1f}s")

    def on_train_end(self, trainer, **kwargs):
        if not self.train_start:
            return
        total = time.time() - self.train_start
        logger.info(f"Total training time: {total:.1f}s ({total/60:.1f}m)")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class SaveHistoryCallback(Callback):
    """Persist per-epoch metrics to a JSON history file."""

    def __init__(self, save_path: str = "training_history.json"):
        self.save_path = save_path
        self.history = {"epochs": [], "metrics": []}

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        self.history["epochs"].append(epoch + 1)
        self.history["metrics"].append(metrics)

        # Rewrite the full history every epoch so a crash loses nothing.
        with open(self.save_path, 'w') as fh:
            json.dump(self.history, fh, indent=2)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class MemoryMonitorCallback(Callback):
    """Periodically log CUDA memory usage (allocated/reserved, in GB)."""

    def __init__(self, log_every: int = 50):
        self.log_every = log_every  # log every N batches
        self.step = 0

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        """Log GPU memory every ``log_every`` steps; no-op without CUDA."""
        self.step += 1
        if self.step % self.log_every != 0:
            return

        # Fix: the original executed `import torch` unconditionally on every
        # batch, crashing environments without torch even when nothing was
        # logged.  Import only when needed and degrade gracefully, mirroring
        # WandbCallback's handling of a missing optional dependency.
        try:
            import torch
        except ImportError:
            return

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f"Step {self.step}: GPU Memory - {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class WandbCallback(Callback):
    """Log training metrics to Weights & Biases (if installed)."""

    def __init__(self, project: str = "langtune", run_name: Optional[str] = None):
        self.project = project
        self.run_name = run_name
        self.wandb = None  # set in on_train_begin when wandb imports cleanly

    def on_train_begin(self, trainer, **kwargs):
        try:
            import wandb
        except ImportError:
            logger.warning("wandb not installed, skipping W&B logging")
            return
        wandb.init(project=self.project, name=self.run_name)
        self.wandb = wandb

    def on_batch_end(self, trainer, batch_idx: int, loss: float, **kwargs):
        if self.wandb:
            self.wandb.log({"train_loss": loss})

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, float], **kwargs):
        if self.wandb:
            self.wandb.log({"epoch": epoch + 1, **metrics})

    def on_train_end(self, trainer, **kwargs):
        if self.wandb:
            self.wandb.finish()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# Default callback presets
|
|
249
|
+
def get_default_callbacks() -> CallbackList:
    """Get a list of recommended default callbacks."""
    defaults = [
        ModelSizeCallback(),
        ProgressCallback(log_every=10),
        TimerCallback(),
        LearningRateMonitorCallback(),
    ]
    return CallbackList(defaults)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def get_verbose_callbacks() -> CallbackList:
    """Get verbose callbacks for debugging."""
    verbose = [
        ModelSizeCallback(),
        ProgressCallback(log_every=1),
        TimerCallback(),
        LearningRateMonitorCallback(),
        GradientMonitorCallback(log_every=50),
        MemoryMonitorCallback(log_every=50),
    ]
    return CallbackList(verbose)
|