brainscore-vision 2.1__py3-none-any.whl → 2.2.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- brainscore_vision/benchmarks/coggan2024_behavior/__init__.py +2 -1
- brainscore_vision/benchmarks/coggan2024_behavior/test.py +2 -2
- brainscore_vision/benchmarks/coggan2024_fMRI/__init__.py +4 -4
- brainscore_vision/benchmarks/coggan2024_fMRI/test.py +2 -2
- brainscore_vision/benchmarks/imagenet/imagenet2012.csv +50000 -50000
- brainscore_vision/benchmarks/imagenet_c/benchmark.py +1 -1
- brainscore_vision/benchmarks/lonnqvist2024/__init__.py +8 -0
- brainscore_vision/benchmarks/lonnqvist2024/benchmark.py +125 -0
- brainscore_vision/benchmarks/lonnqvist2024/test.py +61 -0
- brainscore_vision/benchmarks/malania2007/benchmark.py +3 -0
- brainscore_vision/benchmarks/maniquet2024/benchmark.py +1 -1
- brainscore_vision/data/lonnqvist2024/__init__.py +47 -0
- brainscore_vision/data/lonnqvist2024/data_packaging/lonnqvist_data_assembly.py +53 -0
- brainscore_vision/data/lonnqvist2024/data_packaging/lonnqvist_stimulus_set.py +61 -0
- brainscore_vision/data/lonnqvist2024/test.py +127 -0
- brainscore_vision/model_helpers/brain_transformation/__init__.py +33 -0
- brainscore_vision/models/alexnet/region_layer_map/alexnet.json +1 -0
- brainscore_vision/models/alexnet_7be5be79/setup.py +4 -4
- brainscore_vision/models/alexnet_random/__init__.py +7 -0
- brainscore_vision/models/alexnet_random/model.py +46 -0
- brainscore_vision/models/alexnet_random/setup.py +26 -0
- brainscore_vision/models/alexnet_random/test.py +1 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/region_layer_map/cvt_cvt-13-224-in1k_4.json +6 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/region_layer_map/cvt_cvt-13-224-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_13_224_in1k_4/test.py +8 -0
- brainscore_vision/models/cvt_cvt_13_384_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_13_384_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_13_384_in1k_4/region_layer_map/cvt_cvt-13-384-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_13_384_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_13_384_in1k_4/test.py +8 -0
- brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/region_layer_map/cvt_cvt-13-384-in22k_finetuned-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/test.py +8 -0
- brainscore_vision/models/cvt_cvt_21_224_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_21_224_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_21_224_in1k_4/region_layer_map/cvt_cvt-21-224-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_21_224_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_21_224_in1k_4/test.py +8 -0
- brainscore_vision/models/cvt_cvt_21_384_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_21_384_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_21_384_in1k_4/region_layer_map/cvt_cvt-21-384-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_21_384_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_21_384_in1k_4/test.py +8 -0
- brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/__init__.py +9 -0
- brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/model.py +142 -0
- brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/region_layer_map/cvt_cvt-21-384-in22k_finetuned-in1k_4_LucyV4.json +6 -0
- brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/requirements.txt +4 -0
- brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/test.py +8 -0
- brainscore_vision/models/fixres_resnext101_32x48d_wsl/__init__.py +7 -0
- brainscore_vision/models/fixres_resnext101_32x48d_wsl/model.py +57 -0
- brainscore_vision/models/fixres_resnext101_32x48d_wsl/requirements.txt +5 -0
- brainscore_vision/models/fixres_resnext101_32x48d_wsl/test.py +7 -0
- brainscore_vision/models/inception_v4_pytorch/__init__.py +7 -0
- brainscore_vision/models/inception_v4_pytorch/model.py +64 -0
- brainscore_vision/models/inception_v4_pytorch/requirements.txt +3 -0
- brainscore_vision/models/inception_v4_pytorch/test.py +8 -0
- brainscore_vision/models/mvimgnet_ms_05/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ms_05/model.py +64 -0
- brainscore_vision/models/mvimgnet_ms_05/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ms_05/test.py +1 -0
- brainscore_vision/models/mvimgnet_rf/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_rf/model.py +64 -0
- brainscore_vision/models/mvimgnet_rf/setup.py +25 -0
- brainscore_vision/models/mvimgnet_rf/test.py +1 -0
- brainscore_vision/models/mvimgnet_ss_00/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ss_00/model.py +64 -0
- brainscore_vision/models/mvimgnet_ss_00/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ss_00/test.py +1 -0
- brainscore_vision/models/mvimgnet_ss_02/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ss_02/model.py +64 -0
- brainscore_vision/models/mvimgnet_ss_02/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ss_02/test.py +1 -0
- brainscore_vision/models/mvimgnet_ss_03/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ss_03/model.py +64 -0
- brainscore_vision/models/mvimgnet_ss_03/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ss_03/test.py +1 -0
- brainscore_vision/models/mvimgnet_ss_04/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ss_04/model.py +64 -0
- brainscore_vision/models/mvimgnet_ss_04/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ss_04/test.py +1 -0
- brainscore_vision/models/mvimgnet_ss_05/__init__.py +9 -0
- brainscore_vision/models/mvimgnet_ss_05/model.py +64 -0
- brainscore_vision/models/mvimgnet_ss_05/setup.py +25 -0
- brainscore_vision/models/mvimgnet_ss_05/test.py +1 -0
- brainscore_vision/models/resnet50_tutorial/region_layer_map/resnet50_tutorial.json +1 -0
- brainscore_vision/models/sam_test_resnet/__init__.py +5 -0
- brainscore_vision/models/sam_test_resnet/model.py +26 -0
- brainscore_vision/models/sam_test_resnet/requirements.txt +2 -0
- brainscore_vision/models/sam_test_resnet/test.py +8 -0
- brainscore_vision/models/sam_test_resnet_4/__init__.py +5 -0
- brainscore_vision/models/sam_test_resnet_4/model.py +26 -0
- brainscore_vision/models/sam_test_resnet_4/requirements.txt +2 -0
- brainscore_vision/models/sam_test_resnet_4/test.py +8 -0
- brainscore_vision/models/scaling_models/__init__.py +265 -0
- brainscore_vision/models/scaling_models/model.py +148 -0
- brainscore_vision/models/scaling_models/model_configs.json +869 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_base_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_large_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_100_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_10_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_1_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_base_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_large_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_100_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_10_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_1_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_full_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b0_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b1_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b2_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet101_ecoset_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet101_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet152_ecoset_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet18_ecoset_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet18_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet34_ecoset_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet34_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet50_ecoset_full.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_100_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_10_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_1_seed-0.json +6 -0
- brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_full.json +6 -0
- brainscore_vision/models/scaling_models/requirements.txt +4 -0
- brainscore_vision/models/scaling_models/test.py +0 -0
- brainscore_vision/models/vitb14_dinov2_imagenet1k/__init__.py +5 -0
- brainscore_vision/models/vitb14_dinov2_imagenet1k/model.py +852 -0
- brainscore_vision/models/vitb14_dinov2_imagenet1k/setup.py +25 -0
- brainscore_vision/models/vitb14_dinov2_imagenet1k/test.py +0 -0
- brainscore_vision/models/voneresnet_50_non_stochastic/region_layer_map/voneresnet-50-non_stochastic.json +1 -0
- brainscore_vision/submission/actions_helpers.py +2 -2
- brainscore_vision/submission/endpoints.py +3 -4
- {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/METADATA +2 -2
- {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/RECORD +143 -18
- {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/WHEEL +1 -1
- tests/test_model_helpers/temporal/activations/test_inferencer.py +2 -2
- {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/LICENSE +0 -0
- {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
|
|
1
|
+
from brainscore_vision.model_helpers.check_submission import check_models
|
2
|
+
import functools
|
3
|
+
import os
|
4
|
+
from urllib.request import urlretrieve
|
5
|
+
import torchvision.models
|
6
|
+
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
|
7
|
+
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
|
8
|
+
from pathlib import Path
|
9
|
+
from brainscore_vision.model_helpers import download_weights
|
10
|
+
import torch
|
11
|
+
from collections import OrderedDict
|
12
|
+
|
13
|
+
# This is an example implementation for submitting resnet-50 as a pytorch model
|
14
|
+
|
15
|
+
# Attention: it is important that the wrapper identifier is unique per model!
|
16
|
+
# Otherwise, results will be identical across models because of Brain-Score's internal result-caching mechanism.
|
17
|
+
# Please load your PyTorch model for use on CPU. There won't be GPUs available for scoring your model.
|
18
|
+
# If the model requires a GPU, contact the brain-score team directly.
|
19
|
+
from brainscore_vision.model_helpers.check_submission import check_models
|
20
|
+
|
21
|
+
|
22
|
+
def get_model_list():
    """Return the identifiers of the models defined in this module."""
    identifiers = ["mvimgnet_ss_05"]
    return identifiers
|
24
|
+
|
25
|
+
|
26
|
+
def get_model(name):
    """Build the Brain-Score activations wrapper for ``mvimgnet_ss_05``.

    Downloads the training checkpoint, rebuilds the network from the weights
    stored under ``["state"]["model"]`` via ``load_composer_classifier`` and
    wraps it with 224px preprocessing.

    :param name: model identifier; must be ``"mvimgnet_ss_05"``.
    :return: the configured ``PytorchWrapper``.
    """
    assert name == "mvimgnet_ss_05"
    url = "https://users.flatironinstitute.org/~tyerxa/slow_steady/training_checkpoints/slow_steady/r2/LARS/lmda_0.5/latest-rank0.pt"
    # Without a target filename, urlretrieve downloads into a temporary file that
    # is not removed automatically; delete it as soon as the weights are in memory
    # so repeated model construction does not accumulate checkpoints on disk.
    checkpoint_path, _ = urlretrieve(url)
    try:
        # NOTE(review): torch.load unpickles the downloaded file; this is only
        # acceptable because the checkpoint URL is a fixed, trusted host.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    finally:
        os.remove(checkpoint_path)
    state_dict = checkpoint["state"]["model"]
    model = load_composer_classifier(state_dict)
    preprocessing = functools.partial(load_preprocess_images, image_size=224)
    wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
    wrapper.image_size = 224
    return wrapper
|
36
|
+
|
37
|
+
def load_composer_classifier(sd):
    """Rebuild a torchvision ResNet-50 from a Composer-style state dict.

    Key remapping:
    - keys containing ``lin_cls`` become ``fc.<last segment>`` (the head);
    - keys containing a ``.f.`` segment keep only what follows the first
      ``f`` segment (the backbone);
    - every other key is dropped.

    The result is loaded with ``strict=True``, so the remapped keys must match
    ResNet-50's parameters exactly.

    :param sd: mapping of checkpoint parameter names to tensors.
    :return: the populated ResNet-50 module.
    """
    resnet = torchvision.models.resnet.resnet50()
    remapped = OrderedDict()
    for key, tensor in sd.items():
        if "lin_cls" in key:
            remapped["fc." + key.rsplit(".", 1)[-1]] = tensor
        if ".f." in key:
            segments = key.split(".")
            start = segments.index("f") + 1
            remapped[".".join(segments[start:])] = tensor
    resnet.load_state_dict(remapped, strict=True)
    return resnet
|
51
|
+
|
52
|
+
def get_layers(name):
    """Return the ResNet-50 stage names exposed for ``mvimgnet_ss_05``."""
    assert name == "mvimgnet_ss_05"
    return [f"layer{stage}" for stage in range(1, 5)]
|
57
|
+
|
58
|
+
|
59
|
+
def get_bibtex(model_identifier):
    """Return the BibTeX entry for ``model_identifier``.

    NOTE(review): the returned value is a placeholder, not a real citation.
    """
    placeholder = "xx"
    return placeholder
|
61
|
+
|
62
|
+
|
63
|
+
if __name__ == "__main__":
    # Only when this file is executed directly: run Brain-Score's submission
    # checks over the base models defined in this module.
    this_module = __name__
    check_models.check_base_models(this_module)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
from setuptools import setup, find_packages
|
5
|
+
|
6
|
+
# Runtime dependencies installed alongside this model package.
requirements = ["torchvision", "torch"]

# Packaging metadata, collected once and handed to setuptools.
_setup_kwargs = dict(
    packages=find_packages(exclude=['tests']),
    include_package_data=True,
    install_requires=requirements,
    license="MIT license",
    zip_safe=False,
    keywords='brain-score template',
    classifiers=[
        'Development Status :: 2 - Pre-Alpha',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Natural Language :: English',
        'Programming Language :: Python :: 3.7',
    ],
    test_suite='tests',
)

setup(**_setup_kwargs)
|
@@ -0,0 +1 @@
|
|
1
|
+
# Left empty as part of 2023 models migration
|
@@ -0,0 +1 @@
|
|
1
|
+
{"IT": "layer3", "V2": "layer2", "V1": "layer2", "V4": "layer2"}
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from brainscore_vision import model_registry
|
2
|
+
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
|
3
|
+
from .model import get_model, get_layers
|
4
|
+
|
5
|
+
# Lazily build the commitment: weights are only loaded when the registry entry is called.
model_registry['sam_test_resnet'] = lambda: ModelCommitment(
    identifier='sam_test_resnet',
    activations_model=get_model('sam_test_resnet'),
    layers=get_layers('sam_test_resnet'),
)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from brainscore_vision.model_helpers.check_submission import check_models
|
2
|
+
import functools
|
3
|
+
import torchvision.models
|
4
|
+
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
|
5
|
+
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
|
6
|
+
|
7
|
+
def get_model(name):
    """Wrap an ImageNet-pretrained torchvision ResNet-50 as ``sam_test_resnet``.

    :param name: model identifier; must be ``'sam_test_resnet'``.
    :return: ``PytorchWrapper`` using 224px preprocessing.
    """
    assert name == 'sam_test_resnet'
    image_size = 224
    wrapper = PytorchWrapper(
        identifier='sam_test_resnet',
        model=torchvision.models.resnet50(pretrained=True),
        preprocessing=functools.partial(load_preprocess_images, image_size=image_size),
    )
    wrapper.image_size = image_size
    return wrapper
|
14
|
+
|
15
|
+
|
16
|
+
def get_layers(name):
    """Return the module names passed to ``ModelCommitment`` as ``layers``."""
    assert name == 'sam_test_resnet'
    stages = [f'layer{index}' for index in range(1, 5)]
    return ['conv1', *stages, 'fc']
|
19
|
+
|
20
|
+
|
21
|
+
def get_bibtex(model_identifier):
    """Return the BibTeX entry for ``model_identifier`` (none provided: empty string)."""
    return ''
|
23
|
+
|
24
|
+
|
25
|
+
if __name__ == '__main__':
    # Only when this file is executed directly: run Brain-Score's submission
    # checks over the base models defined in this module.
    this_module = __name__
    check_models.check_base_models(this_module)
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from brainscore_vision import model_registry
|
2
|
+
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
|
3
|
+
from .model import get_model, get_layers
|
4
|
+
|
5
|
+
# Lazily build the commitment: weights are only loaded when the registry entry is called.
model_registry['sam_test_resnet_4'] = lambda: ModelCommitment(
    identifier='sam_test_resnet_4',
    activations_model=get_model('sam_test_resnet_4'),
    layers=get_layers('sam_test_resnet_4'),
)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from brainscore_vision.model_helpers.check_submission import check_models
|
2
|
+
import functools
|
3
|
+
import torchvision.models
|
4
|
+
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
|
5
|
+
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
|
6
|
+
|
7
|
+
def get_model(name):
    """Wrap an ImageNet-pretrained torchvision ResNet-50 as ``sam_test_resnet_4``.

    :param name: model identifier; must be ``'sam_test_resnet_4'``.
    :return: ``PytorchWrapper`` using 224px preprocessing.
    """
    assert name == 'sam_test_resnet_4'
    model = torchvision.models.resnet50(pretrained=True)
    preprocessing = functools.partial(load_preprocess_images, image_size=224)
    # The wrapper identifier must be unique per model: Brain-Score caches results
    # by identifier, so reusing 'sam_test_resnet' here would make this model share
    # cached results with the separate sam_test_resnet submission.
    wrapper = PytorchWrapper(identifier='sam_test_resnet_4', model=model, preprocessing=preprocessing)
    wrapper.image_size = 224
    return wrapper
|
14
|
+
|
15
|
+
|
16
|
+
def get_layers(name):
    """Return the module names passed to ``ModelCommitment`` as ``layers``."""
    assert name == 'sam_test_resnet_4'
    stages = [f'layer{index}' for index in range(1, 5)]
    return ['conv1', *stages, 'fc']
|
19
|
+
|
20
|
+
|
21
|
+
def get_bibtex(model_identifier):
    """Return the BibTeX entry for ``model_identifier`` (none provided: empty string)."""
    return ''
|
23
|
+
|
24
|
+
|
25
|
+
if __name__ == '__main__':
    # Only when this file is executed directly: run Brain-Score's submission
    # checks over the base models defined in this module.
    this_module = __name__
    check_models.check_base_models(this_module)
|
@@ -0,0 +1,265 @@
|
|
1
|
+
from brainscore_vision import model_registry
|
2
|
+
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
|
3
|
+
from .model import get_model, MODEL_CONFIGS
|
4
|
+
|
5
|
+
# Every scaling-model checkpoint is registered the same way: the activations
# model comes from ``get_model`` and the commitment settings (layers, behavioral
# readout layer, region-to-layer map) come from the model's
# ``MODEL_CONFIGS[identifier]["model_commitment"]`` entry. Listing the identifiers
# once replaces 29 copy-pasted ModelCommitment blocks that could drift apart.
# Registration order matches the previous explicit entries.
_SCALING_MODEL_IDENTIFIERS = (
    "resnet18_imagenet_full",
    "resnet34_imagenet_full",
    "resnet50_imagenet_full",
    "resnet101_imagenet_full",
    "resnet152_imagenet_full",
    "resnet18_ecoset_full",
    "resnet34_ecoset_full",
    "resnet50_ecoset_full",
    "resnet101_ecoset_full",
    "resnet152_ecoset_full",
    "resnet50_imagenet_1_seed-0",
    "resnet50_imagenet_10_seed-0",
    "resnet50_imagenet_100_seed-0",
    "efficientnet_b0_imagenet_full",
    "efficientnet_b1_imagenet_full",
    "efficientnet_b2_imagenet_full",
    "deit_small_imagenet_full_seed-0",
    "deit_base_imagenet_full_seed-0",
    "deit_large_imagenet_full_seed-0",
    "deit_small_imagenet_1_seed-0",
    "deit_small_imagenet_10_seed-0",
    "deit_small_imagenet_100_seed-0",
    "convnext_tiny_imagenet_full_seed-0",
    "convnext_small_imagenet_full_seed-0",
    "convnext_base_imagenet_full_seed-0",
    "convnext_large_imagenet_full_seed-0",
    "convnext_small_imagenet_1_seed-0",
    "convnext_small_imagenet_10_seed-0",
    "convnext_small_imagenet_100_seed-0",
)


def _commitment_loader(identifier):
    """Return a zero-argument loader that builds the ModelCommitment for ``identifier``.

    The model and its config are only touched when the loader is called, so
    registering all identifiers stays cheap (same laziness as the original lambdas).
    Binding ``identifier`` as a closure argument avoids the late-binding bug a
    bare ``lambda`` inside the loop would have.
    """
    def load():
        commitment = MODEL_CONFIGS[identifier]["model_commitment"]
        return ModelCommitment(
            identifier=identifier,
            activations_model=get_model(identifier),
            layers=commitment["layers"],
            behavioral_readout_layer=commitment["behavioral_readout_layer"],
            region_layer_map=commitment["region_layer_map"],
        )

    return load


for _identifier in _SCALING_MODEL_IDENTIFIERS:
    model_registry[_identifier] = _commitment_loader(_identifier)
|
264
|
+
|
265
|
+
|
@@ -0,0 +1,148 @@
|
|
1
|
+
import os
|
2
|
+
import functools
|
3
|
+
import json
|
4
|
+
from pathlib import Path
|
5
|
+
import ssl
|
6
|
+
|
7
|
+
import torchvision.models
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
|
11
|
+
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
|
12
|
+
|
13
|
+
import timm
|
14
|
+
import numpy as np
|
15
|
+
import torchvision.transforms as T
|
16
|
+
from PIL import Image
|
17
|
+
|
18
|
+
import albumentations as A
|
19
|
+
from albumentations.pytorch import ToTensorV2
|
20
|
+
|
21
|
+
# Disable SSL verification
# NOTE(review): this swaps the process-wide default HTTPS context for one that
# skips certificate verification. Every stdlib HTTPS client (http.client /
# urllib, which torch.hub downloads go through) that does not pass its own
# context is affected after this module is imported -- not only the checkpoint
# downloads in get_model. Consider scoping it to the download instead.
ssl._create_default_https_context = ssl._create_unverified_context

BIBTEX = """"""

# Per-model settings keyed by registry identifier: architecture, preprocessing,
# checkpoint location and the "model_commitment" layer mapping. JSON is UTF-8
# by specification, so decode it explicitly instead of relying on the
# platform's locale encoding (which differs on e.g. Windows).
with open(Path(__file__).parent / "model_configs.json", "r", encoding="utf-8") as f:
    MODEL_CONFIGS = json.load(f)
|
29
|
+
|
30
|
+
|
31
|
+
def load_image(image_filepath):
    """Load an image file as a 3-channel RGB PIL image.

    The source file is opened in a ``with`` block so its handle is released as
    soon as the RGB copy exists (Pillow does not auto-close multi-frame files
    such as GIF/TIFF). ``convert`` returns a new image, so the returned object
    stays usable after the source image is closed.

    Args:
        image_filepath: Path to the image on disk.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    with Image.open(image_filepath) as image:
        return image.convert("RGB")
|
33
|
+
|
34
|
+
|
35
|
+
def get_interpolation_mode(interpolation: str) -> int:
    """Returns the interpolation mode for albumentations.

    Albumentations takes OpenCV interpolation flags: ``1`` is
    ``cv2.INTER_LINEAR`` and ``2`` is ``cv2.INTER_CUBIC``. Matching is
    substring-based and case-insensitive, so both ``"linear"``/``"bilinear"``
    and ``"cubic"``/``"bicubic"`` are accepted.

    Args:
        interpolation: Interpolation name from the model config.

    Returns:
        ``1`` for linear/bilinear, ``2`` for cubic/bicubic.

    Raises:
        NotImplementedError: For any other interpolation name.
    """
    name = interpolation.lower()
    # Each keyword needs its own membership test: the expression
    # `"linear" or "bilinear" in s` parses as `"linear" or (...)` and is
    # always truthy, which would map every name (even "bicubic") to 1.
    if "cubic" in name:  # "cubic", "bicubic"
        return 2
    if "linear" in name:  # "linear", "bilinear"
        return 1
    raise NotImplementedError(f"Interpolation mode {interpolation} not implemented")
|
43
|
+
|
44
|
+
|
45
|
+
def custom_image_preprocess(
    images,
    resize_size: int,
    crop_size: int,
    interpolation: str,
    transforms=None,
):
    """Run a batch of images through a transform pipeline and stack the results.

    Args:
        images: Iterable of PIL images (``load_preprocess_images_custom``
            passes the RGB images produced by ``load_image``).
        resize_size: Side length of the square resize in the default pipeline.
        crop_size: Side length of the square center crop in the default pipeline.
        interpolation: Interpolation name, translated via
            ``get_interpolation_mode``; only consulted when ``transforms`` is None.
        transforms: A ``torchvision.transforms.Compose`` or an
            ``albumentations.Compose``. When None, a default albumentations
            pipeline is built: square resize -> center crop -> normalize with
            the ImageNet channel mean/std -> ``ToTensorV2``.

    Returns:
        A ``numpy.ndarray`` stacking one transformed entry per input image.

    Raises:
        NotImplementedError: If ``transforms`` is neither pipeline type.
    """
    if transforms is None:
        cv2_interpolation = get_interpolation_mode(interpolation)
        transforms = A.Compose(
            [
                A.Resize(resize_size, resize_size, p=1.0, interpolation=cv2_interpolation),
                A.CenterCrop(crop_size, crop_size, p=1.0),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ]
        )

    if isinstance(transforms, T.Compose):
        # torchvision pipelines take the PIL image directly.
        return np.stack([np.array(transforms(img)) for img in images])

    if isinstance(transforms, A.Compose):
        # albumentations wants a numpy array and returns a dict keyed by "image".
        return np.stack([transforms(image=np.array(img))["image"] for img in images])

    raise NotImplementedError(
        f"Transform of type {type(transforms)} is not implemented"
    )
|
75
|
+
|
76
|
+
|
77
|
+
def load_preprocess_images_custom(
    image_filepaths, preprocess_images=custom_image_preprocess, **kwargs
):
    """Read image files from disk and hand them to a preprocessing function.

    Args:
        image_filepaths: Iterable of image paths, each opened with ``load_image``.
        preprocess_images: Callable receiving the list of loaded images plus
            ``**kwargs``; defaults to ``custom_image_preprocess``.
        **kwargs: Forwarded unchanged to ``preprocess_images``.

    Returns:
        Whatever ``preprocess_images`` returns.
    """
    pil_images = list(map(load_image, image_filepaths))
    return preprocess_images(pil_images, **kwargs)
|
83
|
+
|
84
|
+
|
85
|
+
def get_model(model_id: str):
    """Build the network configured under ``model_id``, load its checkpoint and wrap it.

    Args:
        model_id: Key into ``MODEL_CONFIGS``.

    Returns:
        A ``PytorchWrapper`` around the model, with the checkpoint weights
        loaded and preprocessing bound to the config's resize/crop/interpolation
        settings. Its identifier is the config's own ``"model_id"`` entry.

    Raises:
        KeyError: If ``model_id`` or a required config/checkpoint field is missing.
    """
    # Unpack model config
    config = MODEL_CONFIGS[model_id]
    model_name = config["model_name"]
    # Deliberately rebinds the argument: the config's "model_id" is what names
    # the cached checkpoint file and the wrapper below.
    model_id = config["model_id"]
    resize_size = config["resize_size"]
    crop_size = config["crop_size"]
    interpolation = config["interpolation"]
    num_classes = config["num_classes"]
    ckpt_url = config["checkpoint_url"]
    use_timm = config["use_timm"]
    timm_model_name = config["timm_model_name"]
    epoch = config["epoch"]
    load_model_ema = config["load_model_ema"]
    output_head = config["output_head"]
    is_vit = config["is_vit"]

    # Temporary fix for vit models
    # See https://github.com/brain-score/vision/pull/1232
    if is_vit:
        os.environ['RESULTCACHING_DISABLE'] = 'brainscore_vision.model_helpers.activations.core.ActivationsExtractorHelper._from_paths_stored'

    # Initialize model
    if use_timm:
        model = timm.create_model(timm_model_name, pretrained=False, num_classes=num_classes)
    else:
        # Resolve the constructor by attribute lookup instead of eval(), so the
        # config string cannot execute arbitrary code; dotted names such as
        # "video.r3d_18" still resolve through torchvision.models.
        constructor = functools.reduce(getattr, model_name.split("."), torchvision.models)
        model = constructor(weights=None)
        if num_classes != 1000:
            # Swap the classifier so its output size matches the checkpoint.
            # `output_head` is an attribute expression from model_configs.json
            # evaluated against this function's locals (e.g. `model`).
            # NOTE(review): exec on config text -- only safe while the config
            # file is trusted.
            exec(f'''{output_head} = torch.nn.Linear(
                in_features={output_head}.in_features,
                out_features=num_classes,
                bias={output_head}.bias is not None,
            )'''
            )

    # Load model weights
    # NOTE(review): torch.hub only verifies `check_hash` when the file name
    # contains "-<sha256 prefix>."; the name built here likely does not, so
    # the flag is probably a no-op -- confirm.
    checkpoint = torch.hub.load_state_dict_from_url(
        ckpt_url,
        check_hash=True,
        file_name=f"{model_id}_ep{epoch}.pt",
        map_location="cpu",
    )
    weights_key = "model_ema_state_dict" if load_model_ema else "model"
    state_dict = checkpoint["state"][weights_key]
    # Strip only a leading "module." wrapper prefix (e.g. from DataParallel/DDP
    # or an EMA wrapper). str.replace would also rewrite any key that merely
    # contains "module." further along its path.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict, strict=True)
    print(f"Model loaded from {ckpt_url}")

    # Wrap model
    preprocessing = functools.partial(
        load_preprocess_images_custom,
        resize_size=resize_size,
        crop_size=crop_size,
        interpolation=interpolation,
        transforms=None,
    )
    wrapper = PytorchWrapper(
        identifier=model_id, model=model, preprocessing=preprocessing
    )
    return wrapper
|