langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (41)
  1. langvision/__init__.py +77 -2
  2. langvision/callbacks/base.py +166 -7
  3. langvision/cli/__init__.py +85 -0
  4. langvision/cli/complete_cli.py +319 -0
  5. langvision/cli/config.py +344 -0
  6. langvision/cli/evaluate.py +201 -0
  7. langvision/cli/export.py +177 -0
  8. langvision/cli/finetune.py +165 -48
  9. langvision/cli/model_zoo.py +162 -0
  10. langvision/cli/train.py +27 -13
  11. langvision/cli/utils.py +258 -0
  12. langvision/components/attention.py +4 -1
  13. langvision/concepts/__init__.py +9 -0
  14. langvision/concepts/ccot.py +30 -0
  15. langvision/concepts/cot.py +29 -0
  16. langvision/concepts/dpo.py +37 -0
  17. langvision/concepts/grpo.py +25 -0
  18. langvision/concepts/lime.py +37 -0
  19. langvision/concepts/ppo.py +47 -0
  20. langvision/concepts/rlhf.py +40 -0
  21. langvision/concepts/rlvr.py +25 -0
  22. langvision/concepts/shap.py +37 -0
  23. langvision/data/enhanced_datasets.py +582 -0
  24. langvision/model_zoo.py +169 -2
  25. langvision/models/lora.py +189 -17
  26. langvision/models/multimodal.py +297 -0
  27. langvision/models/resnet.py +303 -0
  28. langvision/training/advanced_trainer.py +478 -0
  29. langvision/training/trainer.py +30 -2
  30. langvision/utils/config.py +180 -9
  31. langvision/utils/metrics.py +448 -0
  32. langvision/utils/setup.py +266 -0
  33. langvision-0.1.0.dist-info/METADATA +50 -0
  34. langvision-0.1.0.dist-info/RECORD +61 -0
  35. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
  36. langvision-0.1.0.dist-info/entry_points.txt +2 -0
  37. langvision-0.0.1.dist-info/METADATA +0 -463
  38. langvision-0.0.1.dist-info/RECORD +0 -40
  39. langvision-0.0.1.dist-info/entry_points.txt +0 -2
  40. langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
  41. {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/data/enhanced_datasets.py
@@ -0,0 +1,582 @@
+ """
+ Enhanced dataset classes with comprehensive data validation, augmentation, and multimodal support.
+ """
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import torchvision.transforms as transforms
+ from torchvision.datasets import ImageFolder, CIFAR10, CIFAR100
+ from PIL import Image
+ import pandas as pd
+ import json
+ import os
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union, Tuple, Callable, Any
+ import numpy as np
+ import logging
+ from dataclasses import dataclass
+ import warnings
+ from collections import Counter
+ import cv2
+
+
+ @dataclass
+ class DatasetConfig:
+     """Configuration for dataset parameters."""
+     # Basic settings
+     root_dir: str
+     image_size: Tuple[int, int] = (224, 224)
+     batch_size: int = 32
+     num_workers: int = 4
+     pin_memory: bool = True
+
+     # Data splits
+     train_split: float = 0.8
+     val_split: float = 0.1
+     test_split: float = 0.1
+
+     # Augmentation settings
+     use_augmentation: bool = True
+     augmentation_strength: float = 0.5
+
+     # Multimodal settings
+     text_max_length: int = 77
+     text_tokenizer: Optional[str] = None
+
+     # Validation settings
+     validate_images: bool = True
+     min_image_size: Tuple[int, int] = (32, 32)
+     max_image_size: Tuple[int, int] = (4096, 4096)
+     allowed_formats: List[str] = None
+
+     def __post_init__(self):
+         if self.allowed_formats is None:
+             self.allowed_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
+
+         # Validate splits sum to 1.0
+         total_split = self.train_split + self.val_split + self.test_split
+         if abs(total_split - 1.0) > 1e-6:
+             raise ValueError(f"Data splits must sum to 1.0, got {total_split}")
+
+
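# Usage sketch (editor's illustration, not part of the published diff; the
# data path is hypothetical). Only `root_dir` is required; splits that do not
# sum to 1.0 make __post_init__ raise ValueError.
example_config = DatasetConfig(root_dir="data/flowers", batch_size=16,
                               train_split=0.7, val_split=0.2, test_split=0.1)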
+ class ImageValidator:
+     """Comprehensive image validation utilities."""
+
+     def __init__(self,
+                  min_size: Tuple[int, int] = (32, 32),
+                  max_size: Tuple[int, int] = (4096, 4096),
+                  allowed_formats: List[str] = None):
+         self.min_size = min_size
+         self.max_size = max_size
+         self.allowed_formats = allowed_formats or ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
+         self.logger = logging.getLogger(__name__)
+
+     def validate_image_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+         """Validate a single image file."""
+         image_path = Path(image_path)
+         result = {
+             'valid': True,
+             'errors': [],
+             'warnings': [],
+             'metadata': {}
+         }
+
+         # Check if file exists
+         if not image_path.exists():
+             result['valid'] = False
+             result['errors'].append(f"File does not exist: {image_path}")
+             return result
+
+         # Check file extension
+         if image_path.suffix.lower() not in self.allowed_formats:
+             result['valid'] = False
+             result['errors'].append(f"Unsupported format: {image_path.suffix}")
+             return result
+
+         try:
+             # Try to open and validate image
+             with Image.open(image_path) as img:
+                 width, height = img.size
+                 result['metadata'].update({
+                     'width': width,
+                     'height': height,
+                     'mode': img.mode,
+                     'format': img.format,
+                     'file_size': image_path.stat().st_size
+                 })
+
+                 # Check image dimensions
+                 if width < self.min_size[0] or height < self.min_size[1]:
+                     result['warnings'].append(f"Image too small: {width}x{height} < {self.min_size}")
+
+                 if width > self.max_size[0] or height > self.max_size[1]:
+                     result['warnings'].append(f"Image very large: {width}x{height} > {self.max_size}")
+
+                 # Check if image is corrupted
+                 img.verify()
+
+         except Exception as e:
+             result['valid'] = False
+             result['errors'].append(f"Failed to open image: {str(e)}")
+
+         return result
+
+     def validate_dataset(self, image_paths: List[Union[str, Path]]) -> Dict[str, Any]:
+         """Validate an entire dataset."""
+         results = {
+             'total_images': len(image_paths),
+             'valid_images': 0,
+             'invalid_images': 0,
+             'warnings_count': 0,
+             'errors': [],
+             'warnings': [],
+             'metadata_stats': {}
+         }
+
+         valid_images = []
+         widths, heights, file_sizes = [], [], []
+
+         for image_path in image_paths:
+             validation_result = self.validate_image_file(image_path)
+
+             if validation_result['valid']:
+                 results['valid_images'] += 1
+                 valid_images.append(image_path)
+
+                 # Collect metadata for statistics
+                 metadata = validation_result['metadata']
+                 widths.append(metadata['width'])
+                 heights.append(metadata['height'])
+                 file_sizes.append(metadata['file_size'])
+             else:
+                 results['invalid_images'] += 1
+                 results['errors'].extend(validation_result['errors'])
+
+             results['warnings'].extend(validation_result['warnings'])
+             results['warnings_count'] += len(validation_result['warnings'])
+
+         # Compute metadata statistics
+         if valid_images:
+             results['metadata_stats'] = {
+                 'width_stats': {
+                     'mean': np.mean(widths),
+                     'std': np.std(widths),
+                     'min': np.min(widths),
+                     'max': np.max(widths)
+                 },
+                 'height_stats': {
+                     'mean': np.mean(heights),
+                     'std': np.std(heights),
+                     'min': np.min(heights),
+                     'max': np.max(heights)
+                 },
+                 'file_size_stats': {
+                     'mean_mb': np.mean(file_sizes) / (1024**2),
+                     'std_mb': np.std(file_sizes) / (1024**2),
+                     'min_mb': np.min(file_sizes) / (1024**2),
+                     'max_mb': np.max(file_sizes) / (1024**2)
+                 }
+             }
+
+         return results
+
+
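# Usage sketch (editor's illustration, hypothetical file path): the report is
# a plain dict, so callers can branch on report['valid'] and inspect
# report['errors'], report['warnings'], and report['metadata'] directly.
validator = ImageValidator(min_size=(64, 64))
report = validator.validate_image_file("data/flowers/rose/001.jpg")
if not report['valid']:
    print(report['errors'])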
+ class SmartAugmentation:
+     """Intelligent data augmentation with adaptive strategies."""
+
+     def __init__(self,
+                  image_size: Tuple[int, int] = (224, 224),
+                  strength: float = 0.5,
+                  preserve_aspect_ratio: bool = True):
+         self.image_size = image_size
+         self.strength = strength
+         self.preserve_aspect_ratio = preserve_aspect_ratio
+
+         # Base transforms
+         self.base_transforms = transforms.Compose([
+             transforms.Resize(image_size),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+         # Augmentation transforms based on strength
+         self.augmentation_transforms = self._create_augmentation_transforms()
+
+     def _create_augmentation_transforms(self) -> transforms.Compose:
+         """Create augmentation transforms based on strength parameter."""
+         aug_list = []
+
+         if self.preserve_aspect_ratio:
+             aug_list.append(transforms.Resize(min(self.image_size)))
+             aug_list.append(transforms.CenterCrop(self.image_size))
+         else:
+             aug_list.append(transforms.Resize(self.image_size))
+
+         # Add augmentations based on strength
+         if self.strength > 0.1:
+             aug_list.extend([
+                 transforms.RandomHorizontalFlip(p=0.5 * self.strength),
+                 transforms.RandomRotation(degrees=10 * self.strength),
+             ])
+
+         if self.strength > 0.3:
+             aug_list.extend([
+                 transforms.ColorJitter(
+                     brightness=0.2 * self.strength,
+                     contrast=0.2 * self.strength,
+                     saturation=0.2 * self.strength,
+                     hue=0.1 * self.strength
+                 ),
+                 transforms.RandomAffine(
+                     degrees=5 * self.strength,
+                     translate=(0.1 * self.strength, 0.1 * self.strength),
+                     scale=(1 - 0.1 * self.strength, 1 + 0.1 * self.strength)
+                 )
+             ])
+
+         if self.strength > 0.5:
+             aug_list.append(
+                 transforms.RandomPerspective(distortion_scale=0.2 * self.strength, p=0.3)
+             )
+
+         # Always add tensor conversion and normalization
+         aug_list.extend([
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+         # RandomErasing operates on tensors, so it must run after ToTensor
+         if self.strength > 0.5:
+             aug_list.append(
+                 transforms.RandomErasing(p=0.2 * self.strength, scale=(0.02, 0.33), ratio=(0.3, 3.3))
+             )
+
+         return transforms.Compose(aug_list)
+
+     def get_train_transforms(self) -> transforms.Compose:
+         """Get training transforms with augmentation."""
+         return self.augmentation_transforms
+
+     def get_val_transforms(self) -> transforms.Compose:
+         """Get validation transforms without augmentation."""
+         return self.base_transforms
+
+
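# Usage sketch (editor's illustration): one SmartAugmentation instance serves
# both pipelines, so train and val share the same resize and normalization.
aug = SmartAugmentation(image_size=(224, 224), strength=0.7)
train_tf = aug.get_train_transforms()  # flips, jitter, affine, perspective, erasing
val_tf = aug.get_val_transforms()      # resize + ToTensor + Normalize only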
+ class EnhancedImageDataset(Dataset):
+     """Enhanced image dataset with validation, smart augmentation, and error handling."""
+
+     def __init__(self,
+                  root_dir: Union[str, Path],
+                  config: Optional[DatasetConfig] = None,
+                  transform: Optional[Callable] = None,
+                  validate_on_init: bool = True,
+                  cache_images: bool = False):
+
+         self.root_dir = Path(root_dir)
+         self.config = config or DatasetConfig(root_dir=str(root_dir))
+         self.cache_images = cache_images
+         self.image_cache = {} if cache_images else None
+
+         # Setup logging
+         self.logger = logging.getLogger(__name__)
+
+         # Find all image files
+         self.image_paths = self._find_image_files()
+         self.logger.info(f"Found {len(self.image_paths)} image files")
+
+         # Setup transforms
+         if transform is None:
+             augmentation = SmartAugmentation(
+                 image_size=self.config.image_size,
+                 strength=self.config.augmentation_strength
+             )
+             if self.config.use_augmentation:
+                 self.transform = augmentation.get_train_transforms()
+             else:
+                 self.transform = augmentation.get_val_transforms()
+         else:
+             self.transform = transform
+
+         # Validate dataset if requested
+         if validate_on_init and self.config.validate_images:
+             self._validate_dataset()
+
+         # Create class mapping if this is a classification dataset
+         self.classes, self.class_to_idx = self._create_class_mapping()
+
+         # Create labels
+         self.labels = self._create_labels()
+
+     def _find_image_files(self) -> List[Path]:
+         """Find all valid image files in the directory."""
+         image_paths = []
+
+         for ext in self.config.allowed_formats:
+             pattern = f"**/*{ext}"
+             image_paths.extend(self.root_dir.glob(pattern))
+             # Also check uppercase extensions
+             pattern = f"**/*{ext.upper()}"
+             image_paths.extend(self.root_dir.glob(pattern))
+
+         return sorted(list(set(image_paths)))  # Remove duplicates and sort
+
+     def _validate_dataset(self):
+         """Validate the entire dataset."""
+         validator = ImageValidator(
+             min_size=self.config.min_image_size,
+             max_size=self.config.max_image_size,
+             allowed_formats=self.config.allowed_formats
+         )
+
+         validation_results = validator.validate_dataset(self.image_paths)
+
+         # Log validation results
+         self.logger.info("Dataset validation completed:")
+         self.logger.info(f"  Total images: {validation_results['total_images']}")
+         self.logger.info(f"  Valid images: {validation_results['valid_images']}")
+         self.logger.info(f"  Invalid images: {validation_results['invalid_images']}")
+         self.logger.info(f"  Warnings: {validation_results['warnings_count']}")
+
+         # Remove invalid images
+         if validation_results['invalid_images'] > 0:
+             self.logger.warning(f"Removing {validation_results['invalid_images']} invalid images")
+             valid_paths = []
+             for image_path in self.image_paths:
+                 result = validator.validate_image_file(image_path)
+                 if result['valid']:
+                     valid_paths.append(image_path)
+             self.image_paths = valid_paths
+
+         # Log metadata statistics (the dict is empty when no image passed validation)
+         if validation_results['metadata_stats']:
+             stats = validation_results['metadata_stats']
+             self.logger.info("Image statistics:")
+             self.logger.info(f"  Width: {stats['width_stats']['mean']:.1f}±{stats['width_stats']['std']:.1f}")
+             self.logger.info(f"  Height: {stats['height_stats']['mean']:.1f}±{stats['height_stats']['std']:.1f}")
+             self.logger.info(f"  File size: {stats['file_size_stats']['mean_mb']:.2f}±{stats['file_size_stats']['std_mb']:.2f} MB")
+
+     def _create_class_mapping(self) -> Tuple[List[str], Dict[str, int]]:
+         """Create class mapping from directory structure."""
+         # Assume directory structure: root_dir/class_name/image_files
+         classes = set()
+
+         for image_path in self.image_paths:
+             # Get parent directory name as class
+             class_name = image_path.parent.name
+             if class_name != self.root_dir.name:  # Skip if image is directly in root
+                 classes.add(class_name)
+
+         classes = sorted(list(classes))
+         class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
+
+         return classes, class_to_idx
+
+     def _create_labels(self) -> List[int]:
+         """Create labels for each image based on directory structure."""
+         labels = []
+
+         for image_path in self.image_paths:
+             class_name = image_path.parent.name
+             if class_name in self.class_to_idx:
+                 labels.append(self.class_to_idx[class_name])
+             else:
+                 labels.append(-1)  # Unknown class
+
+         return labels
+
+     def __len__(self) -> int:
+         return len(self.image_paths)
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         """Get a single item from the dataset."""
+         image_path = self.image_paths[idx]
+
+         try:
+             # Load image from cache or disk
+             if self.cache_images and str(image_path) in self.image_cache:
+                 image = self.image_cache[str(image_path)]
+             else:
+                 image = Image.open(image_path).convert('RGB')
+                 if self.cache_images:
+                     self.image_cache[str(image_path)] = image
+
+             # Apply transforms
+             if self.transform:
+                 image = self.transform(image)
+
+             # Create sample dictionary
+             sample = {
+                 'images': image,
+                 'labels': self.labels[idx] if idx < len(self.labels) else -1,
+                 'image_paths': str(image_path),
+                 'class_names': self.classes[self.labels[idx]] if 0 <= self.labels[idx] < len(self.classes) else 'unknown'
+             }
+
+             return sample
+
+         except Exception as e:
+             self.logger.error(f"Error loading image {image_path}: {str(e)}")
+             # Return a dummy sample to avoid breaking the dataloader
+             dummy_image = torch.zeros(3, *self.config.image_size)
+             return {
+                 'images': dummy_image,
+                 'labels': -1,
+                 'image_paths': str(image_path),
+                 'class_names': 'error'
+             }
+
+     def get_class_distribution(self) -> Dict[str, int]:
+         """Get the distribution of classes in the dataset."""
+         if not self.classes:
+             return {}
+
+         label_counts = Counter(self.labels)
+         return {self.classes[label]: count for label, count in label_counts.items() if label >= 0}
+
+     def get_dataset_info(self) -> Dict[str, Any]:
+         """Get comprehensive dataset information."""
+         return {
+             'num_samples': len(self),
+             'num_classes': len(self.classes),
+             'classes': self.classes,
+             'class_distribution': self.get_class_distribution(),
+             'image_size': self.config.image_size,
+             'root_directory': str(self.root_dir),
+             'cache_enabled': self.cache_images
+         }
+
+
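# Usage sketch (editor's illustration, hypothetical path): classes are inferred
# from subdirectory names, and each item is a dict rather than an
# (image, label) tuple.
dataset = EnhancedImageDataset("data/flowers")
sample = dataset[0]
print(sample['labels'], sample['class_names'])
print(dataset.get_dataset_info()['class_distribution'])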
+ class MultimodalDataset(EnhancedImageDataset):
+     """Multimodal dataset supporting both images and text."""
+
+     def __init__(self,
+                  root_dir: Union[str, Path],
+                  annotations_file: Optional[Union[str, Path]] = None,
+                  config: Optional[DatasetConfig] = None,
+                  **kwargs):
+
+         super().__init__(root_dir, config, **kwargs)
+
+         self.annotations_file = Path(annotations_file) if annotations_file else None
+         self.annotations = self._load_annotations()
+
+         # Setup text processing
+         if self.config.text_tokenizer:
+             from transformers import AutoTokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_tokenizer)
+         else:
+             self.tokenizer = None
+
+     def _load_annotations(self) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+         """Load text annotations from file (a dict for JSON, a list of records for CSV)."""
+         if not self.annotations_file or not self.annotations_file.exists():
+             self.logger.warning("No annotations file found, using image filenames as text")
+             return {}
+
+         try:
+             if self.annotations_file.suffix == '.json':
+                 with open(self.annotations_file, 'r') as f:
+                     return json.load(f)
+             elif self.annotations_file.suffix == '.csv':
+                 df = pd.read_csv(self.annotations_file)
+                 return df.to_dict('records')
+             else:
+                 self.logger.error(f"Unsupported annotations format: {self.annotations_file.suffix}")
+                 return {}
+         except Exception as e:
+             self.logger.error(f"Failed to load annotations: {str(e)}")
+             return {}
+
+     def _get_text_for_image(self, image_path: Path) -> str:
+         """Get text description for an image."""
+         image_name = image_path.name
+
+         # Try to find annotation by filename
+         if isinstance(self.annotations, dict):
+             return self.annotations.get(image_name, image_path.stem)
+         elif isinstance(self.annotations, list):
+             for annotation in self.annotations:
+                 if annotation.get('filename') == image_name or annotation.get('image_path') == str(image_path):
+                     return annotation.get('caption', annotation.get('text', image_path.stem))
+
+         # Fallback to filename
+         return image_path.stem.replace('_', ' ').replace('-', ' ')
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         """Get a multimodal sample."""
+         sample = super().__getitem__(idx)
+
+         # Add text information
+         image_path = Path(sample['image_paths'])
+         text = self._get_text_for_image(image_path)
+
+         sample['texts'] = text
+
+         # Tokenize text if tokenizer is available
+         if self.tokenizer:
+             tokens = self.tokenizer(
+                 text,
+                 padding='max_length',
+                 truncation=True,
+                 max_length=self.config.text_max_length,
+                 return_tensors='pt'
+             )
+             sample['text_tokens'] = {k: v.squeeze(0) for k, v in tokens.items()}
+
+         return sample
+
+
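# Usage sketch (editor's illustration): captions come from a JSON mapping of
# filename -> text; the tokenizer name below is a plausible placeholder, not
# one the package prescribes.
mm_config = DatasetConfig(root_dir="data/flowers",
                          text_tokenizer="openai/clip-vit-base-patch32")
mm_dataset = MultimodalDataset("data/flowers", annotations_file="captions.json",
                               config=mm_config)
item = mm_dataset[0]  # adds 'texts' and, when a tokenizer is set, 'text_tokens'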
+ def create_enhanced_dataloaders(config: DatasetConfig,
+                                 dataset_type: str = "image",
+                                 annotations_file: Optional[str] = None) -> Dict[str, DataLoader]:
+     """Factory function to create enhanced dataloaders."""
+
+     # Create dataset
+     if dataset_type == "multimodal":
+         dataset_class = MultimodalDataset
+         dataset_kwargs = {"annotations_file": annotations_file}
+     else:
+         dataset_class = EnhancedImageDataset
+         dataset_kwargs = {}
+
+     # Create full dataset
+     full_dataset = dataset_class(
+         root_dir=config.root_dir,
+         config=config,
+         **dataset_kwargs
+     )
+
+     # Split dataset
+     total_size = len(full_dataset)
+     train_size = int(config.train_split * total_size)
+     val_size = int(config.val_split * total_size)
+     test_size = total_size - train_size - val_size
+
+     train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
+         full_dataset, [train_size, val_size, test_size]
+     )
+
+     # Create dataloaders
+     dataloaders = {}
+
+     if train_size > 0:
+         dataloaders['train'] = DataLoader(
+             train_dataset,
+             batch_size=config.batch_size,
+             shuffle=True,
+             num_workers=config.num_workers,
+             pin_memory=config.pin_memory,
+             drop_last=True
+         )
+
+     if val_size > 0:
+         dataloaders['val'] = DataLoader(
+             val_dataset,
+             batch_size=config.batch_size,
+             shuffle=False,
+             num_workers=config.num_workers,
+             pin_memory=config.pin_memory
+         )
+
+     if test_size > 0:
+         dataloaders['test'] = DataLoader(
+             test_dataset,
+             batch_size=config.batch_size,
+             shuffle=False,
+             num_workers=config.num_workers,
+             pin_memory=config.pin_memory
+         )
+
+     return dataloaders
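
Taken together, a minimal end-to-end sketch of the new module (editor's illustration; the data directory is hypothetical, and the import path follows the file layout shown in the RECORD):

from langvision.data.enhanced_datasets import DatasetConfig, create_enhanced_dataloaders

config = DatasetConfig(root_dir="data/flowers", batch_size=32)
loaders = create_enhanced_dataloaders(config, dataset_type="image")

batch = next(iter(loaders['train']))  # default collate stacks the sample dicts
print(batch['images'].shape)          # e.g. torch.Size([32, 3, 224, 224])
print(batch['labels'][:8], batch['class_names'][:8])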