langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/model_zoo.py
CHANGED
@@ -1,2 +1,169 @@
-
-
+"""
+Model Zoo for Langvision - Pre-trained Vision Transformer models.
+"""
+
+import os
+import json
+import logging
+from typing import Dict, List, Optional, Any
+from pathlib import Path
+
+# Default model configurations
+DEFAULT_MODELS = {
+    "vit_tiny_patch16_224": {
+        "name": "vit_tiny_patch16_224",
+        "type": "vision_transformer",
+        "size": "5.4M",
+        "description": "Vision Transformer Tiny (16x16 patches, 224x224 input)",
+        "config": {
+            "img_size": 224,
+            "patch_size": 16,
+            "embed_dim": 192,
+            "depth": 12,
+            "num_heads": 3,
+            "mlp_ratio": 4.0,
+            "num_classes": 1000
+        }
+    },
+    "vit_small_patch16_224": {
+        "name": "vit_small_patch16_224",
+        "type": "vision_transformer",
+        "size": "22.1M",
+        "description": "Vision Transformer Small (16x16 patches, 224x224 input)",
+        "config": {
+            "img_size": 224,
+            "patch_size": 16,
+            "embed_dim": 384,
+            "depth": 12,
+            "num_heads": 6,
+            "mlp_ratio": 4.0,
+            "num_classes": 1000
+        }
+    },
+    "vit_base_patch16_224": {
+        "name": "vit_base_patch16_224",
+        "type": "vision_transformer",
+        "size": "86.4M",
+        "description": "Vision Transformer Base (16x16 patches, 224x224 input)",
+        "config": {
+            "img_size": 224,
+            "patch_size": 16,
+            "embed_dim": 768,
+            "depth": 12,
+            "num_heads": 12,
+            "mlp_ratio": 4.0,
+            "num_classes": 1000
+        }
+    },
+    "vit_large_patch16_224": {
+        "name": "vit_large_patch16_224",
+        "type": "vision_transformer",
+        "size": "304.3M",
+        "description": "Vision Transformer Large (16x16 patches, 224x224 input)",
+        "config": {
+            "img_size": 224,
+            "patch_size": 16,
+            "embed_dim": 1024,
+            "depth": 24,
+            "num_heads": 16,
+            "mlp_ratio": 4.0,
+            "num_classes": 1000
+        }
+    }
+}
+
+
+def get_available_models() -> List[Dict[str, Any]]:
+    """Get list of all available models."""
+    return list(DEFAULT_MODELS.values())
+
+
+def get_model_info(model_name: str) -> Dict[str, Any]:
+    """Get detailed information about a specific model."""
+    if model_name not in DEFAULT_MODELS:
+        raise ValueError(f"Model '{model_name}' not found. Available models: {list(DEFAULT_MODELS.keys())}")
+
+    return DEFAULT_MODELS[model_name]
+
+
+def download_model(model_name: str, output_dir: str = "./models", force: bool = False) -> str:
+    """Download a pre-trained model (placeholder implementation)."""
+    if model_name not in DEFAULT_MODELS:
+        raise ValueError(f"Model '{model_name}' not found. Available models: {list(DEFAULT_MODELS.keys())}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # For now, just save the model configuration
+    # In a real implementation, this would download actual model weights
+    model_info = DEFAULT_MODELS[model_name]
+    output_path = os.path.join(output_dir, f"{model_name}.json")
+
+    if os.path.exists(output_path) and not force:
+        raise FileExistsError(f"Model already exists at {output_path}. Use --force to overwrite.")
+
+    with open(output_path, 'w') as f:
+        json.dump(model_info, f, indent=2)
+
+    return output_path
+
+
+def list_models_by_type(model_type: str = None) -> List[Dict[str, Any]]:
+    """List models filtered by type."""
+    models = get_available_models()
+    if model_type:
+        models = [m for m in models if m.get('type') == model_type]
+    return models
+
+
+def search_models(query: str) -> List[Dict[str, Any]]:
+    """Search models by name or description."""
+    models = get_available_models()
+    query_lower = query.lower()
+
+    results = []
+    for model in models:
+        if (query_lower in model.get('name', '').lower() or
+            query_lower in model.get('description', '').lower()):
+            results.append(model)
+
+    return results
+
+
+class ModelZoo:
+    """Model Zoo manager for Langvision."""
+
+    def __init__(self, cache_dir: str = "./models"):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.logger = logging.getLogger(__name__)
+
+    def list_models(self, model_type: str = None) -> List[Dict[str, Any]]:
+        """List available models."""
+        return list_models_by_type(model_type)
+
+    def get_model(self, model_name: str) -> Dict[str, Any]:
+        """Get model information."""
+        return get_model_info(model_name)
+
+    def download(self, model_name: str, force: bool = False) -> str:
+        """Download a model."""
+        return download_model(model_name, str(self.cache_dir), force)
+
+    def search(self, query: str) -> List[Dict[str, Any]]:
+        """Search for models."""
+        return search_models(query)
+
+    def is_downloaded(self, model_name: str) -> bool:
+        """Check if model is downloaded."""
+        model_path = self.cache_dir / f"{model_name}.json"
+        return model_path.exists()
+
+    def get_downloaded_models(self) -> List[str]:
+        """Get list of downloaded models."""
+        downloaded = []
+        for model_file in self.cache_dir.glob("*.json"):
+            model_name = model_file.stem
+            if model_name in DEFAULT_MODELS:
+                downloaded.append(model_name)
+        return downloaded
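For orientation, a minimal usage sketch of the model-zoo API added above. The import path follows the file location, but treat it as illustrative rather than authoritative; note that in this release download() only writes the model's JSON config (a placeholder), not actual pretrained weights.

# Illustrative sketch of the 0.1.0 model-zoo API shown in the diff above.
from langvision.model_zoo import ModelZoo, get_model_info, search_models

zoo = ModelZoo(cache_dir="./models")
print([m["name"] for m in zoo.list_models(model_type="vision_transformer")])

info = get_model_info("vit_base_patch16_224")
print(info["size"], info["config"]["embed_dim"])          # 86.4M 768

# download() currently just saves the JSON config (placeholder implementation).
path = zoo.download("vit_base_patch16_224", force=True)
print(path, zoo.is_downloaded("vit_base_patch16_224"))
print([m["name"] for m in search_models("large")])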
langvision/models/lora.py
CHANGED
@@ -1,30 +1,202 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Dict, Any
+import math
+from dataclasses import dataclass
+
+@dataclass
+class LoRAConfig:
+    """Configuration class for LoRA parameters."""
+    r: int = 4
+    alpha: float = 1.0
+    dropout: float = 0.0
+    target_modules: Optional[list] = None
+    bias: str = "none"  # "none", "all", "lora_only"
+    task_type: str = "FEATURE_EXTRACTION"
+    inference_mode: bool = False
+
+    def __post_init__(self):
+        if self.target_modules is None:
+            self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
 
 class LoRALinear(nn.Module):
-
+    """Enhanced LoRA Linear layer with better initialization and features."""
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 r: int = 4,
+                 alpha: float = 1.0,
+                 dropout: float = 0.0,
+                 bias: bool = True,
+                 fan_in_fan_out: bool = False):
         super().__init__()
         self.r = r
         self.alpha = alpha
-        self.
+        self.fan_in_fan_out = fan_in_fan_out
+        self.in_features = in_features
+        self.out_features = out_features
+
+        # Original linear layer (frozen)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+        # LoRA parameters
         if r > 0:
-            self.lora_A = nn.Parameter(torch.
-            self.lora_B = nn.Parameter(torch.
+            self.lora_A = nn.Parameter(torch.zeros(r, in_features))
+            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
+            self.scaling = alpha / r
+            self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+            self.reset_parameters()
         else:
             self.lora_A = None
             self.lora_B = None
-
+            self.scaling = 1.0
+            self.dropout = nn.Identity()
+
+    def reset_parameters(self):
+        """Initialize LoRA parameters using Kaiming uniform initialization."""
+        if hasattr(self, 'lora_A') and self.lora_A is not None:
+            # Initialize A with random values and B with zeros
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with LoRA adaptation."""
+        # Original linear transformation
+        result = self.linear(x)
+
+        # Add LoRA adaptation if enabled
+        if self.r > 0 and self.lora_A is not None:
+            # Apply dropout to input
+            x_dropped = self.dropout(x)
+
+            # LoRA computation: x @ A^T @ B^T
+            if self.fan_in_fan_out:
+                lora_result = F.linear(x_dropped, self.lora_A.T) @ self.lora_B.T
+            else:
+                lora_result = F.linear(F.linear(x_dropped, self.lora_A.T), self.lora_B.T)
+
+            result = result + lora_result * self.scaling
+
+        return result
+
+    def merge_weights(self):
+        """Merge LoRA weights into the original linear layer."""
+        if self.r > 0 and self.lora_A is not None:
+            # Compute LoRA weight update
+            delta_w = self.lora_B @ self.lora_A * self.scaling
+
+            # Merge with original weights
+            if self.fan_in_fan_out:
+                self.linear.weight.data += delta_w.T
+            else:
+                self.linear.weight.data += delta_w
+
+            # Reset LoRA parameters
+            self.lora_A.data.zero_()
+            self.lora_B.data.zero_()
+
+    def unmerge_weights(self):
+        """Unmerge LoRA weights from the original linear layer."""
+        if self.r > 0 and self.lora_A is not None:
+            # Compute LoRA weight update
+            delta_w = self.lora_B @ self.lora_A * self.scaling
+
+            # Remove from original weights
+            if self.fan_in_fan_out:
+                self.linear.weight.data -= delta_w.T
+            else:
+                self.linear.weight.data -= delta_w
+
+class AdaLoRALinear(LoRALinear):
+    """Adaptive LoRA with dynamic rank adjustment."""
+
+    def __init__(self, *args, **kwargs):
+        self.target_rank = kwargs.pop('target_rank', None)
+        self.rank_pattern = kwargs.pop('rank_pattern', None)
+        super().__init__(*args, **kwargs)
+
+        if self.target_rank is not None:
+            self.rank_scheduler = self._create_rank_scheduler()
+
+    def _create_rank_scheduler(self):
+        """Create a rank scheduler for adaptive rank adjustment."""
+        # Placeholder for rank scheduling logic
+        return None
+
+    def update_rank(self, new_rank: int):
+        """Dynamically update the LoRA rank."""
+        if new_rank != self.r and new_rank > 0:
+            old_r = self.r
+            self.r = new_rank
+            self.scaling = self.alpha / new_rank
+
+            # Resize parameters
+            if old_r > 0:
+                # Preserve existing weights up to min(old_r, new_rank)
+                min_r = min(old_r, new_rank)
+                old_A = self.lora_A.data[:min_r, :]
+                old_B = self.lora_B.data[:, :min_r]
+
+            # Create new parameters
+            self.lora_A = nn.Parameter(torch.zeros(new_rank, self.in_features))
+            self.lora_B = nn.Parameter(torch.zeros(self.out_features, new_rank))
+
+            if old_r > 0:
+                # Copy preserved weights
+                self.lora_A.data[:min_r, :] = old_A
+                self.lora_B.data[:, :min_r] = old_B
+
+            # Initialize new parameters
+            if new_rank > old_r:
+                nn.init.kaiming_uniform_(self.lora_A.data[old_r:, :], a=math.sqrt(5))
+                nn.init.zeros_(self.lora_B.data[:, old_r:])
 
-
+class QLoRALinear(nn.Module):
+    """Quantized LoRA implementation for memory efficiency."""
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 r: int = 4,
+                 alpha: float = 1.0,
+                 dropout: float = 0.0,
+                 compute_dtype: torch.dtype = torch.float16,
+                 quant_type: str = "nf4"):
+        super().__init__()
+        self.r = r
+        self.alpha = alpha
+        self.scaling = alpha / r if r > 0 else 1.0
+        self.compute_dtype = compute_dtype
+        self.quant_type = quant_type
+
+        # Quantized base layer (placeholder - would need actual quantization library)
+        self.base_layer = nn.Linear(in_features, out_features, bias=False)
+
+        # LoRA adapters in full precision
+        if r > 0:
+            self.lora_A = nn.Parameter(torch.zeros(r, in_features, dtype=compute_dtype))
+            self.lora_B = nn.Parameter(torch.zeros(out_features, r, dtype=compute_dtype))
+            self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+            self.reset_parameters()
+
+    def reset_parameters(self):
+        if hasattr(self, 'lora_A') and self.lora_A is not None:
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Convert to compute dtype
+        x = x.to(self.compute_dtype)
+
+        # Base layer forward (quantized)
+        result = self.base_layer(x)
+
+        # LoRA adaptation
         if self.r > 0:
-
-
-
-
-
-            # Reshape back to original except last dim is out_features
-            out_shape = list(orig_shape[:-1]) + [self.lora_B.shape[1]]
-            lora_out = lora_out.view(*out_shape)
-            return lora_out
-        else:
-            return 0.0
+            x_dropped = self.dropout(x)
+            lora_result = F.linear(F.linear(x_dropped, self.lora_A.T), self.lora_B.T)
+            result = result + lora_result * self.scaling
+
+        return result
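A short sketch of how the rewritten LoRA layer is meant to be used, based only on the code above: LoRALinear wraps a dense nn.Linear and adds an (alpha/r)-scaled low-rank update B @ A, and merge_weights() folds that update into the dense weight for inference. Freezing the base weights below is an assumed usage pattern; lora.py does not freeze them automatically.

# Sketch assuming typical LoRA usage: only lora_A / lora_B are trained.
import torch
from langvision.models.lora import LoRAConfig, LoRALinear

layer = LoRALinear(in_features=768, out_features=768, r=8, alpha=16.0, dropout=0.05)
layer.linear.weight.requires_grad = False      # freeze the dense weight (assumption, not enforced by lora.py)
if layer.linear.bias is not None:
    layer.linear.bias.requires_grad = False

x = torch.randn(2, 196, 768)                   # (batch, tokens, features)
y = layer(x)                                   # dense output + scaled B @ A adaptation

layer.merge_weights()                          # fold the low-rank update in for deployment

config = LoRAConfig(r=8, alpha=16.0, dropout=0.05)
print(config.target_modules)                   # defaults to ['q_proj', 'v_proj', 'k_proj', 'o_proj']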
langvision/models/multimodal.py
ADDED
@@ -0,0 +1,297 @@
+"""
+Multimodal Vision-Language Models with LoRA fine-tuning support.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Dict, Any, Tuple, List
+from transformers import AutoTokenizer, AutoModel
+from .vision_transformer import VisionTransformer
+from .lora import LoRALinear, LoRAConfig
+import math
+
+
+class CrossAttention(nn.Module):
+    """Cross-attention mechanism for vision-language fusion."""
+
+    def __init__(self,
+                 vision_dim: int,
+                 text_dim: int,
+                 hidden_dim: int = 512,
+                 num_heads: int = 8,
+                 dropout: float = 0.1,
+                 lora_config: Optional[LoRAConfig] = None):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        # Query, Key, Value projections
+        if lora_config and lora_config.r > 0:
+            self.q_proj = LoRALinear(vision_dim, hidden_dim,
+                                     r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+            self.k_proj = LoRALinear(text_dim, hidden_dim,
+                                     r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+            self.v_proj = LoRALinear(text_dim, hidden_dim,
+                                     r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+            self.out_proj = LoRALinear(hidden_dim, vision_dim,
+                                       r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+        else:
+            self.q_proj = nn.Linear(vision_dim, hidden_dim)
+            self.k_proj = nn.Linear(text_dim, hidden_dim)
+            self.v_proj = nn.Linear(text_dim, hidden_dim)
+            self.out_proj = nn.Linear(hidden_dim, vision_dim)
+
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(vision_dim)
+
+    def forward(self,
+                vision_features: torch.Tensor,
+                text_features: torch.Tensor,
+                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Args:
+            vision_features: (B, N_v, D_v) vision tokens
+            text_features: (B, N_t, D_t) text tokens
+            attention_mask: (B, N_t) mask for text tokens
+        """
+        B, N_v, D_v = vision_features.shape
+        B, N_t, D_t = text_features.shape
+
+        # Project to query, key, value
+        Q = self.q_proj(vision_features)  # (B, N_v, hidden_dim)
+        K = self.k_proj(text_features)  # (B, N_t, hidden_dim)
+        V = self.v_proj(text_features)  # (B, N_t, hidden_dim)
+
+        # Reshape for multi-head attention
+        Q = Q.view(B, N_v, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_v, head_dim)
+        K = K.view(B, N_t, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_t, head_dim)
+        V = V.view(B, N_t, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_t, head_dim)
+
+        # Compute attention scores
+        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B, num_heads, N_v, N_t)
+
+        # Apply attention mask if provided
+        if attention_mask is not None:
+            attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, N_t)
+            attn_scores = attn_scores.masked_fill(attn_mask == 0, float('-inf'))
+
+        # Apply softmax
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Apply attention to values
+        attn_output = torch.matmul(attn_weights, V)  # (B, num_heads, N_v, head_dim)
+
+        # Reshape and project
+        attn_output = attn_output.transpose(1, 2).contiguous().view(B, N_v, self.hidden_dim)
+        output = self.out_proj(attn_output)
+
+        # Residual connection and layer norm
+        output = self.layer_norm(vision_features + output)
+
+        return output
+
+
+class VisionLanguageModel(nn.Module):
+    """Vision-Language Model with cross-modal attention and LoRA fine-tuning."""
+
+    def __init__(self,
+                 vision_model: str = "vit_base",
+                 text_model: str = "bert-base-uncased",
+                 vision_dim: int = 768,
+                 text_dim: int = 768,
+                 hidden_dim: int = 512,
+                 num_classes: int = 1000,
+                 max_text_length: int = 77,
+                 lora_config: Optional[LoRAConfig] = None):
+        super().__init__()
+
+        self.vision_dim = vision_dim
+        self.text_dim = text_dim
+        self.hidden_dim = hidden_dim
+        self.max_text_length = max_text_length
+
+        # Vision encoder
+        if vision_model == "vit_base":
+            self.vision_encoder = VisionTransformer(
+                embed_dim=vision_dim,
+                lora_config=lora_config
+            )
+        else:
+            raise ValueError(f"Unsupported vision model: {vision_model}")
+
+        # Text encoder
+        self.tokenizer = AutoTokenizer.from_pretrained(text_model)
+        self.text_encoder = AutoModel.from_pretrained(text_model)
+
+        # Freeze text encoder if using LoRA
+        if lora_config and lora_config.r > 0:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+
+        # Cross-modal fusion
+        self.cross_attention = CrossAttention(
+            vision_dim=vision_dim,
+            text_dim=text_dim,
+            hidden_dim=hidden_dim,
+            lora_config=lora_config
+        )
+
+        # Classification head
+        if lora_config and lora_config.r > 0:
+            self.classifier = LoRALinear(vision_dim, num_classes,
+                                         r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+        else:
+            self.classifier = nn.Linear(vision_dim, num_classes)
+
+        # Text projection for contrastive learning
+        if lora_config and lora_config.r > 0:
+            self.text_projection = LoRALinear(text_dim, hidden_dim,
+                                              r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+            self.vision_projection = LoRALinear(vision_dim, hidden_dim,
+                                                r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+        else:
+            self.text_projection = nn.Linear(text_dim, hidden_dim)
+            self.vision_projection = nn.Linear(vision_dim, hidden_dim)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+
+    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
+        """Encode images to feature representations."""
+        return self.vision_encoder(images)
+
+    def encode_text(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode texts to feature representations."""
+        # Tokenize texts
+        tokens = self.tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            max_length=self.max_text_length,
+            return_tensors="pt"
+        )
+
+        # Get text features
+        with torch.no_grad() if hasattr(self, 'text_encoder') else torch.enable_grad():
+            text_outputs = self.text_encoder(**tokens)
+            text_features = text_outputs.last_hidden_state  # (B, seq_len, text_dim)
+
+        return text_features, tokens['attention_mask']
+
+    def forward(self,
+                images: torch.Tensor,
+                texts: Optional[List[str]] = None,
+                return_features: bool = False) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass for vision-language model.
+
+        Args:
+            images: (B, C, H, W) input images
+            texts: List of text descriptions
+            return_features: Whether to return intermediate features
+        """
+        # Encode images
+        vision_features = self.encode_image(images)  # (B, N_patches, vision_dim)
+
+        outputs = {"vision_features": vision_features}
+
+        if texts is not None:
+            # Encode texts
+            text_features, attention_mask = self.encode_text(texts)  # (B, seq_len, text_dim)
+            outputs["text_features"] = text_features
+
+            # Cross-modal fusion
+            fused_features = self.cross_attention(
+                vision_features, text_features, attention_mask
+            )  # (B, N_patches, vision_dim)
+            outputs["fused_features"] = fused_features
+
+            # Global pooling for classification
+            pooled_vision = fused_features.mean(dim=1)  # (B, vision_dim)
+            pooled_text = text_features.mean(dim=1)  # (B, text_dim)
+
+            # Classification
+            logits = self.classifier(pooled_vision)
+            outputs["logits"] = logits
+
+            # Contrastive learning projections
+            vision_proj = F.normalize(self.vision_projection(pooled_vision), dim=-1)
+            text_proj = F.normalize(self.text_projection(pooled_text), dim=-1)
+
+            # Compute contrastive logits
+            logit_scale = self.logit_scale.exp()
+            contrastive_logits = logit_scale * vision_proj @ text_proj.T
+
+            outputs.update({
+                "vision_proj": vision_proj,
+                "text_proj": text_proj,
+                "contrastive_logits": contrastive_logits,
+                "logit_scale": logit_scale
+            })
+        else:
+            # Image-only classification
+            pooled_vision = vision_features.mean(dim=1)
+            logits = self.classifier(pooled_vision)
+            outputs["logits"] = logits
+
+        if not return_features:
+            # Return only essential outputs for training
+            essential_keys = ["logits", "contrastive_logits"] if texts else ["logits"]
+            outputs = {k: v for k, v in outputs.items() if k in essential_keys}
+
+        return outputs
+
+
+class CLIPLoss(nn.Module):
+    """CLIP-style contrastive loss for vision-language learning."""
+
+    def __init__(self, temperature: float = 0.07):
+        super().__init__()
+        self.temperature = temperature
+        self.cross_entropy = nn.CrossEntropyLoss()
+
+    def forward(self, vision_proj: torch.Tensor, text_proj: torch.Tensor) -> torch.Tensor:
+        """
+        Compute contrastive loss between vision and text projections.
+
+        Args:
+            vision_proj: (B, D) normalized vision projections
+            text_proj: (B, D) normalized text projections
+        """
+        batch_size = vision_proj.shape[0]
+
+        # Compute similarity matrix
+        logits = vision_proj @ text_proj.T / self.temperature
+
+        # Labels are diagonal (each image matches its corresponding text)
+        labels = torch.arange(batch_size, device=logits.device)
+
+        # Symmetric loss (image-to-text and text-to-image)
+        loss_i2t = self.cross_entropy(logits, labels)
+        loss_t2i = self.cross_entropy(logits.T, labels)
+
+        return (loss_i2t + loss_t2i) / 2
+
+
+def create_multimodal_model(model_config: Dict[str, Any]) -> VisionLanguageModel:
+    """Factory function to create multimodal models with different configurations."""
+
+    # Default configuration
+    default_config = {
+        "vision_model": "vit_base",
+        "text_model": "bert-base-uncased",
+        "vision_dim": 768,
+        "text_dim": 768,
+        "hidden_dim": 512,
+        "num_classes": 1000,
+        "max_text_length": 77,
+        "lora_config": None
+    }
+
+    # Update with user config
+    config = {**default_config, **model_config}
+
+    return VisionLanguageModel(**config)
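Finally, a hedged end-to-end sketch tying the new multimodal pieces together. It assumes the sibling langvision.models.vision_transformer.VisionTransformer accepts embed_dim and lora_config and returns patch tokens, and that transformers can fetch bert-base-uncased; treat it as a sketch of the intended API, not a verified example.

# Illustrative sketch of the VisionLanguageModel / CLIPLoss API added above.
import torch
import torch.nn.functional as F
from langvision.models.lora import LoRAConfig
from langvision.models.multimodal import CLIPLoss, create_multimodal_model

model = create_multimodal_model({
    "num_classes": 10,
    "lora_config": LoRAConfig(r=8, alpha=16.0, dropout=0.05),
})

images = torch.randn(4, 3, 224, 224)
texts = ["a photo of a cat", "a photo of a dog", "a red car", "a bowl of soup"]
labels = torch.randint(0, 10, (4,))

out = model(images, texts, return_features=True)      # keep vision_proj / text_proj in the outputs
cls_loss = F.cross_entropy(out["logits"], labels)
clip_loss = CLIPLoss(temperature=0.07)(out["vision_proj"], out["text_proj"])
(cls_loss + clip_loss).backward()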