langtune-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/config.py ADDED
@@ -0,0 +1,356 @@
+ """
+ config.py: Configuration management for Langtune
+ """
+
+ import yaml
+ import json
+ import os
+ from typing import Dict, Any, Optional, List
+ from dataclasses import dataclass, asdict
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class LoRAConfig:
+     """LoRA configuration parameters."""
+     rank: int = 8
+     alpha: float = 16.0
+     dropout: float = 0.1
+     target_modules: Optional[List[str]] = None
+     merge_weights: bool = False
+
+     def __post_init__(self):
+         if self.target_modules is None:
+             self.target_modules = ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2']
+
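+ # Illustrative usage sketch: leaving target_modules unset falls back to the defaults
+ # assigned in __post_init__.
+ #
+ #     cfg = LoRAConfig(rank=4, alpha=8.0)
+ #     cfg.target_modules  # ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2']
+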
+ @dataclass
+ class ModelConfig:
+     """Model architecture configuration."""
+     vocab_size: int = 32000
+     embed_dim: int = 512
+     num_layers: int = 6
+     num_heads: int = 8
+     max_seq_len: int = 512
+     mlp_ratio: float = 4.0
+     dropout: float = 0.1
+     lora: Optional[LoRAConfig] = None
+
+     def __post_init__(self):
+         if self.lora is None:
+             self.lora = LoRAConfig()
+         elif isinstance(self.lora, dict):
+             self.lora = LoRAConfig(**self.lora)
+
+ @dataclass
+ class TrainingConfig:
+     """Training configuration parameters."""
+     batch_size: int = 32
+     learning_rate: float = 1e-4
+     weight_decay: float = 0.01
+     num_epochs: int = 10
+     warmup_steps: int = 1000
+     max_grad_norm: float = 1.0
+     gradient_accumulation_steps: int = 1
+     mixed_precision: bool = False
+     save_steps: int = 1000
+     eval_steps: int = 500
+     logging_steps: int = 100
+     save_total_limit: int = 3
+     early_stopping_patience: int = 5
+     early_stopping_threshold: float = 0.001
+
+ @dataclass
+ class DataConfig:
+     """Data loading and preprocessing configuration."""
+     train_file: Optional[str] = None
+     eval_file: Optional[str] = None
+     test_file: Optional[str] = None
+     max_length: int = 512
+     padding: str = "max_length"
+     truncation: bool = True
+     tokenizer_name: Optional[str] = None
+     cache_dir: Optional[str] = None
+
+ @dataclass
+ class Config:
+     """Main configuration class."""
+     model: ModelConfig
+     training: TrainingConfig
+     data: DataConfig
+     output_dir: str = "./outputs"
+     seed: int = 42
+     device: str = "auto"  # auto, cpu, cuda, mps
+     num_workers: int = 4
+     pin_memory: bool = True
+
+     def __post_init__(self):
+         if isinstance(self.model, dict):
+             self.model = ModelConfig(**self.model)
+         if isinstance(self.training, dict):
+             self.training = TrainingConfig(**self.training)
+         if isinstance(self.data, dict):
+             self.data = DataConfig(**self.data)
+
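+ # Illustrative sketch: Config.__post_init__ accepts plain dicts for its sub-configs,
+ # so a Config can be built directly from nested dictionaries.
+ #
+ #     cfg = Config(model={"embed_dim": 256, "num_heads": 8},
+ #                  training={"batch_size": 16},
+ #                  data={"max_length": 256})
+ #     isinstance(cfg.model, ModelConfig)  # True
+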
+ # Default configurations
+ default_model_config = ModelConfig()
+ default_training_config = TrainingConfig()
+ default_data_config = DataConfig()
+ default_config = Config(
+     model=default_model_config,
+     training=default_training_config,
+     data=default_data_config
+ )
+
+ # Preset configurations for different model sizes
+ PRESET_CONFIGS = {
+     "tiny": {
+         "model": {
+             "vocab_size": 10000,
+             "embed_dim": 128,
+             "num_layers": 4,
+             "num_heads": 4,
+             "max_seq_len": 256,
+             "lora": {"rank": 4, "alpha": 8}
+         },
+         "training": {
+             "batch_size": 64,
+             "learning_rate": 2e-4
+         }
+     },
+     "small": {
+         "model": {
+             "vocab_size": 32000,
+             "embed_dim": 256,
+             "num_layers": 6,
+             "num_heads": 8,
+             "max_seq_len": 512,
+             "lora": {"rank": 8, "alpha": 16}
+         },
+         "training": {
+             "batch_size": 32,
+             "learning_rate": 1e-4
+         }
+     },
+     "base": {
+         "model": {
+             "vocab_size": 50257,
+             "embed_dim": 768,
+             "num_layers": 12,
+             "num_heads": 12,
+             "max_seq_len": 1024,
+             "lora": {"rank": 16, "alpha": 32}
+         },
+         "training": {
+             "batch_size": 16,
+             "learning_rate": 5e-5
+         }
+     },
+     "large": {
+         "model": {
+             "vocab_size": 50257,
+             "embed_dim": 1024,
+             "num_layers": 24,
+             "num_heads": 16,
+             "max_seq_len": 1024,
+             "lora": {"rank": 32, "alpha": 64}
+         },
+         "training": {
+             "batch_size": 8,
+             "learning_rate": 2e-5
+         }
+     }
+ }
+
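+ # Illustrative note: the presets only override a subset of fields; get_preset_config()
+ # below deep-merges them into the defaults, e.g. PRESET_CONFIGS["tiny"] keeps the default
+ # dropout of 0.1 while shrinking embed_dim to 128.
+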
+ def load_config(path: str) -> Config:
+     """
+     Load configuration from a YAML or JSON file.
+
+     Args:
+         path: Path to the configuration file
+
+     Returns:
+         Config object
+     """
+     if not os.path.exists(path):
+         raise FileNotFoundError(f"Config file not found: {path}")
+
+     with open(path, 'r') as f:
+         if path.endswith('.yaml') or path.endswith('.yml'):
+             config_dict = yaml.safe_load(f)
+         elif path.endswith('.json'):
+             config_dict = json.load(f)
+         else:
+             raise ValueError(f"Unsupported config file format: {path}")
+
+     return dict_to_config(config_dict)
+
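+ # Illustrative usage sketch (the path "configs/run.yaml" is hypothetical):
+ #
+ #     config = load_config("configs/run.yaml")
+ #     validate_config(config)
+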
+ def save_config(config: Config, path: str) -> None:
+     """
+     Save configuration to a YAML or JSON file.
+
+     Args:
+         config: Config object to save
+         path: Path to save the configuration file
+     """
+     config_dict = config_to_dict(config)
+
+     # Only create parent directories when the path actually has one
+     dirname = os.path.dirname(path)
+     if dirname:
+         os.makedirs(dirname, exist_ok=True)
+
+     with open(path, 'w') as f:
+         if path.endswith('.yaml') or path.endswith('.yml'):
+             yaml.dump(config_dict, f, default_flow_style=False, indent=2)
+         elif path.endswith('.json'):
+             json.dump(config_dict, f, indent=2)
+         else:
+             raise ValueError(f"Unsupported config file format: {path}")
+
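+ # Illustrative round-trip sketch: the output format is chosen purely by file extension
+ # (the path "outputs/config.yaml" is hypothetical).
+ #
+ #     save_config(default_config, "outputs/config.yaml")
+ #     reloaded = load_config("outputs/config.yaml")
+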
+ def dict_to_config(config_dict: Dict[str, Any]) -> Config:
+     """
+     Convert a dictionary to a Config object.
+
+     Args:
+         config_dict: Dictionary containing configuration
+
+     Returns:
+         Config object
+     """
+     # Handle nested dictionaries for dataclasses
+     if 'model' in config_dict and isinstance(config_dict['model'], dict):
+         if 'lora' in config_dict['model'] and isinstance(config_dict['model']['lora'], dict):
+             config_dict['model']['lora'] = LoRAConfig(**config_dict['model']['lora'])
+         config_dict['model'] = ModelConfig(**config_dict['model'])
+
+     if 'training' in config_dict and isinstance(config_dict['training'], dict):
+         config_dict['training'] = TrainingConfig(**config_dict['training'])
+
+     if 'data' in config_dict and isinstance(config_dict['data'], dict):
+         config_dict['data'] = DataConfig(**config_dict['data'])
+
+     return Config(**config_dict)
+
+ def config_to_dict(config: Config) -> Dict[str, Any]:
+     """
+     Convert a Config object to a dictionary.
+
+     Args:
+         config: Config object
+
+     Returns:
+         Dictionary representation
+     """
+     return asdict(config)
+
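+ # Illustrative sketch: config_to_dict() and dict_to_config() are inverses up to type,
+ # which is what load_config/save_config and update_config rely on.
+ #
+ #     as_dict = config_to_dict(default_config)   # nested plain dicts via asdict()
+ #     rebuilt = dict_to_config(as_dict)          # nested dataclasses again
+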
+ def get_preset_config(preset_name: str) -> Config:
+     """
+     Get a preset configuration.
+
+     Args:
+         preset_name: Name of the preset (tiny, small, base, large)
+
+     Returns:
+         Config object
+     """
+     if preset_name not in PRESET_CONFIGS:
+         raise ValueError(f"Unknown preset: {preset_name}. Available: {list(PRESET_CONFIGS.keys())}")
+
+     # Start with default config
+     config_dict = config_to_dict(default_config)
+
+     # Update with preset values
+     preset_dict = PRESET_CONFIGS[preset_name]
+     config_dict = deep_update(config_dict, preset_dict)
+
+     return dict_to_config(config_dict)
+
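+ # Illustrative usage sketch:
+ #
+ #     config = get_preset_config("base")
+ #     config.model.embed_dim        # 768, from the "base" preset
+ #     config.training.num_epochs    # 10, inherited from the defaults
+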
+ def deep_update(base_dict: Dict[str, Any], update_dict: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Deep update a dictionary with values from another dictionary.
+
+     Args:
+         base_dict: Base dictionary to update
+         update_dict: Dictionary with updates
+
+     Returns:
+         Updated dictionary
+     """
+     result = base_dict.copy()
+
+     for key, value in update_dict.items():
+         if key in result and isinstance(result[key], dict) and isinstance(value, dict):
+             result[key] = deep_update(result[key], value)
+         else:
+             result[key] = value
+
+     return result
+
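+ # Illustrative sketch of the merge semantics: nested dicts are merged recursively,
+ # everything else is overwritten.
+ #
+ #     deep_update({"a": {"x": 1, "y": 2}, "b": 3}, {"a": {"y": 20}})
+ #     # -> {"a": {"x": 1, "y": 20}, "b": 3}
+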
+ def update_config(base_config: Config, updates: Dict[str, Any]) -> Config:
+     """
+     Update a configuration with new values.
+
+     Args:
+         base_config: Base configuration
+         updates: Dictionary with updates
+
+     Returns:
+         Updated Config object
+     """
+     config_dict = config_to_dict(base_config)
+     updated_dict = deep_update(config_dict, updates)
+     return dict_to_config(updated_dict)
+
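+ # Illustrative usage sketch: updates use the same nested layout as the preset dicts.
+ #
+ #     tuned = update_config(default_config, {"training": {"learning_rate": 5e-5},
+ #                                            "model": {"lora": {"rank": 16}}})
+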
+ def validate_config(config: Config) -> None:
+     """
+     Validate configuration parameters.
+
+     Args:
+         config: Config object to validate
+
+     Raises:
+         ValueError: If configuration is invalid
+     """
+     # Model validation
+     if config.model.embed_dim % config.model.num_heads != 0:
+         raise ValueError(f"embed_dim ({config.model.embed_dim}) must be divisible by num_heads ({config.model.num_heads})")
+
+     if config.model.vocab_size <= 0:
+         raise ValueError(f"vocab_size must be positive, got {config.model.vocab_size}")
+
+     if config.model.num_layers <= 0:
+         raise ValueError(f"num_layers must be positive, got {config.model.num_layers}")
+
+     # LoRA validation
+     if config.model.lora.rank <= 0:
+         raise ValueError(f"LoRA rank must be positive, got {config.model.lora.rank}")
+
+     if config.model.lora.alpha <= 0:
+         raise ValueError(f"LoRA alpha must be positive, got {config.model.lora.alpha}")
+
+     # Training validation
+     if config.training.batch_size <= 0:
+         raise ValueError(f"batch_size must be positive, got {config.training.batch_size}")
+
+     if config.training.learning_rate <= 0:
+         raise ValueError(f"learning_rate must be positive, got {config.training.learning_rate}")
+
+     if config.training.num_epochs <= 0:
+         raise ValueError(f"num_epochs must be positive, got {config.training.num_epochs}")
+
+     # Data validation
+     if config.data.max_length <= 0:
+         raise ValueError(f"max_length must be positive, got {config.data.max_length}")
+
+     logger.info("Configuration validation passed")
+
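+ # Illustrative sketch: validation raises rather than repairing, e.g. an embed_dim that is
+ # not divisible by num_heads is rejected.
+ #
+ #     bad = update_config(default_config, {"model": {"embed_dim": 500}})
+ #     validate_config(bad)  # ValueError: embed_dim (500) must be divisible by num_heads (8)
+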
+ # Backward compatibility
+ def load_config_legacy(path):
+     """Legacy function for backward compatibility."""
+     logger.warning("load_config_legacy is deprecated. Use load_config instead.")
+     config = load_config(path)
+     return config_to_dict(config)
+
+ def update_config_legacy(base_config, updates):
+     """Legacy function for backward compatibility."""
+     logger.warning("update_config_legacy is deprecated. Use update_config instead.")
+     if isinstance(base_config, dict):
+         base_config = dict_to_config(base_config)
+     updated_config = update_config(base_config, updates)
+     return config_to_dict(updated_config)