langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langvision might be problematic. Click here for more details.

Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,478 @@
1
+ """
2
+ Advanced training utilities with LoRA fine-tuning, mixed precision, and comprehensive logging.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ from torch.cuda.amp import GradScaler, autocast
10
+ from typing import Dict, Any, Optional, List, Callable, Union
11
+ import logging
12
+ import time
13
+ import os
14
+ import json
15
+ from pathlib import Path
16
+ import numpy as np
17
+ from tqdm import tqdm
18
+ import wandb
19
+ from dataclasses import dataclass, asdict
20
+
21
+ from ..models.lora import LoRAConfig
22
+ from ..utils.metrics import MetricsTracker
23
+ from ..callbacks.base import Callback
24
+ from ..utils.device import get_device, to_device
25
+
26
+
27
@dataclass
class TrainingConfig:
    """Configuration for training parameters.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration. ``early_stopping_mode`` is normalised to lower case and
    validated in ``__post_init__`` so that a typo cannot silently flip the
    direction of early stopping / best-model selection.
    """
    # Basic training settings
    epochs: int = 100
    batch_size: int = 32  # NOTE(review): not read by AdvancedTrainer in this module (loaders are passed in)
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    warmup_epochs: int = 5  # NOTE(review): not read by AdvancedTrainer in this module

    # LoRA settings
    lora_config: Optional[LoRAConfig] = None
    freeze_backbone: bool = True

    # Optimization settings
    optimizer: str = "adamw"  # "adam", "adamw", "sgd"
    scheduler: str = "cosine"  # "cosine", "step", "plateau"; anything else disables scheduling
    gradient_clip_norm: float = 1.0  # clipping is skipped when <= 0

    # Mixed precision and performance
    use_amp: bool = True
    compile_model: bool = False

    # Regularization
    dropout: float = 0.1  # NOTE(review): not read by AdvancedTrainer in this module
    label_smoothing: float = 0.0

    # Logging and checkpointing
    log_interval: int = 10
    eval_interval: int = 1
    save_interval: int = 5
    save_top_k: int = 3

    # Paths
    output_dir: str = "./outputs"
    experiment_name: str = "langvision_experiment"

    # Distributed training
    distributed: bool = False  # NOTE(review): not read by AdvancedTrainer in this module
    local_rank: int = 0  # NOTE(review): not read by AdvancedTrainer in this module

    # Early stopping
    early_stopping_patience: int = 10
    early_stopping_metric: str = "val_loss"
    early_stopping_mode: str = "min"  # "min" or "max"

    def __post_init__(self):
        """Normalise and validate ``early_stopping_mode``.

        The mode is lower-cased, consistent with how the optimizer and
        scheduler names are compared case-insensitively.

        Raises:
            ValueError: if ``early_stopping_mode`` is not "min" or "max"
                (case-insensitive).
        """
        mode = str(self.early_stopping_mode).lower()
        if mode not in ("min", "max"):
            raise ValueError(
                f"early_stopping_mode must be 'min' or 'max', got {self.early_stopping_mode!r}"
            )
        self.early_stopping_mode = mode
72
+
73
+
74
class AdvancedTrainer:
    """Advanced trainer with comprehensive features for vision model fine-tuning.

    Handles optional LoRA-style parameter freezing, optimizer/scheduler
    construction, mixed precision (CUDA AMP), per-epoch validation,
    checkpointing with rotation, early stopping and user callbacks.

    Batches are indexed as ``batch['images']`` and ``batch['labels']``, and the
    model output is fed to ``nn.CrossEntropyLoss`` and ``argmax(dim=1)``.
    The output is presumably class logits of shape (batch, num_classes) —
    confirm against the models used with this trainer.
    """

    def __init__(self,
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: Optional[DataLoader] = None,
                 config: Optional[TrainingConfig] = None,
                 callbacks: Optional[List[Callback]] = None):
        """Build the trainer and all of its training components.

        Args:
            model: Model to train; moved to the device returned by ``get_device()``.
            train_loader: Loader yielding training batches.
            val_loader: Optional loader yielding validation batches.
            config: Training configuration; ``TrainingConfig()`` when omitted.
            callbacks: Optional callbacks invoked from ``train()``.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config or TrainingConfig()
        self.callbacks = callbacks or []

        # Output directory and logger must exist first: ``_setup_logging``
        # writes into ``output_dir`` and ``_setup_lora`` logs via ``self.logger``.
        # (Creating them last made construction fail unconditionally.)
        self.output_dir = Path(self.config.output_dir) / self.config.experiment_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logger = self._setup_logging()
        self.metrics_tracker = MetricsTracker()

        # Setup device
        self.device = get_device()
        self.model = to_device(self.model, self.device)

        # torch.cuda.amp autocast/GradScaler only take effect on CUDA; on other
        # devices they merely warn and disable themselves, so run in fp32 there.
        self.use_amp = bool(self.config.use_amp) and self._device_type() == 'cuda'

        # LoRA freezing must happen before the optimizer is built, because the
        # optimizer only receives parameters that still require gradients.
        if self.config.lora_config:
            self._setup_lora()

        # Setup optimization
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        self.scaler = GradScaler() if self.use_amp else None

        # Setup loss function
        self.criterion = self._create_criterion()

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self._start_epoch = 0  # first epoch run by train(); advanced by load_checkpoint()
        self.best_metric = float('inf') if self.config.early_stopping_mode == 'min' else float('-inf')
        self.patience_counter = 0

        # Save config
        with open(self.output_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(asdict(self.config), f, indent=2)

        # Model compilation for PyTorch 2.0+
        if self.config.compile_model and hasattr(torch, 'compile'):
            self.model = torch.compile(self.model)

    def _device_type(self) -> str:
        """Return the device type string (e.g. ``'cuda'``, ``'cpu'``) of ``self.device``."""
        return torch.device(self.device).type

    @staticmethod
    def _is_due(step: int, interval: int) -> bool:
        """Return True when ``step`` is a multiple of ``interval``.

        A non-positive interval disables the action instead of raising
        ZeroDivisionError.
        """
        return interval > 0 and step % interval == 0

    def _setup_lora(self):
        """Freeze the backbone so that only LoRA adapter parameters are trained.

        When ``config.freeze_backbone`` is set, every parameter is frozen, then
        parameters whose registered name contains ``'lora'`` inside modules that
        expose a ``lora_A`` or ``lora_B`` attribute are unfrozen. Parameter
        counts are logged in all cases.
        """
        if self.config.freeze_backbone:
            # Freeze all parameters first
            for param in self.model.parameters():
                param.requires_grad = False

            # Unfreeze LoRA parameters, selected by their registered name.
            # Tensors carry no ``.name`` attribute, so the former
            # ``param.name`` check always fell back to True and also unfroze
            # the wrapped base weights of every LoRA layer.
            for module in self.model.modules():
                if hasattr(module, 'lora_A') or hasattr(module, 'lora_B'):
                    for param_name, param in module.named_parameters():
                        if 'lora' in param_name:
                            param.requires_grad = True

        # Log trainable parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        ratio = trainable_params / total_params if total_params else 0.0

        self.logger.info(f"Total parameters: {total_params:,}")
        self.logger.info(f"Trainable parameters: {trainable_params:,}")
        self.logger.info(f"Trainable ratio: {ratio:.2%}")

    def _create_optimizer(self) -> optim.Optimizer:
        """Create the optimizer named by ``config.optimizer`` over trainable parameters.

        Raises:
            ValueError: if the optimizer name is not adam/adamw/sgd.
        """
        name = self.config.optimizer.lower()
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if name == "adam":
            return optim.Adam(trainable_params,
                              lr=self.config.learning_rate,
                              weight_decay=self.config.weight_decay)
        if name == "adamw":
            return optim.AdamW(trainable_params,
                               lr=self.config.learning_rate,
                               weight_decay=self.config.weight_decay)
        if name == "sgd":
            return optim.SGD(trainable_params,
                             lr=self.config.learning_rate,
                             momentum=0.9,
                             weight_decay=self.config.weight_decay)
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    def _create_scheduler(self) -> Optional[optim.lr_scheduler._LRScheduler]:
        """Create the learning-rate scheduler named by ``config.scheduler``.

        Returns None for any unrecognised name, which disables scheduling.
        Schedulers are stepped once per epoch by ``train()``.
        """
        name = self.config.scheduler.lower()
        if name == "cosine":
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=max(1, self.config.epochs),
                eta_min=self.config.learning_rate * 0.01
            )
        if name == "step":
            # StepLR computes ``epoch % step_size``; keep step_size >= 1 so
            # short runs (epochs < 3) do not raise ZeroDivisionError.
            return optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=max(1, self.config.epochs // 3),
                gamma=0.1
            )
        if name == "plateau":
            return optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode=self.config.early_stopping_mode,
                patience=self.config.early_stopping_patience // 2,
                factor=0.5
            )
        return None

    def _create_criterion(self) -> nn.Module:
        """Create a cross-entropy loss, with label smoothing when configured."""
        if self.config.label_smoothing > 0:
            return nn.CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
        return nn.CrossEntropyLoss()

    def _setup_logging(self) -> logging.Logger:
        """Return the experiment logger, writing to ``training.log`` and the console.

        Loggers are process-wide singletons keyed by name, so handlers that a
        previous trainer with the same experiment name attached are removed
        and closed first; otherwise each message would be emitted once per
        trainer ever constructed.
        """
        logger = logging.getLogger(f"langvision.{self.config.experiment_name}")
        logger.setLevel(logging.INFO)

        for handler in list(logger.handlers):
            if getattr(handler, '_langvision_trainer', False):
                logger.removeHandler(handler)
                handler.close()

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        file_handler = logging.FileHandler(self.output_dir / "training.log", encoding="utf-8")
        console_handler = logging.StreamHandler()

        for handler in (file_handler, console_handler):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            handler._langvision_trainer = True  # tag so a later trainer can replace it
            logger.addHandler(handler)

        return logger

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch and return averaged ``train_loss`` / ``train_acc``."""
        self.model.train()
        clip_norm = self.config.gradient_clip_norm

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")

        for batch_idx, batch in enumerate(pbar):
            batch = to_device(batch, self.device)
            images, labels = batch['images'], batch['labels']

            # Forward pass with mixed precision
            with autocast(enabled=self.use_amp):
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            # Backward pass
            self.optimizer.zero_grad()

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                if clip_norm > 0:
                    # Unscale first so the threshold applies to the real gradient norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if clip_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)
                self.optimizer.step()

            # Update metrics (read scalars once to avoid repeated device syncs)
            batch_size = images.size(0)
            loss_value = loss.item()
            with torch.no_grad():
                acc_value = (outputs.argmax(dim=1) == labels).float().mean().item()
            self.metrics_tracker.update('train_loss', loss_value, batch_size)
            self.metrics_tracker.update('train_acc', acc_value, batch_size)

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{acc_value:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
            })

            # Logging
            if self._is_due(batch_idx, self.config.log_interval):
                self.logger.info(
                    f"Epoch {self.current_epoch}, Batch {batch_idx}: "
                    f"Loss={loss_value:.4f}, Acc={acc_value:.4f}"
                )

            self.global_step += 1

        epoch_metrics = self.metrics_tracker.get_averages(['train_loss', 'train_acc'])
        self.metrics_tracker.reset()
        return epoch_metrics

    def validate(self) -> Dict[str, float]:
        """Evaluate on ``val_loader``; return ``val_loss`` / ``val_acc`` or {} if no loader."""
        if self.val_loader is None:
            return {}

        self.model.eval()

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="Validation")

            for batch in pbar:
                batch = to_device(batch, self.device)
                images, labels = batch['images'], batch['labels']

                with autocast(enabled=self.use_amp):
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                batch_size = images.size(0)
                loss_value = loss.item()
                acc_value = (outputs.argmax(dim=1) == labels).float().mean().item()
                self.metrics_tracker.update('val_loss', loss_value, batch_size)
                self.metrics_tracker.update('val_acc', acc_value, batch_size)

                pbar.set_postfix({
                    'val_loss': f"{loss_value:.4f}",
                    'val_acc': f"{acc_value:.4f}"
                })

        val_metrics = self.metrics_tracker.get_averages(['val_loss', 'val_acc'])
        self.metrics_tracker.reset()
        return val_metrics

    def save_checkpoint(self, metrics: Dict[str, float], is_best: bool = False):
        """Save a checkpoint for the current epoch, plus ``best_model.pt`` when ``is_best``.

        Besides model/optimizer/scheduler/scaler state, the early-stopping
        state (``best_metric``, ``patience_counter``) is stored so that
        ``load_checkpoint`` can resume faithfully.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler is not None else None,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'metrics': metrics,
            'config': asdict(self.config),
            'best_metric': self.best_metric,
            'patience_counter': self.patience_counter,
        }

        # Save regular checkpoint
        checkpoint_path = self.output_dir / f"checkpoint_epoch_{self.current_epoch}.pt"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            self.logger.info(f"New best model saved with {self.config.early_stopping_metric}: {metrics.get(self.config.early_stopping_metric, 'N/A')}")

        # Rotate old periodic checkpoints
        self._cleanup_checkpoints()

    def _cleanup_checkpoints(self):
        """Delete periodic checkpoints beyond the ``save_top_k`` most recent epochs.

        Despite the config name, retention is by recency (epoch number), not
        by metric. ``best_model.pt`` does not match the glob and is never
        removed here. A non-positive ``save_top_k`` keeps every checkpoint.
        """
        keep = self.config.save_top_k
        if keep <= 0:
            return

        numbered = []
        for path in self.output_dir.glob("checkpoint_epoch_*.pt"):
            suffix = path.stem.rsplit('_', 1)[-1]
            if suffix.isdigit():  # ignore stray files that don't follow the naming scheme
                numbered.append((int(suffix), path))

        numbered.sort(key=lambda item: item[0])
        for _, old_checkpoint in numbered[:-keep]:
            old_checkpoint.unlink()

    def _is_improvement(self, value: float) -> bool:
        """Return True if ``value`` beats ``best_metric`` under ``early_stopping_mode``."""
        if self.config.early_stopping_mode == 'min':
            return value < self.best_metric
        return value > self.best_metric

    def train(self):
        """Run the main training loop with validation, checkpointing and early stopping.

        Starts at epoch 0, or at the epoch after the one restored by
        ``load_checkpoint``. ``on_train_end`` callbacks run even when training
        is interrupted or fails; other exceptions are logged and re-raised.
        """
        self.logger.info("Starting training...")
        self.logger.info(f"Training for {self.config.epochs} epochs")

        for callback in self.callbacks:
            callback.on_train_start(self)

        try:
            for epoch in range(self._start_epoch, self.config.epochs):
                self.current_epoch = epoch

                for callback in self.callbacks:
                    callback.on_epoch_start(self, epoch)

                train_metrics = self.train_epoch()
                val_metrics = self.validate() if self._is_due(epoch, self.config.eval_interval) else {}
                all_metrics = {**train_metrics, **val_metrics}

                # Learning rate scheduling (plateau needs the monitored metric)
                if self.scheduler is not None:
                    if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                        metric_value = all_metrics.get(self.config.early_stopping_metric)
                        if metric_value is not None:
                            self.scheduler.step(metric_value)
                    else:
                        self.scheduler.step()

                self.logger.info(f"Epoch {epoch} completed:")
                for metric_name, metric_value in all_metrics.items():
                    self.logger.info(f" {metric_name}: {metric_value:.4f}")

                # Best-model tracking; epochs without the metric (e.g. no
                # validation this epoch) leave the patience counter untouched.
                current_metric = all_metrics.get(self.config.early_stopping_metric)
                is_best = False
                if current_metric is not None:
                    is_best = self._is_improvement(current_metric)
                    if is_best:
                        self.best_metric = current_metric
                        self.patience_counter = 0
                    else:
                        self.patience_counter += 1

                if self._is_due(epoch, self.config.save_interval) or is_best:
                    self.save_checkpoint(all_metrics, is_best)

                for callback in self.callbacks:
                    callback.on_epoch_end(self, epoch, all_metrics)

                if self.patience_counter >= self.config.early_stopping_patience:
                    self.logger.info(f"Early stopping triggered after {self.patience_counter} epochs without improvement")
                    break

        except KeyboardInterrupt:
            self.logger.info("Training interrupted by user")

        except Exception as e:
            self.logger.exception(f"Training failed with error: {str(e)}")
            raise

        finally:
            for callback in self.callbacks:
                callback.on_train_end(self)

        self.logger.info("Training completed!")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore trainer state from a checkpoint written by ``save_checkpoint``.

        After loading, ``train()`` resumes at the epoch following the saved
        one. Early-stopping state is restored when present (older checkpoints
        without it keep the current values).

        Note: ``torch.load`` unpickles data; only load checkpoints from
        trusted sources.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler is not None and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self._start_epoch = self.current_epoch + 1

        if 'best_metric' in checkpoint:
            self.best_metric = checkpoint['best_metric']
        if 'patience_counter' in checkpoint:
            self.patience_counter = checkpoint['patience_counter']

        self.logger.info(f"Loaded checkpoint from epoch {self.current_epoch}")
@@ -1,7 +1,15 @@
1
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
2
5
  import os
3
6
  import logging
4
- from langvision.callbacks.base import Callback
7
+ from typing import Optional, List, Dict, Any
8
+ from tqdm import tqdm
9
+ import time
10
+
11
+ from ..callbacks.base import Callback
12
+ from ..utils.device import get_device, to_device
5
13
 
6
14
  logger = logging.getLogger("langvision.trainer")
7
15
 
@@ -13,8 +21,13 @@ class Trainer:
13
21
  - Callbacks
14
22
  - Distributed/multi-GPU (use torch.nn.parallel.DistributedDataParallel)
15
23
  - TPU (use torch_xla)
24
+
25
+ Integration points for advanced LLM concepts:
26
+ - RLHF: Use RLHF-based feedback in the training loop
27
+ - PPO/DPO/GRPO/RLVR: Use RL-based optimization for policy/model updates
28
+ - LIME/SHAP: Use for model interpretability during/after training
16
29
  """
17
- def __init__(self, model, optimizer, criterion, scheduler=None, scaler=None, callbacks=None, device='cpu'):
30
+ def __init__(self, model, optimizer, criterion, scheduler=None, scaler=None, callbacks=None, device='cpu', rlhf=None, ppo=None, dpo=None):
18
31
  self.model = model
19
32
  self.optimizer = optimizer
20
33
  self.criterion = criterion
@@ -22,6 +35,9 @@ class Trainer:
22
35
  self.scaler = scaler
23
36
  self.callbacks = callbacks or []
24
37
  self.device = device
38
+ self.rlhf = rlhf # RLHF integration (optional)
39
+ self.ppo = ppo # PPO integration (optional)
40
+ self.dpo = dpo # DPO integration (optional)
25
41
  # GPU optimization
26
42
  if device.type == 'cuda':
27
43
  torch.backends.cudnn.benchmark = True
@@ -40,6 +56,18 @@ class Trainer:
40
56
  imgs = imgs.to(self.device, non_blocking=True)
41
57
  labels = labels.to(self.device, non_blocking=True)
42
58
  self.optimizer.zero_grad()
59
+ # RLHF integration example (stub)
60
+ if self.rlhf is not None:
61
+ # TODO: Use self.rlhf.train(data, feedback) for RLHF-based updates
62
+ pass
63
+ # PPO integration example (stub)
64
+ if self.ppo is not None:
65
+ # TODO: Use self.ppo.step(state, action, reward) for PPO-based updates
66
+ pass
67
+ # DPO integration example (stub)
68
+ if self.dpo is not None:
69
+ # TODO: Use self.dpo.optimize_with_preferences(preferences) for DPO-based updates
70
+ pass
43
71
  if self.scaler:
44
72
  with torch.cuda.amp.autocast():
45
73
  outputs = self.model(imgs)