langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langvision might be problematic. Click here for more details.

Files changed (41) hide show
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export CLI for Langvision models.
4
+ """
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import torch.onnx
14
+
15
+ # Import from langvision modules
16
+ try:
17
+ from langvision.models.vision_transformer import VisionTransformer
18
+ except ImportError as e:
19
+ print(f"❌ Error importing langvision modules: {e}")
20
+ print("Please ensure langvision is properly installed:")
21
+ print(" pip install langvision")
22
+ sys.exit(1)
23
+
24
+
25
def parse_args():
    """Build and parse the command-line arguments for the export command."""
    parser = argparse.ArgumentParser(
        description='Export a trained VisionTransformer model to various formats',
        epilog='''\nExamples:\n langvision export --checkpoint model.pth --format onnx --output model.onnx\n langvision export --checkpoint model.pth --format torchscript --output model.pt\n''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model architecture options — must match the checkpoint being exported.
    model = parser.add_argument_group('Model')
    model.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    model.add_argument('--img_size', type=int, default=224, help='Input image size')
    model.add_argument('--patch_size', type=int, default=16, help='Patch size for ViT')
    model.add_argument('--num_classes', type=int, default=10, help='Number of classes')
    model.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension')
    model.add_argument('--depth', type=int, default=12, help='Number of transformer layers')
    model.add_argument('--num_heads', type=int, default=12, help='Number of attention heads')
    model.add_argument('--mlp_ratio', type=float, default=4.0, help='MLP hidden dim ratio')

    # Export target and serialization options.
    export = parser.add_argument_group('Export')
    export.add_argument('--format', type=str, required=True, choices=['onnx', 'torchscript', 'state_dict'], help='Export format')
    export.add_argument('--output', type=str, required=True, help='Output file path')
    export.add_argument('--batch_size', type=int, default=1, help='Batch size for export (ONNX/TorchScript)')
    export.add_argument('--opset_version', type=int, default=11, help='ONNX opset version')

    # Device selection.
    device = parser.add_argument_group('Device')
    device.add_argument('--device', type=str, default='cpu', help='Device to use for export (use CPU for ONNX)')

    # Miscellaneous.
    misc = parser.add_argument_group('Misc')
    misc.add_argument('--log_level', type=str, default='info', help='Logging level')

    return parser.parse_args()
60
+
61
+
62
def setup_logging(log_level: str) -> None:
    """Configure root logging at *log_level*, falling back to INFO for unknown names."""
    level = getattr(logging, log_level.upper(), logging.INFO)
    if not isinstance(level, int):
        # getattr may resolve to a non-level attribute of the logging module.
        level = logging.INFO
    logging.basicConfig(level=level, format='[%(levelname)s] %(message)s')
68
+
69
+
70
def main():
    """Run the export CLI.

    Loads a VisionTransformer checkpoint, exports it to the requested format
    (ONNX, TorchScript, or a plain state dict) and, where the format supports
    it, verifies the exported artifact.

    Returns:
        int: 0 on success, 1 on failure (used as the process exit code).
    """
    args = parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info("🚀 Starting model export...")

    # Load model
    logger.info("🤖 Loading model...")
    try:
        model = VisionTransformer(
            img_size=args.img_size,
            patch_size=args.patch_size,
            in_chans=3,  # RGB input
            num_classes=args.num_classes,
            embed_dim=args.embed_dim,
            depth=args.depth,
            num_heads=args.num_heads,
            mlp_ratio=args.mlp_ratio,
        ).to(args.device)

        # Load checkpoint — accept both {'model': state_dict} wrappers
        # and bare state dicts.
        if not os.path.isfile(args.checkpoint):
            logger.error(f"❌ Checkpoint file not found: {args.checkpoint}")
            return 1

        checkpoint = torch.load(args.checkpoint, map_location=args.device)
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()  # freeze dropout/batchnorm behavior before tracing/export
        logger.info("✅ Model loaded successfully")

    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        return 1

    # Create the output directory only when the output path has one.
    # BUG FIX: os.path.dirname() returns '' for a bare filename
    # (e.g. --output model.onnx) and os.makedirs('') raises FileNotFoundError.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Export model
    logger.info(f"📦 Exporting model to {args.format.upper()} format...")
    try:
        if args.format in ('onnx', 'torchscript'):
            # Both tracing-based formats need a representative dummy input.
            dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size).to(args.device)

        if args.format == 'onnx':
            torch.onnx.export(
                model,
                dummy_input,
                args.output,
                export_params=True,
                opset_version=args.opset_version,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                # Allow a variable batch size at inference time.
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }
            )
            logger.info(f"✅ ONNX model exported to {args.output}")

        elif args.format == 'torchscript':
            traced_model = torch.jit.trace(model, dummy_input)
            traced_model.save(args.output)
            logger.info(f"✅ TorchScript model exported to {args.output}")

        elif args.format == 'state_dict':
            torch.save(model.state_dict(), args.output)
            logger.info(f"✅ State dict exported to {args.output}")

    except Exception as e:
        logger.error(f"❌ Export failed: {e}")
        return 1

    # Verify export (best effort: a failed verification only warns).
    if args.format in ['onnx', 'torchscript']:
        logger.info("🔍 Verifying exported model...")
        try:
            if args.format == 'onnx':
                import onnx  # optional dependency, only needed for verification
                onnx_model = onnx.load(args.output)
                onnx.checker.check_model(onnx_model)
                logger.info("✅ ONNX model verification passed")

            elif args.format == 'torchscript':
                torch.jit.load(args.output)  # raises if the archive is corrupt
                logger.info("✅ TorchScript model verification passed")

        except Exception as e:
            logger.warning(f"⚠️ Model verification failed: {e}")

    logger.info("✅ Model export completed successfully!")
    return 0
174
+
175
+
176
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())
@@ -23,47 +23,67 @@ def set_seed(seed: int) -> None:
23
23
 
24
24
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for fine-tuning."""
    parser = argparse.ArgumentParser(
        description='Fine-tune VisionTransformer with LoRA and advanced LLM concepts',
        epilog='''\nExamples:\n langvision finetune --dataset cifar10 --epochs 10\n langvision finetune --dataset cifar100 --lora_r 8 --rlhf\n''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data
    data = parser.add_argument_group('Data')
    data.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to use')
    data.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
    data.add_argument('--num_classes', type=int, default=10, help='Number of classes in the dataset')
    data.add_argument('--img_size', type=int, default=224, help='Input image size (pixels)')
    data.add_argument('--patch_size', type=int, default=16, help='Patch size for Vision Transformer')
    data.add_argument('--num_workers', type=int, default=2, help='Number of data loader workers')

    # Model
    model = parser.add_argument_group('Model')
    model.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension for ViT')
    model.add_argument('--depth', type=int, default=12, help='Number of transformer layers')
    model.add_argument('--num_heads', type=int, default=12, help='Number of attention heads')
    model.add_argument('--mlp_ratio', type=float, default=4.0, help='MLP hidden dim ratio')
    model.add_argument('--lora_r', type=int, default=4, help='LoRA rank (low-rank adaptation)')
    model.add_argument('--lora_alpha', type=float, default=1.0, help='LoRA alpha scaling')
    model.add_argument('--lora_dropout', type=float, default=0.1, help='LoRA dropout rate')

    # Training
    train = parser.add_argument_group('Training')
    train.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
    train.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    train.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    train.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'adamw', 'sgd'], help='Optimizer type')
    train.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay (L2 regularization)')
    train.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'step'], help='Learning rate scheduler')
    train.add_argument('--step_size', type=int, default=5, help='StepLR: step size')
    train.add_argument('--gamma', type=float, default=0.5, help='StepLR: gamma')
    train.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from')
    train.add_argument('--eval_only', action='store_true', help='Only run evaluation (no training)')

    # Output
    output = parser.add_argument_group('Output')
    output.add_argument('--output_dir', type=str, default='outputs', help='Directory to save outputs and checkpoints')
    output.add_argument('--save_name', type=str, default='vit_lora_best.pth', help='Checkpoint file name')

    # Callbacks
    callbacks = parser.add_argument_group('Callbacks')
    callbacks.add_argument('--early_stopping', action='store_true', help='Enable early stopping')
    callbacks.add_argument('--patience', type=int, default=5, help='Early stopping patience')

    # CUDA
    cuda = parser.add_argument_group('CUDA')
    cuda.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use (cuda or cpu)')
    cuda.add_argument('--cuda_deterministic', action='store_true', help='Enable deterministic CUDA (reproducible, slower)')
    cuda.add_argument('--cuda_benchmark', action='store_true', default=True, help='Enable cudnn.benchmark for fast training')
    cuda.add_argument('--cuda_max_split_size_mb', type=int, default=None, help='Set CUDA max split size in MB (for large models, PyTorch >=1.10)')

    # Misc
    misc = parser.add_argument_group('Misc')
    misc.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    misc.add_argument('--log_level', type=str, default='info', help='Logging level (debug, info, warning, error)')

    # Advanced LLM concepts — opt-in experimental flags.
    llm = parser.add_argument_group('Advanced LLM Concepts')
    llm.add_argument('--rlhf', action='store_true', help='Use RLHF (Reinforcement Learning from Human Feedback)')
    llm.add_argument('--ppo', action='store_true', help='Use PPO (Proximal Policy Optimization)')
    llm.add_argument('--dpo', action='store_true', help='Use DPO (Direct Preference Optimization)')
    llm.add_argument('--lime', action='store_true', help='Use LIME for model explainability')
    llm.add_argument('--shap', action='store_true', help='Use SHAP for model explainability')
    llm.add_argument('--cot', action='store_true', help='Use Chain-of-Thought (CoT) prompt generation')
    llm.add_argument('--ccot', action='store_true', help='Use Contrastive Chain-of-Thought (CCoT)')

    return parser.parse_args()
68
88
 
69
89
 
@@ -76,12 +96,11 @@ def setup_logging(log_level: str) -> None:
76
96
 
77
97
 
78
98
  def main() -> None:
79
- """Main function for fine-tuning VisionTransformer with LoRA."""
99
+ """Main function for fine-tuning VisionTransformer with LoRA and advanced LLM concepts."""
80
100
  args = parse_args()
81
101
  setup_logging(args.log_level)
82
102
  logger = logging.getLogger(__name__)
83
- logger.info(f"Using device: {args.device}")
84
-
103
+ logger.info("[STEP 1] Loading dataset...")
85
104
  setup_cuda(seed=args.seed, deterministic=args.cuda_deterministic, benchmark=args.cuda_benchmark, max_split_size_mb=args.cuda_max_split_size_mb)
86
105
  set_seed(args.seed)
87
106
 
@@ -95,6 +114,7 @@ def main() -> None:
95
114
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
96
115
  val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
97
116
 
117
+ logger.info("[STEP 2] Initializing model...")
98
118
  # Model
99
119
  try:
100
120
  model = VisionTransformer(
@@ -116,6 +136,7 @@ def main() -> None:
116
136
  logger.error(f"Failed to initialize model: {e}")
117
137
  return
118
138
 
139
+ logger.info("[STEP 3] Setting up optimizer and scheduler...")
119
140
  # Optimizer
120
141
  lora_params = [p for n, p in model.named_parameters() if 'lora' in n and p.requires_grad]
121
142
  if args.optimizer == 'adam':
@@ -125,7 +146,6 @@ def main() -> None:
125
146
  else:
126
147
  optimizer = torch.optim.SGD(lora_params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
127
148
 
128
- # Scheduler
129
149
  if args.scheduler == 'cosine':
130
150
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
131
151
  else:
@@ -134,11 +154,64 @@ def main() -> None:
134
154
  criterion = torch.nn.CrossEntropyLoss()
135
155
  scaler = torch.cuda.amp.GradScaler() if args.device == 'cuda' else None
136
156
 
157
+ logger.info("[STEP 4] Setting up trainer...")
137
158
  # Callbacks
138
159
  callbacks = []
139
160
  if args.early_stopping:
140
161
  callbacks.append(EarlyStopping(patience=args.patience))
141
162
 
163
+ # Advanced LLM concept objects
164
+ rlhf = None
165
+ ppo = None
166
+ dpo = None
167
+ lime = None
168
+ shap = None
169
+ cot = None
170
+ ccot = None
171
+ if args.rlhf:
172
+ from langvision.concepts.rlhf import RLHF
173
+ class SimpleRLHF(RLHF):
174
+ def train(self, model, data, feedback_fn, optimizer):
175
+ super().train(model, data, feedback_fn, optimizer)
176
+ rlhf = SimpleRLHF()
177
+ if args.ppo:
178
+ from langvision.concepts.ppo import PPO
179
+ class SimplePPO(PPO):
180
+ def step(self, policy, old_log_probs, states, actions, rewards, optimizer):
181
+ super().step(policy, old_log_probs, states, actions, rewards, optimizer)
182
+ ppo = SimplePPO()
183
+ if args.dpo:
184
+ from langvision.concepts.dpo import DPO
185
+ class SimpleDPO(DPO):
186
+ def optimize_with_preferences(self, model, preferences, optimizer):
187
+ super().optimize_with_preferences(model, preferences, optimizer)
188
+ dpo = SimpleDPO()
189
+ if args.lime:
190
+ from langvision.concepts.lime import LIME
191
+ class SimpleLIME(LIME):
192
+ def explain(self, model, input_data):
193
+ return super().explain(model, input_data)
194
+ lime = SimpleLIME()
195
+ if args.shap:
196
+ from langvision.concepts.shap import SHAP
197
+ class SimpleSHAP(SHAP):
198
+ def explain(self, model, input_data):
199
+ return super().explain(model, input_data)
200
+ shap = SimpleSHAP()
201
+ if args.cot:
202
+ from langvision.concepts.cot import CoT
203
+ class SimpleCoT(CoT):
204
+ def generate_chain(self, prompt):
205
+ return super().generate_chain(prompt)
206
+ cot = SimpleCoT()
207
+ if args.ccot:
208
+ from langvision.concepts.ccot import CCoT
209
+ class SimpleCCoT(CCoT):
210
+ def contrastive_train(self, positive_chains, negative_chains):
211
+ super().contrastive_train(positive_chains, negative_chains)
212
+ ccot = SimpleCCoT()
213
+
214
+ logger.info("[STEP 5] (Optional) Loading checkpoint if provided...")
142
215
  # Trainer
143
216
  trainer = Trainer(
144
217
  model=model,
@@ -148,6 +221,9 @@ def main() -> None:
148
221
  scaler=scaler,
149
222
  callbacks=callbacks,
150
223
  device=args.device,
224
+ rlhf=rlhf,
225
+ ppo=ppo,
226
+ dpo=dpo,
151
227
  )
152
228
 
153
229
  # Optionally resume
@@ -162,20 +238,61 @@ def main() -> None:
162
238
  logger.info(f"Resumed from {args.resume} at epoch {start_epoch}")
163
239
 
164
240
  if args.eval_only:
241
+ logger.info("[STEP 6] Running evaluation only...")
165
242
  val_loss, val_acc = trainer.evaluate(val_loader)
166
243
  logger.info(f"Eval Loss: {val_loss:.4f}, Eval Acc: {val_acc:.4f}")
167
244
  return
168
245
 
169
- # Training
170
- best_acc = trainer.fit(
171
- train_loader,
172
- val_loader,
173
- epochs=args.epochs,
174
- start_epoch=start_epoch,
175
- best_acc=best_acc,
176
- checkpoint_path=os.path.join(args.output_dir, args.save_name),
177
- )
178
- logger.info(f"Best validation accuracy: {best_acc:.4f}")
246
+ logger.info("[STEP 6] Starting training...")
247
+ # Training with per-epoch progress
248
+ for epoch in range(start_epoch, args.epochs):
249
+ logger.info(f" [Epoch {epoch+1}/{args.epochs}] Starting epoch...")
250
+ best_acc = trainer.fit(
251
+ train_loader,
252
+ val_loader,
253
+ epochs=epoch+1,
254
+ start_epoch=epoch,
255
+ best_acc=best_acc,
256
+ checkpoint_path=os.path.join(args.output_dir, args.save_name),
257
+ )
258
+ logger.info(f" [Epoch {epoch+1}/{args.epochs}] Done. Best validation accuracy so far: {best_acc:.4f}")
259
+ logger.info(f"[STEP 7] Training complete. Best validation accuracy: {best_acc:.4f}")
260
+
261
+ # Example: Use RLHF, PPO, DPO, LIME, SHAP, CoT, CCoT in training
262
+ if rlhf is not None:
263
+ logger.info("[STEP 8] Applying RLHF...")
264
+ def feedback_fn(output):
265
+ return 1.0 if output.sum().item() > 0 else -1.0
266
+ rlhf.train(model, [torch.randn(3) for _ in range(10)], feedback_fn, optimizer)
267
+ if ppo is not None:
268
+ logger.info("[STEP 9] Applying PPO...")
269
+ import torch
270
+ policy = model
271
+ old_log_probs = torch.zeros(10)
272
+ states = torch.randn(10, 3)
273
+ actions = torch.randint(0, args.num_classes, (10,))
274
+ rewards = torch.randn(10)
275
+ ppo.step(policy, old_log_probs, states, actions, rewards, optimizer)
276
+ if dpo is not None:
277
+ logger.info("[STEP 10] Applying DPO...")
278
+ preferences = [(torch.randn(3), 1.0), (torch.randn(3), -1.0)]
279
+ dpo.optimize_with_preferences(model, preferences, optimizer)
280
+ if lime is not None:
281
+ logger.info("[STEP 11] Running LIME explainability...")
282
+ lime_explanation = lime.explain(model, [[0.5, 1.0, 2.0], [1.0, 2.0, 3.0]])
283
+ logger.info(f"LIME explanation: {lime_explanation}")
284
+ if shap is not None:
285
+ logger.info("[STEP 12] Running SHAP explainability...")
286
+ shap_explanation = shap.explain(model, [[0.5, 1.0, 2.0], [1.0, 2.0, 3.0]])
287
+ logger.info(f"SHAP explanation: {shap_explanation}")
288
+ if cot is not None:
289
+ logger.info("[STEP 13] Generating Chain-of-Thought...")
290
+ chain = cot.generate_chain("What is 2 + 2?")
291
+ logger.info(f"CoT chain: {chain}")
292
+ if ccot is not None:
293
+ logger.info("[STEP 14] Running Contrastive Chain-of-Thought...")
294
+ ccot.contrastive_train([['Step 1: Think', 'Step 2: Solve']], [['Step 1: Guess', 'Step 2: Wrong']])
295
+ logger.info("[COMPLETE] All steps finished.")
179
296
 
180
297
  if __name__ == '__main__':
181
298
  main()
@@ -0,0 +1,162 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Model Zoo CLI for Langvision.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import logging
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Import from langvision modules
13
+ try:
14
+ from langvision.model_zoo import get_available_models, get_model_info, download_model
15
+ except ImportError as e:
16
+ print(f"❌ Error importing langvision modules: {e}")
17
+ print("Please ensure langvision is properly installed:")
18
+ print(" pip install langvision")
19
+ sys.exit(1)
20
+
21
+
22
def parse_args():
    """Parse command-line arguments for model zoo operations."""
    parser = argparse.ArgumentParser(
        description='Langvision Model Zoo - Browse, download, and manage pre-trained models',
        epilog='''\nExamples:\n langvision model-zoo list\n langvision model-zoo info vit_base_patch16_224\n langvision model-zoo download vit_base_patch16_224\n''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    commands = parser.add_subparsers(dest='command', required=True, help='Model zoo commands')

    # 'list': enumerate every model in the zoo.
    cmd = commands.add_parser('list', help='List all available models')
    cmd.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')
    cmd.add_argument('--filter', type=str, help='Filter models by name or type')

    # 'info': metadata for a single model.
    cmd = commands.add_parser('info', help='Get detailed information about a model')
    cmd.add_argument('model_name', type=str, help='Name of the model')
    cmd.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')

    # 'download': fetch pre-trained weights.
    cmd = commands.add_parser('download', help='Download a pre-trained model')
    cmd.add_argument('model_name', type=str, help='Name of the model to download')
    cmd.add_argument('--output_dir', type=str, default='./models', help='Directory to save the model')
    cmd.add_argument('--force', action='store_true', help='Force download even if model exists')

    # 'search': substring search over names and descriptions.
    cmd = commands.add_parser('search', help='Search for models')
    cmd.add_argument('query', type=str, help='Search query')
    cmd.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')

    # Misc
    parser.add_argument('--log_level', type=str, default='info', help='Logging level')

    return parser.parse_args()
57
+
58
+
59
def setup_logging(log_level: str) -> None:
    """Configure root logging at *log_level*, falling back to INFO for unknown names."""
    level = getattr(logging, log_level.upper(), logging.INFO)
    if not isinstance(level, int):
        # getattr may resolve to a non-level attribute of the logging module.
        level = logging.INFO
    logging.basicConfig(level=level, format='[%(levelname)s] %(message)s')
65
+
66
+
67
def print_model_table(models):
    """Render *models* (a list of metadata dicts) as an aligned text table on stdout."""
    if not models:
        print("No models found.")
        return

    # Pad each column to its widest entry, plus two spaces of breathing room.
    name_w = 2 + max(len(m.get('name', '')) for m in models)
    type_w = 2 + max(len(m.get('type', '')) for m in models)
    size_w = 2 + max(len(str(m.get('size', ''))) for m in models)

    # Header row and separator.
    print(f"{'Name':<{name_w}} {'Type':<{type_w}} {'Size':<{size_w}} {'Description'}")
    print("-" * (name_w + type_w + size_w + 50))

    # One row per model; missing keys render as empty cells.
    for m in models:
        print(f"{m.get('name', ''):<{name_w}} {m.get('type', ''):<{type_w}} {m.get('size', ''):<{size_w}} {m.get('description', '')}")
89
+
90
+
91
def print_model_info(model_info, format='table'):
    """Print detailed model information.

    Args:
        model_info: Mapping of model metadata. The optional 'config' key is
            expected to hold a nested mapping of model hyper-parameters.
        format: 'json' for raw JSON output, anything else for a readable table.
    """
    if format == 'json':
        print(json.dumps(model_info, indent=2))
        return

    print(f"\n🤖 Model: {model_info.get('name', 'Unknown')}")
    print("=" * 50)

    for key, value in model_info.items():
        # 'name' already appears in the banner; 'config' gets its own
        # formatted section below. BUG FIX: previously 'config' was printed
        # twice — once here as a raw dict and once formatted below.
        if key in ('name', 'config'):
            continue
        print(f"{key.replace('_', ' ').title()}: {value}")

    if 'config' in model_info:
        print(f"\n📋 Configuration:")
        for key, value in model_info['config'].items():
            print(f"  {key}: {value}")
111
+
112
def main():
    """Dispatch the requested model-zoo subcommand (list / info / download / search).

    Returns:
        int: 0 on success, 1 on failure (used as the process exit code).
    """
    args = parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    try:
        command = args.command

        if command == 'list':
            logger.info("📋 Fetching available models...")
            models = get_available_models()

            # Optional case-insensitive name filter.
            if args.filter:
                needle = args.filter.lower()
                models = [m for m in models if needle in m.get('name', '').lower()]

            if args.format == 'json':
                print(json.dumps(models, indent=2))
            else:
                print(f"\n🎯 Available Models ({len(models)} total):")
                print_model_table(models)

        elif command == 'info':
            logger.info(f"🔍 Getting information for model: {args.model_name}")
            print_model_info(get_model_info(args.model_name), args.format)

        elif command == 'download':
            logger.info(f"⬇️ Downloading model: {args.model_name}")
            output_path = download_model(args.model_name, args.output_dir, force=args.force)
            logger.info(f"✅ Model downloaded to: {output_path}")

        elif command == 'search':
            logger.info(f"🔍 Searching for: {args.query}")
            # Match against both names and descriptions, case-insensitively.
            needle = args.query.lower()
            results = [
                m for m in get_available_models()
                if needle in m.get('name', '').lower() or needle in m.get('description', '').lower()
            ]

            if args.format == 'json':
                print(json.dumps(results, indent=2))
            else:
                print(f"\n🔍 Search Results for '{args.query}' ({len(results)} found):")
                print_model_table(results)

    except Exception as e:
        logger.error(f"❌ Operation failed: {e}")
        return 1

    return 0
159
+
160
+
161
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())