langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langvision might be problematic. Click here for more details.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/__init__.py
CHANGED
|
@@ -1,7 +1,82 @@
|
|
|
1
1
|
"""
|
|
2
|
-
langvision -
|
|
2
|
+
langvision - Modular Vision LLMs with Efficient LoRA Fine-Tuning
|
|
3
|
+
|
|
4
|
+
A research-friendly framework for building and fine-tuning Vision Large Language Models
|
|
5
|
+
with efficient Low-Rank Adaptation (LoRA) support.
|
|
3
6
|
"""
|
|
4
7
|
|
|
5
8
|
__version__ = "0.1.0"
|
|
9
|
+
__author__ = "Pritesh Raj"
|
|
10
|
+
__email__ = "priteshraj10@gmail.com"
|
|
11
|
+
|
|
12
|
+
# Core imports for easy access
|
|
13
|
+
from .models.vision_transformer import VisionTransformer
|
|
14
|
+
from .models.lora import LoRALinear, LoRAConfig, AdaLoRALinear, QLoRALinear
|
|
15
|
+
from .models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
|
|
16
|
+
from .models.multimodal import VisionLanguageModel, create_multimodal_model, CLIPLoss
|
|
17
|
+
from .utils.config import default_config
|
|
18
|
+
from .training.trainer import Trainer
|
|
19
|
+
from .training.advanced_trainer import AdvancedTrainer, TrainingConfig
|
|
20
|
+
from .data.datasets import get_dataset
|
|
21
|
+
from .data.enhanced_datasets import (
|
|
22
|
+
EnhancedImageDataset, MultimodalDataset, DatasetConfig,
|
|
23
|
+
create_enhanced_dataloaders, SmartAugmentation
|
|
24
|
+
)
|
|
25
|
+
from .utils.metrics import (
|
|
26
|
+
MetricsTracker, ClassificationMetrics, ContrastiveMetrics,
|
|
27
|
+
EvaluationSuite, PerformanceMetrics
|
|
28
|
+
)
|
|
29
|
+
from .callbacks.base import Callback, CallbackManager
|
|
30
|
+
from .concepts import RLHF, CoT, CCoT, GRPO, RLVR, DPO, PPO, LIME, SHAP
|
|
31
|
+
|
|
32
|
+
# Version info
|
|
33
|
+
__all__ = [
|
|
34
|
+
"__version__",
|
|
35
|
+
"__author__",
|
|
36
|
+
"__email__",
|
|
37
|
+
# Core Models
|
|
38
|
+
"VisionTransformer",
|
|
39
|
+
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
|
|
40
|
+
"VisionLanguageModel", "create_multimodal_model",
|
|
41
|
+
# LoRA Components
|
|
42
|
+
"LoRALinear", "LoRAConfig", "AdaLoRALinear", "QLoRALinear",
|
|
43
|
+
# Training
|
|
44
|
+
"Trainer", "AdvancedTrainer", "TrainingConfig",
|
|
45
|
+
# Data
|
|
46
|
+
"get_dataset", "EnhancedImageDataset", "MultimodalDataset",
|
|
47
|
+
"DatasetConfig", "create_enhanced_dataloaders", "SmartAugmentation",
|
|
48
|
+
# Utilities
|
|
49
|
+
"default_config", "MetricsTracker", "ClassificationMetrics",
|
|
50
|
+
"ContrastiveMetrics", "EvaluationSuite", "PerformanceMetrics",
|
|
51
|
+
# Callbacks
|
|
52
|
+
"Callback", "CallbackManager",
|
|
53
|
+
# Loss Functions
|
|
54
|
+
"CLIPLoss",
|
|
55
|
+
# Concepts
|
|
56
|
+
"RLHF", "CoT", "CCoT", "GRPO", "RLVR", "DPO", "PPO", "LIME", "SHAP",
|
|
57
|
+
]
|
|
58
|
+
|
|
59
|
+
# Optional imports for advanced usage
|
|
60
|
+
try:
|
|
61
|
+
from .callbacks import EarlyStoppingCallback, LoggingCallback
|
|
62
|
+
from .utils.device import get_device, to_device
|
|
63
|
+
__all__.extend([
|
|
64
|
+
"EarlyStoppingCallback",
|
|
65
|
+
"LoggingCallback",
|
|
66
|
+
"get_device",
|
|
67
|
+
"to_device"
|
|
68
|
+
])
|
|
69
|
+
except ImportError:
|
|
70
|
+
pass
|
|
6
71
|
|
|
7
|
-
|
|
72
|
+
# Package metadata
|
|
73
|
+
PACKAGE_METADATA = {
|
|
74
|
+
"name": "langvision",
|
|
75
|
+
"version": __version__,
|
|
76
|
+
"description": "Modular Vision LLMs with Efficient LoRA Fine-Tuning",
|
|
77
|
+
"author": __author__,
|
|
78
|
+
"email": __email__,
|
|
79
|
+
"url": "https://github.com/langtrain-ai/langtrain",
|
|
80
|
+
"license": "MIT",
|
|
81
|
+
"python_requires": ">=3.8",
|
|
82
|
+
}
|
langvision/callbacks/base.py
CHANGED
|
@@ -1,11 +1,170 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""Enhanced callback system for training and evaluation hooks."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Dict, Any, Optional, List
|
|
5
|
+
import logging
|
|
6
|
+
import traceback
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Callback(ABC):
    """Base class for training/evaluation lifecycle hooks.

    Every hook is a no-op by default, so subclasses override only the
    events they care about. Each instance carries a ``name`` (used by
    :class:`CallbackManager` when logging) and its own logger.
    """

    def __init__(self, name: Optional[str] = None):
        # An empty or missing name falls back to the subclass name so that
        # log records remain identifiable.
        self.name = name or type(self).__name__
        self.logger = logging.getLogger(f"langvision.callbacks.{self.name}")

    def on_train_start(self, trainer) -> None:
        """Called at the beginning of training."""

    def on_train_end(self, trainer) -> None:
        """Called at the end of training."""

    def on_epoch_start(self, trainer, epoch: int) -> None:
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, Any]) -> None:
        """Called at the end of each epoch."""

    def on_batch_start(self, trainer, batch_idx: int, batch: Dict[str, Any]) -> None:
        """Called at the beginning of each batch."""

    def on_batch_end(self, trainer, batch_idx: int, batch: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Called at the end of each batch."""

    def on_validation_start(self, trainer) -> None:
        """Called at the beginning of validation."""

    def on_validation_end(self, trainer, metrics: Dict[str, Any]) -> None:
        """Called at the end of validation."""

    def on_test_start(self, trainer) -> None:
        """Called at the beginning of testing."""

    def on_test_end(self, trainer, metrics: Dict[str, Any]) -> None:
        """Called at the end of testing."""

    def on_checkpoint_save(self, trainer, checkpoint_path: str, metrics: Dict[str, Any]) -> None:
        """Called when a checkpoint is saved."""

    def on_checkpoint_load(self, trainer, checkpoint_path: str) -> None:
        """Called when a checkpoint is loaded."""

    def on_lr_schedule(self, trainer, old_lr: float, new_lr: float) -> None:
        """Called when learning rate is scheduled."""

    def on_exception(self, trainer, exception: Exception) -> bool:
        """Called when an exception occurs during training.

        Returns:
            bool: True if the exception was handled and training should continue,
                  False if training should stop.
        """
        return False
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class CallbackManager:
    """Manager that fans training events out to multiple callbacks.

    Exceptions raised by an individual callback are logged and swallowed so
    one faulty callback cannot abort the dispatch to the others.
    """

    def __init__(self, callbacks: Optional[List["Callback"]] = None):
        # Copy the incoming list: previously the manager aliased the caller's
        # list, so add_callback/remove_callback silently mutated it.
        self.callbacks: List["Callback"] = list(callbacks) if callbacks is not None else []
        self.logger = logging.getLogger("langvision.callbacks.manager")

    def add_callback(self, callback: "Callback") -> None:
        """Register *callback*; raises TypeError for non-Callback objects."""
        if not isinstance(callback, Callback):
            raise TypeError(f"Expected Callback instance, got {type(callback)}")
        self.callbacks.append(callback)
        self.logger.info(f"Added callback: {callback.name}")

    def remove_callback(self, callback_name: str) -> bool:
        """Remove the first callback whose name matches.

        Returns:
            bool: True if a callback was removed, False if none matched.
        """
        for i, callback in enumerate(self.callbacks):
            if callback.name == callback_name:
                del self.callbacks[i]
                self.logger.info(f"Removed callback: {callback_name}")
                return True
        return False

    def _call_callbacks(self, method_name: str, *args, **kwargs) -> None:
        """Safely call *method_name* on every callback, logging failures."""
        for callback in self.callbacks:
            try:
                method = getattr(callback, method_name, None)
                if method and callable(method):
                    method(*args, **kwargs)
            except Exception as e:
                self.logger.error(
                    f"Error in callback {callback.name}.{method_name}: {str(e)}\n"
                    f"Traceback: {traceback.format_exc()}"
                )
                # Continue with other callbacks even if one fails

    def on_train_start(self, trainer) -> None:
        self._call_callbacks('on_train_start', trainer)

    def on_train_end(self, trainer) -> None:
        self._call_callbacks('on_train_end', trainer)

    def on_epoch_start(self, trainer, epoch: int) -> None:
        self._call_callbacks('on_epoch_start', trainer, epoch)

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_epoch_end', trainer, epoch, metrics)

    def on_batch_start(self, trainer, batch_idx: int, batch: Dict[str, Any]) -> None:
        self._call_callbacks('on_batch_start', trainer, batch_idx, batch)

    def on_batch_end(self, trainer, batch_idx: int, batch: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        self._call_callbacks('on_batch_end', trainer, batch_idx, batch, outputs)

    def on_validation_start(self, trainer) -> None:
        self._call_callbacks('on_validation_start', trainer)

    def on_validation_end(self, trainer, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_validation_end', trainer, metrics)

    def on_test_start(self, trainer) -> None:
        self._call_callbacks('on_test_start', trainer)

    def on_test_end(self, trainer, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_test_end', trainer, metrics)

    def on_checkpoint_save(self, trainer, checkpoint_path: str, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_checkpoint_save', trainer, checkpoint_path, metrics)

    def on_checkpoint_load(self, trainer, checkpoint_path: str) -> None:
        self._call_callbacks('on_checkpoint_load', trainer, checkpoint_path)

    def on_lr_schedule(self, trainer, old_lr: float, new_lr: float) -> None:
        self._call_callbacks('on_lr_schedule', trainer, old_lr, new_lr)

    def on_exception(self, trainer, exception: Exception) -> bool:
        """Handle exceptions through callbacks.

        Returns:
            bool: True if any callback handled the exception and training should continue.
        """
        handled = False
        for callback in self.callbacks:
            try:
                if callback.on_exception(trainer, exception):
                    handled = True
                    self.logger.info(f"Exception handled by callback: {callback.name}")
            except Exception as callback_error:
                self.logger.error(
                    f"Error in callback {callback.name}.on_exception: {str(callback_error)}"
                )
        return handled
|
langvision/cli/__init__.py
CHANGED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
from .train import main as train_main
|
|
4
|
+
from .finetune import main as finetune_main
|
|
5
|
+
from .evaluate import main as evaluate_main
|
|
6
|
+
from .export import main as export_main
|
|
7
|
+
from .model_zoo import main as model_zoo_main
|
|
8
|
+
from .config import main as config_main
|
|
9
|
+
|
|
10
|
+
__version__ = "0.1.0" # Keep in sync with package version in pyproject.toml
|
|
11
|
+
|
|
12
|
+
def print_banner() -> None:
    """Print the langvision ASCII banner plus version and documentation links.

    Bug fix: the ANSI colour codes were previously embedded inside the raw
    (r"...") string that holds the ASCII art, so ``\\033[1;36m`` was printed
    literally instead of colouring the output. The art must stay raw (it is
    full of backslashes), so the colour codes are concatenated around it as
    normal string literals, where the escapes are interpreted.
    """
    art = r"""
 _                        __     ___     _
| |    __ _ _ __   __ _   \ \   / (_)___(_) ___  _ __
| |   / _` | '_ \ / _` |   \ \ / / | / __| |/ _ \| '_ \
| |__| (_| | | | | (_| |    \ V /  | \__ \ | (_) | | | |
|_____\__,_|_| |_|\__, |     \_/   |_|___/_|\___/|_| |_|
                  |___/
"""
    # Bright cyan art, then reset so later output is uncoloured.
    print("\033[1;36m" + art + "\033[0m")
    print("\033[1;33mLANGVISION\033[0m: Modular Vision LLMs with Efficient LoRA Fine-Tuning")
    print(f"\033[1;35mVersion:\033[0m {__version__}")
    print("\033[1;32mDocs:\033[0m https://github.com/langtrain-ai/langtrain/tree/main/docs  \033[1;34mPyPI:\033[0m https://pypi.org/project/langvision/\n")
|
|
27
|
+
|
|
28
|
+
def main():
    """Entry point for the ``langvision`` console script.

    Builds the top-level parser, registers one sub-parser per subcommand,
    then hands the leftover argv tokens to the chosen subcommand's own
    entry point.
    """
    print_banner()
    parser = argparse.ArgumentParser(
        prog="langvision",
        description="Langvision: Modular Vision LLMs with Efficient LoRA Fine-Tuning.\n\nUse subcommands to train or finetune vision models."
    )
    parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}')
    subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands')

    # (command, help text, entry point) — one row per subcommand.
    registry = (
        ('train', 'Train a VisionTransformer model', train_main),
        ('finetune', 'Finetune a VisionTransformer model with LoRA and LLM concepts', finetune_main),
        ('evaluate', 'Evaluate a trained model', evaluate_main),
        ('export', 'Export a model to various formats (ONNX, TorchScript)', export_main),
        ('model-zoo', 'Browse and download pre-trained models', model_zoo_main),
        ('config', 'Manage configuration files', config_main),
    )
    dispatch = {}
    for command, help_text, entry in registry:
        sub = subparsers.add_parser(command, help=help_text)
        # Collect everything after the subcommand verbatim; the subcommand's
        # own argparse re-parses it.
        sub.add_argument('args', nargs=argparse.REMAINDER)
        dispatch[command] = entry

    parsed = parser.parse_args()

    entry = dispatch.get(parsed.command)
    if entry is None:
        parser.print_help()
    else:
        # Rewrite argv so the subcommand parser sees only its own arguments.
        sys.argv = [sys.argv[0]] + parsed.args
        entry()
|
|
83
|
+
|
|
84
|
+
if __name__ == '__main__':
|
|
85
|
+
main()
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Complete CLI interface for Langvision framework.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Optional, Dict, Any
|
|
10
|
+
|
|
11
|
+
from ..utils.setup import initialize_langvision, quick_setup
|
|
12
|
+
from ..models.lora import LoRAConfig
|
|
13
|
+
from ..training.advanced_trainer import AdvancedTrainer, TrainingConfig
|
|
14
|
+
from ..data.enhanced_datasets import create_enhanced_dataloaders, DatasetConfig
|
|
15
|
+
from ..models.vision_transformer import VisionTransformer
|
|
16
|
+
from ..models.resnet import resnet50
|
|
17
|
+
from ..models.multimodal import create_multimodal_model
|
|
18
|
+
from ..utils.device import get_device
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with every subcommand attached."""
    parser = argparse.ArgumentParser(
        prog="langvision",
        description="Langvision: Advanced Vision-Language Models with Efficient LoRA Fine-Tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic image classification training
  langvision train --model vit --dataset cifar10 --data-dir ./data --epochs 50

  # LoRA fine-tuning with custom parameters
  langvision train --model resnet50 --lora-r 16 --lora-alpha 32 --freeze-backbone

  # Multimodal training
  langvision train-multimodal --vision-model vit --text-model bert-base-uncased --data-dir ./data

  # Evaluation
  langvision evaluate --model-path ./checkpoints/best_model.pt --data-dir ./test_data
        """
    )

    parser.add_argument("--version", action="version", version="langvision 0.1.0")

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # (name, help text, argument installer) — one row per subcommand.
    for name, help_text, installer in (
        ("train", "Train a vision model", add_training_args),
        ("train-multimodal", "Train a multimodal vision-language model", add_multimodal_args),
        ("evaluate", "Evaluate a trained model", add_evaluation_args),
        ("setup", "Setup and validate Langvision environment", add_setup_args),
    ):
        installer(subparsers.add_parser(name, help=help_text))

    return parser
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def add_training_args(parser: argparse.ArgumentParser) -> None:
    """Install the standard training options on *parser*, grouped by topic."""

    # --- Model ---------------------------------------------------------
    grp = parser.add_argument_group("Model Configuration")
    grp.add_argument("--model",
                     choices=["vit", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"],
                     default="vit", help="Model architecture to use")
    grp.add_argument("--num-classes", type=int, default=10, help="Number of output classes")
    grp.add_argument("--img-size", type=int, default=224, help="Input image size")

    # --- LoRA ----------------------------------------------------------
    grp = parser.add_argument_group("LoRA Configuration")
    grp.add_argument("--use-lora", action="store_true", help="Enable LoRA fine-tuning")
    grp.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    grp.add_argument("--lora-alpha", type=float, default=32.0, help="LoRA alpha")
    grp.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
    grp.add_argument("--freeze-backbone", action="store_true", help="Freeze backbone parameters")

    # --- Data ----------------------------------------------------------
    grp = parser.add_argument_group("Data Configuration")
    grp.add_argument("--dataset", choices=["cifar10", "cifar100", "imagefolder"],
                     default="cifar10", help="Dataset to use")
    grp.add_argument("--data-dir", type=str, required=True, help="Path to dataset")
    grp.add_argument("--batch-size", type=int, default=32, help="Batch size")
    grp.add_argument("--num-workers", type=int, default=4, help="Number of data loading workers")
    grp.add_argument("--augmentation", action="store_true", help="Enable data augmentation")
    grp.add_argument("--augmentation-strength", type=float, default=0.5, help="Augmentation strength")

    # --- Optimization --------------------------------------------------
    grp = parser.add_argument_group("Training Configuration")
    grp.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    grp.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    grp.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay")
    grp.add_argument("--optimizer", choices=["adam", "adamw", "sgd"], default="adamw", help="Optimizer")
    grp.add_argument("--scheduler", choices=["cosine", "step", "plateau"], default="cosine", help="LR scheduler")
    grp.add_argument("--warmup-epochs", type=int, default=5, help="Warmup epochs")
    grp.add_argument("--use-amp", action="store_true", help="Use mixed precision training")
    grp.add_argument("--gradient-clip", type=float, default=1.0, help="Gradient clipping norm")

    # --- Output --------------------------------------------------------
    grp = parser.add_argument_group("Output Configuration")
    grp.add_argument("--output-dir", type=str, default="./outputs", help="Output directory")
    grp.add_argument("--experiment-name", type=str, default="langvision_experiment", help="Experiment name")
    grp.add_argument("--save-interval", type=int, default=5, help="Save checkpoint every N epochs")
    grp.add_argument("--log-interval", type=int, default=10, help="Log every N batches")

    # --- System --------------------------------------------------------
    grp = parser.add_argument_group("System Configuration")
    grp.add_argument("--seed", type=int, default=42, help="Random seed")
    grp.add_argument("--device", type=str, default="auto", help="Device to use (auto, cpu, cuda, mps)")
    grp.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                     default="INFO", help="Logging level")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def add_multimodal_args(parser: argparse.ArgumentParser) -> None:
    """Install multimodal options: the full training set plus fusion/text extras."""

    # Multimodal training reuses every basic training option.
    add_training_args(parser)

    grp = parser.add_argument_group("Multimodal Configuration")
    grp.add_argument("--vision-model", choices=["vit_base"], default="vit_base",
                     help="Vision model architecture")
    grp.add_argument("--text-model", type=str, default="bert-base-uncased",
                     help="Text model from HuggingFace")
    # Feature/fusion dimensions for the vision-language projection heads.
    grp.add_argument("--vision-dim", type=int, default=768, help="Vision feature dimension")
    grp.add_argument("--text-dim", type=int, default=768, help="Text feature dimension")
    grp.add_argument("--hidden-dim", type=int, default=512, help="Hidden dimension for fusion")
    grp.add_argument("--max-text-length", type=int, default=77, help="Maximum text sequence length")
    grp.add_argument("--annotations-file", type=str, help="Path to text annotations file")
    # Relative weighting of the two training objectives.
    grp.add_argument("--contrastive-weight", type=float, default=1.0,
                     help="Weight for contrastive loss")
    grp.add_argument("--classification-weight", type=float, default=0.5,
                     help="Weight for classification loss")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def add_evaluation_args(parser: argparse.ArgumentParser) -> None:
    """Install the options used by the ``evaluate`` subcommand."""

    # Required inputs: the checkpoint to load and the data to score it on.
    parser.add_argument("--model-path", type=str, required=True, help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", type=str, required=True, help="Path to evaluation dataset")

    parser.add_argument("--dataset", choices=["cifar10", "cifar100", "imagefolder"],
                        default="cifar10", help="Dataset type")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of data loading workers")
    parser.add_argument("--output-dir", type=str, default="./eval_results", help="Output directory for results")
    parser.add_argument("--save-predictions", action="store_true", help="Save model predictions")
    parser.add_argument("--benchmark", action="store_true", help="Run inference speed benchmark")
    parser.add_argument("--device", type=str, default="auto", help="Device to use")
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def add_setup_args(parser: argparse.ArgumentParser) -> None:
    """Install the options used by the ``setup`` subcommand."""

    # All flags are opt-in diagnostics; `setup` with no flags still runs
    # the basic environment report.
    for flag, help_text in (
        ("--check-deps", "Check all dependencies"),
        ("--validate-env", "Validate environment"),
        ("--setup-cuda", "Setup CUDA environment"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--seed", type=int, default=42, help="Random seed for testing")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def handle_train_command(args: argparse.Namespace) -> None:
    """Handle the ``train`` subcommand.

    Builds the model (optionally LoRA-adapted), the dataloaders and the
    training configuration from the parsed CLI arguments, then runs the
    AdvancedTrainer to completion.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    logger = logging.getLogger("langvision.cli")
    logger.info("Starting training...")

    # Initialize framework (seeding, logging) before touching models or data.
    quick_setup(seed=args.seed, log_level=args.log_level)

    # Create LoRA config if requested
    lora_config = None
    if args.use_lora:
        lora_config = LoRAConfig(
            r=args.lora_r,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout
        )
        logger.info(f"Using LoRA with r={args.lora_r}, alpha={args.lora_alpha}")

    # Create model
    if args.model == "vit":
        model = VisionTransformer(
            img_size=args.img_size,
            num_classes=args.num_classes,
            lora_config=lora_config
        )
    elif args.model.startswith("resnet"):
        # The argparse choices already restrict args.model to the constructor
        # names exported by ..models.resnet, so look the constructor up
        # directly instead of going through a redundant name->name mapping.
        from ..models import resnet
        model = getattr(resnet, args.model)(num_classes=args.num_classes, lora_config=lora_config)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    logger.info(f"Created {args.model} model with {args.num_classes} classes")

    # Create dataset config.
    # NOTE(review): args.dataset is only used in the log line below —
    # DatasetConfig receives just the root dir. Confirm that dataset-type
    # selection actually happens inside create_enhanced_dataloaders.
    dataset_config = DatasetConfig(
        root_dir=args.data_dir,
        image_size=(args.img_size, args.img_size),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_augmentation=args.augmentation,
        augmentation_strength=args.augmentation_strength
    )

    # Create dataloaders
    dataloaders = create_enhanced_dataloaders(dataset_config)
    logger.info(f"Created dataloaders for {args.dataset}")

    # Create training config mirroring the CLI options one-to-one.
    training_config = TrainingConfig(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        optimizer=args.optimizer,
        scheduler=args.scheduler,
        warmup_epochs=args.warmup_epochs,
        use_amp=args.use_amp,
        gradient_clip_norm=args.gradient_clip,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        save_interval=args.save_interval,
        log_interval=args.log_interval,
        lora_config=lora_config,
        freeze_backbone=args.freeze_backbone
    )

    # Create trainer; the validation loader is optional ('val' may be absent).
    trainer = AdvancedTrainer(
        model=model,
        train_loader=dataloaders['train'],
        val_loader=dataloaders.get('val'),
        config=training_config
    )

    # Start training
    trainer.train()
    logger.info("Training completed!")
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def handle_setup_command(args: argparse.Namespace) -> None:
    """Handle the ``setup`` subcommand: run environment init and print a report."""

    setup_options: Dict[str, Any] = {
        "log_level": "INFO",
        "seed": args.seed,
    }
    if args.setup_cuda:
        setup_options["setup_cuda"] = True

    report = initialize_langvision(setup_options)
    validation = report['validation_results']

    # Environment summary.
    print("\n=== Langvision Environment Setup ===")
    print(f"Python: {validation['python_version']}")
    print(f"PyTorch: {validation['pytorch_version']}")

    if validation['cuda_available']:
        print(f"CUDA: {validation['cuda_version']}")
        print(f"GPUs: {validation['gpu_count']}")
    else:
        print("CUDA: Not available")

    # Optional per-dependency availability listing.
    if args.check_deps:
        print("\n=== Dependencies ===")
        for dep, available in report['dependencies'].items():
            print(f"{'✓' if available else '✗'} {dep}")

    print("\nSetup completed successfully!")
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def main() -> None:
    """Main CLI entry point: parse args and dispatch to the command handler."""

    parser = create_parser()
    args = parser.parse_args()

    # No subcommand given: show usage and exit non-zero.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    try:
        command = args.command
        if command == "train":
            handle_train_command(args)
        elif command == "setup":
            handle_setup_command(args)
        elif command == "train-multimodal":
            # TODO: Implement multimodal training handler
            print("Multimodal training not yet implemented in CLI")
            sys.exit(1)
        elif command == "evaluate":
            # TODO: Implement evaluation handler
            print("Evaluation not yet implemented in CLI")
            sys.exit(1)
        else:
            parser.print_help()
            sys.exit(1)
    except Exception as e:
        # SystemExit is not an Exception subclass, so the sys.exit calls
        # above pass through untouched; only real failures land here.
        logging.getLogger("langvision.cli").error(f"Command failed: {str(e)}")
        sys.exit(1)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
if __name__ == "__main__":
|
|
319
|
+
main()
|