langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/data/enhanced_datasets.py (new file)
@@ -0,0 +1,582 @@
"""
Enhanced dataset classes with comprehensive data validation, augmentation, and multimodal support.
"""

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder, CIFAR10, CIFAR100
from PIL import Image
import pandas as pd
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Union, Tuple, Callable, Any
import numpy as np
import logging
from dataclasses import dataclass
import warnings
from collections import Counter
import cv2

@dataclass
class DatasetConfig:
    """Configuration for dataset parameters."""
    # Basic settings
    root_dir: str
    image_size: Tuple[int, int] = (224, 224)
    batch_size: int = 32
    num_workers: int = 4
    pin_memory: bool = True

    # Data splits
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # Augmentation settings
    use_augmentation: bool = True
    augmentation_strength: float = 0.5

    # Multimodal settings
    text_max_length: int = 77
    text_tokenizer: Optional[str] = None

    # Validation settings
    validate_images: bool = True
    min_image_size: Tuple[int, int] = (32, 32)
    max_image_size: Tuple[int, int] = (4096, 4096)
    allowed_formats: Optional[List[str]] = None

    def __post_init__(self):
        if self.allowed_formats is None:
            self.allowed_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']

        # Validate splits sum to 1.0
        total_split = self.train_split + self.val_split + self.test_split
        if abs(total_split - 1.0) > 1e-6:
            raise ValueError(f"Data splits must sum to 1.0, got {total_split}")
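
# Usage sketch (illustrative; the "data/pets" directory is hypothetical).
# __post_init__ fills in allowed_formats and rejects splits that do not sum to 1.0:
#
#   config = DatasetConfig(root_dir="data/pets", batch_size=16,
#                          train_split=0.7, val_split=0.15, test_split=0.15)
#   config.allowed_formats          # ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
#   DatasetConfig(root_dir="data/pets", train_split=0.5)   # ValueError: splits sum to 0.7
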
class ImageValidator:
    """Comprehensive image validation utilities."""

    def __init__(self,
                 min_size: Tuple[int, int] = (32, 32),
                 max_size: Tuple[int, int] = (4096, 4096),
                 allowed_formats: Optional[List[str]] = None):
        self.min_size = min_size
        self.max_size = max_size
        self.allowed_formats = allowed_formats or ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
        self.logger = logging.getLogger(__name__)

    def validate_image_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
        """Validate a single image file."""
        image_path = Path(image_path)
        result = {
            'valid': True,
            'errors': [],
            'warnings': [],
            'metadata': {}
        }

        # Check if file exists
        if not image_path.exists():
            result['valid'] = False
            result['errors'].append(f"File does not exist: {image_path}")
            return result

        # Check file extension
        if image_path.suffix.lower() not in self.allowed_formats:
            result['valid'] = False
            result['errors'].append(f"Unsupported format: {image_path.suffix}")
            return result

        try:
            # Try to open and validate image
            with Image.open(image_path) as img:
                width, height = img.size
                result['metadata'].update({
                    'width': width,
                    'height': height,
                    'mode': img.mode,
                    'format': img.format,
                    'file_size': image_path.stat().st_size
                })

                # Check image dimensions
                if width < self.min_size[0] or height < self.min_size[1]:
                    result['warnings'].append(f"Image too small: {width}x{height} < {self.min_size}")

                if width > self.max_size[0] or height > self.max_size[1]:
                    result['warnings'].append(f"Image very large: {width}x{height} > {self.max_size}")

                # Check if image is corrupted by trying to load it
                img.verify()

        except Exception as e:
            result['valid'] = False
            result['errors'].append(f"Failed to open image: {str(e)}")

        return result

    def validate_dataset(self, image_paths: List[Union[str, Path]]) -> Dict[str, Any]:
        """Validate an entire dataset."""
        results = {
            'total_images': len(image_paths),
            'valid_images': 0,
            'invalid_images': 0,
            'warnings_count': 0,
            'errors': [],
            'warnings': [],
            'metadata_stats': {}
        }

        valid_images = []
        widths, heights, file_sizes = [], [], []

        for image_path in image_paths:
            validation_result = self.validate_image_file(image_path)

            if validation_result['valid']:
                results['valid_images'] += 1
                valid_images.append(image_path)

                # Collect metadata for statistics
                metadata = validation_result['metadata']
                widths.append(metadata['width'])
                heights.append(metadata['height'])
                file_sizes.append(metadata['file_size'])
            else:
                results['invalid_images'] += 1
                results['errors'].extend(validation_result['errors'])

            results['warnings'].extend(validation_result['warnings'])
            results['warnings_count'] += len(validation_result['warnings'])

        # Compute metadata statistics
        if valid_images:
            results['metadata_stats'] = {
                'width_stats': {
                    'mean': np.mean(widths),
                    'std': np.std(widths),
                    'min': np.min(widths),
                    'max': np.max(widths)
                },
                'height_stats': {
                    'mean': np.mean(heights),
                    'std': np.std(heights),
                    'min': np.min(heights),
                    'max': np.max(heights)
                },
                'file_size_stats': {
                    'mean_mb': np.mean(file_sizes) / (1024**2),
                    'std_mb': np.std(file_sizes) / (1024**2),
                    'min_mb': np.min(file_sizes) / (1024**2),
                    'max_mb': np.max(file_sizes) / (1024**2)
                }
            }

        return results
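
# Usage sketch (illustrative; "data/pets" is a hypothetical directory).
# validate_dataset() returns the aggregate dict built above, so counts and
# per-dimension statistics can be inspected directly:
#
#   validator = ImageValidator(min_size=(64, 64))
#   paths = list(Path("data/pets").glob("**/*.jpg"))
#   report = validator.validate_dataset(paths)
#   print(report['valid_images'], report['invalid_images'])
#   print(report['metadata_stats'].get('width_stats'))
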
class SmartAugmentation:
    """Intelligent data augmentation with adaptive strategies."""

    def __init__(self,
                 image_size: Tuple[int, int] = (224, 224),
                 strength: float = 0.5,
                 preserve_aspect_ratio: bool = True):
        self.image_size = image_size
        self.strength = strength
        self.preserve_aspect_ratio = preserve_aspect_ratio

        # Base transforms
        self.base_transforms = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Augmentation transforms based on strength
        self.augmentation_transforms = self._create_augmentation_transforms()

    def _create_augmentation_transforms(self) -> transforms.Compose:
        """Create augmentation transforms based on strength parameter."""
        aug_list = []

        if self.preserve_aspect_ratio:
            aug_list.append(transforms.Resize(min(self.image_size)))
            aug_list.append(transforms.CenterCrop(self.image_size))
        else:
            aug_list.append(transforms.Resize(self.image_size))

        # Add augmentations based on strength
        if self.strength > 0.1:
            aug_list.extend([
                transforms.RandomHorizontalFlip(p=0.5 * self.strength),
                transforms.RandomRotation(degrees=10 * self.strength),
            ])

        if self.strength > 0.3:
            aug_list.extend([
                transforms.ColorJitter(
                    brightness=0.2 * self.strength,
                    contrast=0.2 * self.strength,
                    saturation=0.2 * self.strength,
                    hue=0.1 * self.strength
                ),
                transforms.RandomAffine(
                    degrees=5 * self.strength,
                    translate=(0.1 * self.strength, 0.1 * self.strength),
                    scale=(1 - 0.1 * self.strength, 1 + 0.1 * self.strength)
                )
            ])

        if self.strength > 0.5:
            aug_list.append(
                transforms.RandomPerspective(distortion_scale=0.2 * self.strength, p=0.3)
            )

        # Always add tensor conversion and normalization
        aug_list.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # RandomErasing operates on tensors, so it must come after ToTensor()
        if self.strength > 0.5:
            aug_list.append(
                transforms.RandomErasing(p=0.2 * self.strength, scale=(0.02, 0.33), ratio=(0.3, 3.3))
            )

        return transforms.Compose(aug_list)

    def get_train_transforms(self) -> transforms.Compose:
        """Get training transforms with augmentation."""
        return self.augmentation_transforms

    def get_val_transforms(self) -> transforms.Compose:
        """Get validation transforms without augmentation."""
        return self.base_transforms
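
# Usage sketch (illustrative; "cat.jpg" is a hypothetical file). One instance
# serves both pipelines, so train and validation share the same resize and
# normalization while only the training side is randomized:
#
#   aug = SmartAugmentation(image_size=(224, 224), strength=0.8)
#   train_tf = aug.get_train_transforms()   # flips, jitter, perspective, erasing
#   val_tf = aug.get_val_transforms()       # deterministic resize + normalize
#   tensor = train_tf(Image.open("cat.jpg").convert("RGB"))   # -> 3x224x224 tensor
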
class EnhancedImageDataset(Dataset):
    """Enhanced image dataset with validation, smart augmentation, and error handling."""

    def __init__(self,
                 root_dir: Union[str, Path],
                 config: Optional[DatasetConfig] = None,
                 transform: Optional[Callable] = None,
                 validate_on_init: bool = True,
                 cache_images: bool = False):

        self.root_dir = Path(root_dir)
        self.config = config or DatasetConfig(root_dir=str(root_dir))
        self.cache_images = cache_images
        self.image_cache = {} if cache_images else None

        # Setup logging
        self.logger = logging.getLogger(__name__)

        # Find all image files
        self.image_paths = self._find_image_files()
        self.logger.info(f"Found {len(self.image_paths)} image files")

        # Setup transforms
        if transform is None:
            augmentation = SmartAugmentation(
                image_size=self.config.image_size,
                strength=self.config.augmentation_strength
            )
            self.transform = augmentation.get_train_transforms() if self.config.use_augmentation else augmentation.get_val_transforms()
        else:
            self.transform = transform

        # Validate dataset if requested
        if validate_on_init and self.config.validate_images:
            self._validate_dataset()

        # Create class mapping if this is a classification dataset
        self.classes, self.class_to_idx = self._create_class_mapping()

        # Create labels
        self.labels = self._create_labels()

    def _find_image_files(self) -> List[Path]:
        """Find all valid image files in the directory."""
        image_paths = []

        for ext in self.config.allowed_formats:
            pattern = f"**/*{ext}"
            image_paths.extend(self.root_dir.glob(pattern))
            # Also check uppercase extensions
            pattern = f"**/*{ext.upper()}"
            image_paths.extend(self.root_dir.glob(pattern))

        return sorted(list(set(image_paths)))  # Remove duplicates and sort

    def _validate_dataset(self):
        """Validate the entire dataset."""
        validator = ImageValidator(
            min_size=self.config.min_image_size,
            max_size=self.config.max_image_size,
            allowed_formats=self.config.allowed_formats
        )

        validation_results = validator.validate_dataset(self.image_paths)

        # Log validation results
        self.logger.info("Dataset validation completed:")
        self.logger.info(f"  Total images: {validation_results['total_images']}")
        self.logger.info(f"  Valid images: {validation_results['valid_images']}")
        self.logger.info(f"  Invalid images: {validation_results['invalid_images']}")
        self.logger.info(f"  Warnings: {validation_results['warnings_count']}")

        # Remove invalid images
        if validation_results['invalid_images'] > 0:
            self.logger.warning(f"Removing {validation_results['invalid_images']} invalid images")
            valid_paths = []
            for image_path in self.image_paths:
                result = validator.validate_image_file(image_path)
                if result['valid']:
                    valid_paths.append(image_path)
            self.image_paths = valid_paths

        # Log metadata statistics (empty when no images were valid)
        if validation_results['metadata_stats']:
            stats = validation_results['metadata_stats']
            self.logger.info("Image statistics:")
            self.logger.info(f"  Width: {stats['width_stats']['mean']:.1f}±{stats['width_stats']['std']:.1f}")
            self.logger.info(f"  Height: {stats['height_stats']['mean']:.1f}±{stats['height_stats']['std']:.1f}")
            self.logger.info(f"  File size: {stats['file_size_stats']['mean_mb']:.2f}±{stats['file_size_stats']['std_mb']:.2f} MB")

    def _create_class_mapping(self) -> Tuple[List[str], Dict[str, int]]:
        """Create class mapping from directory structure."""
        # Assume directory structure: root_dir/class_name/image_files
        classes = set()

        for image_path in self.image_paths:
            # Get parent directory name as class
            class_name = image_path.parent.name
            if class_name != self.root_dir.name:  # Skip if image is directly in root
                classes.add(class_name)

        classes = sorted(list(classes))
        class_to_idx = {cls: idx for idx, cls in enumerate(classes)}

        return classes, class_to_idx

    def _create_labels(self) -> List[int]:
        """Create labels for each image based on directory structure."""
        labels = []

        for image_path in self.image_paths:
            class_name = image_path.parent.name
            if class_name in self.class_to_idx:
                labels.append(self.class_to_idx[class_name])
            else:
                labels.append(-1)  # Unknown class

        return labels

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single item from the dataset."""
        image_path = self.image_paths[idx]

        try:
            # Load image from cache or disk
            if self.cache_images and str(image_path) in self.image_cache:
                image = self.image_cache[str(image_path)]
            else:
                image = Image.open(image_path).convert('RGB')
                if self.cache_images:
                    self.image_cache[str(image_path)] = image

            # Apply transforms
            if self.transform:
                image = self.transform(image)

            # Create sample dictionary
            sample = {
                'images': image,
                'labels': self.labels[idx] if idx < len(self.labels) else -1,
                'image_paths': str(image_path),
                'class_names': self.classes[self.labels[idx]] if 0 <= self.labels[idx] < len(self.classes) else 'unknown'
            }

            return sample

        except Exception as e:
            self.logger.error(f"Error loading image {image_path}: {str(e)}")
            # Return a dummy sample to avoid breaking the dataloader
            dummy_image = torch.zeros(3, *self.config.image_size)
            return {
                'images': dummy_image,
                'labels': -1,
                'image_paths': str(image_path),
                'class_names': 'error'
            }

    def get_class_distribution(self) -> Dict[str, int]:
        """Get the distribution of classes in the dataset."""
        if not self.classes:
            return {}

        label_counts = Counter(self.labels)
        return {self.classes[label]: count for label, count in label_counts.items() if label >= 0}

    def get_dataset_info(self) -> Dict[str, Any]:
        """Get comprehensive dataset information."""
        return {
            'num_samples': len(self),
            'num_classes': len(self.classes),
            'classes': self.classes,
            'class_distribution': self.get_class_distribution(),
            'image_size': self.config.image_size,
            'root_directory': str(self.root_dir),
            'cache_enabled': self.cache_images
        }
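
# Usage sketch (illustrative; assumes an ImageFolder-style layout
# root_dir/<class_name>/<image>, with a hypothetical "data/pets" root).
# Samples are dicts rather than (image, label) tuples, so downstream code
# indexes by key:
#
#   dataset = EnhancedImageDataset("data/pets", validate_on_init=True)
#   sample = dataset[0]
#   sample['images'].shape                   # torch.Size([3, 224, 224])
#   sample['labels'], sample['class_names']
#   dataset.get_dataset_info()['class_distribution']
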
class MultimodalDataset(EnhancedImageDataset):
    """Multimodal dataset supporting both images and text."""

    def __init__(self,
                 root_dir: Union[str, Path],
                 annotations_file: Optional[Union[str, Path]] = None,
                 config: Optional[DatasetConfig] = None,
                 **kwargs):

        super().__init__(root_dir, config, **kwargs)

        self.annotations_file = Path(annotations_file) if annotations_file else None
        self.annotations = self._load_annotations()

        # Setup text processing
        if self.config.text_tokenizer:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_tokenizer)
        else:
            self.tokenizer = None

    def _load_annotations(self) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Load text annotations from file."""
        if not self.annotations_file or not self.annotations_file.exists():
            self.logger.warning("No annotations file found, using image filenames as text")
            return {}

        try:
            if self.annotations_file.suffix == '.json':
                with open(self.annotations_file, 'r') as f:
                    return json.load(f)
            elif self.annotations_file.suffix == '.csv':
                df = pd.read_csv(self.annotations_file)
                return df.to_dict('records')
            else:
                self.logger.error(f"Unsupported annotations format: {self.annotations_file.suffix}")
                return {}
        except Exception as e:
            self.logger.error(f"Failed to load annotations: {str(e)}")
            return {}

    def _get_text_for_image(self, image_path: Path) -> str:
        """Get text description for an image."""
        image_name = image_path.name

        # Try to find annotation by filename
        if isinstance(self.annotations, dict):
            return self.annotations.get(image_name, image_path.stem)
        elif isinstance(self.annotations, list):
            for annotation in self.annotations:
                if annotation.get('filename') == image_name or annotation.get('image_path') == str(image_path):
                    return annotation.get('caption', annotation.get('text', image_path.stem))

        # Fallback to filename
        return image_path.stem.replace('_', ' ').replace('-', ' ')

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a multimodal sample."""
        sample = super().__getitem__(idx)

        # Add text information
        image_path = Path(sample['image_paths'])
        text = self._get_text_for_image(image_path)

        sample['texts'] = text

        # Tokenize text if tokenizer is available
        if self.tokenizer:
            tokens = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.config.text_max_length,
                return_tensors='pt'
            )
            sample['text_tokens'] = {k: v.squeeze(0) for k, v in tokens.items()}

        return sample
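
# Usage sketch (illustrative; "captions.json" and the tokenizer id are
# examples, not shipped defaults). A JSON dict mapping filenames to captions
# is one of the two layouts _load_annotations() accepts; the other is a CSV
# whose rows carry filename/caption columns:
#
#   config = DatasetConfig(root_dir="data/pets",
#                          text_tokenizer="openai/clip-vit-base-patch32")
#   dataset = MultimodalDataset("data/pets", annotations_file="captions.json",
#                               config=config)
#   sample = dataset[0]
#   sample['texts']                        # caption string
#   sample['text_tokens']['input_ids']     # shape (text_max_length,)
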
def create_enhanced_dataloaders(config: DatasetConfig,
                                dataset_type: str = "image",
                                annotations_file: Optional[str] = None) -> Dict[str, DataLoader]:
    """Factory function to create enhanced dataloaders."""

    # Create dataset
    if dataset_type == "multimodal":
        dataset_class = MultimodalDataset
        dataset_kwargs = {"annotations_file": annotations_file}
    else:
        dataset_class = EnhancedImageDataset
        dataset_kwargs = {}

    # Create full dataset
    full_dataset = dataset_class(
        root_dir=config.root_dir,
        config=config,
        **dataset_kwargs
    )

    # Split dataset
    total_size = len(full_dataset)
    train_size = int(config.train_split * total_size)
    val_size = int(config.val_split * total_size)
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size]
    )

    # Create dataloaders
    dataloaders = {}

    if train_size > 0:
        dataloaders['train'] = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            drop_last=True
        )

    if val_size > 0:
        dataloaders['val'] = DataLoader(
            val_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

    if test_size > 0:
        dataloaders['test'] = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

    return dataloaders
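
# Usage sketch (illustrative) plus a behavioral note: random_split wraps a
# single underlying dataset, so in this version the validation and test
# subsets reuse the training transform, augmentation included; callers who
# need clean eval transforms can build separate datasets instead.
#
#   config = DatasetConfig(root_dir="data/pets", batch_size=32)
#   loaders = create_enhanced_dataloaders(config, dataset_type="image")
#   for batch in loaders['train']:
#       images, labels = batch['images'], batch['labels']
#       break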