langvision-0.0.1-py3-none-any.whl → langvision-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Potentially problematic release: this version of langvision might be problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/cli/config.py
ADDED
@@ -0,0 +1,344 @@
#!/usr/bin/env python3
"""
Configuration CLI for Langvision.
"""

import argparse
import json
import logging
import os
import sys
import yaml
from pathlib import Path

# Import from langvision modules
try:
    from langvision.utils.config import default_config, load_config, save_config
except ImportError as e:
    print(f"❌ Error importing langvision modules: {e}")
    print("Please ensure langvision is properly installed:")
    print(" pip install langvision")
    sys.exit(1)


def parse_args():
    """Parse command-line arguments for configuration operations."""
    parser = argparse.ArgumentParser(
        description='Langvision Configuration Manager - Create, validate, and manage config files',
        epilog='''\nExamples:\n langvision config create --output config.yaml\n langvision config validate config.yaml\n langvision config show --format json\n''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    subparsers = parser.add_subparsers(dest='command', required=True, help='Configuration commands')

    # Create command
    create_parser = subparsers.add_parser('create', help='Create a new configuration file')
    create_parser.add_argument('--output', type=str, default='config.yaml', help='Output configuration file')
    create_parser.add_argument('--template', type=str, choices=['basic', 'advanced', 'custom'], default='basic', help='Configuration template')
    create_parser.add_argument('--dataset', type=str, default='cifar10', help='Default dataset')
    create_parser.add_argument('--model', type=str, default='vit_base', help='Default model type')

    # Validate command
    validate_parser = subparsers.add_parser('validate', help='Validate a configuration file')
    validate_parser.add_argument('config_file', type=str, help='Configuration file to validate')
    validate_parser.add_argument('--strict', action='store_true', help='Use strict validation')

    # Show command
    show_parser = subparsers.add_parser('show', help='Show default configuration')
    show_parser.add_argument('--format', type=str, default='yaml', choices=['yaml', 'json'], help='Output format')
    show_parser.add_argument('--section', type=str, help='Show specific configuration section')

    # Convert command
    convert_parser = subparsers.add_parser('convert', help='Convert between configuration formats')
    convert_parser.add_argument('input_file', type=str, help='Input configuration file')
    convert_parser.add_argument('--output', type=str, help='Output configuration file')
    convert_parser.add_argument('--format', type=str, choices=['yaml', 'json'], help='Output format')

    # Diff command
    diff_parser = subparsers.add_parser('diff', help='Compare two configuration files')
    diff_parser.add_argument('config1', type=str, help='First configuration file')
    diff_parser.add_argument('config2', type=str, help='Second configuration file')
    diff_parser.add_argument('--format', type=str, default='table', choices=['table', 'json'], help='Output format')

    # Misc
    parser.add_argument('--log_level', type=str, default='info', help='Logging level')

    return parser.parse_args()


def setup_logging(log_level: str) -> None:
    """Set up logging with the specified log level."""
    numeric_level = getattr(logging, log_level.upper(), None)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO
    logging.basicConfig(level=numeric_level, format='[%(levelname)s] %(message)s')


def create_config_template(template_type, dataset, model):
    """Create a configuration template."""
    if template_type == 'basic':
        config = {
            'model': {
                'name': model,
                'img_size': 224,
                'patch_size': 16,
                'num_classes': 10 if dataset == 'cifar10' else 100,
                'embed_dim': 768,
                'depth': 12,
                'num_heads': 12,
                'mlp_ratio': 4.0
            },
            'data': {
                'dataset': dataset,
                'data_dir': './data',
                'batch_size': 64,
                'num_workers': 2
            },
            'training': {
                'epochs': 10,
                'learning_rate': 1e-3,
                'optimizer': 'adam',
                'weight_decay': 0.01,
                'scheduler': 'cosine'
            },
            'lora': {
                'rank': 4,
                'alpha': 1.0,
                'dropout': 0.1
            }
        }
    elif template_type == 'advanced':
        config = {
            'model': {
                'name': model,
                'img_size': 224,
                'patch_size': 16,
                'num_classes': 10 if dataset == 'cifar10' else 100,
                'embed_dim': 768,
                'depth': 12,
                'num_heads': 12,
                'mlp_ratio': 4.0,
                'dropout': 0.1,
                'attention_dropout': 0.1
            },
            'data': {
                'dataset': dataset,
                'data_dir': './data',
                'batch_size': 64,
                'num_workers': 4,
                'pin_memory': True,
                'persistent_workers': True
            },
            'training': {
                'epochs': 50,
                'learning_rate': 1e-4,
                'optimizer': 'adamw',
                'weight_decay': 0.05,
                'scheduler': 'cosine',
                'warmup_epochs': 5,
                'min_lr': 1e-6,
                'gradient_clip': 1.0
            },
            'lora': {
                'rank': 16,
                'alpha': 32,
                'dropout': 0.1,
                'target_modules': ['attention.qkv', 'attention.proj', 'mlp.fc1', 'mlp.fc2']
            },
            'callbacks': {
                'early_stopping': {
                    'enabled': True,
                    'patience': 10,
                    'min_delta': 0.001
                },
                'checkpointing': {
                    'enabled': True,
                    'save_best': True,
                    'save_last': True
                }
            },
            'logging': {
                'level': 'info',
                'log_interval': 100,
                'save_interval': 5
            }
        }
    else:  # custom
        config = default_config.copy()
        config['data']['dataset'] = dataset
        config['model']['name'] = model

    return config


def validate_config(config_file, strict=False):
    """Validate a configuration file."""
    try:
        config = load_config(config_file)

        # Basic validation
        required_sections = ['model', 'data', 'training']
        for section in required_sections:
            if section not in config:
                return False, f"Missing required section: {section}"

        # Model validation
        model_config = config['model']
        required_model_keys = ['img_size', 'patch_size', 'num_classes']
        for key in required_model_keys:
            if key not in model_config:
                return False, f"Missing required model key: {key}"

        # Data validation
        data_config = config['data']
        required_data_keys = ['dataset', 'batch_size']
        for key in required_data_keys:
            if key not in data_config:
                return False, f"Missing required data key: {key}"

        # Training validation
        training_config = config['training']
        required_training_keys = ['epochs', 'learning_rate']
        for key in required_training_keys:
            if key not in training_config:
                return False, f"Missing required training key: {key}"

        if strict:
            # Additional strict validation
            if training_config['learning_rate'] <= 0:
                return False, "Learning rate must be positive"
            if training_config['epochs'] <= 0:
                return False, "Epochs must be positive"
            if data_config['batch_size'] <= 0:
                return False, "Batch size must be positive"

        return True, "Configuration is valid"

    except Exception as e:
        return False, f"Validation error: {e}"


def main():
    """Main function for configuration operations."""
    args = parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    try:
        if args.command == 'create':
            logger.info(f"📝 Creating {args.template} configuration template...")
            config = create_config_template(args.template, args.dataset, args.model)

            # Save configuration
            if args.output.endswith('.json'):
                with open(args.output, 'w') as f:
                    json.dump(config, f, indent=2)
            else:
                with open(args.output, 'w') as f:
                    yaml.dump(config, f, default_flow_style=False, indent=2)

            logger.info(f"✅ Configuration saved to {args.output}")

        elif args.command == 'validate':
            logger.info(f"🔍 Validating configuration: {args.config_file}")
            is_valid, message = validate_config(args.config_file, args.strict)

            if is_valid:
                logger.info(f"✅ {message}")
            else:
                logger.error(f"❌ {message}")
                return 1

        elif args.command == 'show':
            logger.info("📋 Showing default configuration...")
            config = default_config

            if args.section:
                if args.section in config:
                    config = {args.section: config[args.section]}
                else:
                    logger.error(f"❌ Section '{args.section}' not found")
                    return 1

            if args.format == 'json':
                print(json.dumps(config, indent=2))
            else:
                print(yaml.dump(config, default_flow_style=False, indent=2))

        elif args.command == 'convert':
            logger.info(f"🔄 Converting configuration: {args.input_file}")

            # Load input config
            config = load_config(args.input_file)

            # Determine output format
            if args.format:
                output_format = args.format
            elif args.output:
                output_format = 'json' if args.output.endswith('.json') else 'yaml'
            else:
                output_format = 'yaml'

            # Determine output file
            if args.output:
                output_file = args.output
            else:
                base_name = os.path.splitext(args.input_file)[0]
                output_file = f"{base_name}.{output_format}"

            # Save in new format
            if output_format == 'json':
                with open(output_file, 'w') as f:
                    json.dump(config, f, indent=2)
            else:
                with open(output_file, 'w') as f:
                    yaml.dump(config, f, default_flow_style=False, indent=2)

            logger.info(f"✅ Configuration converted to {output_file}")

        elif args.command == 'diff':
            logger.info(f"🔍 Comparing configurations: {args.config1} vs {args.config2}")

            config1 = load_config(args.config1)
            config2 = load_config(args.config2)

            # Simple diff implementation
            def get_differences(dict1, dict2, path=""):
                differences = []
                all_keys = set(dict1.keys()) | set(dict2.keys())

                for key in all_keys:
                    current_path = f"{path}.{key}" if path else key

                    if key not in dict1:
                        differences.append(f"+ {current_path}: {dict2[key]}")
                    elif key not in dict2:
                        differences.append(f"- {current_path}: {dict1[key]}")
                    elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                        differences.extend(get_differences(dict1[key], dict2[key], current_path))
                    elif dict1[key] != dict2[key]:
                        differences.append(f"~ {current_path}: {dict1[key]} -> {dict2[key]}")

                return differences

            differences = get_differences(config1, config2)

            if args.format == 'json':
                print(json.dumps(differences, indent=2))
            else:
                if differences:
                    print("Configuration differences:")
                    for diff in differences:
                        print(f" {diff}")
                else:
                    print("No differences found")

    except Exception as e:
        logger.error(f"❌ Operation failed: {e}")
        return 1

    return 0


if __name__ == '__main__':
    sys.exit(main())
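The helpers above are plain functions, so the new configuration tooling can also be driven from Python rather than the console. A minimal sketch, assuming langvision 0.1.0 is installed and the module is importable as langvision.cli.config exactly as laid out in this diff (the file name demo_config.yaml is only an illustration):

import yaml
from langvision.cli.config import create_config_template, validate_config

# Build the 'advanced' template for CIFAR-100 with the default ViT backbone name
config = create_config_template('advanced', dataset='cifar100', model='vit_base')

# Write it to disk the same way the CLI 'create' command does
with open('demo_config.yaml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False, indent=2)

# Re-check the file with the strict rules (positive learning rate, epochs, batch size)
is_valid, message = validate_config('demo_config.yaml', strict=True)
print(is_valid, message)

The same round trip is what the `langvision config create` and `langvision config validate` examples in the module's epilog perform from the command line.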
langvision/cli/evaluate.py
ADDED
@@ -0,0 +1,201 @@
#!/usr/bin/env python3
"""
Evaluation CLI for Langvision models.
"""

import argparse
import logging
import os
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Import from langvision modules
try:
    from langvision.models.vision_transformer import VisionTransformer
    from langvision.utils.data import get_dataset
    from langvision.utils.device import setup_cuda, set_seed
    from langvision.training.trainer import Trainer
except ImportError as e:
    print(f"❌ Error importing langvision modules: {e}")
    print("Please ensure langvision is properly installed:")
    print(" pip install langvision")
    sys.exit(1)


def parse_args():
    """Parse command-line arguments for model evaluation."""
    parser = argparse.ArgumentParser(
        description='Evaluate a trained VisionTransformer model',
        epilog='''\nExamples:\n langvision evaluate --checkpoint model.pth --dataset cifar10\n langvision evaluate --checkpoint model.pth --dataset cifar100 --batch_size 128\n''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Model
    model_group = parser.add_argument_group('Model')
    model_group.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    model_group.add_argument('--img_size', type=int, default=224, help='Input image size')
    model_group.add_argument('--patch_size', type=int, default=16, help='Patch size for ViT')
    model_group.add_argument('--num_classes', type=int, default=10, help='Number of classes')
    model_group.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension')
    model_group.add_argument('--depth', type=int, default=12, help='Number of transformer layers')
    model_group.add_argument('--num_heads', type=int, default=12, help='Number of attention heads')
    model_group.add_argument('--mlp_ratio', type=float, default=4.0, help='MLP hidden dim ratio')

    # Data
    data_group = parser.add_argument_group('Data')
    data_group.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset to evaluate on')
    data_group.add_argument('--data_dir', type=str, default='./data', help='Dataset directory')
    data_group.add_argument('--batch_size', type=int, default=64, help='Batch size for evaluation')
    data_group.add_argument('--num_workers', type=int, default=2, help='Number of data loader workers')

    # Output
    output_group = parser.add_argument_group('Output')
    output_group.add_argument('--output_dir', type=str, default='./evaluation_results', help='Directory to save evaluation results')
    output_group.add_argument('--save_predictions', action='store_true', help='Save model predictions to file')
    output_group.add_argument('--save_confusion_matrix', action='store_true', help='Save confusion matrix plot')

    # Device
    device_group = parser.add_argument_group('Device')
    device_group.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')

    # Misc
    misc_group = parser.add_argument_group('Misc')
    misc_group.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    misc_group.add_argument('--log_level', type=str, default='info', help='Logging level')

    return parser.parse_args()


def setup_logging(log_level: str) -> None:
    """Set up logging with the specified log level."""
    numeric_level = getattr(logging, log_level.upper(), None)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO
    logging.basicConfig(
        level=numeric_level,
        format='[%(levelname)s] %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler('evaluation.log')
        ]
    )


def main():
    """Main function for model evaluation."""
    args = parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info("🚀 Starting model evaluation...")

    # Setup
    setup_cuda(seed=args.seed)
    set_seed(args.seed)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load dataset
    logger.info(f"📊 Loading {args.dataset} dataset...")
    try:
        test_dataset = get_dataset(args.dataset, args.data_dir, train=False, img_size=args.img_size)
        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers
        )
        logger.info(f"✅ Dataset loaded: {len(test_dataset)} test samples")
    except Exception as e:
        logger.error(f"❌ Failed to load dataset: {e}")
        return 1

    # Load model
    logger.info("🤖 Loading model...")
    try:
        model = VisionTransformer(
            img_size=args.img_size,
            patch_size=args.patch_size,
            in_chans=3,
            num_classes=args.num_classes,
            embed_dim=args.embed_dim,
            depth=args.depth,
            num_heads=args.num_heads,
            mlp_ratio=args.mlp_ratio,
        ).to(args.device)

        # Load checkpoint
        if not os.path.isfile(args.checkpoint):
            logger.error(f"❌ Checkpoint file not found: {args.checkpoint}")
            return 1

        checkpoint = torch.load(args.checkpoint, map_location=args.device)
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        logger.info("✅ Model loaded successfully")

    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        return 1

    # Evaluate
    logger.info("📈 Running evaluation...")
    try:
        trainer = Trainer(model, device=args.device)
        test_loss, test_acc = trainer.evaluate(test_loader)

        logger.info(f"🎯 Evaluation Results:")
        logger.info(f" Test Loss: {test_loss:.4f}")
        logger.info(f" Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")

        # Save results
        results = {
            'test_loss': float(test_loss),
            'test_accuracy': float(test_acc),
            'dataset': args.dataset,
            'checkpoint': args.checkpoint,
            'model_config': {
                'img_size': args.img_size,
                'patch_size': args.patch_size,
                'num_classes': args.num_classes,
                'embed_dim': args.embed_dim,
                'depth': args.depth,
                'num_heads': args.num_heads,
                'mlp_ratio': args.mlp_ratio,
            }
        }

        import json
        results_file = os.path.join(args.output_dir, 'evaluation_results.json')
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=2)
        logger.info(f"💾 Results saved to {results_file}")

        if args.save_predictions:
            logger.info("💾 Saving predictions...")
            # TODO: Implement prediction saving
            pass

        if args.save_confusion_matrix:
            logger.info("📊 Saving confusion matrix...")
            # TODO: Implement confusion matrix
            pass

    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        return 1

    logger.info("✅ Evaluation completed successfully!")
    return 0


if __name__ == '__main__':
    sys.exit(main())
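Note that the checkpoint loader above accepts either a bare state_dict or a dict that wraps it under a 'model' key. A minimal sketch of writing a compatible checkpoint, assuming the package's VisionTransformer is importable exactly as in the module above (the path model.pth is illustrative):

import torch
from langvision.models.vision_transformer import VisionTransformer

# Constructor arguments mirror the CLI defaults in evaluate.py above
model = VisionTransformer(
    img_size=224, patch_size=16, in_chans=3, num_classes=10,
    embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
)

# Either form is accepted by `langvision evaluate --checkpoint model.pth`:
torch.save({'model': model.state_dict()}, 'model.pth')  # wrapped under a 'model' key
# torch.save(model.state_dict(), 'model.pth')           # or a bare state_dict

The wrapped form matches what training loops that also store optimizer state typically emit, while the bare form covers checkpoints exported as a plain state_dict.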