langvision 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langvision might be problematic; see the registry's advisory page for more details.

Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.2.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.2.dist-info/METADATA +0 -372
  38. langvision-0.0.2.dist-info/RECORD +0 -40
  39. langvision-0.0.2.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.2.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.2.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/__init__.py CHANGED
@@ -1,7 +1,82 @@
"""
langvision - Modular Vision LLMs with Efficient LoRA Fine-Tuning

A research-friendly framework for building and fine-tuning Vision Large
Language Models with efficient Low-Rank Adaptation (LoRA) support.
"""

__version__ = "0.1.0"
__author__ = "Pritesh Raj"
__email__ = "priteshraj10@gmail.com"

# Re-export the core building blocks so users get a flat public API
# (``from langvision import VisionTransformer`` etc.).
from .models.vision_transformer import VisionTransformer
from .models.lora import LoRALinear, LoRAConfig, AdaLoRALinear, QLoRALinear
from .models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .models.multimodal import VisionLanguageModel, create_multimodal_model, CLIPLoss
from .utils.config import default_config
from .training.trainer import Trainer
from .training.advanced_trainer import AdvancedTrainer, TrainingConfig
from .data.datasets import get_dataset
from .data.enhanced_datasets import (
    EnhancedImageDataset,
    MultimodalDataset,
    DatasetConfig,
    create_enhanced_dataloaders,
    SmartAugmentation,
)
from .utils.metrics import (
    MetricsTracker,
    ClassificationMetrics,
    ContrastiveMetrics,
    EvaluationSuite,
    PerformanceMetrics,
)
from .callbacks.base import Callback, CallbackManager
from .concepts import RLHF, CoT, CCoT, GRPO, RLVR, DPO, PPO, LIME, SHAP

# Explicit public API for ``from langvision import *``.
__all__ = [
    "__version__",
    "__author__",
    "__email__",
    # Core Models
    "VisionTransformer",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "VisionLanguageModel", "create_multimodal_model",
    # LoRA Components
    "LoRALinear", "LoRAConfig", "AdaLoRALinear", "QLoRALinear",
    # Training
    "Trainer", "AdvancedTrainer", "TrainingConfig",
    # Data
    "get_dataset", "EnhancedImageDataset", "MultimodalDataset",
    "DatasetConfig", "create_enhanced_dataloaders", "SmartAugmentation",
    # Utilities
    "default_config", "MetricsTracker", "ClassificationMetrics",
    "ContrastiveMetrics", "EvaluationSuite", "PerformanceMetrics",
    # Callbacks
    "Callback", "CallbackManager",
    # Loss Functions
    "CLIPLoss",
    # Concepts
    "RLHF", "CoT", "CCoT", "GRPO", "RLVR", "DPO", "PPO", "LIME", "SHAP",
]

# Optional extras: only exported when their (possibly heavier) dependencies
# import cleanly; a missing extra must not break ``import langvision``.
try:
    from .callbacks import EarlyStoppingCallback, LoggingCallback
    from .utils.device import get_device, to_device

    __all__.extend([
        "EarlyStoppingCallback",
        "LoggingCallback",
        "get_device",
        "to_device",
    ])
except ImportError:
    pass

# Package metadata; keep in sync with pyproject.toml.
PACKAGE_METADATA = {
    "name": "langvision",
    "version": __version__,
    "description": "Modular Vision LLMs with Efficient LoRA Fine-Tuning",
    "author": __author__,
    "email": __email__,
    "url": "https://github.com/langtrain-ai/langtrain",
    "license": "MIT",
    "python_requires": ">=3.8",
}
@@ -1,11 +1,170 @@
"""Enhanced callback system for training and evaluation hooks."""

from abc import ABC
from typing import Dict, Any, Optional, List
import logging
import traceback


class Callback(ABC):
    """Base class for all callbacks with comprehensive training hooks.

    Every hook is a no-op by default; subclasses override only the events
    they care about. ``on_exception`` is the one hook whose return value is
    meaningful (whether training should continue after an error).
    """

    def __init__(self, name: Optional[str] = None):
        # Fall back to the concrete subclass name so log lines identify
        # which callback emitted them.
        self.name = name or self.__class__.__name__
        self.logger = logging.getLogger(f"langvision.callbacks.{self.name}")

    def on_train_start(self, trainer) -> None:
        """Called at the beginning of training."""
        pass

    def on_train_end(self, trainer) -> None:
        """Called at the end of training."""
        pass

    def on_epoch_start(self, trainer, epoch: int) -> None:
        """Called at the beginning of each epoch."""
        pass

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, Any]) -> None:
        """Called at the end of each epoch."""
        pass

    def on_batch_start(self, trainer, batch_idx: int, batch: Dict[str, Any]) -> None:
        """Called at the beginning of each batch."""
        pass

    def on_batch_end(self, trainer, batch_idx: int, batch: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Called at the end of each batch."""
        pass

    def on_validation_start(self, trainer) -> None:
        """Called at the beginning of validation."""
        pass

    def on_validation_end(self, trainer, metrics: Dict[str, Any]) -> None:
        """Called at the end of validation."""
        pass

    def on_test_start(self, trainer) -> None:
        """Called at the beginning of testing."""
        pass

    def on_test_end(self, trainer, metrics: Dict[str, Any]) -> None:
        """Called at the end of testing."""
        pass

    def on_checkpoint_save(self, trainer, checkpoint_path: str, metrics: Dict[str, Any]) -> None:
        """Called when a checkpoint is saved."""
        pass

    def on_checkpoint_load(self, trainer, checkpoint_path: str) -> None:
        """Called when a checkpoint is loaded."""
        pass

    def on_lr_schedule(self, trainer, old_lr: float, new_lr: float) -> None:
        """Called when learning rate is scheduled."""
        pass

    def on_exception(self, trainer, exception: Exception) -> bool:
        """Called when an exception occurs during training.

        Returns:
            bool: True if the exception was handled and training should
            continue, False if training should stop.
        """
        return False


class CallbackManager:
    """Manager for handling multiple callbacks with error handling.

    A failing callback never aborts the dispatch loop: its error is logged
    and the remaining callbacks still run.
    """

    def __init__(self, callbacks: Optional[List[Callback]] = None):
        self.callbacks: List[Callback] = []
        self.logger = logging.getLogger("langvision.callbacks.manager")
        # Route initial callbacks through add_callback so the same type
        # validation applies as for later additions (previously an arbitrary
        # list was stored unchecked, and the caller's list object was
        # aliased directly).
        for callback in callbacks or []:
            self.add_callback(callback)

    def add_callback(self, callback: Callback) -> None:
        """Add a callback, rejecting objects that are not Callback instances."""
        if not isinstance(callback, Callback):
            raise TypeError(f"Expected Callback instance, got {type(callback)}")
        self.callbacks.append(callback)
        self.logger.info(f"Added callback: {callback.name}")

    def remove_callback(self, callback_name: str) -> bool:
        """Remove the first callback whose name matches; return whether one was found."""
        for i, callback in enumerate(self.callbacks):
            if callback.name == callback_name:
                del self.callbacks[i]
                self.logger.info(f"Removed callback: {callback_name}")
                return True
        return False

    def _call_callbacks(self, method_name: str, *args, **kwargs) -> None:
        """Invoke *method_name* on every callback, isolating failures."""
        for callback in self.callbacks:
            try:
                method = getattr(callback, method_name, None)
                if method and callable(method):
                    method(*args, **kwargs)
            except Exception as e:
                # Log and continue: one broken callback must not starve the rest.
                self.logger.error(
                    f"Error in callback {callback.name}.{method_name}: {str(e)}\n"
                    f"Traceback: {traceback.format_exc()}"
                )

    # --- Thin dispatch wrappers, one per hook ------------------------------

    def on_train_start(self, trainer) -> None:
        self._call_callbacks('on_train_start', trainer)

    def on_train_end(self, trainer) -> None:
        self._call_callbacks('on_train_end', trainer)

    def on_epoch_start(self, trainer, epoch: int) -> None:
        self._call_callbacks('on_epoch_start', trainer, epoch)

    def on_epoch_end(self, trainer, epoch: int, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_epoch_end', trainer, epoch, metrics)

    def on_batch_start(self, trainer, batch_idx: int, batch: Dict[str, Any]) -> None:
        self._call_callbacks('on_batch_start', trainer, batch_idx, batch)

    def on_batch_end(self, trainer, batch_idx: int, batch: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        self._call_callbacks('on_batch_end', trainer, batch_idx, batch, outputs)

    def on_validation_start(self, trainer) -> None:
        self._call_callbacks('on_validation_start', trainer)

    def on_validation_end(self, trainer, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_validation_end', trainer, metrics)

    def on_test_start(self, trainer) -> None:
        self._call_callbacks('on_test_start', trainer)

    def on_test_end(self, trainer, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_test_end', trainer, metrics)

    def on_checkpoint_save(self, trainer, checkpoint_path: str, metrics: Dict[str, Any]) -> None:
        self._call_callbacks('on_checkpoint_save', trainer, checkpoint_path, metrics)

    def on_checkpoint_load(self, trainer, checkpoint_path: str) -> None:
        self._call_callbacks('on_checkpoint_load', trainer, checkpoint_path)

    def on_lr_schedule(self, trainer, old_lr: float, new_lr: float) -> None:
        self._call_callbacks('on_lr_schedule', trainer, old_lr, new_lr)

    def on_exception(self, trainer, exception: Exception) -> bool:
        """Give every callback a chance to handle *exception*.

        Returns:
            bool: True if any callback handled the exception and training
            should continue.
        """
        handled = False
        for callback in self.callbacks:
            try:
                if callback.on_exception(trainer, exception):
                    handled = True
                    self.logger.info(f"Exception handled by callback: {callback.name}")
            except Exception as callback_error:
                self.logger.error(
                    f"Error in callback {callback.name}.on_exception: {str(callback_error)}"
                )
        return handled
@@ -0,0 +1,85 @@
"""langvision command-line entry point with subcommand dispatch."""

import argparse
import sys

from .train import main as train_main
from .finetune import main as finetune_main
from .evaluate import main as evaluate_main
from .export import main as export_main
from .model_zoo import main as model_zoo_main
from .config import main as config_main

__version__ = "0.1.0"  # Keep in sync with package version in pyproject.toml

# The ASCII art must stay a raw string (it is full of backslashes), but ANSI
# color codes must NOT be embedded in it: inside an r-string "\033[1;36m"
# prints literally instead of coloring.  The escapes are applied at print
# time instead (this was a bug in the previous version).
_BANNER_ART = r"""
 _                     __     ___     _
| |    __ _ _ __   __ _\ \   / (_)___(_) ___  _ __
| |   / _` | '_ \ / _` |\ \ / / | / __| |/ _ \| '_ \
| |__| (_| | | | | (_| | \ V /  | \__ \ | (_) | | | |
|_____\__,_|_| |_|\__, |  \_/   |_|___/_|\___/|_| |_|
                  |___/
"""


def print_banner():
    """Print the colorized langvision banner plus version and docs links."""
    print(f"\033[1;36m{_BANNER_ART}\033[0m")
    print("\033[1;33mLANGVISION\033[0m: Modular Vision LLMs with Efficient LoRA Fine-Tuning")
    print(f"\033[1;35mVersion:\033[0m {__version__}")
    print("\033[1;32mDocs:\033[0m https://github.com/langtrain-ai/langtrain/tree/main/docs \033[1;34mPyPI:\033[0m https://pypi.org/project/langvision/\n")


def main():
    """Parse the top-level CLI and dispatch to the selected subcommand.

    Each subcommand swallows the remaining argv (argparse.REMAINDER) and
    re-exposes it via sys.argv so the subcommand module can run its own
    argument parser unchanged.
    """
    print_banner()
    parser = argparse.ArgumentParser(
        prog="langvision",
        description="Langvision: Modular Vision LLMs with Efficient LoRA Fine-Tuning.\n\nUse subcommands to train or finetune vision models."
    )
    parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}')
    subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands')

    # Single source of truth: subcommand name -> (entry point, help text).
    commands = {
        'train': (train_main, 'Train a VisionTransformer model'),
        'finetune': (finetune_main, 'Finetune a VisionTransformer model with LoRA and LLM concepts'),
        'evaluate': (evaluate_main, 'Evaluate a trained model'),
        'export': (export_main, 'Export a model to various formats (ONNX, TorchScript)'),
        'model-zoo': (model_zoo_main, 'Browse and download pre-trained models'),
        'config': (config_main, 'Manage configuration files'),
    }
    for name, (_, help_text) in commands.items():
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('args', nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # required=True above guarantees args.command is a key of `commands`,
    # so the previous unreachable print_help() fallback is gone.
    handler = commands[args.command][0]
    sys.argv = [sys.argv[0]] + args.args
    handler()


if __name__ == '__main__':
    main()
@@ -0,0 +1,319 @@
1
+ """
2
+ Complete CLI interface for Langvision framework.
3
+ """
4
+
5
+ import argparse
6
+ import sys
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Optional, Dict, Any
10
+
11
+ from ..utils.setup import initialize_langvision, quick_setup
12
+ from ..models.lora import LoRAConfig
13
+ from ..training.advanced_trainer import AdvancedTrainer, TrainingConfig
14
+ from ..data.enhanced_datasets import create_enhanced_dataloaders, DatasetConfig
15
+ from ..models.vision_transformer import VisionTransformer
16
+ from ..models.resnet import resnet50
17
+ from ..models.multimodal import create_multimodal_model
18
+ from ..utils.device import get_device
19
+
20
+
21
def create_parser() -> argparse.ArgumentParser:
    """Create the top-level Langvision argument parser with all subcommands."""
    epilog = """
Examples:
  # Basic image classification training
  langvision train --model vit --dataset cifar10 --data-dir ./data --epochs 50

  # LoRA fine-tuning with custom parameters
  langvision train --model resnet50 --lora-r 16 --lora-alpha 32 --freeze-backbone

  # Multimodal training
  langvision train-multimodal --vision-model vit --text-model bert-base-uncased --data-dir ./data

  # Evaluation
  langvision evaluate --model-path ./checkpoints/best_model.pt --data-dir ./test_data
"""
    parser = argparse.ArgumentParser(
        prog="langvision",
        description="Langvision: Advanced Vision-Language Models with Efficient LoRA Fine-Tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    parser.add_argument("--version", action="version", version="langvision 0.1.0")

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Each subcommand delegates its argument surface to a dedicated helper.
    for name, help_text, add_args in (
        ("train", "Train a vision model", add_training_args),
        ("train-multimodal", "Train a multimodal vision-language model", add_multimodal_args),
        ("evaluate", "Evaluate a trained model", add_evaluation_args),
        ("setup", "Setup and validate Langvision environment", add_setup_args),
    ):
        add_args(subparsers.add_parser(name, help=help_text))

    return parser
68
def add_training_args(parser: argparse.ArgumentParser) -> None:
    """Register model, LoRA, data, training, output, and system options on *parser*."""
    # Model architecture options.
    group = parser.add_argument_group("Model Configuration")
    group.add_argument("--model",
                       choices=["vit", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"],
                       default="vit", help="Model architecture to use")
    group.add_argument("--num-classes", type=int, default=10, help="Number of output classes")
    group.add_argument("--img-size", type=int, default=224, help="Input image size")

    # LoRA fine-tuning options.
    group = parser.add_argument_group("LoRA Configuration")
    group.add_argument("--use-lora", action="store_true", help="Enable LoRA fine-tuning")
    group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    group.add_argument("--lora-alpha", type=float, default=32.0, help="LoRA alpha")
    group.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
    group.add_argument("--freeze-backbone", action="store_true", help="Freeze backbone parameters")

    # Dataset / dataloader options.
    group = parser.add_argument_group("Data Configuration")
    group.add_argument("--dataset", choices=["cifar10", "cifar100", "imagefolder"],
                       default="cifar10", help="Dataset to use")
    group.add_argument("--data-dir", type=str, required=True, help="Path to dataset")
    group.add_argument("--batch-size", type=int, default=32, help="Batch size")
    group.add_argument("--num-workers", type=int, default=4, help="Number of data loading workers")
    group.add_argument("--augmentation", action="store_true", help="Enable data augmentation")
    group.add_argument("--augmentation-strength", type=float, default=0.5, help="Augmentation strength")

    # Optimization schedule options.
    group = parser.add_argument_group("Training Configuration")
    group.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    group.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    group.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay")
    group.add_argument("--optimizer", choices=["adam", "adamw", "sgd"], default="adamw", help="Optimizer")
    group.add_argument("--scheduler", choices=["cosine", "step", "plateau"], default="cosine", help="LR scheduler")
    group.add_argument("--warmup-epochs", type=int, default=5, help="Warmup epochs")
    group.add_argument("--use-amp", action="store_true", help="Use mixed precision training")
    group.add_argument("--gradient-clip", type=float, default=1.0, help="Gradient clipping norm")

    # Checkpoint / logging output options.
    group = parser.add_argument_group("Output Configuration")
    group.add_argument("--output-dir", type=str, default="./outputs", help="Output directory")
    group.add_argument("--experiment-name", type=str, default="langvision_experiment", help="Experiment name")
    group.add_argument("--save-interval", type=int, default=5, help="Save checkpoint every N epochs")
    group.add_argument("--log-interval", type=int, default=10, help="Log every N batches")

    # Runtime environment options.
    group = parser.add_argument_group("System Configuration")
    group.add_argument("--seed", type=int, default=42, help="Random seed")
    group.add_argument("--device", type=str, default="auto", help="Device to use (auto, cpu, cuda, mps)")
    group.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                       default="INFO", help="Logging level")
122
def add_multimodal_args(parser: argparse.ArgumentParser) -> None:
    """Register multimodal options on *parser*, on top of the shared training options."""
    # Reuse the full single-modality training argument surface first.
    add_training_args(parser)

    # Vision-language fusion options.
    group = parser.add_argument_group("Multimodal Configuration")
    group.add_argument("--vision-model", choices=["vit_base"], default="vit_base",
                       help="Vision model architecture")
    group.add_argument("--text-model", type=str, default="bert-base-uncased",
                       help="Text model from HuggingFace")
    group.add_argument("--vision-dim", type=int, default=768, help="Vision feature dimension")
    group.add_argument("--text-dim", type=int, default=768, help="Text feature dimension")
    group.add_argument("--hidden-dim", type=int, default=512, help="Hidden dimension for fusion")
    group.add_argument("--max-text-length", type=int, default=77, help="Maximum text sequence length")
    group.add_argument("--annotations-file", type=str, help="Path to text annotations file")
    group.add_argument("--contrastive-weight", type=float, default=1.0,
                       help="Weight for contrastive loss")
    group.add_argument("--classification-weight", type=float, default=0.5,
                       help="Weight for classification loss")
145
def add_evaluation_args(parser: argparse.ArgumentParser) -> None:
    """Register evaluation options on *parser*."""
    # Required inputs.
    parser.add_argument("--model-path", type=str, required=True, help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", type=str, required=True, help="Path to evaluation dataset")

    # Dataset / dataloader knobs.
    parser.add_argument("--dataset", choices=["cifar10", "cifar100", "imagefolder"],
                        default="cifar10", help="Dataset type")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of data loading workers")

    # Output and reporting.
    parser.add_argument("--output-dir", type=str, default="./eval_results", help="Output directory for results")
    parser.add_argument("--save-predictions", action="store_true", help="Save model predictions")
    parser.add_argument("--benchmark", action="store_true", help="Run inference speed benchmark")

    # Runtime.
    parser.add_argument("--device", type=str, default="auto", help="Device to use")
160
def add_setup_args(parser: argparse.ArgumentParser) -> None:
    """Register environment-setup options on *parser*."""
    parser.add_argument("--check-deps", action="store_true", help="Check all dependencies")
    parser.add_argument("--validate-env", action="store_true", help="Validate environment")
    parser.add_argument("--setup-cuda", action="store_true", help="Setup CUDA environment")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for testing")
169
def handle_train_command(args: argparse.Namespace) -> None:
    """Build the model, dataloaders, and trainer from CLI args and run training.

    Args:
        args: Parsed namespace produced by the ``train`` subparser
            (see ``add_training_args``).

    Raises:
        ValueError: If ``args.model`` is not a supported architecture.
    """
    logger = logging.getLogger("langvision.cli")
    logger.info("Starting training...")

    # Seed RNGs and configure logging before any model/data construction.
    quick_setup(seed=args.seed, log_level=args.log_level)

    # Optional LoRA adapter configuration.
    lora_config = None
    if args.use_lora:
        lora_config = LoRAConfig(
            r=args.lora_r,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout
        )
        logger.info(f"Using LoRA with r={args.lora_r}, alpha={args.lora_alpha}")

    # Model selection. The resnet CLI choices match the factory function
    # names in ..models.resnet exactly, so the previous identity
    # string-to-string mapping dict was redundant; resolve via getattr with
    # an explicit failure path instead.
    if args.model == "vit":
        model = VisionTransformer(
            img_size=args.img_size,
            num_classes=args.num_classes,
            lora_config=lora_config
        )
    elif args.model.startswith("resnet"):
        from ..models import resnet
        model_fn = getattr(resnet, args.model, None)
        if model_fn is None:
            raise ValueError(f"Unsupported model: {args.model}")
        model = model_fn(num_classes=args.num_classes, lora_config=lora_config)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    logger.info(f"Created {args.model} model with {args.num_classes} classes")

    # Dataset configuration. NOTE(review): args.dataset is not forwarded
    # here — presumably create_enhanced_dataloaders infers the dataset from
    # root_dir; confirm, otherwise --dataset is silently ignored.
    dataset_config = DatasetConfig(
        root_dir=args.data_dir,
        image_size=(args.img_size, args.img_size),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_augmentation=args.augmentation,
        augmentation_strength=args.augmentation_strength
    )

    dataloaders = create_enhanced_dataloaders(dataset_config)
    logger.info(f"Created dataloaders for {args.dataset}")

    # Full training schedule and output configuration.
    training_config = TrainingConfig(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        optimizer=args.optimizer,
        scheduler=args.scheduler,
        warmup_epochs=args.warmup_epochs,
        use_amp=args.use_amp,
        gradient_clip_norm=args.gradient_clip,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        save_interval=args.save_interval,
        log_interval=args.log_interval,
        lora_config=lora_config,
        freeze_backbone=args.freeze_backbone
    )

    # 'val' may be absent from the dataloaders dict; .get keeps that optional.
    trainer = AdvancedTrainer(
        model=model,
        train_loader=dataloaders['train'],
        val_loader=dataloaders.get('val'),
        config=training_config
    )

    trainer.train()
    logger.info("Training completed!")
253
def handle_setup_command(args: argparse.Namespace) -> None:
    """Validate the environment and print a setup summary report."""
    config = {
        "log_level": "INFO",
        "seed": args.seed,
    }
    if args.setup_cuda:
        config["setup_cuda"] = True

    results = initialize_langvision(config)
    validation = results['validation_results']

    # Interpreter / framework versions.
    print("\n=== Langvision Environment Setup ===")
    print(f"Python: {validation['python_version']}")
    print(f"PyTorch: {validation['pytorch_version']}")

    # GPU availability summary.
    if validation['cuda_available']:
        print(f"CUDA: {validation['cuda_version']}")
        print(f"GPUs: {validation['gpu_count']}")
    else:
        print("CUDA: Not available")

    # Optional dependency checklist.
    if args.check_deps:
        print("\n=== Dependencies ===")
        for dep, available in results['dependencies'].items():
            marker = "✓" if available else "✗"
            print(f"{marker} {dep}")

    print("\nSetup completed successfully!")
285
def main() -> None:
    """Main CLI entry point: parse arguments and dispatch to a command handler.

    Exits with status 1 when no command is given, a command is unimplemented,
    or a handler raises.
    """
    parser = create_parser()
    args = parser.parse_args()

    # Subparsers are not marked required, so args.command may be None.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    try:
        if args.command == "train":
            handle_train_command(args)
        elif args.command == "train-multimodal":
            # TODO: Implement multimodal training handler
            print("Multimodal training not yet implemented in CLI")
            sys.exit(1)
        elif args.command == "evaluate":
            # TODO: Implement evaluation handler
            print("Evaluation not yet implemented in CLI")
            sys.exit(1)
        elif args.command == "setup":
            handle_setup_command(args)
        else:
            parser.print_help()
            sys.exit(1)
    except Exception as e:
        # sys.exit raises SystemExit, which derives from BaseException and is
        # therefore NOT caught here — the explicit exit codes above pass
        # through untouched. logger.exception (rather than logger.error)
        # records the full traceback for real failures.
        logger = logging.getLogger("langvision.cli")
        logger.exception(f"Command failed: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()