langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/utils/config.py
@@ -1,15 +1,186 @@
+"""
+Configuration utilities for Langvision.
+"""
+
+import os
+import json
+import yaml
+from typing import Dict, Any, Optional
+from pathlib import Path
+
+
 default_config = {
-    'img_size': 224,
-    'patch_size': 16,
-    'in_chans': 3,
-    'num_classes': 1000,
-    'embed_dim': 768,
-    'depth': 12,
-    'num_heads': 12,
-    'mlp_ratio': 4.0,
+    'model': {
+        'name': 'vit_base',
+        'img_size': 224,
+        'patch_size': 16,
+        'in_chans': 3,
+        'num_classes': 1000,
+        'embed_dim': 768,
+        'depth': 12,
+        'num_heads': 12,
+        'mlp_ratio': 4.0,
+        'dropout': 0.1,
+        'attention_dropout': 0.1,
+    },
+    'data': {
+        'dataset': 'cifar10',
+        'data_dir': './data',
+        'batch_size': 64,
+        'num_workers': 2,
+        'pin_memory': True,
+        'persistent_workers': False,
+    },
+    'training': {
+        'epochs': 10,
+        'learning_rate': 1e-3,
+        'optimizer': 'adam',
+        'weight_decay': 0.01,
+        'scheduler': 'cosine',
+        'warmup_epochs': 0,
+        'min_lr': 1e-6,
+        'gradient_clip': None,
+    },
     'lora': {
         'r': 4,
         'alpha': 1.0,
         'dropout': 0.1,
+        'target_modules': ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2'],
+    },
+    'callbacks': {
+        'early_stopping': {
+            'enabled': False,
+            'patience': 5,
+            'min_delta': 0.001,
+        },
+        'checkpointing': {
+            'enabled': True,
+            'save_best': True,
+            'save_last': True,
+        },
+    },
+    'logging': {
+        'level': 'info',
+        'log_interval': 100,
+        'save_interval': 5,
+    },
+    'output': {
+        'output_dir': './outputs',
+        'save_name': 'vit_lora_best.pth',
+    },
+    'device': {
+        'device': 'cuda',
+        'cuda_deterministic': False,
+        'cuda_benchmark': True,
+    },
+    'misc': {
+        'seed': 42,
+        'log_level': 'info',
     },
-}
+}
+
+
+def load_config(config_path: str) -> Dict[str, Any]:
+    """Load configuration from YAML or JSON file."""
+    config_path = Path(config_path)
+
+    if not config_path.exists():
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    with open(config_path, 'r') as f:
+        if config_path.suffix.lower() in ['.yaml', '.yml']:
+            config = yaml.safe_load(f)
+        elif config_path.suffix.lower() == '.json':
+            config = json.load(f)
+        else:
+            raise ValueError(f"Unsupported config file format: {config_path.suffix}")
+
+    return config
+
+
+def save_config(config: Dict[str, Any], config_path: str) -> None:
+    """Save configuration to YAML or JSON file."""
+    config_path = Path(config_path)
+
+    # Create parent directory if it doesn't exist
+    config_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(config_path, 'w') as f:
+        if config_path.suffix.lower() in ['.yaml', '.yml']:
+            yaml.dump(config, f, default_flow_style=False, indent=2)
+        elif config_path.suffix.lower() == '.json':
+            json.dump(config, f, indent=2)
+        else:
+            raise ValueError(f"Unsupported config file format: {config_path.suffix}")
+
+
+def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two configurations, with override_config taking precedence."""
+    merged = base_config.copy()
+
+    for key, value in override_config.items():
+        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+            merged[key] = merge_configs(merged[key], value)
+        else:
+            merged[key] = value
+
+    return merged
+
+
+def validate_config(config: Dict[str, Any]) -> bool:
+    """Validate configuration structure."""
+    required_sections = ['model', 'data', 'training']
+
+    for section in required_sections:
+        if section not in config:
+            raise ValueError(f"Missing required configuration section: {section}")
+
+    # Validate model section
+    model_config = config['model']
+    required_model_keys = ['img_size', 'patch_size', 'num_classes']
+    for key in required_model_keys:
+        if key not in model_config:
+            raise ValueError(f"Missing required model configuration key: {key}")
+
+    # Validate data section
+    data_config = config['data']
+    required_data_keys = ['dataset', 'batch_size']
+    for key in required_data_keys:
+        if key not in data_config:
+            raise ValueError(f"Missing required data configuration key: {key}")
+
+    # Validate training section
+    training_config = config['training']
+    required_training_keys = ['epochs', 'learning_rate']
+    for key in required_training_keys:
+        if key not in training_config:
+            raise ValueError(f"Missing required training configuration key: {key}")
+
+    return True
+
+
+def get_config_value(config: Dict[str, Any], key_path: str, default: Any = None) -> Any:
+    """Get configuration value using dot notation (e.g., 'model.img_size')."""
+    keys = key_path.split('.')
+    value = config
+
+    for key in keys:
+        if isinstance(value, dict) and key in value:
+            value = value[key]
+        else:
+            return default
+
+    return value
+
+
+def set_config_value(config: Dict[str, Any], key_path: str, value: Any) -> None:
+    """Set configuration value using dot notation (e.g., 'model.img_size')."""
+    keys = key_path.split('.')
+    current = config
+
+    for key in keys[:-1]:
+        if key not in current:
+            current[key] = {}
+        current = current[key]
+
+    current[keys[-1]] = value
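
Taken together, the new helpers form a small pipeline: load a file, merge it over default_config, validate the result, then read or write values by dot path. A minimal usage sketch, not part of the diff: it assumes these functions are importable from langvision.utils.config (the path in the file list above), and 'override.yaml' / 'resolved_config.json' are hypothetical file names.

from pathlib import Path
from langvision.utils.config import (
    default_config, load_config, save_config, merge_configs,
    validate_config, get_config_value, set_config_value,
)

# Write a small override file, then load it back (hypothetical file names).
Path('override.yaml').write_text('training:\n  epochs: 20\n  learning_rate: 3.0e-4\n')
overrides = load_config('override.yaml')

# Nested dicts merge key by key; override values win.
config = merge_configs(default_config, overrides)
validate_config(config)  # raises ValueError if a required section or key is missing

print(get_config_value(config, 'training.epochs'))  # 20
print(get_config_value(config, 'model.embed_dim'))  # 768
set_config_value(config, 'data.batch_size', 128)    # dot-path write
save_config(config, 'resolved_config.json')         # JSON round-trips too

One caveat worth knowing: merge_configs only shallow-copies the base, so subsections that were not overridden (here 'data') still alias default_config, and a later set_config_value on them mutates the shared defaults as well.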
langvision/utils/metrics.py
@@ -0,0 +1,448 @@
+"""
+Comprehensive metrics tracking and evaluation utilities.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from typing import Dict, List, Optional, Union, Any
+from collections import defaultdict
+import sklearn.metrics as sk_metrics
+from dataclasses import dataclass
+import warnings
+
+
+@dataclass
+class MetricResult:
+    """Container for metric computation results."""
+    value: float
+    count: int
+    sum: float
+
+    def __post_init__(self):
+        if self.count == 0:
+            self.value = 0.0
+
+
+class MetricsTracker:
+    """Advanced metrics tracking with support for various metric types."""
+
+    def __init__(self):
+        self.metrics = defaultdict(list)
+        self.counts = defaultdict(int)
+        self.sums = defaultdict(float)
+
+    def update(self, name: str, value: float, count: int = 1):
+        """Update a metric with a new value."""
+        self.metrics[name].append(value)
+        self.counts[name] += count
+        self.sums[name] += value * count
+
+    def get_average(self, name: str) -> float:
+        """Get the average value of a metric."""
+        if self.counts[name] == 0:
+            return 0.0
+        return self.sums[name] / self.counts[name]
+
+    def get_averages(self, names: Optional[List[str]] = None) -> Dict[str, float]:
+        """Get average values for multiple metrics."""
+        if names is None:
+            names = list(self.metrics.keys())
+
+        return {name: self.get_average(name) for name in names}
+
+    def get_latest(self, name: str) -> float:
+        """Get the latest value of a metric."""
+        if not self.metrics[name]:
+            return 0.0
+        return self.metrics[name][-1]
+
+    def reset(self, names: Optional[List[str]] = None):
+        """Reset metrics."""
+        if names is None:
+            self.metrics.clear()
+            self.counts.clear()
+            self.sums.clear()
+        else:
+            for name in names:
+                if name in self.metrics:
+                    del self.metrics[name]
+                if name in self.counts:
+                    del self.counts[name]
+                if name in self.sums:
+                    del self.sums[name]
+
+    def get_summary(self) -> Dict[str, Dict[str, float]]:
+        """Get comprehensive summary of all metrics."""
+        summary = {}
+        for name in self.metrics.keys():
+            values = self.metrics[name]
+            if values:
+                summary[name] = {
+                    'mean': np.mean(values),
+                    'std': np.std(values),
+                    'min': np.min(values),
+                    'max': np.max(values),
+                    'count': len(values),
+                    'latest': values[-1]
+                }
+        return summary
+
+
+class ClassificationMetrics:
+    """Comprehensive classification metrics computation."""
+
+    @staticmethod
+    def accuracy(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+        """Compute accuracy."""
+        correct = (predictions.argmax(dim=1) == targets).float()
+        return correct.mean().item()
+
+    @staticmethod
+    def top_k_accuracy(predictions: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
+        """Compute top-k accuracy."""
+        _, top_k_preds = predictions.topk(k, dim=1)
+        correct = top_k_preds.eq(targets.view(-1, 1).expand_as(top_k_preds))
+        return correct.any(dim=1).float().mean().item()
+
+    @staticmethod
+    def precision_recall_f1(predictions: torch.Tensor,
+                            targets: torch.Tensor,
+                            average: str = 'weighted') -> Dict[str, float]:
+        """Compute precision, recall, and F1 score."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+
+        precision = sk_metrics.precision_score(targets, preds, average=average, zero_division=0)
+        recall = sk_metrics.recall_score(targets, preds, average=average, zero_division=0)
+        f1 = sk_metrics.f1_score(targets, preds, average=average, zero_division=0)
+
+        return {
+            'precision': precision,
+            'recall': recall,
+            'f1': f1
+        }
+
+    @staticmethod
+    def confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor) -> np.ndarray:
+        """Compute confusion matrix."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+        return sk_metrics.confusion_matrix(targets, preds)
+
+    @staticmethod
+    def classification_report(predictions: torch.Tensor,
+                              targets: torch.Tensor,
+                              class_names: Optional[List[str]] = None) -> str:
+        """Generate classification report."""
+        preds = predictions.argmax(dim=1).cpu().numpy()
+        targets = targets.cpu().numpy()
+        return sk_metrics.classification_report(targets, preds, target_names=class_names)
+
+
+class ContrastiveMetrics:
+    """Metrics for contrastive learning and multimodal models."""
+
+    @staticmethod
+    def contrastive_accuracy(image_features: torch.Tensor,
+                             text_features: torch.Tensor,
+                             temperature: float = 0.07) -> Dict[str, float]:
+        """Compute contrastive accuracy (image-to-text and text-to-image)."""
+        # Normalize features
+        image_features = F.normalize(image_features, dim=-1)
+        text_features = F.normalize(text_features, dim=-1)
+
+        # Compute similarity matrix
+        logits = torch.matmul(image_features, text_features.T) / temperature
+
+        # Ground truth labels (diagonal)
+        batch_size = image_features.shape[0]
+        labels = torch.arange(batch_size, device=logits.device)
+
+        # Image-to-text accuracy
+        i2t_acc = (logits.argmax(dim=1) == labels).float().mean().item()
+
+        # Text-to-image accuracy
+        t2i_acc = (logits.T.argmax(dim=1) == labels).float().mean().item()
+
+        return {
+            'i2t_accuracy': i2t_acc,
+            't2i_accuracy': t2i_acc,
+            'mean_accuracy': (i2t_acc + t2i_acc) / 2
+        }
+
+    @staticmethod
+    def retrieval_metrics(image_features: torch.Tensor,
+                          text_features: torch.Tensor,
+                          k_values: List[int] = [1, 5, 10]) -> Dict[str, float]:
+        """Compute retrieval metrics (Recall@K)."""
+        # Normalize features
+        image_features = F.normalize(image_features, dim=-1)
+        text_features = F.normalize(text_features, dim=-1)
+
+        # Compute similarity matrix
+        similarities = torch.matmul(image_features, text_features.T)
+
+        batch_size = similarities.shape[0]
+        metrics = {}
+
+        # Image-to-text retrieval
+        for k in k_values:
+            if k <= batch_size:
+                # Get top-k indices for each image
+                _, top_k_indices = similarities.topk(k, dim=1)
+
+                # Check if correct text is in top-k
+                correct_indices = torch.arange(batch_size, device=similarities.device).unsqueeze(1)
+                hits = (top_k_indices == correct_indices).any(dim=1)
+                recall_at_k = hits.float().mean().item()
+
+                metrics[f'i2t_recall@{k}'] = recall_at_k
+
+        # Text-to-image retrieval
+        similarities_t2i = similarities.T
+        for k in k_values:
+            if k <= batch_size:
+                _, top_k_indices = similarities_t2i.topk(k, dim=1)
+                correct_indices = torch.arange(batch_size, device=similarities.device).unsqueeze(1)
+                hits = (top_k_indices == correct_indices).any(dim=1)
+                recall_at_k = hits.float().mean().item()
+
+                metrics[f't2i_recall@{k}'] = recall_at_k
+
+        return metrics
+
+
+class PerformanceMetrics:
+    """Performance and efficiency metrics."""
+
+    @staticmethod
+    def model_size(model: torch.nn.Module) -> Dict[str, int]:
+        """Calculate model size metrics."""
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        return {
+            'total_parameters': total_params,
+            'trainable_parameters': trainable_params,
+            'frozen_parameters': total_params - trainable_params,
+            'trainable_ratio': trainable_params / total_params if total_params > 0 else 0
+        }
+
+    @staticmethod
+    def memory_usage() -> Dict[str, float]:
+        """Get GPU memory usage statistics."""
+        if torch.cuda.is_available():
+            return {
+                'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
+                'cached_mb': torch.cuda.memory_reserved() / 1024**2,
+                'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
+                'max_cached_mb': torch.cuda.max_memory_reserved() / 1024**2
+            }
+        else:
+            return {
+                'allocated_mb': 0.0,
+                'cached_mb': 0.0,
+                'max_allocated_mb': 0.0,
+                'max_cached_mb': 0.0
+            }
+
+    @staticmethod
+    def throughput_metrics(batch_size: int,
+                           processing_time: float,
+                           num_samples: Optional[int] = None) -> Dict[str, float]:
+        """Calculate throughput metrics."""
+        if num_samples is None:
+            num_samples = batch_size
+
+        return {
+            'samples_per_second': num_samples / processing_time if processing_time > 0 else 0,
+            'batches_per_second': 1 / processing_time if processing_time > 0 else 0,
+            'ms_per_sample': (processing_time * 1000) / num_samples if num_samples > 0 else 0,
+            'ms_per_batch': processing_time * 1000
+        }
+
+
+class EvaluationSuite:
+    """Comprehensive evaluation suite for vision models."""
+
+    def __init__(self,
+                 model: torch.nn.Module,
+                 device: torch.device,
+                 class_names: Optional[List[str]] = None):
+        self.model = model
+        self.device = device
+        self.class_names = class_names
+        self.metrics_tracker = MetricsTracker()
+        self.classification_metrics = ClassificationMetrics()
+        self.contrastive_metrics = ContrastiveMetrics()
+        self.performance_metrics = PerformanceMetrics()
+
+    def evaluate_classification(self,
+                                dataloader: torch.utils.data.DataLoader,
+                                return_predictions: bool = False) -> Dict[str, Any]:
+        """Comprehensive classification evaluation."""
+        self.model.eval()
+        all_predictions = []
+        all_targets = []
+
+        with torch.no_grad():
+            for batch in dataloader:
+                images = batch['images'].to(self.device)
+                targets = batch['labels'].to(self.device)
+
+                # Forward pass
+                start_time = torch.cuda.Event(enable_timing=True)
+                end_time = torch.cuda.Event(enable_timing=True)
+
+                start_time.record()
+                outputs = self.model(images)
+                end_time.record()
+
+                torch.cuda.synchronize()
+                batch_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds
+
+                # Collect predictions and targets
+                all_predictions.append(outputs.cpu())
+                all_targets.append(targets.cpu())
+
+                # Update metrics
+                batch_size = images.size(0)
+                acc = self.classification_metrics.accuracy(outputs, targets)
+                top5_acc = self.classification_metrics.top_k_accuracy(outputs, targets, k=5)
+
+                self.metrics_tracker.update('accuracy', acc, batch_size)
+                self.metrics_tracker.update('top5_accuracy', top5_acc, batch_size)
+
+                # Performance metrics
+                throughput = self.performance_metrics.throughput_metrics(batch_size, batch_time)
+                for metric_name, metric_value in throughput.items():
+                    self.metrics_tracker.update(metric_name, metric_value, 1)
+
+        # Combine all predictions and targets
+        all_predictions = torch.cat(all_predictions, dim=0)
+        all_targets = torch.cat(all_targets, dim=0)
+
+        # Compute comprehensive metrics
+        results = self.metrics_tracker.get_averages()
+
+        # Add detailed classification metrics
+        detailed_metrics = self.classification_metrics.precision_recall_f1(all_predictions, all_targets)
+        results.update(detailed_metrics)
+
+        # Add confusion matrix
+        results['confusion_matrix'] = self.classification_metrics.confusion_matrix(all_predictions, all_targets)
+
+        # Add classification report
+        if self.class_names:
+            results['classification_report'] = self.classification_metrics.classification_report(
+                all_predictions, all_targets, self.class_names
+            )
+
+        # Add model size metrics
+        results.update(self.performance_metrics.model_size(self.model))
+
+        # Add memory usage
+        results.update(self.performance_metrics.memory_usage())
+
+        if return_predictions:
+            results['predictions'] = all_predictions
+            results['targets'] = all_targets
+
+        return results
+
+    def evaluate_contrastive(self,
+                             dataloader: torch.utils.data.DataLoader) -> Dict[str, Any]:
+        """Evaluate contrastive/multimodal model."""
+        self.model.eval()
+        all_image_features = []
+        all_text_features = []
+
+        with torch.no_grad():
+            for batch in dataloader:
+                images = batch['images'].to(self.device)
+                texts = batch.get('texts', None)
+
+                if texts is None:
+                    warnings.warn("No text data found for contrastive evaluation")
+                    continue
+
+                # Forward pass
+                outputs = self.model(images, texts, return_features=True)
+
+                # Collect features
+                if 'vision_proj' in outputs and 'text_proj' in outputs:
+                    all_image_features.append(outputs['vision_proj'].cpu())
+                    all_text_features.append(outputs['text_proj'].cpu())
+
+        if not all_image_features:
+            return {'error': 'No valid batches for contrastive evaluation'}
+
+        # Combine all features
+        all_image_features = torch.cat(all_image_features, dim=0)
+        all_text_features = torch.cat(all_text_features, dim=0)
+
+        # Compute contrastive metrics
+        results = {}
+
+        # Contrastive accuracy
+        contrastive_acc = self.contrastive_metrics.contrastive_accuracy(
+            all_image_features, all_text_features
+        )
+        results.update(contrastive_acc)
+
+        # Retrieval metrics
+        retrieval_metrics = self.contrastive_metrics.retrieval_metrics(
+            all_image_features, all_text_features
+        )
+        results.update(retrieval_metrics)
+
+        return results
+
+    def benchmark_inference(self,
+                            dataloader: torch.utils.data.DataLoader,
+                            num_warmup: int = 10,
+                            num_benchmark: int = 100) -> Dict[str, float]:
+        """Benchmark model inference performance."""
+        self.model.eval()
+
+        # Warmup
+        with torch.no_grad():
+            for i, batch in enumerate(dataloader):
+                if i >= num_warmup:
+                    break
+
+                images = batch['images'].to(self.device)
+                _ = self.model(images)
+
+        # Benchmark
+        times = []
+        with torch.no_grad():
+            for i, batch in enumerate(dataloader):
+                if i >= num_benchmark:
+                    break
+
+                images = batch['images'].to(self.device)
+                batch_size = images.size(0)
+
+                start_time = torch.cuda.Event(enable_timing=True)
+                end_time = torch.cuda.Event(enable_timing=True)
+
+                start_time.record()
+                _ = self.model(images)
+                end_time.record()
+
+                torch.cuda.synchronize()
+                batch_time = start_time.elapsed_time(end_time) / 1000.0
+                times.append(batch_time / batch_size)  # Time per sample
+
+        # Compute statistics
+        times = np.array(times)
+        return {
+            'mean_inference_time_ms': np.mean(times) * 1000,
+            'std_inference_time_ms': np.std(times) * 1000,
+            'min_inference_time_ms': np.min(times) * 1000,
+            'max_inference_time_ms': np.max(times) * 1000,
+            'median_inference_time_ms': np.median(times) * 1000,
+            'throughput_fps': 1.0 / np.mean(times)
+        }
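
For orientation, a sketch of how the new metrics module might be exercised end to end. It is not part of the diff: it assumes the classes are importable from langvision.utils.metrics (the path in the file list), that batches are dicts keyed 'images' and 'labels' (the format evaluate_classification indexes), and that a CUDA device is available, since the timing code calls torch.cuda.Event unconditionally. The model, dataset, and loader below are throwaway placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset
from langvision.utils.metrics import MetricsTracker, EvaluationSuite

# MetricsTracker keeps count-weighted running averages.
tracker = MetricsTracker()
tracker.update('loss', 0.9, count=64)
tracker.update('loss', 0.7, count=64)
print(tracker.get_average('loss'))  # 0.8

# Dict-style batches, as evaluate_classification expects.
def collate(samples):
    images = torch.stack([s[0] for s in samples])
    labels = torch.stack([s[1] for s in samples])
    return {'images': images, 'labels': labels}

device = torch.device('cuda')  # required as written: timing uses torch.cuda.Event
model = torch.nn.Sequential(torch.nn.Flatten(),
                            torch.nn.Linear(3 * 224 * 224, 10)).to(device)
data = TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 10, (32,)))
loader = DataLoader(data, batch_size=8, collate_fn=collate)

suite = EvaluationSuite(model, device)
results = suite.evaluate_classification(loader)
print(results['accuracy'], results['f1'], results['trainable_parameters'])

timing = suite.benchmark_inference(loader, num_warmup=2, num_benchmark=4)
print(timing['mean_inference_time_ms'], timing['throughput_fps'])

# The contrastive/retrieval metrics are static and work on any paired
# feature matrices, independent of the model above.
img_f, txt_f = torch.randn(16, 512), torch.randn(16, 512)
print(suite.contrastive_metrics.retrieval_metrics(img_f, txt_f)['i2t_recall@5'])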