langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/cli/export.py
ADDED
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+"""
+Export CLI for Langvision models.
+"""
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.onnx
+
+# Import from langvision modules
+try:
+    from langvision.models.vision_transformer import VisionTransformer
+except ImportError as e:
+    print(f"❌ Error importing langvision modules: {e}")
+    print("Please ensure langvision is properly installed:")
+    print("  pip install langvision")
+    sys.exit(1)
+
+
+def parse_args():
+    """Parse command-line arguments for model export."""
+    parser = argparse.ArgumentParser(
+        description='Export a trained VisionTransformer model to various formats',
+        epilog='''\nExamples:\n  langvision export --checkpoint model.pth --format onnx --output model.onnx\n  langvision export --checkpoint model.pth --format torchscript --output model.pt\n''',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    # Model
+    model_group = parser.add_argument_group('Model')
+    model_group.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
+    model_group.add_argument('--img_size', type=int, default=224, help='Input image size')
+    model_group.add_argument('--patch_size', type=int, default=16, help='Patch size for ViT')
+    model_group.add_argument('--num_classes', type=int, default=10, help='Number of classes')
+    model_group.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension')
+    model_group.add_argument('--depth', type=int, default=12, help='Number of transformer layers')
+    model_group.add_argument('--num_heads', type=int, default=12, help='Number of attention heads')
+    model_group.add_argument('--mlp_ratio', type=float, default=4.0, help='MLP hidden dim ratio')
+
+    # Export
+    export_group = parser.add_argument_group('Export')
+    export_group.add_argument('--format', type=str, required=True, choices=['onnx', 'torchscript', 'state_dict'], help='Export format')
+    export_group.add_argument('--output', type=str, required=True, help='Output file path')
+    export_group.add_argument('--batch_size', type=int, default=1, help='Batch size for export (ONNX/TorchScript)')
+    export_group.add_argument('--opset_version', type=int, default=11, help='ONNX opset version')
+
+    # Device
+    device_group = parser.add_argument_group('Device')
+    device_group.add_argument('--device', type=str, default='cpu', help='Device to use for export (use CPU for ONNX)')
+
+    # Misc
+    misc_group = parser.add_argument_group('Misc')
+    misc_group.add_argument('--log_level', type=str, default='info', help='Logging level')
+
+    return parser.parse_args()
+
+
+def setup_logging(log_level: str) -> None:
+    """Set up logging with the specified log level."""
+    numeric_level = getattr(logging, log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+        numeric_level = logging.INFO
+    logging.basicConfig(level=numeric_level, format='[%(levelname)s] %(message)s')
+
+
+def main():
+    """Main function for model export."""
+    args = parse_args()
+    setup_logging(args.log_level)
+    logger = logging.getLogger(__name__)
+
+    logger.info("🚀 Starting model export...")
+
+    # Load model
+    logger.info("🤖 Loading model...")
+    try:
+        model = VisionTransformer(
+            img_size=args.img_size,
+            patch_size=args.patch_size,
+            in_chans=3,
+            num_classes=args.num_classes,
+            embed_dim=args.embed_dim,
+            depth=args.depth,
+            num_heads=args.num_heads,
+            mlp_ratio=args.mlp_ratio,
+        ).to(args.device)
+
+        # Load checkpoint
+        if not os.path.isfile(args.checkpoint):
+            logger.error(f"❌ Checkpoint file not found: {args.checkpoint}")
+            return 1
+
+        checkpoint = torch.load(args.checkpoint, map_location=args.device)
+        if 'model' in checkpoint:
+            model.load_state_dict(checkpoint['model'])
+        else:
+            model.load_state_dict(checkpoint)
+
+        model.eval()
+        logger.info("✅ Model loaded successfully")
+
+    except Exception as e:
+        logger.error(f"❌ Failed to load model: {e}")
+        return 1
+
+    # Create output directory
+    os.makedirs(os.path.dirname(args.output), exist_ok=True)
+
+    # Export model
+    logger.info(f"📦 Exporting model to {args.format.upper()} format...")
+    try:
+        if args.format == 'onnx':
+            # Create dummy input
+            dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size).to(args.device)
+
+            # Export to ONNX
+            torch.onnx.export(
+                model,
+                dummy_input,
+                args.output,
+                export_params=True,
+                opset_version=args.opset_version,
+                do_constant_folding=True,
+                input_names=['input'],
+                output_names=['output'],
+                dynamic_axes={
+                    'input': {0: 'batch_size'},
+                    'output': {0: 'batch_size'}
+                }
+            )
+            logger.info(f"✅ ONNX model exported to {args.output}")
+
+        elif args.format == 'torchscript':
+            # Create dummy input
+            dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size).to(args.device)
+
+            # Export to TorchScript
+            traced_model = torch.jit.trace(model, dummy_input)
+            traced_model.save(args.output)
+            logger.info(f"✅ TorchScript model exported to {args.output}")
+
+        elif args.format == 'state_dict':
+            # Export state dict
+            torch.save(model.state_dict(), args.output)
+            logger.info(f"✅ State dict exported to {args.output}")
+
+    except Exception as e:
+        logger.error(f"❌ Export failed: {e}")
+        return 1
+
+    # Verify export
+    if args.format in ['onnx', 'torchscript']:
+        logger.info("🔍 Verifying exported model...")
+        try:
+            if args.format == 'onnx':
+                import onnx
+                onnx_model = onnx.load(args.output)
+                onnx.checker.check_model(onnx_model)
+                logger.info("✅ ONNX model verification passed")
+
+            elif args.format == 'torchscript':
+                loaded_model = torch.jit.load(args.output)
+                logger.info("✅ TorchScript model verification passed")
+
+        except Exception as e:
+            logger.warning(f"⚠️ Model verification failed: {e}")
+
+    logger.info("✅ Model export completed successfully!")
+    return 0
+
+
+if __name__ == '__main__':
+    sys.exit(main())
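For reference, here is a minimal smoke test of the resulting ONNX artifact, run outside this CLI. This is a hedged sketch, not part of the package: it assumes onnxruntime is installed, and "model.onnx" plus the tensor name "input" simply mirror the torch.onnx.export call above.

    # Smoke-test an exported ONNX model (assumption: onnxruntime is installed;
    # "model.onnx" and the input name "input" mirror torch.onnx.export above).
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    # One 224x224 RGB image, matching the dummy input used at export time.
    batch = np.random.rand(1, 3, 224, 224).astype(np.float32)
    (logits,) = session.run(None, {"input": batch})
    print(logits.shape)  # (1, num_classes) -- (1, 10) with the default flags

Because the export declares dynamic_axes for the batch dimension, the same session accepts any batch size.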
langvision/cli/finetune.py
CHANGED
@@ -23,47 +23,67 @@ def set_seed(seed: int) -> None:
 
 def parse_args() -> argparse.Namespace:
     """Parse command-line arguments for fine-tuning."""
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
+        description='Fine-tune VisionTransformer with LoRA and advanced LLM concepts',
+        epilog='''\nExamples:\n  langvision finetune --dataset cifar10 --epochs 10\n  langvision finetune --dataset cifar100 --lora_r 8 --rlhf\n''',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
     # Data
-    parser.
-
-
-
-
-
+    data_group = parser.add_argument_group('Data')
+    data_group.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to use')
+    data_group.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
+    data_group.add_argument('--num_classes', type=int, default=10, help='Number of classes in the dataset')
+    data_group.add_argument('--img_size', type=int, default=224, help='Input image size (pixels)')
+    data_group.add_argument('--patch_size', type=int, default=16, help='Patch size for Vision Transformer')
+    data_group.add_argument('--num_workers', type=int, default=2, help='Number of data loader workers')
     # Model
-    parser.
-
-
-
-
-
-
+    model_group = parser.add_argument_group('Model')
+    model_group.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension for ViT')
+    model_group.add_argument('--depth', type=int, default=12, help='Number of transformer layers')
+    model_group.add_argument('--num_heads', type=int, default=12, help='Number of attention heads')
+    model_group.add_argument('--mlp_ratio', type=float, default=4.0, help='MLP hidden dim ratio')
+    model_group.add_argument('--lora_r', type=int, default=4, help='LoRA rank (low-rank adaptation)')
+    model_group.add_argument('--lora_alpha', type=float, default=1.0, help='LoRA alpha scaling')
+    model_group.add_argument('--lora_dropout', type=float, default=0.1, help='LoRA dropout rate')
     # Training
-    parser.
-
-
-
-
-
-
-
-
-
+    train_group = parser.add_argument_group('Training')
+    train_group.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    train_group.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
+    train_group.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+    train_group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'adamw', 'sgd'], help='Optimizer type')
+    train_group.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay (L2 regularization)')
+    train_group.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'step'], help='Learning rate scheduler')
+    train_group.add_argument('--step_size', type=int, default=5, help='StepLR: step size')
+    train_group.add_argument('--gamma', type=float, default=0.5, help='StepLR: gamma')
+    train_group.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from')
+    train_group.add_argument('--eval_only', action='store_true', help='Only run evaluation (no training)')
     # Output
-    parser.
-
+    output_group = parser.add_argument_group('Output')
+    output_group.add_argument('--output_dir', type=str, default='outputs', help='Directory to save outputs and checkpoints')
+    output_group.add_argument('--save_name', type=str, default='vit_lora_best.pth', help='Checkpoint file name')
     # Callbacks
-    parser.
-
+    callback_group = parser.add_argument_group('Callbacks')
+    callback_group.add_argument('--early_stopping', action='store_true', help='Enable early stopping')
+    callback_group.add_argument('--patience', type=int, default=5, help='Early stopping patience')
     # CUDA
-
-
-
-
+    cuda_group = parser.add_argument_group('CUDA')
+    cuda_group.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use (cuda or cpu)')
+    cuda_group.add_argument('--cuda_deterministic', action='store_true', help='Enable deterministic CUDA (reproducible, slower)')
+    cuda_group.add_argument('--cuda_benchmark', action='store_true', default=True, help='Enable cudnn.benchmark for fast training')
+    cuda_group.add_argument('--cuda_max_split_size_mb', type=int, default=None, help='Set CUDA max split size in MB (for large models, PyTorch >=1.10)')
     # Misc
-    parser.
-
+    misc_group = parser.add_argument_group('Misc')
+    misc_group.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
+    misc_group.add_argument('--log_level', type=str, default='info', help='Logging level (debug, info, warning, error)')
+    # Advanced LLM concepts
+    llm_group = parser.add_argument_group('Advanced LLM Concepts')
+    llm_group.add_argument('--rlhf', action='store_true', help='Use RLHF (Reinforcement Learning from Human Feedback)')
+    llm_group.add_argument('--ppo', action='store_true', help='Use PPO (Proximal Policy Optimization)')
+    llm_group.add_argument('--dpo', action='store_true', help='Use DPO (Direct Preference Optimization)')
+    llm_group.add_argument('--lime', action='store_true', help='Use LIME for model explainability')
+    llm_group.add_argument('--shap', action='store_true', help='Use SHAP for model explainability')
+    llm_group.add_argument('--cot', action='store_true', help='Use Chain-of-Thought (CoT) prompt generation')
+    llm_group.add_argument('--ccot', action='store_true', help='Use Contrastive Chain-of-Thought (CCoT)')
     return parser.parse_args()
 
 
@@ -76,12 +96,11 @@ def setup_logging(log_level: str) -> None:
 
 
 def main() -> None:
-    """Main function for fine-tuning VisionTransformer with LoRA."""
+    """Main function for fine-tuning VisionTransformer with LoRA and advanced LLM concepts."""
     args = parse_args()
     setup_logging(args.log_level)
     logger = logging.getLogger(__name__)
-    logger.info(
-
+    logger.info("[STEP 1] Loading dataset...")
     setup_cuda(seed=args.seed, deterministic=args.cuda_deterministic, benchmark=args.cuda_benchmark, max_split_size_mb=args.cuda_max_split_size_mb)
     set_seed(args.seed)
 
@@ -95,6 +114,7 @@ def main() -> None:
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
 
+    logger.info("[STEP 2] Initializing model...")
     # Model
     try:
         model = VisionTransformer(
@@ -116,6 +136,7 @@ def main() -> None:
         logger.error(f"Failed to initialize model: {e}")
         return
 
+    logger.info("[STEP 3] Setting up optimizer and scheduler...")
     # Optimizer
     lora_params = [p for n, p in model.named_parameters() if 'lora' in n and p.requires_grad]
     if args.optimizer == 'adam':
@@ -125,7 +146,6 @@ def main() -> None:
     else:
         optimizer = torch.optim.SGD(lora_params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
 
-    # Scheduler
     if args.scheduler == 'cosine':
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
     else:
@@ -134,11 +154,64 @@ def main() -> None:
     criterion = torch.nn.CrossEntropyLoss()
     scaler = torch.cuda.amp.GradScaler() if args.device == 'cuda' else None
 
+    logger.info("[STEP 4] Setting up trainer...")
     # Callbacks
    callbacks = []
     if args.early_stopping:
         callbacks.append(EarlyStopping(patience=args.patience))
 
+    # Advanced LLM concept objects
+    rlhf = None
+    ppo = None
+    dpo = None
+    lime = None
+    shap = None
+    cot = None
+    ccot = None
+    if args.rlhf:
+        from langvision.concepts.rlhf import RLHF
+        class SimpleRLHF(RLHF):
+            def train(self, model, data, feedback_fn, optimizer):
+                super().train(model, data, feedback_fn, optimizer)
+        rlhf = SimpleRLHF()
+    if args.ppo:
+        from langvision.concepts.ppo import PPO
+        class SimplePPO(PPO):
+            def step(self, policy, old_log_probs, states, actions, rewards, optimizer):
+                super().step(policy, old_log_probs, states, actions, rewards, optimizer)
+        ppo = SimplePPO()
+    if args.dpo:
+        from langvision.concepts.dpo import DPO
+        class SimpleDPO(DPO):
+            def optimize_with_preferences(self, model, preferences, optimizer):
+                super().optimize_with_preferences(model, preferences, optimizer)
+        dpo = SimpleDPO()
+    if args.lime:
+        from langvision.concepts.lime import LIME
+        class SimpleLIME(LIME):
+            def explain(self, model, input_data):
+                return super().explain(model, input_data)
+        lime = SimpleLIME()
+    if args.shap:
+        from langvision.concepts.shap import SHAP
+        class SimpleSHAP(SHAP):
+            def explain(self, model, input_data):
+                return super().explain(model, input_data)
+        shap = SimpleSHAP()
+    if args.cot:
+        from langvision.concepts.cot import CoT
+        class SimpleCoT(CoT):
+            def generate_chain(self, prompt):
+                return super().generate_chain(prompt)
+        cot = SimpleCoT()
+    if args.ccot:
+        from langvision.concepts.ccot import CCoT
+        class SimpleCCoT(CCoT):
+            def contrastive_train(self, positive_chains, negative_chains):
+                super().contrastive_train(positive_chains, negative_chains)
+        ccot = SimpleCCoT()
+
+    logger.info("[STEP 5] (Optional) Loading checkpoint if provided...")
     # Trainer
     trainer = Trainer(
         model=model,
@@ -148,6 +221,9 @@ def main() -> None:
         scaler=scaler,
         callbacks=callbacks,
         device=args.device,
+        rlhf=rlhf,
+        ppo=ppo,
+        dpo=dpo,
     )
 
     # Optionally resume
@@ -162,20 +238,61 @@ def main() -> None:
         logger.info(f"Resumed from {args.resume} at epoch {start_epoch}")
 
     if args.eval_only:
+        logger.info("[STEP 6] Running evaluation only...")
         val_loss, val_acc = trainer.evaluate(val_loader)
         logger.info(f"Eval Loss: {val_loss:.4f}, Eval Acc: {val_acc:.4f}")
         return
 
-
-
-
-
-
-
-
-
-
-
+    logger.info("[STEP 6] Starting training...")
+    # Training with per-epoch progress
+    for epoch in range(start_epoch, args.epochs):
+        logger.info(f"  [Epoch {epoch+1}/{args.epochs}] Starting epoch...")
+        best_acc = trainer.fit(
+            train_loader,
+            val_loader,
+            epochs=epoch+1,
+            start_epoch=epoch,
+            best_acc=best_acc,
+            checkpoint_path=os.path.join(args.output_dir, args.save_name),
+        )
+        logger.info(f"  [Epoch {epoch+1}/{args.epochs}] Done. Best validation accuracy so far: {best_acc:.4f}")
+    logger.info(f"[STEP 7] Training complete. Best validation accuracy: {best_acc:.4f}")
+
+    # Example: Use RLHF, PPO, DPO, LIME, SHAP, CoT, CCoT in training
+    if rlhf is not None:
+        logger.info("[STEP 8] Applying RLHF...")
+        def feedback_fn(output):
+            return 1.0 if output.sum().item() > 0 else -1.0
+        rlhf.train(model, [torch.randn(3) for _ in range(10)], feedback_fn, optimizer)
+    if ppo is not None:
+        logger.info("[STEP 9] Applying PPO...")
+        import torch
+        policy = model
+        old_log_probs = torch.zeros(10)
+        states = torch.randn(10, 3)
+        actions = torch.randint(0, args.num_classes, (10,))
+        rewards = torch.randn(10)
+        ppo.step(policy, old_log_probs, states, actions, rewards, optimizer)
+    if dpo is not None:
+        logger.info("[STEP 10] Applying DPO...")
+        preferences = [(torch.randn(3), 1.0), (torch.randn(3), -1.0)]
+        dpo.optimize_with_preferences(model, preferences, optimizer)
+    if lime is not None:
+        logger.info("[STEP 11] Running LIME explainability...")
+        lime_explanation = lime.explain(model, [[0.5, 1.0, 2.0], [1.0, 2.0, 3.0]])
+        logger.info(f"LIME explanation: {lime_explanation}")
+    if shap is not None:
+        logger.info("[STEP 12] Running SHAP explainability...")
+        shap_explanation = shap.explain(model, [[0.5, 1.0, 2.0], [1.0, 2.0, 3.0]])
+        logger.info(f"SHAP explanation: {shap_explanation}")
+    if cot is not None:
+        logger.info("[STEP 13] Generating Chain-of-Thought...")
+        chain = cot.generate_chain("What is 2 + 2?")
+        logger.info(f"CoT chain: {chain}")
+    if ccot is not None:
+        logger.info("[STEP 14] Running Contrastive Chain-of-Thought...")
+        ccot.contrastive_train([['Step 1: Think', 'Step 2: Solve']], [['Step 1: Guess', 'Step 2: Wrong']])
+    logger.info("[COMPLETE] All steps finished.")
 
 if __name__ == '__main__':
     main()
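The optimizer in main() above is built over LoRA parameters only, selected by name. Below is a self-contained sketch of that idiom using a toy module; the module, shapes, and hyperparameters are illustrative stand-ins, not the package's VisionTransformer.

    import torch
    import torch.nn as nn

    class TinyLoRALinear(nn.Module):
        """Frozen base weight plus trainable low-rank (LoRA) factors."""
        def __init__(self, dim=8, r=4):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(dim, dim), requires_grad=False)  # frozen base
            self.lora_a = nn.Parameter(torch.zeros(r, dim))  # trainable low-rank factor A
            self.lora_b = nn.Parameter(torch.zeros(dim, r))  # trainable low-rank factor B

        def forward(self, x):
            return x @ (self.weight + self.lora_b @ self.lora_a).T

    model = TinyLoRALinear()
    # Same selection as the CLI: train only parameters named like LoRA factors.
    lora_params = [p for n, p in model.named_parameters() if 'lora' in n and p.requires_grad]
    optimizer = torch.optim.Adam(lora_params, lr=1e-3)
    print(sorted(n for n, p in model.named_parameters() if p.requires_grad))
    # ['lora_a', 'lora_b'] -- the frozen base weight is excluded

Only the two low-rank factors receive gradient updates, which is what keeps LoRA fine-tuning cheap relative to full fine-tuning.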
langvision/cli/model_zoo.py
ADDED
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+Model Zoo CLI for Langvision.
+"""
+
+import argparse
+import json
+import logging
+import sys
+from pathlib import Path
+
+# Import from langvision modules
+try:
+    from langvision.model_zoo import get_available_models, get_model_info, download_model
+except ImportError as e:
+    print(f"❌ Error importing langvision modules: {e}")
+    print("Please ensure langvision is properly installed:")
+    print("  pip install langvision")
+    sys.exit(1)
+
+
+def parse_args():
+    """Parse command-line arguments for model zoo operations."""
+    parser = argparse.ArgumentParser(
+        description='Langvision Model Zoo - Browse, download, and manage pre-trained models',
+        epilog='''\nExamples:\n  langvision model-zoo list\n  langvision model-zoo info vit_base_patch16_224\n  langvision model-zoo download vit_base_patch16_224\n''',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    subparsers = parser.add_subparsers(dest='command', required=True, help='Model zoo commands')
+
+    # List command
+    list_parser = subparsers.add_parser('list', help='List all available models')
+    list_parser.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')
+    list_parser.add_argument('--filter', type=str, help='Filter models by name or type')
+
+    # Info command
+    info_parser = subparsers.add_parser('info', help='Get detailed information about a model')
+    info_parser.add_argument('model_name', type=str, help='Name of the model')
+    info_parser.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')
+
+    # Download command
+    download_parser = subparsers.add_parser('download', help='Download a pre-trained model')
+    download_parser.add_argument('model_name', type=str, help='Name of the model to download')
+    download_parser.add_argument('--output_dir', type=str, default='./models', help='Directory to save the model')
+    download_parser.add_argument('--force', action='store_true', help='Force download even if model exists')
+
+    # Search command
+    search_parser = subparsers.add_parser('search', help='Search for models')
+    search_parser.add_argument('query', type=str, help='Search query')
+    search_parser.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')
+
+    # Misc
+    parser.add_argument('--log_level', type=str, default='info', help='Logging level')
+
+    return parser.parse_args()
+
+
+def setup_logging(log_level: str) -> None:
+    """Set up logging with the specified log level."""
+    numeric_level = getattr(logging, log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+        numeric_level = logging.INFO
+    logging.basicConfig(level=numeric_level, format='[%(levelname)s] %(message)s')
+
+
+def print_model_table(models):
+    """Print models in a formatted table."""
+    if not models:
+        print("No models found.")
+        return
+
+    # Calculate column widths
+    name_width = max(len(model.get('name', '')) for model in models) + 2
+    type_width = max(len(model.get('type', '')) for model in models) + 2
+    size_width = max(len(str(model.get('size', ''))) for model in models) + 2
+
+    # Print header
+    print(f"{'Name':<{name_width}} {'Type':<{type_width}} {'Size':<{size_width}} {'Description'}")
+    print("-" * (name_width + type_width + size_width + 50))
+
+    # Print models
+    for model in models:
+        name = model.get('name', '')
+        model_type = model.get('type', '')
+        size = model.get('size', '')
+        description = model.get('description', '')
+        print(f"{name:<{name_width}} {model_type:<{type_width}} {size:<{size_width}} {description}")
+
+
+def print_model_info(model_info, format='table'):
+    """Print detailed model information."""
+    if format == 'json':
+        print(json.dumps(model_info, indent=2))
+        return
+
+    print(f"\n🤖 Model: {model_info.get('name', 'Unknown')}")
+    print("=" * 50)
+
+    for key, value in model_info.items():
+        if key == 'name':
+            continue
+        print(f"{key.replace('_', ' ').title()}: {value}")
+
+    if 'config' in model_info:
+        print(f"\n📋 Configuration:")
+        config = model_info['config']
+        for key, value in config.items():
+            print(f"  {key}: {value}")
+
+
+def main():
+    """Main function for model zoo operations."""
+    args = parse_args()
+    setup_logging(args.log_level)
+    logger = logging.getLogger(__name__)
+
+    try:
+        if args.command == 'list':
+            logger.info("📋 Fetching available models...")
+            models = get_available_models()
+
+            if args.filter:
+                models = [m for m in models if args.filter.lower() in m.get('name', '').lower()]
+
+            if args.format == 'json':
+                print(json.dumps(models, indent=2))
+            else:
+                print(f"\n🎯 Available Models ({len(models)} total):")
+                print_model_table(models)
+
+        elif args.command == 'info':
+            logger.info(f"🔍 Getting information for model: {args.model_name}")
+            model_info = get_model_info(args.model_name)
+            print_model_info(model_info, args.format)
+
+        elif args.command == 'download':
+            logger.info(f"⬇️ Downloading model: {args.model_name}")
+            output_path = download_model(args.model_name, args.output_dir, force=args.force)
+            logger.info(f"✅ Model downloaded to: {output_path}")
+
+        elif args.command == 'search':
+            logger.info(f"🔍 Searching for: {args.query}")
+            models = get_available_models()
+            results = [m for m in models if args.query.lower() in m.get('name', '').lower() or
+                       args.query.lower() in m.get('description', '').lower()]
+
+            if args.format == 'json':
+                print(json.dumps(results, indent=2))
+            else:
+                print(f"\n🔍 Search Results for '{args.query}' ({len(results)} found):")
+                print_model_table(results)
+
+    except Exception as e:
+        logger.error(f"❌ Operation failed: {e}")
+        return 1
+
+    return 0
+
+
+if __name__ == '__main__':
+    sys.exit(main())
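The table and search paths above read only a handful of keys from each zoo entry via .get(). A hedged sketch of an entry that would satisfy them follows; the field values and the 'config' block are illustrative, and the real schema is whatever get_available_models() in langvision/model_zoo.py returns.

    # Illustrative zoo entry: 'name', 'type', 'size', and 'description' are the
    # keys print_model_table() and the search command consume; the values and
    # the 'config' block shown by print_model_info() are hypothetical.
    example_model = {
        "name": "vit_base_patch16_224",
        "type": "vision_transformer",
        "size": "330MB",
        "description": "ViT-Base, 16x16 patches, 224px input",
        "config": {"img_size": 224, "patch_size": 16, "embed_dim": 768},
    }

Only the model name is taken from the source (it appears in the CLI epilog's examples); everything else in this entry is invented for illustration.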