langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/cli/train.py
CHANGED
@@ -9,19 +9,33 @@ import os
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-
-
-
-
-
-    parser.
-
-
-
-    parser.
-
-
+    parser = argparse.ArgumentParser(
+        description='Train or evaluate VisionTransformer with LoRA',
+        epilog='''\nExamples:\n langvision train --dataset cifar10 --epochs 5\n langvision train --dataset cifar100 --lora_rank 8 --eval\n''',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    # Data
+    data_group = parser.add_argument_group('Data')
+    data_group.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to use')
+    data_group.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
+    # Training
+    train_group = parser.add_argument_group('Training')
+    train_group.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
+    train_group.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    train_group.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
+    # LoRA
+    lora_group = parser.add_argument_group('LoRA')
+    lora_group.add_argument('--lora_rank', type=int, default=4, help='LoRA rank (low-rank adaptation)')
+    lora_group.add_argument('--lora_alpha', type=float, default=1.0, help='LoRA alpha scaling')
+    lora_group.add_argument('--lora_dropout', type=float, default=0.1, help='LoRA dropout rate')
+    # Output
+    output_group = parser.add_argument_group('Output')
+    output_group.add_argument('--output_dir', type=str, default='./checkpoints', help='Directory to save checkpoints')
+    # Misc
+    misc_group = parser.add_argument_group('Misc')
+    misc_group.add_argument('--eval', action='store_true', help='Run evaluation only (no training)')
+    misc_group.add_argument('--export', action='store_true', help='Export model for inference')
+    misc_group.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use (cuda or cpu)')
     return parser.parse_args()
 
 
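For orientation, a minimal sketch of exercising the reworked parser from Python. The flag names and defaults come from the hunk above; the argv values, and the assumption that the module imports cleanly as langvision.cli.train, are illustrative.

    # Hypothetical smoke test of the new argument groups (values are illustrative).
    import sys
    from langvision.cli.train import parse_args

    sys.argv = ["langvision-train", "--dataset", "cifar100", "--epochs", "5", "--lora_rank", "8"]
    args = parse_args()
    # Flags not passed on the command line fall back to the defaults defined above.
    print(args.dataset, args.epochs, args.lora_rank, args.lora_dropout)  # cifar100 5 8 0.1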
langvision/cli/utils.py
ADDED
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Utility functions for CLI operations.
+"""
+
+import os
+import sys
+import logging
+from typing import Optional, Dict, Any
+from pathlib import Path
+
+
+def setup_logging_with_progress(log_level: str = 'info', log_file: Optional[str] = None) -> logging.Logger:
+    """Set up logging with progress bar support."""
+    numeric_level = getattr(logging, log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+        numeric_level = logging.INFO
+
+    # Create logger
+    logger = logging.getLogger('langvision')
+    logger.setLevel(numeric_level)
+
+    # Clear existing handlers
+    logger.handlers.clear()
+
+    # Create formatter
+    formatter = logging.Formatter('[%(levelname)s] %(message)s')
+
+    # Console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(numeric_level)
+    console_handler.setFormatter(formatter)
+    logger.addHandler(console_handler)
+
+    # File handler (if specified)
+    if log_file:
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setLevel(numeric_level)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+def create_progress_bar(total: int, desc: str = "Processing") -> Any:
+    """Create a progress bar using tqdm."""
+    try:
+        from tqdm import tqdm
+        return tqdm(total=total, desc=desc, unit="it", ncols=100,
+                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+    except ImportError:
+        # Fallback to simple progress indicator
+        return SimpleProgressBar(total, desc)
+
+
+class SimpleProgressBar:
+    """Simple progress bar fallback when tqdm is not available."""
+
+    def __init__(self, total: int, desc: str = "Processing"):
+        self.total = total
+        self.desc = desc
+        self.current = 0
+        self.width = 50
+
+    def update(self, n: int = 1):
+        """Update progress."""
+        self.current += n
+        self._display()
+
+    def _display(self):
+        """Display progress."""
+        if self.total > 0:
+            percent = self.current / self.total
+            filled = int(self.width * percent)
+            bar = '█' * filled + '░' * (self.width - filled)
+            print(f'\r{self.desc}: |{bar}| {self.current}/{self.total} ({percent:.1%})', end='', flush=True)
+
+    def close(self):
+        """Close progress bar."""
+        print()  # New line
+
+
+def validate_file_path(file_path: str, must_exist: bool = True) -> bool:
+    """Validate file path."""
+    path = Path(file_path)
+
+    if must_exist and not path.exists():
+        return False
+
+    if not must_exist:
+        # Check if parent directory exists
+        if not path.parent.exists():
+            try:
+                path.parent.mkdir(parents=True, exist_ok=True)
+            except Exception:
+                return False
+
+    return True
+
+
+def get_file_size_mb(file_path: str) -> float:
+    """Get file size in MB."""
+    try:
+        size_bytes = os.path.getsize(file_path)
+        return size_bytes / (1024 * 1024)
+    except OSError:
+        return 0.0
+
+
+def format_time(seconds: float) -> str:
+    """Format time in human-readable format."""
+    if seconds < 60:
+        return f"{seconds:.1f}s"
+    elif seconds < 3600:
+        minutes = seconds / 60
+        return f"{minutes:.1f}m"
+    else:
+        hours = seconds / 3600
+        return f"{hours:.1f}h"
+
+
+def format_size(size_bytes: int) -> str:
+    """Format size in human-readable format."""
+    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
+        if size_bytes < 1024.0:
+            return f"{size_bytes:.1f}{unit}"
+        size_bytes /= 1024.0
+    return f"{size_bytes:.1f}PB"
+
+
+def print_success(message: str):
+    """Print success message with green color."""
+    print(f"\033[1;32m✅ {message}\033[0m")
+
+
+def print_error(message: str):
+    """Print error message with red color."""
+    print(f"\033[1;31m❌ {message}\033[0m")
+
+
+def print_warning(message: str):
+    """Print warning message with yellow color."""
+    print(f"\033[1;33m⚠️ {message}\033[0m")
+
+
+def print_info(message: str):
+    """Print info message with blue color."""
+    print(f"\033[1;34mℹ️ {message}\033[0m")
+
+
+def print_step(step: int, total: int, message: str):
+    """Print step message with progress indicator."""
+    print(f"\033[1;36m[STEP {step}/{total}]\033[0m {message}")
+
+
+def check_dependencies() -> Dict[str, bool]:
+    """Check if required dependencies are available."""
+    dependencies = {
+        'torch': False,
+        'torchvision': False,
+        'numpy': False,
+        'tqdm': False,
+        'pyyaml': False,
+        'matplotlib': False,
+        'pillow': False,
+    }
+
+    for dep in dependencies:
+        try:
+            __import__(dep)
+            dependencies[dep] = True
+        except ImportError:
+            dependencies[dep] = False
+
+    return dependencies
+
+
+def print_dependency_status():
+    """Print dependency status."""
+    deps = check_dependencies()
+
+    print("\n📦 Dependency Status:")
+    print("=" * 30)
+
+    for dep, available in deps.items():
+        status = "✅" if available else "❌"
+        print(f"{status} {dep}")
+
+    missing = [dep for dep, available in deps.items() if not available]
+    if missing:
+        print(f"\n⚠️ Missing dependencies: {', '.join(missing)}")
+        print("Install with: pip install " + " ".join(missing))
+
+
+def get_system_info() -> Dict[str, Any]:
+    """Get system information."""
+    import platform
+    import psutil
+
+    info = {
+        'platform': platform.platform(),
+        'python_version': platform.python_version(),
+        'cpu_count': psutil.cpu_count(),
+        'memory_gb': round(psutil.virtual_memory().total / (1024**3), 1),
+    }
+
+    # Check for CUDA
+    try:
+        import torch
+        info['cuda_available'] = torch.cuda.is_available()
+        if info['cuda_available']:
+            info['cuda_version'] = torch.version.cuda
+            info['gpu_count'] = torch.cuda.device_count()
+            info['gpu_name'] = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "Unknown"
+    except ImportError:
+        info['cuda_available'] = False
+
+    return info
+
+
+def print_system_info():
+    """Print system information."""
+    info = get_system_info()
+
+    print("\n🖥️ System Information:")
+    print("=" * 30)
+    print(f"Platform: {info['platform']}")
+    print(f"Python: {info['python_version']}")
+    print(f"CPU Cores: {info['cpu_count']}")
+    print(f"Memory: {info['memory_gb']} GB")
+
+    if info['cuda_available']:
+        print(f"CUDA: {info['cuda_version']}")
+        print(f"GPU: {info['gpu_name']} ({info['gpu_count']} device(s))")
+    else:
+        print("CUDA: Not available")
+
+
+def create_output_directory(output_dir: str) -> bool:
+    """Create output directory if it doesn't exist."""
+    try:
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+        return True
+    except Exception as e:
+        print_error(f"Failed to create output directory '{output_dir}': {e}")
+        return False
+
+
+def save_results(results: Dict[str, Any], output_file: str) -> bool:
+    """Save results to file."""
+    try:
+        import json
+        with open(output_file, 'w') as f:
+            json.dump(results, f, indent=2)
+        return True
+    except Exception as e:
+        print_error(f"Failed to save results to '{output_file}': {e}")
+        return False
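A short sketch of how these helpers might compose inside a command implementation. Only the utility functions come from the file above; the output directory, the loop, and the metric values are illustrative.

    import time
    from langvision.cli.utils import (
        create_output_directory, create_progress_bar,
        format_time, print_success, save_results,
    )

    start = time.time()
    if create_output_directory("./results"):             # hypothetical output path
        bar = create_progress_bar(total=100, desc="Evaluating")
        for _ in range(100):
            bar.update(1)                                 # tqdm, or SimpleProgressBar fallback
        bar.close()
        save_results({"accuracy": 0.91}, "./results/eval.json")   # illustrative result
        print_success(f"Done in {format_time(time.time() - start)}")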
langvision/components/attention.py
CHANGED

@@ -1,6 +1,9 @@
 import torch
 import torch.nn as nn
-
+import torch.nn.functional as F
+from typing import Optional, Tuple
+import math
+from ..models.lora import LoRALinear, LoRAConfig
 
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, lora_config=None):
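The new imports point at LoRA-adapted projections inside Attention, but the updated class body is not shown in this diff. As a rough sketch of the idea only, here is a self-contained low-rank adapter around a qkv-style projection; TinyLoRALinear is a stand-in defined here, not the package's LoRALinear, and the shapes are assumptions.

    import torch
    import torch.nn as nn

    class TinyLoRALinear(nn.Module):
        """Illustrative LoRA-wrapped linear layer: y = Wx + (alpha / r) * B(Ax)."""
        def __init__(self, in_features, out_features, r=4, alpha=1.0):
            super().__init__()
            self.base = nn.Linear(in_features, out_features)
            self.lora_a = nn.Linear(in_features, r, bias=False)
            self.lora_b = nn.Linear(r, out_features, bias=False)
            nn.init.zeros_(self.lora_b.weight)   # adapter starts as a zero update
            self.scale = alpha / r

        def forward(self, x):
            return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

    # A qkv projection for dim=64, roughly as an attention block would use it.
    qkv = TinyLoRALinear(64, 3 * 64, r=4, alpha=1.0)
    print(qkv(torch.randn(2, 16, 64)).shape)   # torch.Size([2, 16, 192])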
langvision/concepts/ccot.py
ADDED

@@ -0,0 +1,30 @@
+"""
+CCoT (Contrastive Chain-of-Thought)
+----------------------------------
+Extends CoT by using contrastive learning to distinguish between correct and incorrect reasoning chains.
+
+Example usage:
+    >>> from langvision.concepts.ccot import CCoT
+    >>> class MyCCoT(CCoT):
+    ...     def contrastive_train(self, positive_chains, negative_chains):
+    ...         return super().contrastive_train(positive_chains, negative_chains)
+    >>> ccot = MyCCoT()
+    >>> pos = [['Step 1: Think', 'Step 2: Solve']]
+    >>> neg = [['Step 1: Guess', 'Step 2: Wrong']]
+    >>> ccot.contrastive_train(pos, neg)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class CCoT(ABC):
+    """
+    Abstract base class for Contrastive Chain-of-Thought (CCoT).
+    """
+    @abstractmethod
+    def contrastive_train(self, positive_chains: List[Any], negative_chains: List[Any]) -> None:
+        """
+        Train using positive and negative reasoning chains.
+        """
+        # Simple example: print contrastive pairs (toy logic)
+        for pos, neg in zip(positive_chains, negative_chains):
+            print(f"Positive: {pos} | Negative: {neg}")
langvision/concepts/cot.py
ADDED

@@ -0,0 +1,29 @@
+"""
+CoT (Chain-of-Thought)
+----------------------
+A prompting or training method that encourages models to reason step-by-step, improving performance on complex tasks.
+
+Example usage:
+    >>> from langvision.concepts.cot import CoT
+    >>> class MyCoT(CoT):
+    ...     def generate_chain(self, prompt: str) -> list:
+    ...         return super().generate_chain(prompt)
+    >>> cot = MyCoT()
+    >>> chain = cot.generate_chain('What is 2 + 2?')
+    >>> print(chain)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class CoT(ABC):
+    """
+    Abstract base class for Chain-of-Thought (CoT).
+    """
+    @abstractmethod
+    def generate_chain(self, prompt: str) -> List[Any]:
+        """
+        Generate a chain of thought for a given prompt.
+        """
+        # Simple example: split prompt into steps (toy logic)
+        steps = [f"Step {i+1}: {word}" for i, word in enumerate(prompt.split())]
+        return steps
langvision/concepts/dpo.py
ADDED

@@ -0,0 +1,37 @@
+"""
+DPO (Direct Preference Optimization)
+-----------------------------------
+An RL method that directly optimizes model outputs based on preference data, often used in LLM fine-tuning.
+
+Example usage:
+    >>> from langvision.concepts.dpo import DPO
+    >>> import torch
+    >>> class MyDPO(DPO):
+    ...     def optimize_with_preferences(self, model, preferences, optimizer):
+    ...         super().optimize_with_preferences(model, preferences, optimizer)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> preferences = [(torch.tensor([1.0, 2.0]), 1.0), (torch.tensor([2.0, 3.0]), -1.0)]
+    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    >>> dpo = MyDPO()
+    >>> dpo.optimize_with_preferences(model, preferences, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+import torch
+
+class DPO(ABC):
+    """
+    Abstract base class for Direct Preference Optimization (DPO).
+    """
+    @abstractmethod
+    def optimize_with_preferences(self, model: torch.nn.Module, preferences: List[Any], optimizer: torch.optim.Optimizer) -> None:
+        """
+        Optimize the model using preference data.
+        """
+        for x, pref in preferences:
+            optimizer.zero_grad()
+            output = model(x)
+            # Simple preference loss: maximize output if preferred, minimize if not
+            loss = -pref * output.sum()
+            loss.backward()
+            optimizer.step()
langvision/concepts/grpo.py
ADDED

@@ -0,0 +1,25 @@
+"""
+GRPO (Generalized Reinforcement Policy Optimization)
+---------------------------------------------------
+A family of algorithms for optimizing policies in reinforcement learning, generalizing methods like PPO and DPO.
+
+Example usage:
+    >>> from langvision.concepts.grpo import GRPO
+    >>> class MyGRPO(GRPO):
+    ...     def optimize(self, policy: Any, rewards: list) -> None:
+    ...         # Implement GRPO optimization
+    ...         pass
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class GRPO(ABC):
+    """
+    Abstract base class for Generalized Reinforcement Policy Optimization (GRPO).
+    """
+    @abstractmethod
+    def optimize(self, policy: Any, rewards: List[Any]) -> None:
+        """
+        Optimize the policy using provided rewards.
+        """
+        pass
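Because GRPO.optimize ships with no default body, here is a minimal sketch of a concrete subclass. It uses a plain advantage-weighted policy-gradient step as a stand-in for a full GRPO objective; the policy, the fake rollout states, the rewards, and the learning rate are all illustrative.

    import torch
    from langvision.concepts.grpo import GRPO

    class SimpleGRPO(GRPO):
        def optimize(self, policy, rewards):
            # Toy update: weight log-probs by mean-centred rewards and take one gradient step.
            optimizer = torch.optim.SGD(policy.parameters(), lr=0.01)
            states = torch.randn(len(rewards), policy.in_features)   # fake rollout states
            rewards = torch.tensor(rewards)
            advantage = rewards - rewards.mean()
            log_probs = torch.log_softmax(policy(states), dim=-1).max(dim=-1).values
            loss = -(advantage * log_probs).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    SimpleGRPO().optimize(torch.nn.Linear(4, 2), [1.0, 0.5, -0.2])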
langvision/concepts/lime.py
ADDED

@@ -0,0 +1,37 @@
+"""
+LIME (Local Interpretable Model-agnostic Explanations)
+-----------------------------------------------------
+A technique for explaining model predictions by approximating them locally with interpretable models.
+
+Example usage:
+    >>> from langvision.concepts.lime import LIME
+    >>> import torch
+    >>> class MyLIME(LIME):
+    ...     def explain(self, model, input_data):
+    ...         return super().explain(model, input_data)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> lime = MyLIME()
+    >>> explanation = lime.explain(model, [[0.5, 1.0], [1.0, 2.0]])
+    >>> print(explanation)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+class LIME(ABC):
+    """
+    Abstract base class for Local Interpretable Model-agnostic Explanations (LIME).
+    """
+    @abstractmethod
+    def explain(self, model: Any, input_data: Any) -> Dict[str, Any]:
+        """
+        Generate a local explanation for the model's prediction on input_data using lime if available.
+        """
+        try:
+            from lime.lime_tabular import LimeTabularExplainer
+        except ImportError:
+            raise ImportError("Please install the 'lime' package to use LIME explanations.")
+        import numpy as np
+        X = np.array(input_data)
+        explainer = LimeTabularExplainer(X, mode="regression")
+        explanation = explainer.explain_instance(X[0], lambda x: model(torch.tensor(x, dtype=torch.float32)).detach().numpy())
+        return {"explanation": explanation.as_list()}
langvision/concepts/ppo.py
ADDED

@@ -0,0 +1,47 @@
+"""
+PPO (Proximal Policy Optimization)
+---------------------------------
+A popular RL algorithm that balances exploration and exploitation by limiting policy updates to stay within a trust region.
+
+Example usage:
+    >>> import torch
+    >>> from langvision.concepts.ppo import PPO
+    >>> class MyPPO(PPO):
+    ...     def step(self, policy, old_log_probs, states, actions, rewards, optimizer):
+    ...         super().step(policy, old_log_probs, states, actions, rewards, optimizer)
+    >>> policy = torch.nn.Linear(2, 2)
+    >>> old_log_probs = torch.tensor([0.5, 0.5])
+    >>> states = torch.randn(2, 2)
+    >>> actions = torch.tensor([0, 1])
+    >>> rewards = torch.tensor([1.0, 0.5])
+    >>> optimizer = torch.optim.Adam(policy.parameters(), lr=0.01)
+    >>> ppo = MyPPO()
+    >>> ppo.step(policy, old_log_probs, states, actions, rewards, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any
+import torch
+
+class PPO(ABC):
+    """
+    Abstract base class for Proximal Policy Optimization (PPO).
+    """
+    @abstractmethod
+    def step(self, policy: torch.nn.Module, old_log_probs: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, optimizer: torch.optim.Optimizer) -> None:
+        """
+        Perform a PPO update step.
+        """
+        # Forward pass
+        logits = policy(states)
+        log_probs = torch.log_softmax(logits, dim=-1)
+        selected_log_probs = log_probs[range(len(actions)), actions]
+        # Calculate ratio
+        ratio = torch.exp(selected_log_probs - old_log_probs)
+        # Calculate surrogate loss
+        advantage = rewards - rewards.mean()
+        surr1 = ratio * advantage
+        surr2 = torch.clamp(ratio, 0.8, 1.2) * advantage
+        loss = -torch.min(surr1, surr2).mean()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
langvision/concepts/rlhf.py
ADDED

@@ -0,0 +1,40 @@
+"""
+RLHF (Reinforcement Learning from Human Feedback)
+-------------------------------------------------
+A technique where models are trained using feedback from humans to align outputs with human preferences.
+
+Example usage:
+    >>> import torch
+    >>> from langvision.concepts.rlhf import RLHF
+    >>> class MyRLHF(RLHF):
+    ...     def train(self, model, data, feedback_fn, optimizer):
+    ...         super().train(model, data, feedback_fn, optimizer)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> data = [torch.tensor([1.0, 2.0]), torch.tensor([2.0, 3.0])]
+    >>> def feedback_fn(output): return 1.0 if output.item() > 0 else -1.0
+    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    >>> rlhf = MyRLHF()
+    >>> rlhf.train(model, data, feedback_fn, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List
+import torch
+
+class RLHF(ABC):
+    """
+    Abstract base class for Reinforcement Learning from Human Feedback (RLHF).
+    """
+    @abstractmethod
+    def train(self, model: torch.nn.Module, data: List[Any], feedback_fn: Callable[[Any], float], optimizer: torch.optim.Optimizer) -> None:
+        """
+        Train the model using data and a feedback function that simulates human feedback.
+        """
+        for x in data:
+            optimizer.zero_grad()
+            output = model(x)
+            # Synthetic feedback as reward
+            reward = feedback_fn(output)
+            # Simple loss: negative reward (maximize reward)
+            loss = -reward * output.sum()
+            loss.backward()
+            optimizer.step()
langvision/concepts/rlvr.py
ADDED

@@ -0,0 +1,25 @@
+"""
+RLVR (Reinforcement Learning with Value Ranking)
+-----------------------------------------------
+A method that uses value-based ranking to guide policy optimization in RL settings.
+
+Example usage:
+    >>> from langvision.concepts.rlvr import RLVR
+    >>> class MyRLVR(RLVR):
+    ...     def rank_and_update(self, values: list) -> None:
+    ...         # Implement value ranking and update
+    ...         pass
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class RLVR(ABC):
+    """
+    Abstract base class for Reinforcement Learning with Value Ranking (RLVR).
+    """
+    @abstractmethod
+    def rank_and_update(self, values: List[Any]) -> None:
+        """
+        Rank values and update the policy accordingly.
+        """
+        pass
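RLVR.rank_and_update is likewise abstract, so here is a minimal sketch of a subclass that sorts candidates by an estimated value and keeps the top fraction; the dict-shaped candidates and the keep ratio are assumptions for illustration, not anything defined by the package.

    from langvision.concepts.rlvr import RLVR

    class TopKRLVR(RLVR):
        def __init__(self, keep_ratio: float = 0.5):
            self.keep_ratio = keep_ratio
            self.selected = []

        def rank_and_update(self, values):
            # Rank candidates by estimated value; the kept subset would guide the next policy update.
            ranked = sorted(values, key=lambda v: v["value"], reverse=True)
            k = max(1, int(len(ranked) * self.keep_ratio))
            self.selected = ranked[:k]

    rlvr = TopKRLVR()
    rlvr.rank_and_update([{"action": "a", "value": 0.2}, {"action": "b", "value": 0.9}])
    print(rlvr.selected)   # [{'action': 'b', 'value': 0.9}]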
langvision/concepts/shap.py
ADDED

@@ -0,0 +1,37 @@
+"""
+SHAP (SHapley Additive exPlanations)
+------------------------------------
+A unified approach to interpreting model predictions using Shapley values from cooperative game theory.
+
+Example usage:
+    >>> from langvision.concepts.shap import SHAP
+    >>> import torch
+    >>> class MySHAP(SHAP):
+    ...     def explain(self, model, input_data):
+    ...         return super().explain(model, input_data)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> shap = MySHAP()
+    >>> explanation = shap.explain(model, [[0.5, 1.0], [1.0, 2.0]])
+    >>> print(explanation)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+class SHAP(ABC):
+    """
+    Abstract base class for SHapley Additive exPlanations (SHAP).
+    """
+    @abstractmethod
+    def explain(self, model: Any, input_data: Any) -> Dict[str, Any]:
+        """
+        Generate SHAP values for the model's prediction on input_data using shap if available.
+        """
+        try:
+            import shap
+        except ImportError:
+            raise ImportError("Please install the 'shap' package to use SHAP explanations.")
+        import numpy as np
+        X = np.array(input_data)
+        explainer = shap.Explainer(model)
+        shap_values = explainer(X)
+        return {"shap_values": shap_values.values.tolist()}