langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of langvision was flagged as potentially problematic.
Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/cli/train.py CHANGED
@@ -9,19 +9,33 @@ import os
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Train or evaluate VisionTransformer with LoRA')
-    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to use')
-    parser.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
-    parser.add_argument('--epochs', type=int, default=10)
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--learning_rate', type=float, default=1e-3)
-    parser.add_argument('--lora_rank', type=int, default=4)
-    parser.add_argument('--lora_alpha', type=float, default=1.0)
-    parser.add_argument('--lora_dropout', type=float, default=0.1)
-    parser.add_argument('--output_dir', type=str, default='./checkpoints')
-    parser.add_argument('--eval', action='store_true', help='Run evaluation only')
-    parser.add_argument('--export', action='store_true', help='Export model for inference')
-    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+    parser = argparse.ArgumentParser(
+        description='Train or evaluate VisionTransformer with LoRA',
+        epilog='''\nExamples:\n  langvision train --dataset cifar10 --epochs 5\n  langvision train --dataset cifar100 --lora_rank 8 --eval\n''',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    # Data
+    data_group = parser.add_argument_group('Data')
+    data_group.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to use')
+    data_group.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
+    # Training
+    train_group = parser.add_argument_group('Training')
+    train_group.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
+    train_group.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    train_group.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
+    # LoRA
+    lora_group = parser.add_argument_group('LoRA')
+    lora_group.add_argument('--lora_rank', type=int, default=4, help='LoRA rank (low-rank adaptation)')
+    lora_group.add_argument('--lora_alpha', type=float, default=1.0, help='LoRA alpha scaling')
+    lora_group.add_argument('--lora_dropout', type=float, default=0.1, help='LoRA dropout rate')
+    # Output
+    output_group = parser.add_argument_group('Output')
+    output_group.add_argument('--output_dir', type=str, default='./checkpoints', help='Directory to save checkpoints')
+    # Misc
+    misc_group = parser.add_argument_group('Misc')
+    misc_group.add_argument('--eval', action='store_true', help='Run evaluation only (no training)')
+    misc_group.add_argument('--export', action='store_true', help='Export model for inference')
+    misc_group.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use (cuda or cpu)')
     return parser.parse_args()
 
 
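Note that argparse argument groups only change how --help is rendered; parsing itself is unchanged, so scripts that already pass these flags keep working. A quick standalone check (not part of the package):

import argparse

# Minimal sketch: grouping arguments affects help output, not the namespace.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
data_group = parser.add_argument_group('Data')
data_group.add_argument('--dataset', default='cifar10', help='Dataset to use')

args = parser.parse_args(['--dataset', 'cifar100'])
print(args.dataset)   # -> cifar100; the namespace stays flat, groups are cosmetic
parser.print_help()   # the option is now listed under a "Data:" heading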
langvision/cli/utils.py ADDED
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Utility functions for CLI operations.
+"""
+
+import os
+import sys
+import logging
+from typing import Optional, Dict, Any
+from pathlib import Path
+
+
+def setup_logging_with_progress(log_level: str = 'info', log_file: Optional[str] = None) -> logging.Logger:
+    """Set up logging with progress bar support."""
+    numeric_level = getattr(logging, log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+        numeric_level = logging.INFO
+
+    # Create logger
+    logger = logging.getLogger('langvision')
+    logger.setLevel(numeric_level)
+
+    # Clear existing handlers
+    logger.handlers.clear()
+
+    # Create formatter
+    formatter = logging.Formatter('[%(levelname)s] %(message)s')
+
+    # Console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(numeric_level)
+    console_handler.setFormatter(formatter)
+    logger.addHandler(console_handler)
+
+    # File handler (if specified)
+    if log_file:
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setLevel(numeric_level)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+def create_progress_bar(total: int, desc: str = "Processing") -> Any:
+    """Create a progress bar using tqdm."""
+    try:
+        from tqdm import tqdm
+        return tqdm(total=total, desc=desc, unit="it", ncols=100,
+                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+    except ImportError:
+        # Fallback to simple progress indicator
+        return SimpleProgressBar(total, desc)
+
+
+class SimpleProgressBar:
+    """Simple progress bar fallback when tqdm is not available."""
+
+    def __init__(self, total: int, desc: str = "Processing"):
+        self.total = total
+        self.desc = desc
+        self.current = 0
+        self.width = 50
+
+    def update(self, n: int = 1):
+        """Update progress."""
+        self.current += n
+        self._display()
+
+    def _display(self):
+        """Display progress."""
+        if self.total > 0:
+            percent = self.current / self.total
+            filled = int(self.width * percent)
+            bar = '█' * filled + '░' * (self.width - filled)
+            print(f'\r{self.desc}: |{bar}| {self.current}/{self.total} ({percent:.1%})', end='', flush=True)
+
+    def close(self):
+        """Close progress bar."""
+        print()  # New line
+
+
+def validate_file_path(file_path: str, must_exist: bool = True) -> bool:
+    """Validate file path."""
+    path = Path(file_path)
+
+    if must_exist and not path.exists():
+        return False
+
+    if not must_exist:
+        # Check if parent directory exists
+        if not path.parent.exists():
+            try:
+                path.parent.mkdir(parents=True, exist_ok=True)
+            except Exception:
+                return False
+
+    return True
+
+
+def get_file_size_mb(file_path: str) -> float:
+    """Get file size in MB."""
+    try:
+        size_bytes = os.path.getsize(file_path)
+        return size_bytes / (1024 * 1024)
+    except OSError:
+        return 0.0
+
+
+def format_time(seconds: float) -> str:
+    """Format time in human-readable format."""
+    if seconds < 60:
+        return f"{seconds:.1f}s"
+    elif seconds < 3600:
+        minutes = seconds / 60
+        return f"{minutes:.1f}m"
+    else:
+        hours = seconds / 3600
+        return f"{hours:.1f}h"
+
+
+def format_size(size_bytes: int) -> str:
+    """Format size in human-readable format."""
+    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
+        if size_bytes < 1024.0:
+            return f"{size_bytes:.1f}{unit}"
+        size_bytes /= 1024.0
+    return f"{size_bytes:.1f}PB"
+
+
+def print_success(message: str):
+    """Print success message with green color."""
+    print(f"\033[1;32m✅ {message}\033[0m")
+
+
+def print_error(message: str):
+    """Print error message with red color."""
+    print(f"\033[1;31m❌ {message}\033[0m")
+
+
+def print_warning(message: str):
+    """Print warning message with yellow color."""
+    print(f"\033[1;33m⚠️ {message}\033[0m")
+
+
+def print_info(message: str):
+    """Print info message with blue color."""
+    print(f"\033[1;34mℹ️ {message}\033[0m")
+
+
+def print_step(step: int, total: int, message: str):
+    """Print step message with progress indicator."""
+    print(f"\033[1;36m[STEP {step}/{total}]\033[0m {message}")
+
+
+def check_dependencies() -> Dict[str, bool]:
+    """Check if required dependencies are available."""
+    dependencies = {
+        'torch': False,
+        'torchvision': False,
+        'numpy': False,
+        'tqdm': False,
+        'pyyaml': False,
+        'matplotlib': False,
+        'pillow': False,
+    }
+
+    for dep in dependencies:
+        try:
+            __import__(dep)
+            dependencies[dep] = True
+        except ImportError:
+            dependencies[dep] = False
+
+    return dependencies
+
+
+def print_dependency_status():
+    """Print dependency status."""
+    deps = check_dependencies()
+
+    print("\n📦 Dependency Status:")
+    print("=" * 30)
+
+    for dep, available in deps.items():
+        status = "✅" if available else "❌"
+        print(f"{status} {dep}")
+
+    missing = [dep for dep, available in deps.items() if not available]
+    if missing:
+        print(f"\n⚠️ Missing dependencies: {', '.join(missing)}")
+        print("Install with: pip install " + " ".join(missing))
+
+
+def get_system_info() -> Dict[str, Any]:
+    """Get system information."""
+    import platform
+    import psutil
+
+    info = {
+        'platform': platform.platform(),
+        'python_version': platform.python_version(),
+        'cpu_count': psutil.cpu_count(),
+        'memory_gb': round(psutil.virtual_memory().total / (1024**3), 1),
+    }
+
+    # Check for CUDA
+    try:
+        import torch
+        info['cuda_available'] = torch.cuda.is_available()
+        if info['cuda_available']:
+            info['cuda_version'] = torch.version.cuda
+            info['gpu_count'] = torch.cuda.device_count()
+            info['gpu_name'] = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "Unknown"
+    except ImportError:
+        info['cuda_available'] = False
+
+    return info
+
+
+def print_system_info():
+    """Print system information."""
+    info = get_system_info()
+
+    print("\n🖥️ System Information:")
+    print("=" * 30)
+    print(f"Platform: {info['platform']}")
+    print(f"Python: {info['python_version']}")
+    print(f"CPU Cores: {info['cpu_count']}")
+    print(f"Memory: {info['memory_gb']} GB")
+
+    if info['cuda_available']:
+        print(f"CUDA: {info['cuda_version']}")
+        print(f"GPU: {info['gpu_name']} ({info['gpu_count']} device(s))")
+    else:
+        print("CUDA: Not available")
+
+
+def create_output_directory(output_dir: str) -> bool:
+    """Create output directory if it doesn't exist."""
+    try:
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+        return True
+    except Exception as e:
+        print_error(f"Failed to create output directory '{output_dir}': {e}")
+        return False
+
+
+def save_results(results: Dict[str, Any], output_file: str) -> bool:
+    """Save results to file."""
+    try:
+        import json
+        with open(output_file, 'w') as f:
+            json.dump(results, f, indent=2)
+        return True
+    except Exception as e:
+        print_error(f"Failed to save results to '{output_file}': {e}")
+        return False
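For orientation, a small usage sketch of these helpers (the totals and sizes are made up for illustration; expected output shown in comments):

from langvision.cli.utils import create_progress_bar, format_size, format_time, print_info

bar = create_progress_bar(total=100, desc="Downloading")  # tqdm if installed, else SimpleProgressBar
for _ in range(100):
    bar.update(1)
bar.close()

print_info(f"checkpoint size: {format_size(87_654_321)}")  # -> 83.6MB
print_info(f"epoch time: {format_time(754.3)}")            # -> 12.6m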
langvision/components/attention.py CHANGED
@@ -1,6 +1,9 @@
 import torch
 import torch.nn as nn
-from langchain.models.lora import LoRALinear
+import torch.nn.functional as F
+from typing import Optional, Tuple
+import math
+from ..models.lora import LoRALinear, LoRAConfig
 
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, lora_config=None):
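For readers new to LoRA: the idea is to freeze the pretrained weight W and learn a low-rank update, y = Wx + (alpha/r)·BAx, with only A and B trainable. A minimal self-contained sketch of the standard formulation (illustrative, not necessarily langvision's exact LoRALinear):

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Standard LoRA update over a frozen base linear layer (illustrative)."""
    def __init__(self, in_features, out_features, rank=4, alpha=1.0, dropout=0.1):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)            # frozen pretrained weight
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))  # zero init: no change at start
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # y = W x + (alpha / r) * B A x, with only A and B trainable
        update = self.dropout(x) @ (self.lora_b @ self.lora_a).T
        return self.base(x) + update * self.scaling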
langvision/concepts/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .rlhf import RLHF
+from .cot import CoT
+from .ccot import CCoT
+from .grpo import GRPO
+from .rlvr import RLVR
+from .dpo import DPO
+from .ppo import PPO
+from .lime import LIME
+from .shap import SHAP
langvision/concepts/ccot.py ADDED
@@ -0,0 +1,30 @@
+"""
+CCoT (Contrastive Chain-of-Thought)
+----------------------------------
+Extends CoT by using contrastive learning to distinguish between correct and incorrect reasoning chains.
+
+Example usage:
+    >>> from langvision.concepts.ccot import CCoT
+    >>> class MyCCoT(CCoT):
+    ...     def contrastive_train(self, positive_chains, negative_chains):
+    ...         return super().contrastive_train(positive_chains, negative_chains)
+    >>> ccot = MyCCoT()
+    >>> pos = [['Step 1: Think', 'Step 2: Solve']]
+    >>> neg = [['Step 1: Guess', 'Step 2: Wrong']]
+    >>> ccot.contrastive_train(pos, neg)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class CCoT(ABC):
+    """
+    Abstract base class for Contrastive Chain-of-Thought (CCoT).
+    """
+    @abstractmethod
+    def contrastive_train(self, positive_chains: List[Any], negative_chains: List[Any]) -> None:
+        """
+        Train using positive and negative reasoning chains.
+        """
+        # Simple example: print contrastive pairs (toy logic)
+        for pos, neg in zip(positive_chains, negative_chains):
+            print(f"Positive: {pos} | Negative: {neg}")
langvision/concepts/cot.py ADDED
@@ -0,0 +1,29 @@
+"""
+CoT (Chain-of-Thought)
+----------------------
+A prompting or training method that encourages models to reason step-by-step, improving performance on complex tasks.
+
+Example usage:
+    >>> from langvision.concepts.cot import CoT
+    >>> class MyCoT(CoT):
+    ...     def generate_chain(self, prompt: str) -> list:
+    ...         return super().generate_chain(prompt)
+    >>> cot = MyCoT()
+    >>> chain = cot.generate_chain('What is 2 + 2?')
+    >>> print(chain)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class CoT(ABC):
+    """
+    Abstract base class for Chain-of-Thought (CoT).
+    """
+    @abstractmethod
+    def generate_chain(self, prompt: str) -> List[Any]:
+        """
+        Generate a chain of thought for a given prompt.
+        """
+        # Simple example: split prompt into steps (toy logic)
+        steps = [f"Step {i+1}: {word}" for i, word in enumerate(prompt.split())]
+        return steps
langvision/concepts/dpo.py ADDED
@@ -0,0 +1,37 @@
+"""
+DPO (Direct Preference Optimization)
+-----------------------------------
+An RL method that directly optimizes model outputs based on preference data, often used in LLM fine-tuning.
+
+Example usage:
+    >>> from langvision.concepts.dpo import DPO
+    >>> import torch
+    >>> class MyDPO(DPO):
+    ...     def optimize_with_preferences(self, model, preferences, optimizer):
+    ...         super().optimize_with_preferences(model, preferences, optimizer)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> preferences = [(torch.tensor([1.0, 2.0]), 1.0), (torch.tensor([2.0, 3.0]), -1.0)]
+    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    >>> dpo = MyDPO()
+    >>> dpo.optimize_with_preferences(model, preferences, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+import torch
+
+class DPO(ABC):
+    """
+    Abstract base class for Direct Preference Optimization (DPO).
+    """
+    @abstractmethod
+    def optimize_with_preferences(self, model: torch.nn.Module, preferences: List[Any], optimizer: torch.optim.Optimizer) -> None:
+        """
+        Optimize the model using preference data.
+        """
+        for x, pref in preferences:
+            optimizer.zero_grad()
+            output = model(x)
+            # Simple preference loss: maximize output if preferred, minimize if not
+            loss = -pref * output.sum()
+            loss.backward()
+            optimizer.step()
+
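The loop above uses a toy per-example loss. For reference, the canonical DPO objective compares policy and reference log-probabilities on a (chosen, rejected) pair; a hedged standalone sketch (the package may implement this differently):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # -log sigmoid(beta * [(log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))])
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()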
langvision/concepts/grpo.py ADDED
@@ -0,0 +1,25 @@
+"""
+GRPO (Generalized Reinforcement Policy Optimization)
+---------------------------------------------------
+A family of algorithms for optimizing policies in reinforcement learning, generalizing methods like PPO and DPO.
+
+Example usage:
+    >>> from langvision.concepts.grpo import GRPO
+    >>> class MyGRPO(GRPO):
+    ...     def optimize(self, policy: Any, rewards: list) -> None:
+    ...         # Implement GRPO optimization
+    ...         pass
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class GRPO(ABC):
+    """
+    Abstract base class for Generalized Reinforcement Policy Optimization (GRPO).
+    """
+    @abstractmethod
+    def optimize(self, policy: Any, rewards: List[Any]) -> None:
+        """
+        Optimize the policy using provided rewards.
+        """
+        pass
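Since optimize is left abstract, here is one possible concrete subclass; the (state, action, reward) tuple format is an illustrative assumption, not the package's API. It performs a plain REINFORCE-style update where each reward weights the log-probability of the action taken:

import torch
from langvision.concepts.grpo import GRPO

class ReinforceGRPO(GRPO):
    def __init__(self, lr=1e-2):
        self.lr = lr

    def optimize(self, policy, rewards):
        # Assumes rewards is a list of (state, action, reward) tuples and
        # policy maps a state tensor to action logits.
        opt = torch.optim.SGD(policy.parameters(), lr=self.lr)
        opt.zero_grad()
        loss = torch.tensor(0.0)
        for state, action, reward in rewards:
            log_probs = torch.log_softmax(policy(state), dim=-1)
            loss = loss - reward * log_probs[action]   # REINFORCE objective
        loss.backward()
        opt.step()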
langvision/concepts/lime.py ADDED
@@ -0,0 +1,37 @@
+"""
+LIME (Local Interpretable Model-agnostic Explanations)
+-----------------------------------------------------
+A technique for explaining model predictions by approximating them locally with interpretable models.
+
+Example usage:
+    >>> from langvision.concepts.lime import LIME
+    >>> import torch
+    >>> class MyLIME(LIME):
+    ...     def explain(self, model, input_data):
+    ...         return super().explain(model, input_data)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> lime = MyLIME()
+    >>> explanation = lime.explain(model, [[0.5, 1.0], [1.0, 2.0]])
+    >>> print(explanation)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+class LIME(ABC):
+    """
+    Abstract base class for Local Interpretable Model-agnostic Explanations (LIME).
+    """
+    @abstractmethod
+    def explain(self, model: Any, input_data: Any) -> Dict[str, Any]:
+        """
+        Generate a local explanation for the model's prediction on input_data using lime if available.
+        """
+        try:
+            from lime.lime_tabular import LimeTabularExplainer
+        except ImportError:
+            raise ImportError("Please install the 'lime' package to use LIME explanations.")
+        import numpy as np, torch  # torch import added: the predict lambda below needs it
+        X = np.array(input_data)
+        explainer = LimeTabularExplainer(X, mode="regression")
+        explanation = explainer.explain_instance(X[0], lambda x: model(torch.tensor(x, dtype=torch.float32)).detach().numpy())
+        return {"explanation": explanation.as_list()}
langvision/concepts/ppo.py ADDED
@@ -0,0 +1,47 @@
+"""
+PPO (Proximal Policy Optimization)
+---------------------------------
+A popular RL algorithm that balances exploration and exploitation by limiting policy updates to stay within a trust region.
+
+Example usage:
+    >>> import torch
+    >>> from langvision.concepts.ppo import PPO
+    >>> class MyPPO(PPO):
+    ...     def step(self, policy, old_log_probs, states, actions, rewards, optimizer):
+    ...         super().step(policy, old_log_probs, states, actions, rewards, optimizer)
+    >>> policy = torch.nn.Linear(2, 2)
+    >>> old_log_probs = torch.tensor([0.5, 0.5])
+    >>> states = torch.randn(2, 2)
+    >>> actions = torch.tensor([0, 1])
+    >>> rewards = torch.tensor([1.0, 0.5])
+    >>> optimizer = torch.optim.Adam(policy.parameters(), lr=0.01)
+    >>> ppo = MyPPO()
+    >>> ppo.step(policy, old_log_probs, states, actions, rewards, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any
+import torch
+
+class PPO(ABC):
+    """
+    Abstract base class for Proximal Policy Optimization (PPO).
+    """
+    @abstractmethod
+    def step(self, policy: torch.nn.Module, old_log_probs: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, optimizer: torch.optim.Optimizer) -> None:
+        """
+        Perform a PPO update step.
+        """
+        # Forward pass
+        logits = policy(states)
+        log_probs = torch.log_softmax(logits, dim=-1)
+        selected_log_probs = log_probs[range(len(actions)), actions]
+        # Calculate ratio
+        ratio = torch.exp(selected_log_probs - old_log_probs)
+        # Calculate surrogate loss
+        advantage = rewards - rewards.mean()
+        surr1 = ratio * advantage
+        surr2 = torch.clamp(ratio, 0.8, 1.2) * advantage
+        loss = -torch.min(surr1, surr2).mean()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
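Note: clamp(ratio, 0.8, 1.2) is the standard PPO clipped surrogate with a fixed clip range of epsilon = 0.2, and rewards - rewards.mean() is a simple baseline-subtracted advantage estimate; fuller PPO implementations typically expose epsilon as a hyperparameter and compute advantages with GAE.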
langvision/concepts/rlhf.py ADDED
@@ -0,0 +1,40 @@
+"""
+RLHF (Reinforcement Learning from Human Feedback)
+-------------------------------------------------
+A technique where models are trained using feedback from humans to align outputs with human preferences.
+
+Example usage:
+    >>> import torch
+    >>> from langvision.concepts.rlhf import RLHF
+    >>> class MyRLHF(RLHF):
+    ...     def train(self, model, data, feedback_fn, optimizer):
+    ...         super().train(model, data, feedback_fn, optimizer)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> data = [torch.tensor([1.0, 2.0]), torch.tensor([2.0, 3.0])]
+    >>> def feedback_fn(output): return 1.0 if output.item() > 0 else -1.0
+    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    >>> rlhf = MyRLHF()
+    >>> rlhf.train(model, data, feedback_fn, optimizer)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List
+import torch
+
+class RLHF(ABC):
+    """
+    Abstract base class for Reinforcement Learning from Human Feedback (RLHF).
+    """
+    @abstractmethod
+    def train(self, model: torch.nn.Module, data: List[Any], feedback_fn: Callable[[Any], float], optimizer: torch.optim.Optimizer) -> None:
+        """
+        Train the model using data and a feedback function that simulates human feedback.
+        """
+        for x in data:
+            optimizer.zero_grad()
+            output = model(x)
+            # Synthetic feedback as reward
+            reward = feedback_fn(output)
+            # Simple loss: negative reward (maximize reward)
+            loss = -reward * output.sum()
+            loss.backward()
+            optimizer.step()
langvision/concepts/rlvr.py ADDED
@@ -0,0 +1,25 @@
+"""
+RLVR (Reinforcement Learning with Value Ranking)
+-----------------------------------------------
+A method that uses value-based ranking to guide policy optimization in RL settings.
+
+Example usage:
+    >>> from langvision.concepts.rlvr import RLVR
+    >>> class MyRLVR(RLVR):
+    ...     def rank_and_update(self, values: list) -> None:
+    ...         # Implement value ranking and update
+    ...         pass
+"""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+class RLVR(ABC):
+    """
+    Abstract base class for Reinforcement Learning with Value Ranking (RLVR).
+    """
+    @abstractmethod
+    def rank_and_update(self, values: List[Any]) -> None:
+        """
+        Rank values and update the policy accordingly.
+        """
+        pass
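rank_and_update is likewise abstract. A hedged toy subclass (the candidate tuple format and policy attribute are assumptions, not the package's API): rank scored candidates by value and reinforce the top half.

import torch
from langvision.concepts.rlvr import RLVR

class TopHalfRLVR(RLVR):
    def __init__(self, policy, lr=1e-2):
        self.policy = policy
        self.opt = torch.optim.SGD(policy.parameters(), lr=lr)

    def rank_and_update(self, values):
        # Assumes values is a list of (score, state, action) candidates.
        ranked = sorted(values, key=lambda v: v[0], reverse=True)
        best = ranked[: max(1, len(ranked) // 2)]      # keep the top-ranked half
        self.opt.zero_grad()
        loss = torch.tensor(0.0)
        for _, state, action in best:
            log_probs = torch.log_softmax(self.policy(state), dim=-1)
            loss = loss - log_probs[action]            # imitate top-ranked actions
        loss.backward()
        self.opt.step()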
langvision/concepts/shap.py ADDED
@@ -0,0 +1,37 @@
+"""
+SHAP (SHapley Additive exPlanations)
+------------------------------------
+A unified approach to interpreting model predictions using Shapley values from cooperative game theory.
+
+Example usage:
+    >>> from langvision.concepts.shap import SHAP
+    >>> import torch
+    >>> class MySHAP(SHAP):
+    ...     def explain(self, model, input_data):
+    ...         return super().explain(model, input_data)
+    >>> model = torch.nn.Linear(2, 1)
+    >>> shap = MySHAP()
+    >>> explanation = shap.explain(model, [[0.5, 1.0], [1.0, 2.0]])
+    >>> print(explanation)
+"""
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+class SHAP(ABC):
+    """
+    Abstract base class for SHapley Additive exPlanations (SHAP).
+    """
+    @abstractmethod
+    def explain(self, model: Any, input_data: Any) -> Dict[str, Any]:
+        """
+        Generate SHAP values for the model's prediction on input_data using shap if available.
+        """
+        try:
+            import shap
+        except ImportError:
+            raise ImportError("Please install the 'shap' package to use SHAP explanations.")
+        import numpy as np
+        X = np.array(input_data)
+        explainer = shap.Explainer(model)
+        shap_values = explainer(X)
+        return {"shap_values": shap_values.values.tolist()}