langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/config.py
ADDED
@@ -0,0 +1,356 @@
"""
config.py: Configuration management for Langtune
"""

import yaml
import json
import os
from typing import Dict, Any, Optional, Union
from dataclasses import dataclass, asdict
import logging

logger = logging.getLogger(__name__)

@dataclass
class LoRAConfig:
    """LoRA configuration parameters."""
    rank: int = 8
    alpha: float = 16.0
    dropout: float = 0.1
    target_modules: list = None
    merge_weights: bool = False

    def __post_init__(self):
        if self.target_modules is None:
            self.target_modules = ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2']

@dataclass
class ModelConfig:
    """Model architecture configuration."""
    vocab_size: int = 32000
    embed_dim: int = 512
    num_layers: int = 6
    num_heads: int = 8
    max_seq_len: int = 512
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    lora: LoRAConfig = None

    def __post_init__(self):
        if self.lora is None:
            self.lora = LoRAConfig()

@dataclass
class TrainingConfig:
    """Training configuration parameters."""
    batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 10
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    gradient_accumulation_steps: int = 1
    mixed_precision: bool = False
    save_steps: int = 1000
    eval_steps: int = 500
    logging_steps: int = 100
    save_total_limit: int = 3
    early_stopping_patience: int = 5
    early_stopping_threshold: float = 0.001

@dataclass
class DataConfig:
    """Data loading and preprocessing configuration."""
    train_file: Optional[str] = None
    eval_file: Optional[str] = None
    test_file: Optional[str] = None
    max_length: int = 512
    padding: str = "max_length"
    truncation: bool = True
    tokenizer_name: Optional[str] = None
    cache_dir: Optional[str] = None

@dataclass
class Config:
    """Main configuration class."""
    model: ModelConfig
    training: TrainingConfig
    data: DataConfig
    output_dir: str = "./outputs"
    seed: int = 42
    device: str = "auto"  # auto, cpu, cuda, mps
    num_workers: int = 4
    pin_memory: bool = True

    def __post_init__(self):
        if isinstance(self.model, dict):
            self.model = ModelConfig(**self.model)
        if isinstance(self.training, dict):
            self.training = TrainingConfig(**self.training)
        if isinstance(self.data, dict):
            self.data = DataConfig(**self.data)

# Default configurations
default_model_config = ModelConfig()
default_training_config = TrainingConfig()
default_data_config = DataConfig()
default_config = Config(
    model=default_model_config,
    training=default_training_config,
    data=default_data_config
)

# Preset configurations for different model sizes
PRESET_CONFIGS = {
    "tiny": {
        "model": {
            "vocab_size": 10000,
            "embed_dim": 128,
            "num_layers": 4,
            "num_heads": 4,
            "max_seq_len": 256,
            "lora": {"rank": 4, "alpha": 8}
        },
        "training": {
            "batch_size": 64,
            "learning_rate": 2e-4
        }
    },
    "small": {
        "model": {
            "vocab_size": 32000,
            "embed_dim": 256,
            "num_layers": 6,
            "num_heads": 8,
            "max_seq_len": 512,
            "lora": {"rank": 8, "alpha": 16}
        },
        "training": {
            "batch_size": 32,
            "learning_rate": 1e-4
        }
    },
    "base": {
        "model": {
            "vocab_size": 50257,
            "embed_dim": 768,
            "num_layers": 12,
            "num_heads": 12,
            "max_seq_len": 1024,
            "lora": {"rank": 16, "alpha": 32}
        },
        "training": {
            "batch_size": 16,
            "learning_rate": 5e-5
        }
    },
    "large": {
        "model": {
            "vocab_size": 50257,
            "embed_dim": 1024,
            "num_layers": 24,
            "num_heads": 16,
            "max_seq_len": 1024,
            "lora": {"rank": 32, "alpha": 64}
        },
        "training": {
            "batch_size": 8,
            "learning_rate": 2e-5
        }
    }
}

def load_config(path: str) -> Config:
    """
    Load configuration from a YAML or JSON file.

    Args:
        path: Path to the configuration file

    Returns:
        Config object
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Config file not found: {path}")

    with open(path, 'r') as f:
        if path.endswith('.yaml') or path.endswith('.yml'):
            config_dict = yaml.safe_load(f)
        elif path.endswith('.json'):
            config_dict = json.load(f)
        else:
            raise ValueError(f"Unsupported config file format: {path}")

    return dict_to_config(config_dict)

def save_config(config: Config, path: str) -> None:
    """
    Save configuration to a YAML or JSON file.

    Args:
        config: Config object to save
        path: Path to save the configuration file
    """
    config_dict = config_to_dict(config)

    os.makedirs(os.path.dirname(path), exist_ok=True)

    with open(path, 'w') as f:
        if path.endswith('.yaml') or path.endswith('.yml'):
            yaml.dump(config_dict, f, default_flow_style=False, indent=2)
        elif path.endswith('.json'):
            json.dump(config_dict, f, indent=2)
        else:
            raise ValueError(f"Unsupported config file format: {path}")

def dict_to_config(config_dict: Dict[str, Any]) -> Config:
    """
    Convert a dictionary to a Config object.

    Args:
        config_dict: Dictionary containing configuration

    Returns:
        Config object
    """
    # Handle nested dictionaries for dataclasses
    if 'model' in config_dict and isinstance(config_dict['model'], dict):
        if 'lora' in config_dict['model'] and isinstance(config_dict['model']['lora'], dict):
            config_dict['model']['lora'] = LoRAConfig(**config_dict['model']['lora'])
        config_dict['model'] = ModelConfig(**config_dict['model'])

    if 'training' in config_dict and isinstance(config_dict['training'], dict):
        config_dict['training'] = TrainingConfig(**config_dict['training'])

    if 'data' in config_dict and isinstance(config_dict['data'], dict):
        config_dict['data'] = DataConfig(**config_dict['data'])

    return Config(**config_dict)

def config_to_dict(config: Config) -> Dict[str, Any]:
    """
    Convert a Config object to a dictionary.

    Args:
        config: Config object

    Returns:
        Dictionary representation
    """
    return asdict(config)

def get_preset_config(preset_name: str) -> Config:
    """
    Get a preset configuration.

    Args:
        preset_name: Name of the preset (tiny, small, base, large)

    Returns:
        Config object
    """
    if preset_name not in PRESET_CONFIGS:
        raise ValueError(f"Unknown preset: {preset_name}. Available: {list(PRESET_CONFIGS.keys())}")

    # Start with default config
    config_dict = config_to_dict(default_config)

    # Update with preset values
    preset_dict = PRESET_CONFIGS[preset_name]
    config_dict = deep_update(config_dict, preset_dict)

    return dict_to_config(config_dict)

def deep_update(base_dict: Dict[str, Any], update_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Deep update a dictionary with values from another dictionary.

    Args:
        base_dict: Base dictionary to update
        update_dict: Dictionary with updates

    Returns:
        Updated dictionary
    """
    result = base_dict.copy()

    for key, value in update_dict.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = deep_update(result[key], value)
        else:
            result[key] = value

    return result

def update_config(base_config: Config, updates: Dict[str, Any]) -> Config:
    """
    Update a configuration with new values.

    Args:
        base_config: Base configuration
        updates: Dictionary with updates

    Returns:
        Updated Config object
    """
    config_dict = config_to_dict(base_config)
    updated_dict = deep_update(config_dict, updates)
    return dict_to_config(updated_dict)

def validate_config(config: Config) -> None:
    """
    Validate configuration parameters.

    Args:
        config: Config object to validate

    Raises:
        ValueError: If configuration is invalid
    """
    # Model validation
    if config.model.embed_dim % config.model.num_heads != 0:
        raise ValueError(f"embed_dim ({config.model.embed_dim}) must be divisible by num_heads ({config.model.num_heads})")

    if config.model.vocab_size <= 0:
        raise ValueError(f"vocab_size must be positive, got {config.model.vocab_size}")

    if config.model.num_layers <= 0:
        raise ValueError(f"num_layers must be positive, got {config.model.num_layers}")

    # LoRA validation
    if config.model.lora.rank <= 0:
        raise ValueError(f"LoRA rank must be positive, got {config.model.lora.rank}")

    if config.model.lora.alpha <= 0:
        raise ValueError(f"LoRA alpha must be positive, got {config.model.lora.alpha}")

    # Training validation
    if config.training.batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {config.training.batch_size}")

    if config.training.learning_rate <= 0:
        raise ValueError(f"learning_rate must be positive, got {config.training.learning_rate}")

    if config.training.num_epochs <= 0:
        raise ValueError(f"num_epochs must be positive, got {config.training.num_epochs}")

    # Data validation
    if config.data.max_length <= 0:
        raise ValueError(f"max_length must be positive, got {config.data.max_length}")

    logger.info("Configuration validation passed")

# Backward compatibility
def load_config_legacy(path):
    """Legacy function for backward compatibility."""
    logger.warning("load_config_legacy is deprecated. Use load_config instead.")
    config_dict = load_config(path)
    return config_to_dict(config_dict)

def update_config_legacy(base_config, updates):
    """Legacy function for backward compatibility."""
    logger.warning("update_config_legacy is deprecated. Use update_config instead.")
    if isinstance(base_config, dict):
        base_config = dict_to_config(base_config)
    updated_config = update_config(base_config, updates)
    return config_to_dict(updated_config)
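
For orientation, the sketch below shows how the helpers defined in config.py above compose: build a Config from a preset, override nested fields, validate, and round-trip through YAML. It is not part of the package diff; the import path follows the file's location at langtune/config.py, and the output path and override values are illustrative only.

# Usage sketch (not from the package): assumes the helpers above are
# importable from langtune.config; paths and values are illustrative.
from langtune.config import (
    get_preset_config,
    update_config,
    validate_config,
    save_config,
    load_config,
)

# Start from the "small" preset and override a few training fields;
# update_config deep-merges the dict into the existing Config.
config = get_preset_config("small")
config = update_config(config, {"training": {"batch_size": 16, "num_epochs": 3}})

# Check invariants such as embed_dim % num_heads == 0 before training.
validate_config(config)

# Round-trip through YAML; load_config rebuilds the nested dataclasses.
save_config(config, "./outputs/run_config.yaml")
restored = load_config("./outputs/run_config.yaml")
assert restored.training.batch_size == 16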