langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/utils/config.py
CHANGED
@@ -1,15 +1,186 @@
+"""
+Configuration utilities for Langvision.
+"""
+
+import os
+import json
+import yaml
+from typing import Dict, Any, Optional
+from pathlib import Path
+
+
 default_config = {
-    '
-
-
-
-
-
-
-
+    'model': {
+        'name': 'vit_base',
+        'img_size': 224,
+        'patch_size': 16,
+        'in_chans': 3,
+        'num_classes': 1000,
+        'embed_dim': 768,
+        'depth': 12,
+        'num_heads': 12,
+        'mlp_ratio': 4.0,
+        'dropout': 0.1,
+        'attention_dropout': 0.1,
+    },
+    'data': {
+        'dataset': 'cifar10',
+        'data_dir': './data',
+        'batch_size': 64,
+        'num_workers': 2,
+        'pin_memory': True,
+        'persistent_workers': False,
+    },
+    'training': {
+        'epochs': 10,
+        'learning_rate': 1e-3,
+        'optimizer': 'adam',
+        'weight_decay': 0.01,
+        'scheduler': 'cosine',
+        'warmup_epochs': 0,
+        'min_lr': 1e-6,
+        'gradient_clip': None,
+    },
     'lora': {
         'r': 4,
         'alpha': 1.0,
         'dropout': 0.1,
+        'target_modules': ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2'],
+    },
+    'callbacks': {
+        'early_stopping': {
+            'enabled': False,
+            'patience': 5,
+            'min_delta': 0.001,
+        },
+        'checkpointing': {
+            'enabled': True,
+            'save_best': True,
+            'save_last': True,
+        },
+    },
+    'logging': {
+        'level': 'info',
+        'log_interval': 100,
+        'save_interval': 5,
+    },
+    'output': {
+        'output_dir': './outputs',
+        'save_name': 'vit_lora_best.pth',
+    },
+    'device': {
+        'device': 'cuda',
+        'cuda_deterministic': False,
+        'cuda_benchmark': True,
+    },
+    'misc': {
+        'seed': 42,
+        'log_level': 'info',
     },
-}
+}
+
+
+def load_config(config_path: str) -> Dict[str, Any]:
+    """Load configuration from YAML or JSON file."""
+    config_path = Path(config_path)
+
+    if not config_path.exists():
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    with open(config_path, 'r') as f:
+        if config_path.suffix.lower() in ['.yaml', '.yml']:
+            config = yaml.safe_load(f)
+        elif config_path.suffix.lower() == '.json':
+            config = json.load(f)
+        else:
+            raise ValueError(f"Unsupported config file format: {config_path.suffix}")
+
+    return config
+
+
+def save_config(config: Dict[str, Any], config_path: str) -> None:
+    """Save configuration to YAML or JSON file."""
+    config_path = Path(config_path)
+
+    # Create parent directory if it doesn't exist
+    config_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(config_path, 'w') as f:
+        if config_path.suffix.lower() in ['.yaml', '.yml']:
+            yaml.dump(config, f, default_flow_style=False, indent=2)
+        elif config_path.suffix.lower() == '.json':
+            json.dump(config, f, indent=2)
+        else:
+            raise ValueError(f"Unsupported config file format: {config_path.suffix}")
+
+
+def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two configurations, with override_config taking precedence."""
+    merged = base_config.copy()
+
+    for key, value in override_config.items():
+        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+            merged[key] = merge_configs(merged[key], value)
+        else:
+            merged[key] = value
+
+    return merged
+
+
+def validate_config(config: Dict[str, Any]) -> bool:
+    """Validate configuration structure."""
+    required_sections = ['model', 'data', 'training']
+
+    for section in required_sections:
+        if section not in config:
+            raise ValueError(f"Missing required configuration section: {section}")
+
+    # Validate model section
+    model_config = config['model']
+    required_model_keys = ['img_size', 'patch_size', 'num_classes']
+    for key in required_model_keys:
+        if key not in model_config:
+            raise ValueError(f"Missing required model configuration key: {key}")
+
+    # Validate data section
+    data_config = config['data']
+    required_data_keys = ['dataset', 'batch_size']
+    for key in required_data_keys:
+        if key not in data_config:
+            raise ValueError(f"Missing required data configuration key: {key}")
+
+    # Validate training section
+    training_config = config['training']
+    required_training_keys = ['epochs', 'learning_rate']
+    for key in required_training_keys:
+        if key not in training_config:
+            raise ValueError(f"Missing required training configuration key: {key}")
+
+    return True
+
+
+def get_config_value(config: Dict[str, Any], key_path: str, default: Any = None) -> Any:
+    """Get configuration value using dot notation (e.g., 'model.img_size')."""
+    keys = key_path.split('.')
+    value = config
+
+    for key in keys:
+        if isinstance(value, dict) and key in value:
+            value = value[key]
+        else:
+            return default
+
+    return value
+
+
+def set_config_value(config: Dict[str, Any], key_path: str, value: Any) -> None:
+    """Set configuration value using dot notation (e.g., 'model.img_size')."""
+    keys = key_path.split('.')
+    current = config
+
+    for key in keys[:-1]:
+        if key not in current:
+            current[key] = {}
+        current = current[key]
+
+    current[keys[-1]] = value
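For orientation, a minimal usage sketch of the new configuration helpers above (not part of the package diff; the file path my_config.yaml and the override values are illustrative assumptions):

    from langvision.utils.config import (
        default_config, load_config, merge_configs,
        validate_config, get_config_value, set_config_value,
    )

    # Load user overrides from disk (YAML or JSON) and merge them onto the
    # defaults; values from the loaded file take precedence over default_config.
    user_config = load_config('my_config.yaml')          # hypothetical path
    config = merge_configs(default_config, user_config)

    # validate_config raises ValueError if a required section or key is missing.
    validate_config(config)

    # Dot-notation access to nested values.
    lr = get_config_value(config, 'training.learning_rate', default=1e-3)
    set_config_value(config, 'model.num_classes', 10)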
langvision/utils/metrics.py
ADDED
@@ -0,0 +1,448 @@
+"""
+Comprehensive metrics tracking and evaluation utilities.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from typing import Dict, List, Optional, Union, Any
+from collections import defaultdict
+import sklearn.metrics as sk_metrics
+from dataclasses import dataclass
+import warnings
+
+
+@dataclass
+class MetricResult:
+    """Container for metric computation results."""
+    value: float
+    count: int
+    sum: float
+
+    def __post_init__(self):
+        if self.count == 0:
+            self.value = 0.0
+
+
+class MetricsTracker:
+    """Advanced metrics tracking with support for various metric types."""
+
+    def __init__(self):
+        self.metrics = defaultdict(list)
+        self.counts = defaultdict(int)
+        self.sums = defaultdict(float)
+
+    def update(self, name: str, value: float, count: int = 1):
+        """Update a metric with a new value."""
+        self.metrics[name].append(value)
+        self.counts[name] += count
+        self.sums[name] += value * count
+
+    def get_average(self, name: str) -> float:
+        """Get the average value of a metric."""
+        if self.counts[name] == 0:
+            return 0.0
+        return self.sums[name] / self.counts[name]
+
+    def get_averages(self, names: Optional[List[str]] = None) -> Dict[str, float]:
+        """Get average values for multiple metrics."""
+        if names is None:
+            names = list(self.metrics.keys())
+
+        return {name: self.get_average(name) for name in names}
+
+    def get_latest(self, name: str) -> float:
+        """Get the latest value of a metric."""
+        if not self.metrics[name]:
+            return 0.0
+        return self.metrics[name][-1]
+
+    def reset(self, names: Optional[List[str]] = None):
+        """Reset metrics."""
+        if names is None:
+            self.metrics.clear()
+            self.counts.clear()
+            self.sums.clear()
+        else:
+            for name in names:
+                if name in self.metrics:
+                    del self.metrics[name]
+                if name in self.counts:
+                    del self.counts[name]
+                if name in self.sums:
+                    del self.sums[name]
+
+    def get_summary(self) -> Dict[str, Dict[str, float]]:
+        """Get comprehensive summary of all metrics."""
+        summary = {}
+        for name in self.metrics.keys():
+            values = self.metrics[name]
+            if values:
+                summary[name] = {
+                    'mean': np.mean(values),
+                    'std': np.std(values),
+                    'min': np.min(values),
+                    'max': np.max(values),
+                    'count': len(values),
+                    'latest': values[-1]
+                }
+        return summary
+
+
+class ClassificationMetrics:
+    """Comprehensive classification metrics computation."""
+
+    @staticmethod
+    def accuracy(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+        """Compute accuracy."""
+        correct = (predictions.argmax(dim=1) == targets).float()
+        return correct.mean().item()
+
+    @staticmethod
+    def top_k_accuracy(predictions: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
+        """Compute top-k accuracy."""
+        _, top_k_preds = predictions.topk(k, dim=1)
+        correct = top_k_preds.eq(targets.view(-1, 1).expand_as(top_k_preds))
+        return correct.any(dim=1).float().mean().item()
+
+    @staticmethod
+    def precision_recall_f1(predictions: torch.Tensor,
+                            targets: torch.Tensor,
+                            average: str = 'weighted') -> Dict[str, float]:
+        """Compute precision, recall, and F1 score."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+
+        precision = sk_metrics.precision_score(targets, preds, average=average, zero_division=0)
+        recall = sk_metrics.recall_score(targets, preds, average=average, zero_division=0)
+        f1 = sk_metrics.f1_score(targets, preds, average=average, zero_division=0)
+
+        return {
+            'precision': precision,
+            'recall': recall,
+            'f1': f1
+        }
+
+    @staticmethod
+    def confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor) -> np.ndarray:
+        """Compute confusion matrix."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+        return sk_metrics.confusion_matrix(targets, preds)
+
+    @staticmethod
+    def classification_report(predictions: torch.Tensor,
+                              targets: torch.Tensor,
+                              class_names: Optional[List[str]] = None) -> str:
+        """Generate classification report."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+        return sk_metrics.classification_report(targets, preds, target_names=class_names)
+
+
+class ContrastiveMetrics:
+    """Metrics for contrastive learning and multimodal models."""
+
+    @staticmethod
+    def contrastive_accuracy(image_features: torch.Tensor,
+                             text_features: torch.Tensor,
+                             temperature: float = 0.07) -> Dict[str, float]:
+        """Compute contrastive accuracy (image-to-text and text-to-image)."""
+        # Normalize features
+        image_features = F.normalize(image_features, dim=-1)
+        text_features = F.normalize(text_features, dim=-1)
+
+        # Compute similarity matrix
+        logits = torch.matmul(image_features, text_features.T) / temperature
+
+        # Ground truth labels (diagonal)
+        batch_size = image_features.shape[0]
+        labels = torch.arange(batch_size, device=logits.device)
+
+        # Image-to-text accuracy
+        i2t_acc = (logits.argmax(dim=1) == labels).float().mean().item()
+
+        # Text-to-image accuracy
+        t2i_acc = (logits.T.argmax(dim=1) == labels).float().mean().item()
+
+        return {
+            'i2t_accuracy': i2t_acc,
+            't2i_accuracy': t2i_acc,
+            'mean_accuracy': (i2t_acc + t2i_acc) / 2
+        }
+
+    @staticmethod
+    def retrieval_metrics(image_features: torch.Tensor,
+                          text_features: torch.Tensor,
+                          k_values: List[int] = [1, 5, 10]) -> Dict[str, float]:
+        """Compute retrieval metrics (Recall@K)."""
+        # Normalize features
+        image_features = F.normalize(image_features, dim=-1)
+        text_features = F.normalize(text_features, dim=-1)
+
+        # Compute similarity matrix
+        similarities = torch.matmul(image_features, text_features.T)
+
+        batch_size = similarities.shape[0]
+        metrics = {}
+
+        # Image-to-text retrieval
+        for k in k_values:
+            if k <= batch_size:
+                # Get top-k indices for each image
+                _, top_k_indices = similarities.topk(k, dim=1)
+
+                # Check if correct text is in top-k
+                correct_indices = torch.arange(batch_size, device=similarities.device).unsqueeze(1)
+                hits = (top_k_indices == correct_indices).any(dim=1)
+                recall_at_k = hits.float().mean().item()
+
+                metrics[f'i2t_recall@{k}'] = recall_at_k
+
+        # Text-to-image retrieval
+        similarities_t2i = similarities.T
+        for k in k_values:
+            if k <= batch_size:
+                _, top_k_indices = similarities_t2i.topk(k, dim=1)
+                correct_indices = torch.arange(batch_size, device=similarities.device).unsqueeze(1)
+                hits = (top_k_indices == correct_indices).any(dim=1)
+                recall_at_k = hits.float().mean().item()
+
+                metrics[f't2i_recall@{k}'] = recall_at_k
+
+        return metrics
+
+
+class PerformanceMetrics:
+    """Performance and efficiency metrics."""
+
+    @staticmethod
+    def model_size(model: torch.nn.Module) -> Dict[str, int]:
+        """Calculate model size metrics."""
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        return {
+            'total_parameters': total_params,
+            'trainable_parameters': trainable_params,
+            'frozen_parameters': total_params - trainable_params,
+            'trainable_ratio': trainable_params / total_params if total_params > 0 else 0
+        }
+
+    @staticmethod
+    def memory_usage() -> Dict[str, float]:
+        """Get GPU memory usage statistics."""
+        if torch.cuda.is_available():
+            return {
+                'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
+                'cached_mb': torch.cuda.memory_reserved() / 1024**2,
+                'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
+                'max_cached_mb': torch.cuda.max_memory_reserved() / 1024**2
+            }
+        else:
+            return {
+                'allocated_mb': 0.0,
+                'cached_mb': 0.0,
+                'max_allocated_mb': 0.0,
+                'max_cached_mb': 0.0
+            }
+
+    @staticmethod
+    def throughput_metrics(batch_size: int,
+                           processing_time: float,
+                           num_samples: Optional[int] = None) -> Dict[str, float]:
+        """Calculate throughput metrics."""
+        if num_samples is None:
+            num_samples = batch_size
+
+        return {
+            'samples_per_second': num_samples / processing_time if processing_time > 0 else 0,
+            'batches_per_second': 1 / processing_time if processing_time > 0 else 0,
+            'ms_per_sample': (processing_time * 1000) / num_samples if num_samples > 0 else 0,
+            'ms_per_batch': processing_time * 1000
+        }
+
+
+class EvaluationSuite:
+    """Comprehensive evaluation suite for vision models."""
+
+    def __init__(self,
+                 model: torch.nn.Module,
+                 device: torch.device,
+                 class_names: Optional[List[str]] = None):
+        self.model = model
+        self.device = device
+        self.class_names = class_names
+        self.metrics_tracker = MetricsTracker()
+        self.classification_metrics = ClassificationMetrics()
+        self.contrastive_metrics = ContrastiveMetrics()
+        self.performance_metrics = PerformanceMetrics()
+
+    def evaluate_classification(self,
+                                dataloader: torch.utils.data.DataLoader,
+                                return_predictions: bool = False) -> Dict[str, Any]:
+        """Comprehensive classification evaluation."""
+        self.model.eval()
+        all_predictions = []
+        all_targets = []
+
+        with torch.no_grad():
+            for batch in dataloader:
+                images = batch['images'].to(self.device)
+                targets = batch['labels'].to(self.device)
+
+                # Forward pass
+                start_time = torch.cuda.Event(enable_timing=True)
+                end_time = torch.cuda.Event(enable_timing=True)
+
+                start_time.record()
+                outputs = self.model(images)
+                end_time.record()
+
+                torch.cuda.synchronize()
+                batch_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds
+
+                # Collect predictions and targets
+                all_predictions.append(outputs.cpu())
+                all_targets.append(targets.cpu())
+
+                # Update metrics
+                batch_size = images.size(0)
+                acc = self.classification_metrics.accuracy(outputs, targets)
+                top5_acc = self.classification_metrics.top_k_accuracy(outputs, targets, k=5)
+
+                self.metrics_tracker.update('accuracy', acc, batch_size)
+                self.metrics_tracker.update('top5_accuracy', top5_acc, batch_size)
+
+                # Performance metrics
+                throughput = self.performance_metrics.throughput_metrics(batch_size, batch_time)
+                for metric_name, metric_value in throughput.items():
+                    self.metrics_tracker.update(metric_name, metric_value, 1)
+
+        # Combine all predictions and targets
+        all_predictions = torch.cat(all_predictions, dim=0)
+        all_targets = torch.cat(all_targets, dim=0)
+
+        # Compute comprehensive metrics
+        results = self.metrics_tracker.get_averages()
+
+        # Add detailed classification metrics
+        detailed_metrics = self.classification_metrics.precision_recall_f1(all_predictions, all_targets)
+        results.update(detailed_metrics)
+
+        # Add confusion matrix
+        results['confusion_matrix'] = self.classification_metrics.confusion_matrix(all_predictions, all_targets)
+
+        # Add classification report
+        if self.class_names:
+            results['classification_report'] = self.classification_metrics.classification_report(
+                all_predictions, all_targets, self.class_names
+            )
+
+        # Add model size metrics
+        results.update(self.performance_metrics.model_size(self.model))
+
+        # Add memory usage
+        results.update(self.performance_metrics.memory_usage())
+
+        if return_predictions:
+            results['predictions'] = all_predictions
+            results['targets'] = all_targets
+
+        return results
+
+    def evaluate_contrastive(self,
+                             dataloader: torch.utils.data.DataLoader) -> Dict[str, Any]:
+        """Evaluate contrastive/multimodal model."""
+        self.model.eval()
+        all_image_features = []
+        all_text_features = []
+
+        with torch.no_grad():
+            for batch in dataloader:
+                images = batch['images'].to(self.device)
+                texts = batch.get('texts', None)
+
+                if texts is None:
+                    warnings.warn("No text data found for contrastive evaluation")
+                    continue
+
+                # Forward pass
+                outputs = self.model(images, texts, return_features=True)
+
+                # Collect features
+                if 'vision_proj' in outputs and 'text_proj' in outputs:
+                    all_image_features.append(outputs['vision_proj'].cpu())
+                    all_text_features.append(outputs['text_proj'].cpu())
+
+        if not all_image_features:
+            return {'error': 'No valid batches for contrastive evaluation'}
+
+        # Combine all features
+        all_image_features = torch.cat(all_image_features, dim=0)
+        all_text_features = torch.cat(all_text_features, dim=0)
+
+        # Compute contrastive metrics
+        results = {}
+
+        # Contrastive accuracy
+        contrastive_acc = self.contrastive_metrics.contrastive_accuracy(
+            all_image_features, all_text_features
+        )
+        results.update(contrastive_acc)
+
+        # Retrieval metrics
+        retrieval_metrics = self.contrastive_metrics.retrieval_metrics(
+            all_image_features, all_text_features
+        )
+        results.update(retrieval_metrics)
+
+        return results
+
+    def benchmark_inference(self,
+                            dataloader: torch.utils.data.DataLoader,
+                            num_warmup: int = 10,
+                            num_benchmark: int = 100) -> Dict[str, float]:
+        """Benchmark model inference performance."""
+        self.model.eval()
+
+        # Warmup
+        with torch.no_grad():
+            for i, batch in enumerate(dataloader):
+                if i >= num_warmup:
+                    break
+
+                images = batch['images'].to(self.device)
+                _ = self.model(images)
+
+        # Benchmark
+        times = []
+        with torch.no_grad():
+            for i, batch in enumerate(dataloader):
+                if i >= num_benchmark:
+                    break
+
+                images = batch['images'].to(self.device)
+                batch_size = images.size(0)
+
+                start_time = torch.cuda.Event(enable_timing=True)
+                end_time = torch.cuda.Event(enable_timing=True)
+
+                start_time.record()
+                _ = self.model(images)
+                end_time.record()
+
+                torch.cuda.synchronize()
+                batch_time = start_time.elapsed_time(end_time) / 1000.0
+                times.append(batch_time / batch_size)  # Time per sample
+
+        # Compute statistics
+        times = np.array(times)
+        return {
+            'mean_inference_time_ms': np.mean(times) * 1000,
+            'std_inference_time_ms': np.std(times) * 1000,
+            'min_inference_time_ms': np.min(times) * 1000,
+            'max_inference_time_ms': np.max(times) * 1000,
+            'median_inference_time_ms': np.median(times) * 1000,
+            'throughput_fps': 1.0 / np.mean(times)
+        }