langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langvision has been flagged as possibly problematic.
Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/model_zoo.py CHANGED
@@ -1,2 +1,169 @@
- class VisualModelLoader:
-     pass
+ """
+ Model Zoo for Langvision - Pre-trained Vision Transformer models.
+ """
+
+ import os
+ import json
+ import logging
+ from typing import Dict, List, Optional, Any
+ from pathlib import Path
+
+ # Default model configurations
+ DEFAULT_MODELS = {
+     "vit_tiny_patch16_224": {
+         "name": "vit_tiny_patch16_224",
+         "type": "vision_transformer",
+         "size": "5.4M",
+         "description": "Vision Transformer Tiny (16x16 patches, 224x224 input)",
+         "config": {
+             "img_size": 224,
+             "patch_size": 16,
+             "embed_dim": 192,
+             "depth": 12,
+             "num_heads": 3,
+             "mlp_ratio": 4.0,
+             "num_classes": 1000
+         }
+     },
+     "vit_small_patch16_224": {
+         "name": "vit_small_patch16_224",
+         "type": "vision_transformer",
+         "size": "22.1M",
+         "description": "Vision Transformer Small (16x16 patches, 224x224 input)",
+         "config": {
+             "img_size": 224,
+             "patch_size": 16,
+             "embed_dim": 384,
+             "depth": 12,
+             "num_heads": 6,
+             "mlp_ratio": 4.0,
+             "num_classes": 1000
+         }
+     },
+     "vit_base_patch16_224": {
+         "name": "vit_base_patch16_224",
+         "type": "vision_transformer",
+         "size": "86.4M",
+         "description": "Vision Transformer Base (16x16 patches, 224x224 input)",
+         "config": {
+             "img_size": 224,
+             "patch_size": 16,
+             "embed_dim": 768,
+             "depth": 12,
+             "num_heads": 12,
+             "mlp_ratio": 4.0,
+             "num_classes": 1000
+         }
+     },
+     "vit_large_patch16_224": {
+         "name": "vit_large_patch16_224",
+         "type": "vision_transformer",
+         "size": "304.3M",
+         "description": "Vision Transformer Large (16x16 patches, 224x224 input)",
+         "config": {
+             "img_size": 224,
+             "patch_size": 16,
+             "embed_dim": 1024,
+             "depth": 24,
+             "num_heads": 16,
+             "mlp_ratio": 4.0,
+             "num_classes": 1000
+         }
+     }
+ }
+
+
+ def get_available_models() -> List[Dict[str, Any]]:
+     """Get list of all available models."""
+     return list(DEFAULT_MODELS.values())
+
+
+ def get_model_info(model_name: str) -> Dict[str, Any]:
+     """Get detailed information about a specific model."""
+     if model_name not in DEFAULT_MODELS:
+         raise ValueError(f"Model '{model_name}' not found. Available models: {list(DEFAULT_MODELS.keys())}")
+
+     return DEFAULT_MODELS[model_name]
+
+
+ def download_model(model_name: str, output_dir: str = "./models", force: bool = False) -> str:
+     """Download a pre-trained model (placeholder implementation)."""
+     if model_name not in DEFAULT_MODELS:
+         raise ValueError(f"Model '{model_name}' not found. Available models: {list(DEFAULT_MODELS.keys())}")
+
+     # Create output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # For now, just save the model configuration
+     # In a real implementation, this would download actual model weights
+     model_info = DEFAULT_MODELS[model_name]
+     output_path = os.path.join(output_dir, f"{model_name}.json")
+
+     if os.path.exists(output_path) and not force:
+         raise FileExistsError(f"Model already exists at {output_path}. Use --force to overwrite.")
+
+     with open(output_path, 'w') as f:
+         json.dump(model_info, f, indent=2)
+
+     return output_path
+
+
+ def list_models_by_type(model_type: str = None) -> List[Dict[str, Any]]:
+     """List models filtered by type."""
+     models = get_available_models()
+     if model_type:
+         models = [m for m in models if m.get('type') == model_type]
+     return models
+
+
+ def search_models(query: str) -> List[Dict[str, Any]]:
+     """Search models by name or description."""
+     models = get_available_models()
+     query_lower = query.lower()
+
+     results = []
+     for model in models:
+         if (query_lower in model.get('name', '').lower() or
+                 query_lower in model.get('description', '').lower()):
+             results.append(model)
+
+     return results
+
+
+ class ModelZoo:
+     """Model Zoo manager for Langvision."""
+
+     def __init__(self, cache_dir: str = "./models"):
+         self.cache_dir = Path(cache_dir)
+         self.cache_dir.mkdir(parents=True, exist_ok=True)
+         self.logger = logging.getLogger(__name__)
+
+     def list_models(self, model_type: str = None) -> List[Dict[str, Any]]:
+         """List available models."""
+         return list_models_by_type(model_type)
+
+     def get_model(self, model_name: str) -> Dict[str, Any]:
+         """Get model information."""
+         return get_model_info(model_name)
+
+     def download(self, model_name: str, force: bool = False) -> str:
+         """Download a model."""
+         return download_model(model_name, str(self.cache_dir), force)
+
+     def search(self, query: str) -> List[Dict[str, Any]]:
+         """Search for models."""
+         return search_models(query)
+
+     def is_downloaded(self, model_name: str) -> bool:
+         """Check if model is downloaded."""
+         model_path = self.cache_dir / f"{model_name}.json"
+         return model_path.exists()
+
+     def get_downloaded_models(self) -> List[str]:
+         """Get list of downloaded models."""
+         downloaded = []
+         for model_file in self.cache_dir.glob("*.json"):
+             model_name = model_file.stem
+             if model_name in DEFAULT_MODELS:
+                 downloaded.append(model_name)
+         return downloaded
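
For orientation, the following sketch (not part of the release) exercises only the functions defined in the new model_zoo.py above. Note that download() currently just writes the model's JSON configuration to the cache directory rather than fetching real weights.

    from langvision.model_zoo import ModelZoo, get_model_info

    zoo = ModelZoo(cache_dir="./models")

    # Enumerate the four bundled ViT configurations
    for model in zoo.list_models(model_type="vision_transformer"):
        print(model["name"], model["size"])

    # Inspect one entry and "download" it (saves the JSON config only)
    info = get_model_info("vit_base_patch16_224")
    print(info["config"]["embed_dim"])                 # 768
    path = zoo.download("vit_base_patch16_224", force=True)
    print(zoo.is_downloaded("vit_base_patch16_224"), path)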
langvision/models/lora.py CHANGED
@@ -1,30 +1,202 @@
  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Dict, Any
+ import math
+ from dataclasses import dataclass
+
+ @dataclass
+ class LoRAConfig:
+     """Configuration class for LoRA parameters."""
+     r: int = 4
+     alpha: float = 1.0
+     dropout: float = 0.0
+     target_modules: Optional[list] = None
+     bias: str = "none"  # "none", "all", "lora_only"
+     task_type: str = "FEATURE_EXTRACTION"
+     inference_mode: bool = False
+
+     def __post_init__(self):
+         if self.target_modules is None:
+             self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]

  class LoRALinear(nn.Module):
-     def __init__(self, in_features, out_features, r=4, alpha=1.0, dropout=0.0):
+     """Enhanced LoRA Linear layer with better initialization and features."""
+
+     def __init__(self,
+                  in_features: int,
+                  out_features: int,
+                  r: int = 4,
+                  alpha: float = 1.0,
+                  dropout: float = 0.0,
+                  bias: bool = True,
+                  fan_in_fan_out: bool = False):
          super().__init__()
          self.r = r
          self.alpha = alpha
-         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         self.fan_in_fan_out = fan_in_fan_out
+         self.in_features = in_features
+         self.out_features = out_features
+
+         # Original linear layer (frozen)
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+         # LoRA parameters
          if r > 0:
-             self.lora_A = nn.Parameter(torch.randn(in_features, r) * 0.01)
-             self.lora_B = nn.Parameter(torch.randn(r, out_features) * 0.01)
+             self.lora_A = nn.Parameter(torch.zeros(r, in_features))
+             self.lora_B = nn.Parameter(torch.zeros(out_features, r))
+             self.scaling = alpha / r
+             self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+             self.reset_parameters()
          else:
              self.lora_A = None
              self.lora_B = None
-         self.scale = alpha / r if r > 0 else 1.0
+             self.scaling = 1.0
+             self.dropout = nn.Identity()
+
+     def reset_parameters(self):
+         """Initialize LoRA parameters using Kaiming uniform initialization."""
+         if hasattr(self, 'lora_A') and self.lora_A is not None:
+             # Initialize A with random values and B with zeros
+             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_B)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass with LoRA adaptation."""
+         # Original linear transformation
+         result = self.linear(x)
+
+         # Add LoRA adaptation if enabled
+         if self.r > 0 and self.lora_A is not None:
+             # Apply dropout to input
+             x_dropped = self.dropout(x)
+
+             # LoRA computation: x @ A^T @ B^T
+             if self.fan_in_fan_out:
+                 lora_result = F.linear(x_dropped, self.lora_A.T) @ self.lora_B.T
+             else:
+                 lora_result = F.linear(F.linear(x_dropped, self.lora_A.T), self.lora_B.T)
+
+             result = result + lora_result * self.scaling
+
+         return result
+
+     def merge_weights(self):
+         """Merge LoRA weights into the original linear layer."""
+         if self.r > 0 and self.lora_A is not None:
+             # Compute LoRA weight update
+             delta_w = self.lora_B @ self.lora_A * self.scaling
+
+             # Merge with original weights
+             if self.fan_in_fan_out:
+                 self.linear.weight.data += delta_w.T
+             else:
+                 self.linear.weight.data += delta_w
+
+             # Reset LoRA parameters
+             self.lora_A.data.zero_()
+             self.lora_B.data.zero_()
+
+     def unmerge_weights(self):
+         """Unmerge LoRA weights from the original linear layer."""
+         if self.r > 0 and self.lora_A is not None:
+             # Compute LoRA weight update
+             delta_w = self.lora_B @ self.lora_A * self.scaling
+
+             # Remove from original weights
+             if self.fan_in_fan_out:
+                 self.linear.weight.data -= delta_w.T
+             else:
+                 self.linear.weight.data -= delta_w
+
+ class AdaLoRALinear(LoRALinear):
+     """Adaptive LoRA with dynamic rank adjustment."""
+
+     def __init__(self, *args, **kwargs):
+         self.target_rank = kwargs.pop('target_rank', None)
+         self.rank_pattern = kwargs.pop('rank_pattern', None)
+         super().__init__(*args, **kwargs)
+
+         if self.target_rank is not None:
+             self.rank_scheduler = self._create_rank_scheduler()
+
+     def _create_rank_scheduler(self):
+         """Create a rank scheduler for adaptive rank adjustment."""
+         # Placeholder for rank scheduling logic
+         return None
+
+     def update_rank(self, new_rank: int):
+         """Dynamically update the LoRA rank."""
+         if new_rank != self.r and new_rank > 0:
+             old_r = self.r
+             self.r = new_rank
+             self.scaling = self.alpha / new_rank
+
+             # Resize parameters
+             if old_r > 0:
+                 # Preserve existing weights up to min(old_r, new_rank)
+                 min_r = min(old_r, new_rank)
+                 old_A = self.lora_A.data[:min_r, :]
+                 old_B = self.lora_B.data[:, :min_r]
+
+             # Create new parameters
+             self.lora_A = nn.Parameter(torch.zeros(new_rank, self.in_features))
+             self.lora_B = nn.Parameter(torch.zeros(self.out_features, new_rank))
+
+             if old_r > 0:
+                 # Copy preserved weights
+                 self.lora_A.data[:min_r, :] = old_A
+                 self.lora_B.data[:, :min_r] = old_B
+
+             # Initialize new parameters
+             if new_rank > old_r:
+                 nn.init.kaiming_uniform_(self.lora_A.data[old_r:, :], a=math.sqrt(5))
+                 nn.init.zeros_(self.lora_B.data[:, old_r:])

-     def forward(self, x):
+ class QLoRALinear(nn.Module):
+     """Quantized LoRA implementation for memory efficiency."""
+
+     def __init__(self,
+                  in_features: int,
+                  out_features: int,
+                  r: int = 4,
+                  alpha: float = 1.0,
+                  dropout: float = 0.0,
+                  compute_dtype: torch.dtype = torch.float16,
+                  quant_type: str = "nf4"):
+         super().__init__()
+         self.r = r
+         self.alpha = alpha
+         self.scaling = alpha / r if r > 0 else 1.0
+         self.compute_dtype = compute_dtype
+         self.quant_type = quant_type
+
+         # Quantized base layer (placeholder - would need actual quantization library)
+         self.base_layer = nn.Linear(in_features, out_features, bias=False)
+
+         # LoRA adapters in full precision
+         if r > 0:
+             self.lora_A = nn.Parameter(torch.zeros(r, in_features, dtype=compute_dtype))
+             self.lora_B = nn.Parameter(torch.zeros(out_features, r, dtype=compute_dtype))
+             self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+             self.reset_parameters()
+
+     def reset_parameters(self):
+         if hasattr(self, 'lora_A') and self.lora_A is not None:
+             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_B)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Convert to compute dtype
+         x = x.to(self.compute_dtype)
+
+         # Base layer forward (quantized)
+         result = self.base_layer(x)
+
+         # LoRA adaptation
          if self.r > 0:
-             orig_shape = x.shape
-             # Flatten all but last dim
-             x_2d = x.reshape(-1, x.shape[-1])
-             # Apply LoRA: (N, in_features) @ (in_features, r) @ (r, out_features) = (N, out_features)
-             lora_out = self.dropout(x_2d) @ self.lora_A @ self.lora_B * self.scale
-             # Reshape back to original except last dim is out_features
-             out_shape = list(orig_shape[:-1]) + [self.lora_B.shape[1]]
-             lora_out = lora_out.view(*out_shape)
-             return lora_out
-         else:
-             return 0.0
+             x_dropped = self.dropout(x)
+             lora_result = F.linear(F.linear(x_dropped, self.lora_A.T), self.lora_B.T)
+             result = result + lora_result * self.scaling
+
+         return result
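
As a quick sanity check on the new LoRA API, here is a minimal sketch (again, not shipped with the package) that builds a LoRALinear, simulates a trained adapter, and verifies the merge_weights bookkeeping, i.e. W <- W + (alpha / r) * B @ A. It only uses names defined in this file; the concrete sizes are illustrative.

    import torch
    from langvision.models.lora import LoRAConfig, LoRALinear

    config = LoRAConfig(r=8, alpha=16, dropout=0.05)
    print(config.target_modules)      # defaults to ['q_proj', 'v_proj', 'k_proj', 'o_proj']

    layer = LoRALinear(in_features=768, out_features=768, r=8, alpha=16)
    print(layer.lora_A.shape, layer.lora_B.shape, layer.scaling)   # (8, 768), (768, 8), 2.0

    with torch.no_grad():
        # lora_B is zero-initialised, so perturb it to imitate a trained adapter
        layer.lora_B.normal_(std=0.02)
        delta = layer.scaling * (layer.lora_B @ layer.lora_A)

    w_before = layer.linear.weight.data.clone()
    layer.merge_weights()             # folds the low-rank update into the frozen base weight
    print(torch.allclose(layer.linear.weight.data, w_before + delta))   # True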
langvision/models/multimodal.py ADDED
@@ -0,0 +1,297 @@
+ """
+ Multimodal Vision-Language Models with LoRA fine-tuning support.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Dict, Any, Tuple, List
+ from transformers import AutoTokenizer, AutoModel
+ from .vision_transformer import VisionTransformer
+ from .lora import LoRALinear, LoRAConfig
+ import math
+
+
+ class CrossAttention(nn.Module):
+     """Cross-attention mechanism for vision-language fusion."""
+
+     def __init__(self,
+                  vision_dim: int,
+                  text_dim: int,
+                  hidden_dim: int = 512,
+                  num_heads: int = 8,
+                  dropout: float = 0.1,
+                  lora_config: Optional[LoRAConfig] = None):
+         super().__init__()
+         self.hidden_dim = hidden_dim
+         self.num_heads = num_heads
+         self.head_dim = hidden_dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         # Query, Key, Value projections
+         if lora_config and lora_config.r > 0:
+             self.q_proj = LoRALinear(vision_dim, hidden_dim,
+                                      r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+             self.k_proj = LoRALinear(text_dim, hidden_dim,
+                                      r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+             self.v_proj = LoRALinear(text_dim, hidden_dim,
+                                      r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+             self.out_proj = LoRALinear(hidden_dim, vision_dim,
+                                        r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+         else:
+             self.q_proj = nn.Linear(vision_dim, hidden_dim)
+             self.k_proj = nn.Linear(text_dim, hidden_dim)
+             self.v_proj = nn.Linear(text_dim, hidden_dim)
+             self.out_proj = nn.Linear(hidden_dim, vision_dim)
+
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(vision_dim)
+
+     def forward(self,
+                 vision_features: torch.Tensor,
+                 text_features: torch.Tensor,
+                 attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         Args:
+             vision_features: (B, N_v, D_v) vision tokens
+             text_features: (B, N_t, D_t) text tokens
+             attention_mask: (B, N_t) mask for text tokens
+         """
+         B, N_v, D_v = vision_features.shape
+         B, N_t, D_t = text_features.shape
+
+         # Project to query, key, value
+         Q = self.q_proj(vision_features)  # (B, N_v, hidden_dim)
+         K = self.k_proj(text_features)    # (B, N_t, hidden_dim)
+         V = self.v_proj(text_features)    # (B, N_t, hidden_dim)
+
+         # Reshape for multi-head attention
+         Q = Q.view(B, N_v, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_v, head_dim)
+         K = K.view(B, N_t, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_t, head_dim)
+         V = V.view(B, N_t, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, N_t, head_dim)
+
+         # Compute attention scores
+         attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B, num_heads, N_v, N_t)
+
+         # Apply attention mask if provided
+         if attention_mask is not None:
+             attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, N_t)
+             attn_scores = attn_scores.masked_fill(attn_mask == 0, float('-inf'))
+
+         # Apply softmax
+         attn_weights = F.softmax(attn_scores, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Apply attention to values
+         attn_output = torch.matmul(attn_weights, V)  # (B, num_heads, N_v, head_dim)
+
+         # Reshape and project
+         attn_output = attn_output.transpose(1, 2).contiguous().view(B, N_v, self.hidden_dim)
+         output = self.out_proj(attn_output)
+
+         # Residual connection and layer norm
+         output = self.layer_norm(vision_features + output)
+
+         return output
+
+
+ class VisionLanguageModel(nn.Module):
+     """Vision-Language Model with cross-modal attention and LoRA fine-tuning."""
+
+     def __init__(self,
+                  vision_model: str = "vit_base",
+                  text_model: str = "bert-base-uncased",
+                  vision_dim: int = 768,
+                  text_dim: int = 768,
+                  hidden_dim: int = 512,
+                  num_classes: int = 1000,
+                  max_text_length: int = 77,
+                  lora_config: Optional[LoRAConfig] = None):
+         super().__init__()
+
+         self.vision_dim = vision_dim
+         self.text_dim = text_dim
+         self.hidden_dim = hidden_dim
+         self.max_text_length = max_text_length
+
+         # Vision encoder
+         if vision_model == "vit_base":
+             self.vision_encoder = VisionTransformer(
+                 embed_dim=vision_dim,
+                 lora_config=lora_config
+             )
+         else:
+             raise ValueError(f"Unsupported vision model: {vision_model}")
+
+         # Text encoder
+         self.tokenizer = AutoTokenizer.from_pretrained(text_model)
+         self.text_encoder = AutoModel.from_pretrained(text_model)
+
+         # Freeze text encoder if using LoRA
+         if lora_config and lora_config.r > 0:
+             for param in self.text_encoder.parameters():
+                 param.requires_grad = False
+
+         # Cross-modal fusion
+         self.cross_attention = CrossAttention(
+             vision_dim=vision_dim,
+             text_dim=text_dim,
+             hidden_dim=hidden_dim,
+             lora_config=lora_config
+         )
+
+         # Classification head
+         if lora_config and lora_config.r > 0:
+             self.classifier = LoRALinear(vision_dim, num_classes,
+                                          r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+         else:
+             self.classifier = nn.Linear(vision_dim, num_classes)
+
+         # Text projection for contrastive learning
+         if lora_config and lora_config.r > 0:
+             self.text_projection = LoRALinear(text_dim, hidden_dim,
+                                               r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+             self.vision_projection = LoRALinear(vision_dim, hidden_dim,
+                                                 r=lora_config.r, alpha=lora_config.alpha, dropout=lora_config.dropout)
+         else:
+             self.text_projection = nn.Linear(text_dim, hidden_dim)
+             self.vision_projection = nn.Linear(vision_dim, hidden_dim)
+
+         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+
+     def encode_image(self, images: torch.Tensor) -> torch.Tensor:
+         """Encode images to feature representations."""
+         return self.vision_encoder(images)
+
+     def encode_text(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Encode texts to feature representations."""
+         # Tokenize texts
+         tokens = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.max_text_length,
+             return_tensors="pt"
+         )
+
+         # Get text features
+         with torch.no_grad() if hasattr(self, 'text_encoder') else torch.enable_grad():
+             text_outputs = self.text_encoder(**tokens)
+             text_features = text_outputs.last_hidden_state  # (B, seq_len, text_dim)
+
+         return text_features, tokens['attention_mask']
+
+     def forward(self,
+                 images: torch.Tensor,
+                 texts: Optional[List[str]] = None,
+                 return_features: bool = False) -> Dict[str, torch.Tensor]:
+         """
+         Forward pass for vision-language model.
+
+         Args:
+             images: (B, C, H, W) input images
+             texts: List of text descriptions
+             return_features: Whether to return intermediate features
+         """
+         # Encode images
+         vision_features = self.encode_image(images)  # (B, N_patches, vision_dim)
+
+         outputs = {"vision_features": vision_features}
+
+         if texts is not None:
+             # Encode texts
+             text_features, attention_mask = self.encode_text(texts)  # (B, seq_len, text_dim)
+             outputs["text_features"] = text_features
+
+             # Cross-modal fusion
+             fused_features = self.cross_attention(
+                 vision_features, text_features, attention_mask
+             )  # (B, N_patches, vision_dim)
+             outputs["fused_features"] = fused_features
+
+             # Global pooling for classification
+             pooled_vision = fused_features.mean(dim=1)  # (B, vision_dim)
+             pooled_text = text_features.mean(dim=1)     # (B, text_dim)
+
+             # Classification
+             logits = self.classifier(pooled_vision)
+             outputs["logits"] = logits
+
+             # Contrastive learning projections
+             vision_proj = F.normalize(self.vision_projection(pooled_vision), dim=-1)
+             text_proj = F.normalize(self.text_projection(pooled_text), dim=-1)
+
+             # Compute contrastive logits
+             logit_scale = self.logit_scale.exp()
+             contrastive_logits = logit_scale * vision_proj @ text_proj.T
+
+             outputs.update({
+                 "vision_proj": vision_proj,
+                 "text_proj": text_proj,
+                 "contrastive_logits": contrastive_logits,
+                 "logit_scale": logit_scale
+             })
+         else:
+             # Image-only classification
+             pooled_vision = vision_features.mean(dim=1)
+             logits = self.classifier(pooled_vision)
+             outputs["logits"] = logits
+
+         if not return_features:
+             # Return only essential outputs for training
+             essential_keys = ["logits", "contrastive_logits"] if texts else ["logits"]
+             outputs = {k: v for k, v in outputs.items() if k in essential_keys}
+
+         return outputs
+
+
+ class CLIPLoss(nn.Module):
+     """CLIP-style contrastive loss for vision-language learning."""
+
+     def __init__(self, temperature: float = 0.07):
+         super().__init__()
+         self.temperature = temperature
+         self.cross_entropy = nn.CrossEntropyLoss()
+
+     def forward(self, vision_proj: torch.Tensor, text_proj: torch.Tensor) -> torch.Tensor:
+         """
+         Compute contrastive loss between vision and text projections.
+
+         Args:
+             vision_proj: (B, D) normalized vision projections
+             text_proj: (B, D) normalized text projections
+         """
+         batch_size = vision_proj.shape[0]
+
+         # Compute similarity matrix
+         logits = vision_proj @ text_proj.T / self.temperature
+
+         # Labels are diagonal (each image matches its corresponding text)
+         labels = torch.arange(batch_size, device=logits.device)
+
+         # Symmetric loss (image-to-text and text-to-image)
+         loss_i2t = self.cross_entropy(logits, labels)
+         loss_t2i = self.cross_entropy(logits.T, labels)
+
+         return (loss_i2t + loss_t2i) / 2
+
+
+ def create_multimodal_model(model_config: Dict[str, Any]) -> VisionLanguageModel:
+     """Factory function to create multimodal models with different configurations."""
+
+     # Default configuration
+     default_config = {
+         "vision_model": "vit_base",
+         "text_model": "bert-base-uncased",
+         "vision_dim": 768,
+         "text_dim": 768,
+         "hidden_dim": 512,
+         "num_classes": 1000,
+         "max_text_length": 77,
+         "lora_config": None
+     }
+
+     # Update with user config
+     config = {**default_config, **model_config}
+
+     return VisionLanguageModel(**config)
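
To make the contrastive piece concrete, the sketch below (not shipped with the package) runs random normalized projections through CLIPLoss; importing the module assumes the package's vision_transformer module and the transformers dependency are installed. The create_multimodal_model call is left commented out because it instantiates VisionTransformer and downloads bert-base-uncased from Hugging Face.

    import torch
    import torch.nn.functional as F
    from langvision.models.multimodal import CLIPLoss, create_multimodal_model
    from langvision.models.lora import LoRAConfig

    # CLIP-style symmetric contrastive loss on a toy batch of paired embeddings
    batch, dim = 8, 512
    vision_proj = F.normalize(torch.randn(batch, dim), dim=-1)
    text_proj = F.normalize(torch.randn(batch, dim), dim=-1)

    loss = CLIPLoss(temperature=0.07)(vision_proj, text_proj)
    print(loss.item())   # roughly log(batch) for random, unaligned pairs

    # Building the full model (needs network access for bert-base-uncased):
    # model = create_multimodal_model({
    #     "num_classes": 10,
    #     "lora_config": LoRAConfig(r=8, alpha=16, dropout=0.05),
    # })
    # out = model(images=torch.randn(2, 3, 224, 224), texts=["a cat", "a dog"])
    # print(out["logits"].shape, out["contrastive_logits"].shape)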