brainscore-vision 2.2.2__py3-none-any.whl → 2.2.4__py3-none-any.whl
- brainscore_vision/models/alexnet_less_variation_1/__init__.py +6 -0
- brainscore_vision/models/alexnet_less_variation_1/model.py +200 -0
- brainscore_vision/models/alexnet_less_variation_1/region_layer_map/alexnet_less_variation_iteration=1.json +6 -0
- brainscore_vision/models/alexnet_less_variation_1/setup.py +29 -0
- brainscore_vision/models/alexnet_less_variation_1/test.py +3 -0
- brainscore_vision/models/alexnet_less_variation_2/__init__.py +6 -0
- brainscore_vision/models/alexnet_less_variation_2/model.py +200 -0
- brainscore_vision/models/alexnet_less_variation_2/region_layer_map/alexnet_less_variation_iteration=2.json +6 -0
- brainscore_vision/models/alexnet_less_variation_2/setup.py +29 -0
- brainscore_vision/models/alexnet_less_variation_2/test.py +3 -0
- brainscore_vision/models/alexnet_less_variation_4/__init__.py +6 -0
- brainscore_vision/models/alexnet_less_variation_4/model.py +200 -0
- brainscore_vision/models/alexnet_less_variation_4/region_layer_map/alexnet_less_variation_iteration=4.json +6 -0
- brainscore_vision/models/alexnet_less_variation_4/setup.py +29 -0
- brainscore_vision/models/alexnet_less_variation_4/test.py +3 -0
- brainscore_vision/models/alexnet_no_specular_2/__init__.py +6 -0
- brainscore_vision/models/alexnet_no_specular_2/model.py +200 -0
- brainscore_vision/models/alexnet_no_specular_2/region_layer_map/alexnet_no_specular_iteration=2.json +6 -0
- brainscore_vision/models/alexnet_no_specular_2/setup.py +29 -0
- brainscore_vision/models/alexnet_no_specular_2/test.py +3 -0
- brainscore_vision/models/alexnet_no_specular_4/__init__.py +6 -0
- brainscore_vision/models/alexnet_no_specular_4/model.py +200 -0
- brainscore_vision/models/alexnet_no_specular_4/region_layer_map/alexnet_no_specular_iteration=4.json +6 -0
- brainscore_vision/models/alexnet_no_specular_4/setup.py +29 -0
- brainscore_vision/models/alexnet_no_specular_4/test.py +3 -0
- brainscore_vision/models/alexnet_no_variation_4/__init__.py +6 -0
- brainscore_vision/models/alexnet_no_variation_4/model.py +200 -0
- brainscore_vision/models/alexnet_no_variation_4/region_layer_map/alexnet_no_variation_iteration=4.json +6 -0
- brainscore_vision/models/alexnet_no_variation_4/setup.py +29 -0
- brainscore_vision/models/alexnet_no_variation_4/test.py +3 -0
- brainscore_vision/models/alexnet_original_3/__init__.py +6 -0
- brainscore_vision/models/alexnet_original_3/model.py +200 -0
- brainscore_vision/models/alexnet_original_3/region_layer_map/alexnet_original_iteration=3.json +6 -0
- brainscore_vision/models/alexnet_original_3/setup.py +29 -0
- brainscore_vision/models/alexnet_original_3/test.py +3 -0
- brainscore_vision/models/alexnet_wo_shading_4/__init__.py +6 -0
- brainscore_vision/models/alexnet_wo_shading_4/model.py +200 -0
- brainscore_vision/models/alexnet_wo_shading_4/region_layer_map/alexnet_wo_shading_iteration=4.json +6 -0
- brainscore_vision/models/alexnet_wo_shading_4/setup.py +29 -0
- brainscore_vision/models/alexnet_wo_shading_4/test.py +3 -0
- brainscore_vision/models/alexnet_wo_shadows_5/__init__.py +6 -0
- brainscore_vision/models/alexnet_wo_shadows_5/model.py +200 -0
- brainscore_vision/models/alexnet_wo_shadows_5/region_layer_map/alexnet_wo_shadows_iteration=5.json +6 -0
- brainscore_vision/models/alexnet_wo_shadows_5/setup.py +29 -0
- brainscore_vision/models/alexnet_wo_shadows_5/test.py +3 -0
- brainscore_vision/models/alexnet_z_axis_1/__init__.py +6 -0
- brainscore_vision/models/alexnet_z_axis_1/model.py +200 -0
- brainscore_vision/models/alexnet_z_axis_1/region_layer_map/alexnet_z_axis_iteration=1.json +6 -0
- brainscore_vision/models/alexnet_z_axis_1/setup.py +29 -0
- brainscore_vision/models/alexnet_z_axis_1/test.py +3 -0
- brainscore_vision/models/alexnet_z_axis_2/__init__.py +6 -0
- brainscore_vision/models/alexnet_z_axis_2/model.py +200 -0
- brainscore_vision/models/alexnet_z_axis_2/region_layer_map/alexnet_z_axis_iteration=2.json +6 -0
- brainscore_vision/models/alexnet_z_axis_2/setup.py +29 -0
- brainscore_vision/models/alexnet_z_axis_2/test.py +3 -0
- brainscore_vision/models/alexnet_z_axis_3/__init__.py +6 -0
- brainscore_vision/models/alexnet_z_axis_3/model.py +200 -0
- brainscore_vision/models/alexnet_z_axis_3/region_layer_map/alexnet_z_axis_iteration=3.json +6 -0
- brainscore_vision/models/alexnet_z_axis_3/setup.py +29 -0
- brainscore_vision/models/alexnet_z_axis_3/test.py +3 -0
- brainscore_vision/models/alexnet_z_axis_4/__init__.py +6 -0
- brainscore_vision/models/alexnet_z_axis_4/model.py +200 -0
- brainscore_vision/models/alexnet_z_axis_4/region_layer_map/alexnet_z_axis_iteration=4.json +6 -0
- brainscore_vision/models/alexnet_z_axis_4/setup.py +29 -0
- brainscore_vision/models/alexnet_z_axis_4/test.py +3 -0
- brainscore_vision/models/artResNet18_1/__init__.py +5 -0
- brainscore_vision/models/artResNet18_1/model.py +66 -0
- brainscore_vision/models/artResNet18_1/requirements.txt +4 -0
- brainscore_vision/models/artResNet18_1/test.py +12 -0
- brainscore_vision/models/barlow_twins_custom/__init__.py +5 -0
- brainscore_vision/models/barlow_twins_custom/model.py +58 -0
- brainscore_vision/models/barlow_twins_custom/requirements.txt +4 -0
- brainscore_vision/models/barlow_twins_custom/test.py +12 -0
- brainscore_vision/models/blt-vs/__init__.py +15 -0
- brainscore_vision/models/blt-vs/model.py +962 -0
- brainscore_vision/models/blt-vs/pretrained.py +219 -0
- brainscore_vision/models/blt-vs/region_layer_map/blt_vs.json +6 -0
- brainscore_vision/models/blt-vs/setup.py +22 -0
- brainscore_vision/models/blt-vs/test.py +0 -0
- brainscore_vision/models/cifar_resnet18_1/__init__.py +5 -0
- brainscore_vision/models/cifar_resnet18_1/model.py +68 -0
- brainscore_vision/models/cifar_resnet18_1/requirements.txt +4 -0
- brainscore_vision/models/cifar_resnet18_1/test.py +10 -0
- brainscore_vision/models/resnet18_random/__init__.py +5 -0
- brainscore_vision/models/resnet18_random/archive_name.zip +0 -0
- brainscore_vision/models/resnet18_random/model.py +42 -0
- brainscore_vision/models/resnet18_random/requirements.txt +2 -0
- brainscore_vision/models/resnet18_random/test.py +12 -0
- brainscore_vision/models/resnet50_less_variation_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_less_variation_1/model.py +200 -0
- brainscore_vision/models/resnet50_less_variation_1/region_layer_map/resnet50_less_variation_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_less_variation_1/setup.py +29 -0
- brainscore_vision/models/resnet50_less_variation_1/test.py +3 -0
- brainscore_vision/models/resnet50_less_variation_2/__init__.py +6 -0
- brainscore_vision/models/resnet50_less_variation_2/model.py +200 -0
- brainscore_vision/models/resnet50_less_variation_2/region_layer_map/resnet50_less_variation_iteration=2.json +6 -0
- brainscore_vision/models/resnet50_less_variation_2/setup.py +29 -0
- brainscore_vision/models/resnet50_less_variation_2/test.py +3 -0
- brainscore_vision/models/resnet50_less_variation_3/__init__.py +6 -0
- brainscore_vision/models/resnet50_less_variation_3/model.py +200 -0
- brainscore_vision/models/resnet50_less_variation_3/region_layer_map/resnet50_less_variation_iteration=3.json +6 -0
- brainscore_vision/models/resnet50_less_variation_3/setup.py +29 -0
- brainscore_vision/models/resnet50_less_variation_3/test.py +3 -0
- brainscore_vision/models/resnet50_less_variation_4/__init__.py +6 -0
- brainscore_vision/models/resnet50_less_variation_4/model.py +200 -0
- brainscore_vision/models/resnet50_less_variation_4/region_layer_map/resnet50_less_variation_iteration=4.json +6 -0
- brainscore_vision/models/resnet50_less_variation_4/setup.py +29 -0
- brainscore_vision/models/resnet50_less_variation_4/test.py +3 -0
- brainscore_vision/models/resnet50_less_variation_5/__init__.py +6 -0
- brainscore_vision/models/resnet50_less_variation_5/model.py +200 -0
- brainscore_vision/models/resnet50_less_variation_5/region_layer_map/resnet50_less_variation_iteration=5.json +6 -0
- brainscore_vision/models/resnet50_less_variation_5/setup.py +29 -0
- brainscore_vision/models/resnet50_less_variation_5/test.py +3 -0
- brainscore_vision/models/resnet50_no_variation_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_no_variation_1/model.py +200 -0
- brainscore_vision/models/resnet50_no_variation_1/region_layer_map/resnet50_no_variation_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_no_variation_1/setup.py +29 -0
- brainscore_vision/models/resnet50_no_variation_1/test.py +3 -0
- brainscore_vision/models/resnet50_no_variation_2/__init__.py +6 -0
- brainscore_vision/models/resnet50_no_variation_2/model.py +200 -0
- brainscore_vision/models/resnet50_no_variation_2/region_layer_map/resnet50_no_variation_iteration=2.json +6 -0
- brainscore_vision/models/resnet50_no_variation_2/setup.py +29 -0
- brainscore_vision/models/resnet50_no_variation_2/test.py +3 -0
- brainscore_vision/models/resnet50_no_variation_5/__init__.py +6 -0
- brainscore_vision/models/resnet50_no_variation_5/model.py +200 -0
- brainscore_vision/models/resnet50_no_variation_5/region_layer_map/resnet50_no_variation_iteration=5.json +6 -0
- brainscore_vision/models/resnet50_no_variation_5/setup.py +29 -0
- brainscore_vision/models/resnet50_no_variation_5/test.py +3 -0
- brainscore_vision/models/resnet50_original_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_original_1/model.py +200 -0
- brainscore_vision/models/resnet50_original_1/region_layer_map/resnet50_original_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_original_1/setup.py +29 -0
- brainscore_vision/models/resnet50_original_1/test.py +3 -0
- brainscore_vision/models/resnet50_original_2/__init__.py +6 -0
- brainscore_vision/models/resnet50_original_2/model.py +200 -0
- brainscore_vision/models/resnet50_original_2/region_layer_map/resnet50_original_iteration=2.json +6 -0
- brainscore_vision/models/resnet50_original_2/setup.py +29 -0
- brainscore_vision/models/resnet50_original_2/test.py +3 -0
- brainscore_vision/models/resnet50_original_5/__init__.py +6 -0
- brainscore_vision/models/resnet50_original_5/model.py +200 -0
- brainscore_vision/models/resnet50_original_5/region_layer_map/resnet50_original_iteration=5.json +6 -0
- brainscore_vision/models/resnet50_original_5/setup.py +29 -0
- brainscore_vision/models/resnet50_original_5/test.py +3 -0
- brainscore_vision/models/resnet50_textures_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_textures_1/model.py +200 -0
- brainscore_vision/models/resnet50_textures_1/region_layer_map/resnet50_textures_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_textures_1/setup.py +29 -0
- brainscore_vision/models/resnet50_textures_1/test.py +3 -0
- brainscore_vision/models/resnet50_textures_2/__init__.py +6 -0
- brainscore_vision/models/resnet50_textures_2/model.py +200 -0
- brainscore_vision/models/resnet50_textures_2/region_layer_map/resnet50_textures_iteration=2.json +6 -0
- brainscore_vision/models/resnet50_textures_2/setup.py +29 -0
- brainscore_vision/models/resnet50_textures_2/test.py +3 -0
- brainscore_vision/models/resnet50_textures_3/__init__.py +6 -0
- brainscore_vision/models/resnet50_textures_3/model.py +200 -0
- brainscore_vision/models/resnet50_textures_3/region_layer_map/resnet50_textures_iteration=3.json +6 -0
- brainscore_vision/models/resnet50_textures_3/setup.py +29 -0
- brainscore_vision/models/resnet50_textures_3/test.py +3 -0
- brainscore_vision/models/resnet50_textures_4/__init__.py +6 -0
- brainscore_vision/models/resnet50_textures_4/model.py +200 -0
- brainscore_vision/models/resnet50_textures_4/region_layer_map/resnet50_textures_iteration=4.json +6 -0
- brainscore_vision/models/resnet50_textures_4/setup.py +29 -0
- brainscore_vision/models/resnet50_textures_4/test.py +3 -0
- brainscore_vision/models/resnet50_textures_5/__init__.py +6 -0
- brainscore_vision/models/resnet50_textures_5/model.py +200 -0
- brainscore_vision/models/resnet50_textures_5/region_layer_map/resnet50_textures_iteration=5.json +6 -0
- brainscore_vision/models/resnet50_textures_5/setup.py +29 -0
- brainscore_vision/models/resnet50_textures_5/test.py +3 -0
- brainscore_vision/models/resnet50_wo_shading_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_wo_shading_1/model.py +200 -0
- brainscore_vision/models/resnet50_wo_shading_1/region_layer_map/resnet50_wo_shading_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_wo_shading_1/setup.py +29 -0
- brainscore_vision/models/resnet50_wo_shading_1/test.py +3 -0
- brainscore_vision/models/resnet50_wo_shading_3/__init__.py +6 -0
- brainscore_vision/models/resnet50_wo_shading_3/model.py +200 -0
- brainscore_vision/models/resnet50_wo_shading_3/region_layer_map/resnet50_wo_shading_iteration=3.json +6 -0
- brainscore_vision/models/resnet50_wo_shading_3/setup.py +29 -0
- brainscore_vision/models/resnet50_wo_shading_3/test.py +3 -0
- brainscore_vision/models/resnet50_wo_shading_4/__init__.py +6 -0
- brainscore_vision/models/resnet50_wo_shading_4/model.py +200 -0
- brainscore_vision/models/resnet50_wo_shading_4/region_layer_map/resnet50_wo_shading_iteration=4.json +6 -0
- brainscore_vision/models/resnet50_wo_shading_4/setup.py +29 -0
- brainscore_vision/models/resnet50_wo_shading_4/test.py +3 -0
- brainscore_vision/models/resnet50_wo_shadows_4/__init__.py +6 -0
- brainscore_vision/models/resnet50_wo_shadows_4/model.py +200 -0
- brainscore_vision/models/resnet50_wo_shadows_4/region_layer_map/resnet50_wo_shadows_iteration=4.json +6 -0
- brainscore_vision/models/resnet50_wo_shadows_4/setup.py +29 -0
- brainscore_vision/models/resnet50_wo_shadows_4/test.py +3 -0
- brainscore_vision/models/resnet50_z_axis_1/__init__.py +6 -0
- brainscore_vision/models/resnet50_z_axis_1/model.py +200 -0
- brainscore_vision/models/resnet50_z_axis_1/region_layer_map/resnet50_z_axis_iteration=1.json +6 -0
- brainscore_vision/models/resnet50_z_axis_1/setup.py +29 -0
- brainscore_vision/models/resnet50_z_axis_1/test.py +3 -0
- brainscore_vision/models/resnet50_z_axis_2/__init__.py +6 -0
- brainscore_vision/models/resnet50_z_axis_2/model.py +200 -0
- brainscore_vision/models/resnet50_z_axis_2/region_layer_map/resnet50_z_axis_iteration=2.json +6 -0
- brainscore_vision/models/resnet50_z_axis_2/setup.py +29 -0
- brainscore_vision/models/resnet50_z_axis_2/test.py +3 -0
- brainscore_vision/models/resnet50_z_axis_3/__init__.py +6 -0
- brainscore_vision/models/resnet50_z_axis_3/model.py +200 -0
- brainscore_vision/models/resnet50_z_axis_3/region_layer_map/resnet50_z_axis_iteration=3.json +6 -0
- brainscore_vision/models/resnet50_z_axis_3/setup.py +29 -0
- brainscore_vision/models/resnet50_z_axis_3/test.py +3 -0
- brainscore_vision/models/resnet50_z_axis_5/__init__.py +6 -0
- brainscore_vision/models/resnet50_z_axis_5/model.py +200 -0
- brainscore_vision/models/resnet50_z_axis_5/region_layer_map/resnet50_z_axis_iteration=5.json +6 -0
- brainscore_vision/models/resnet50_z_axis_5/setup.py +29 -0
- brainscore_vision/models/resnet50_z_axis_5/test.py +3 -0
- {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/METADATA +1 -1
- {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/RECORD +213 -5
- {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/LICENSE +0 -0
- {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/WHEEL +0 -0
- {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/top_level.txt +0 -0
+++ b/brainscore_vision/models/blt-vs/model.py
@@ -0,0 +1,962 @@
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torchvision import transforms
from brainscore_vision.model_helpers.check_submission import check_models
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from pretrained import get_model_instance, clear_models_and_aliases, register_model, register_aliases
from PIL import Image

SUBMODULE_SEPARATOR = '.'

LAYERS = ['Retina_5', 'LGN_5', 'V1_5', 'V2_5', 'V3_5', 'V4_5', 'LOC_5', 'logits']


def get_model(model_name='blt_vs', key_or_alias='blt_vs', image_size=224):
    """
    Get a model instance with preprocessing wrapped in a PytorchWrapper.

    Args:
        model_name (str): Identifier for the model.
        key_or_alias (str): Key or alias for the registered model.
        image_size (int): Input image size for preprocessing.

    Returns:
        PytorchWrapper: A wrapper around the model with preprocessing.
    """
    clear_models_and_aliases(BLT_VS)

    register_model(
        BLT_VS,
        'blt_vs',
        'https://zenodo.org/records/14223659/files/blt_vs.zip',
        '36d74a367a261e788028c6c9caa7a5675fee48e938a6b86a6c62655b23afaf53'
    )

    register_aliases(BLT_VS, 'blt_vs', 'blt_vs')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    preprocessing = functools.partial(load_preprocess_images_sush, image_size=image_size)

    blt_model = get_model_instance(BLT_VS, key_or_alias)
    blt_model.to(device)
    wrapper = PytorchWrapper(identifier=model_name, model=blt_model, preprocessing=preprocessing)

    return wrapper
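
# Editor's illustrative sketch (not part of the original model.py): minimal
# usage of get_model() together with the LAYERS list above. The call
# convention wrapper(stimuli_paths, layers=...) follows Brain-Score's
# PytorchWrapper; the stimulus path is a hypothetical placeholder.
def _example_usage():
    wrapper = get_model(model_name='blt_vs', key_or_alias='blt_vs', image_size=224)
    activations = wrapper(['/path/to/stimulus.png'], layers=LAYERS)
    return activations
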
def load_preprocess_images_sush(image_filepaths, image_size, **kwargs):
    images = load_images(image_filepaths)
    images = preprocess_images_sush(images, image_size=image_size, **kwargs)
    return images


def load_images(image_filepaths):
    return [load_image(image_filepath) for image_filepath in image_filepaths]


def preprocess_images_sush(images, image_size, **kwargs):
    preprocess = torchvision_preprocess_input_sush(image_size, **kwargs)
    images = [preprocess(image) for image in images]
    images = np.concatenate(images)
    return images


def torchvision_preprocess_sush(normalize_mean=(0.5, 0.5, 0.5), normalize_std=(0.5, 0.5, 0.5)):
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
        lambda img: img.unsqueeze(0)
    ])


def torchvision_preprocess_input_sush(image_size, **kwargs):
    from torchvision import transforms
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        torchvision_preprocess_sush(**kwargs),
    ])


def load_image(image_filepath):
    with Image.open(image_filepath) as pil_image:
        if 'L' not in pil_image.mode.upper() and 'A' not in pil_image.mode.upper() \
                and 'P' not in pil_image.mode.upper():  # not binary, not alpha, not palettized
            # workaround for https://github.com/python-pillow/Pillow/issues/1144,
            # see https://stackoverflow.com/a/30376272/2225200
            return pil_image.copy()
        else:  # make sure potential binary images are in RGB
            rgb_image = Image.new("RGB", pil_image.size)
            rgb_image.paste(pil_image)
            return rgb_image
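
# Editor's illustrative sketch (not part of the original model.py): what the
# preprocessing pipeline above yields for one synthetic RGB image, using only
# imports already present in this module. Normalizing with mean/std 0.5 maps
# pixels into [-1, 1]; the trailing unsqueeze adds the batch dimension.
def _demo_preprocessing(image_size=224):
    img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    batch = torchvision_preprocess_input_sush(image_size)(img)
    return batch.shape  # torch.Size([1, 3, image_size, image_size])
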
class BLT_VS(nn.Module):
    """
    The BLT_VS model simulates the ventral stream of the visual cortex. See BLT_VS_info.txt for more details on motivation and design.

    Parameters:
    -----------
    timesteps : int
        Number of time steps for the recurrent computation.
    num_classes : int
        Number of output classes for classification.
    add_feats : int
        Additional features to maintain orientation, color, etc.
    lateral_connections : bool
        Whether to include lateral connections.
    topdown_connections : bool
        Whether to include top-down connections.
    skip_connections : bool
        Whether to include skip connections.
    bio_unroll : bool
        Whether to use biological unrolling.
    image_size : int
        Size of the input image (height and width).
    hook_type : str
        Which area/timestep hooks to register. Options are 'concat' (concatenate BU/TD), 'separate', 'None'.
    """

    def __init__(
        self,
        timesteps=12,
        num_classes=565,
        add_feats=100,
        lateral_connections=True,
        topdown_connections=True,
        skip_connections=True,
        bio_unroll=True,
        image_size=224,
        hook_type='None',
    ):
        super(BLT_VS, self).__init__()

        self.timesteps = timesteps
        self.num_classes = num_classes
        self.add_feats = add_feats
        self.lateral_connections = lateral_connections
        self.topdown_connections = topdown_connections
        self.skip_connections = skip_connections
        self.bio_unroll = bio_unroll
        self.image_size = image_size
        self.hook_type = hook_type

        # Define network areas and configurations
        self.areas = ["Retina", "LGN", "V1", "V2", "V3", "V4", "LOC", "Readout"]

        if image_size == 224:
            self.kernel_sizes = [7, 7, 5, 1, 5, 3, 3, 5]
            self.kernel_sizes_lateral = [0, 0, 5, 5, 5, 5, 5, 0]
        else:
            self.kernel_sizes = [5, 3, 3, 1, 3, 3, 3, 3]
            self.kernel_sizes_lateral = [0, 0, 3, 3, 3, 3, 3, 0]

        self.strides = [2, 2, 2, 1, 1, 1, 2, 2]
        self.paddings = (np.array(self.kernel_sizes) - 1) // 2  # For 'same' padding
        self.channel_sizes = [
            32,
            32,
            576,
            480,
            352,
            256,
            352,
            int(num_classes + add_feats),
        ]

        # Top-down connections configuration
        self.topdown_connections_layers = [
            False,
            True,
            True,
            True,
            True,
            True,
            True,
            False,
        ]

        # Initialize network layers
        self.connections = nn.ModuleDict()
        for idx in range(len(self.areas) - 1):
            area = self.areas[idx]
            self.connections[area] = BLT_VS_Layer(
                layer_n=idx,
                channel_sizes=self.channel_sizes,
                strides=self.strides,
                kernel_sizes=self.kernel_sizes,
                kernel_sizes_lateral=self.kernel_sizes_lateral,
                paddings=self.paddings,
                lateral_connections=self.lateral_connections
                and (self.kernel_sizes_lateral[idx] > 0),
                topdown_connections=self.topdown_connections
                and self.topdown_connections_layers[idx],
                skip_connections_bu=self.skip_connections and (idx == 5),
                skip_connections_td=self.skip_connections and (idx == 2),
                image_size=image_size,
            )
        self.connections["Readout"] = BLT_VS_Readout(
            layer_n=7,
            channel_sizes=self.channel_sizes,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            num_classes=num_classes,
        )

        # Create an nn.Identity per area and timestep so that hooks can be
        # registered to acquire BU and TD activations for any area/timestep
        if self.hook_type != 'None':
            for area in self.areas:
                for t in range(timesteps):
                    if self.hook_type == 'concat' and area != 'Readout':  # we can't concat for readout
                        setattr(self, f"{area}_{t}", nn.Identity())
                    else:
                        setattr(self, f"{area}_{t}_BU", nn.Identity())
                        setattr(self, f"{area}_{t}_TD", nn.Identity())
            setattr(self, "logits", nn.Identity())

        # Precompute output shapes
        self.output_shapes = self.compute_output_shapes(image_size)

    def compute_output_shapes(self, image_size):
        """
        Compute the output shapes for each area based on the image size.

        Parameters:
        -----------
        image_size : int
            The input image size.

        Returns:
        --------
        output_shapes : list of tuples
            The output height and width for each area.
        """
        output_shapes = []
        height = width = image_size
        for idx in range(len(self.areas)):
            kernel_size = self.kernel_sizes[idx]
            stride = self.strides[idx]
            padding = self.paddings[idx]
            height = (height + 2 * padding - kernel_size) // stride + 1
            width = (width + 2 * padding - kernel_size) // stride + 1
            output_shapes.append((int(height), int(width)))
        return output_shapes
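
    # Worked example (editor's note, not in the original file): with
    # image_size=224 the first three areas use kernels [7, 7, 5], strides
    # [2, 2, 2] and paddings [3, 3, 2], so the conv formula
    # (H + 2p - k) // s + 1 used above gives
    #   Retina: (224 + 6 - 7) // 2 + 1 = 112
    #   LGN:    (112 + 6 - 7) // 2 + 1 = 56
    #   V1:     (56  + 4 - 5) // 2 + 1 = 28
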
    def forward(
        self,
        img_input,
        extract_actvs=False,
        areas=None,
        timesteps=None,
        bu=True,
        td=True,
        concat=False,
    ):
        """
        Forward pass for the BLT_VS model.

        Parameters:
        -----------
        img_input : torch.Tensor
            Input image tensor.
        extract_actvs : bool
            Whether to extract activations.
        areas : list of str
            List of area names to retrieve activations from.
        timesteps : list of int
            List of timesteps to retrieve activations at.
        bu : bool
            Whether to retrieve bottom-up activations.
        td : bool
            Whether to retrieve top-down activations.
        concat : bool
            Whether to concatenate BU and TD activations.

        Returns:
        --------
        If extract_actvs is False:
            readout_output : list of torch.Tensor
                The readout outputs at each timestep.
        If extract_actvs is True:
            (readout_output, activations) : tuple
                readout_output is as above.
                activations is a dict with structure activations[area][timestep] = activation
        """
        # Check whether the input has 4 dims, else add a batch dim
        if len(img_input.shape) == 3:
            img_input = img_input.unsqueeze(0)

        if extract_actvs:
            if areas is None or timesteps is None:
                raise ValueError(
                    "When extract_actvs is True, areas and timesteps must be specified."
                )
            activations = {area: {} for area in areas}
        else:
            activations = None

        readout_output = []
        bu_activations = [None for _ in self.areas]
        td_activations = [None for _ in self.areas]
        batch_size = img_input.size(0)

        if self.bio_unroll:
            # Implement the bio_unroll forward pass
            bu_activations_old = [None for _ in self.areas]
            td_activations_old = [None for _ in self.areas]

            # Initial activation for Retina
            bu_activations_old[0], _ = self.connections["Retina"](bu_input=img_input)
            bu_activations[0] = bu_activations_old[0]

            # Timestep 0 (if 0 is in timesteps)
            t = 0
            activations = self.activation_shenanigans(
                extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
            )

            for t in range(1, self.timesteps):
                # For each timestep, update the outputs of the areas
                for idx, area in enumerate(self.areas[1:-1]):
                    # Update only if necessary
                    should_update = any(
                        [
                            bu_activations_old[idx] is not None,  # bottom-up connection
                            (bu_activations_old[2] is not None and (idx + 1) == 5),  # skip connection bu
                            td_activations_old[idx + 2] is not None,  # top-down connection
                            (td_activations_old[5] is not None and (idx + 1) == 2),  # skip connection td
                        ]
                    )
                    if should_update:
                        bu_act, td_act = self.connections[area](
                            bu_input=bu_activations_old[idx],
                            bu_l_input=bu_activations_old[idx + 1],
                            td_input=td_activations_old[idx + 2],
                            td_l_input=td_activations_old[idx + 1],
                            bu_skip_input=bu_activations_old[2]
                            if (idx + 1) == 5
                            else None,
                            td_skip_input=td_activations_old[5]
                            if (idx + 1) == 2
                            else None,
                        )
                        bu_activations[idx + 1] = bu_act
                        td_activations[idx + 1] = td_act

                bu_activations_old = bu_activations[:]
                td_activations_old = td_activations[:]

                # Activate readout when LOC output is ready
                if bu_activations_old[-2] is not None:
                    bu_act, td_act = self.connections["Readout"](
                        bu_input=bu_activations_old[-2]
                    )
                    bu_activations_old[-1] = bu_act
                    td_activations_old[-1] = td_act
                    readout_output.append(bu_act)
                    bu_activations[-1] = bu_act
                    td_activations[-1] = td_act

                activations = self.activation_shenanigans(
                    extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
                )

        else:
            # Implement the standard forward pass
            bu_activations[0], _ = self.connections["Retina"](bu_input=img_input)
            for idx, area in enumerate(self.areas[1:-1]):
                bu_act, _ = self.connections[area](
                    bu_input=bu_activations[idx],
                    bu_skip_input=bu_activations[2] if idx + 1 == 5 else None,
                )
                bu_activations[idx + 1] = bu_act

            bu_act, td_act = self.connections["Readout"](bu_input=bu_activations[-2])
            bu_activations[-1] = bu_act
            td_activations[-1] = td_act
            readout_output.append(bu_act)

            for idx, area in enumerate(reversed(self.areas[1:-1])):
                _, td_act = self.connections[area](
                    bu_input=bu_activations[-(idx + 2) - 1],
                    td_input=td_activations[-(idx + 2) + 1],
                    td_skip_input=td_activations[5] if idx + 1 == 2 else None,
                )
                td_activations[-(idx + 2)] = td_act
            _, td_act = self.connections["Retina"](
                bu_input=img_input,
                td_input=td_activations[1],
            )
            td_activations[0] = td_act

            t = 0
            activations = self.activation_shenanigans(
                extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
            )

            for t in range(1, self.timesteps):
                # For each timestep, compute the activations
                for idx, area in enumerate(self.areas[1:-1]):
                    bu_act, _ = self.connections[area](
                        bu_input=bu_activations[idx],
                        bu_l_input=bu_activations[idx + 1],
                        td_input=td_activations[idx + 2],
                        bu_skip_input=bu_activations[2] if idx + 1 == 5 else None,
                    )
                    bu_activations[idx + 1] = bu_act

                bu_act, td_act = self.connections["Readout"](bu_input=bu_activations[-2])
                bu_activations[-1] = bu_act
                td_activations[-1] = td_act
                readout_output.append(bu_act)

                for idx, area in enumerate(reversed(self.areas[1:-1])):
                    _, td_act = self.connections[area](
                        bu_input=bu_activations[-(idx + 2) - 1],
                        td_input=td_activations[-(idx + 2) + 1],
                        td_l_input=td_activations[-(idx + 2)],
                        td_skip_input=td_activations[5] if idx + 1 == 2 else None,
                    )
                    td_activations[-(idx + 2)] = td_act
                _, td_act = self.connections["Retina"](
                    bu_input=img_input,
                    td_input=td_activations[1],
                    td_l_input=td_activations[0],
                )
                td_activations[0] = td_act

                activations = self.activation_shenanigans(
                    extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
                )

        if self.hook_type != 'None':
            _ = self.logits(readout_output[-1])

        if extract_actvs:
            return readout_output, activations
        else:
            return readout_output

    def activation_shenanigans(
        self, extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
    ):
        """
        Helper function that collects activations and feeds the per-area identity modules used for hook registration.

        Parameters:
        -----------
        extract_actvs : bool
            Whether to extract activations.
        areas : list of str
            List of area names to retrieve activations from.
        timesteps : list of int
            List of timesteps to retrieve activations at.
        bu : bool
            Whether to retrieve bottom-up activations.
        td : bool
            Whether to retrieve top-down activations.
        concat : bool
            Whether to concatenate BU and TD activations.
        batch_size : int
            Batch size of the input data.
        bu_activations : list of torch.Tensor
            List of bottom-up activations.
        td_activations : list of torch.Tensor
            List of top-down activations.
        activations : dict
            Dictionary to store activations.
        t : int
            Current timestep.

        Returns:
        --------
        activations : dict
            Updated activations dictionary.
        """
        if extract_actvs and t in timesteps:
            for idx, area in enumerate(self.areas):
                if area in areas:
                    # If concat is True and area is 'Readout', skip
                    if concat and area == 'Readout':
                        continue
                    activation = self.collect_activation(
                        bu_activations[idx],
                        td_activations[idx],
                        bu,
                        td,
                        concat,
                        idx,
                        batch_size,
                    )
                    activations[area][t] = activation

        if self.hook_type != 'None':
            for idx, area in enumerate(self.areas):
                if self.hook_type == 'concat' and area != 'Readout':
                    _ = getattr(self, f"{area}_{t}")(concat_or_not(bu_activations[idx], td_activations[idx], dim=1))
                elif self.hook_type == 'separate':
                    _ = getattr(self, f"{area}_{t}_BU")(bu_activations[idx])
                    _ = getattr(self, f"{area}_{t}_TD")(td_activations[idx])

        return activations

    def collect_activation(
        self, bu_activation, td_activation, bu_flag, td_flag, concat, area_idx, batch_size
    ):
        """
        Helper function to collect activations, handling None values and concatenation.

        Parameters:
        -----------
        bu_activation : torch.Tensor or None
            Bottom-up activation.
        td_activation : torch.Tensor or None
            Top-down activation.
        bu_flag : bool
            Whether to collect BU activations.
        td_flag : bool
            Whether to collect TD activations.
        concat : bool
            Whether to concatenate BU and TD activations.
        area_idx : int
            Index of the area in self.areas.
        batch_size : int
            Batch size of the input data.

        Returns:
        --------
        activation : torch.Tensor or dict
            The collected activation. If concat is True, returns a single tensor.
            If concat is False, returns a dict with keys 'bu' and/or 'td'.
        """
        device = next(self.parameters()).device  # Get the device of the model

        if concat:
            # Handle None activations
            if bu_activation is None and td_activation is None:
                # Get output shape and channels
                channels = self.channel_sizes[area_idx] * 2  # BU and TD activations concatenated
                height, width = self.output_shapes[area_idx]
                zeros = torch.zeros((batch_size, channels, height, width), device=device)
                return zeros
            if bu_activation is None:
                bu_activation = torch.zeros_like(td_activation)
            if td_activation is None:
                td_activation = torch.zeros_like(bu_activation)
            activation = torch.cat([bu_activation, td_activation], dim=1)
            return activation
        else:
            activation = {}
            if bu_flag:
                if bu_activation is not None:
                    activation['bu'] = bu_activation
                elif td_activation is not None:
                    activation['bu'] = torch.zeros_like(td_activation)
                else:
                    # Create zeros of appropriate shape
                    channels = self.channel_sizes[area_idx]
                    height, width = self.output_shapes[area_idx]
                    activation['bu'] = torch.zeros(
                        (batch_size, channels, height, width), device=device
                    )
            if td_flag:
                if td_activation is not None:
                    activation['td'] = td_activation
                elif bu_activation is not None:
                    activation['td'] = torch.zeros_like(bu_activation)
                else:
                    channels = self.channel_sizes[area_idx]
                    height, width = self.output_shapes[area_idx]
                    activation['td'] = torch.zeros(
                        (batch_size, channels, height, width), device=device
                    )
            return activation
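
# Editor's illustrative sketch (not part of the original model.py): pulling
# per-area activations through BLT_VS.forward. The areas/timesteps chosen here
# are hypothetical; with concat=True each entry is a single BU||TD tensor.
def _demo_extract_activations():
    model = BLT_VS()  # defaults: timesteps=12, bio_unroll=True, image_size=224
    x = torch.rand(1, 3, 224, 224)
    readout, actvs = model(x, extract_actvs=True, areas=['V1', 'LOC'],
                           timesteps=[0, 11], concat=True)
    # readout[-1]: (1, 565) class scores; actvs['V1'][11]: (1, 2 * 576, 28, 28)
    return readout[-1].shape, actvs['V1'][11].shape
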
class BLT_VS_Layer(nn.Module):
    """
    A single layer in the BLT_VS model, representing a cortical area.

    Parameters:
    -----------
    layer_n : int
        Layer index.
    channel_sizes : list
        List of channel sizes for each layer.
    strides : list
        List of strides for each layer.
    kernel_sizes : list
        List of kernel sizes for each layer.
    kernel_sizes_lateral : list
        List of lateral kernel sizes for each layer.
    paddings : list
        List of paddings for each layer.
    lateral_connections : bool
        Whether to include lateral connections.
    topdown_connections : bool
        Whether to include top-down connections.
    skip_connections_bu : bool
        Whether to include bottom-up skip connections.
    skip_connections_td : bool
        Whether to include top-down skip connections.
    image_size : int
        Size of the input image (height and width).
    """

    def __init__(
        self,
        layer_n,
        channel_sizes,
        strides,
        kernel_sizes,
        kernel_sizes_lateral,
        paddings,
        lateral_connections=True,
        topdown_connections=True,
        skip_connections_bu=False,
        skip_connections_td=False,
        image_size=224,
    ):
        super(BLT_VS_Layer, self).__init__()

        in_channels = 3 if layer_n == 0 else channel_sizes[layer_n - 1]
        out_channels = channel_sizes[layer_n]

        # Bottom-up convolution
        self.bu_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_sizes[layer_n],
            stride=strides[layer_n],
            padding=paddings[layer_n],
        )

        # Lateral connections
        if lateral_connections:
            kernel_size_lateral = kernel_sizes_lateral[layer_n]
            self.bu_l_conv_depthwise = nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size_lateral,
                stride=1,
                padding='same',
                groups=out_channels,
            )
            self.bu_l_conv_pointwise = nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        else:
            self.bu_l_conv_depthwise = NoOpModule()
            self.bu_l_conv_pointwise = NoOpModule()

        # Top-down connections
        if topdown_connections:
            self.td_conv = nn.ConvTranspose2d(
                in_channels=channel_sizes[layer_n + 1],
                out_channels=out_channels,
                kernel_size=kernel_sizes[layer_n + 1],
                stride=strides[layer_n + 1],
                padding=(kernel_sizes[layer_n + 1] - 1) // 2
            )
            if lateral_connections:
                self.td_l_conv_depthwise = nn.Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_sizes_lateral[layer_n],
                    stride=1,
                    padding='same',
                    groups=out_channels,
                )
                self.td_l_conv_pointwise = nn.Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            else:
                self.td_l_conv_depthwise = NoOpModule()
                self.td_l_conv_pointwise = NoOpModule()
        else:
            self.td_conv = NoOpModule()
            self.td_l_conv_depthwise = NoOpModule()
            self.td_l_conv_pointwise = NoOpModule()

        # Skip connections
        if skip_connections_bu:
            self.skip_bu_depthwise = nn.Conv2d(
                in_channels=channel_sizes[2],  # From V1
                out_channels=out_channels,
                kernel_size=7 if image_size == 224 else 5,
                stride=1,
                padding='same',
                groups=np.gcd(channel_sizes[2], out_channels),
            )
            self.skip_bu_pointwise = nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        else:
            self.skip_bu_depthwise = NoOpModule()
            self.skip_bu_pointwise = NoOpModule()

        if skip_connections_td:
            self.skip_td_depthwise = nn.Conv2d(
                in_channels=channel_sizes[5],  # From V4
                out_channels=out_channels,
                kernel_size=3,  # V4 to V1 skip connection
                stride=1,
                padding='same',
                groups=np.gcd(channel_sizes[5], out_channels),
            )
            self.skip_td_pointwise = nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        else:
            self.skip_td_depthwise = NoOpModule()
            self.skip_td_pointwise = NoOpModule()

        self.layer_norm_bu = nn.GroupNorm(num_groups=1, num_channels=out_channels)
        self.layer_norm_td = nn.GroupNorm(num_groups=1, num_channels=out_channels)

    def forward(
        self,
        bu_input,
        bu_l_input=None,
        td_input=None,
        td_l_input=None,
        bu_skip_input=None,
        td_skip_input=None,
    ):
        """
        Forward pass for a single BLT_VS layer.

        Parameters:
        -----------
        bu_input : torch.Tensor or None
            Bottom-up input tensor.
        bu_l_input : torch.Tensor or None
            Bottom-up lateral input tensor.
        td_input : torch.Tensor or None
            Top-down input tensor.
        td_l_input : torch.Tensor or None
            Top-down lateral input tensor.
        bu_skip_input : torch.Tensor or None
            Bottom-up skip connection input.
        td_skip_input : torch.Tensor or None
            Top-down skip connection input.

        Returns:
        --------
        bu_output : torch.Tensor
            Bottom-up output tensor.
        td_output : torch.Tensor
            Top-down output tensor.
        """
        # Process bottom-up input
        bu_processed = self.bu_conv(bu_input) if bu_input is not None else 0

        # Process top-down input
        td_processed = (
            self.td_conv(td_input, output_size=bu_processed.size())
            if td_input is not None
            else 0
        )

        # Process bottom-up lateral input
        bu_l_processed = (
            self.bu_l_conv_pointwise(self.bu_l_conv_depthwise(bu_l_input))
            if bu_l_input is not None
            else 0
        )

        # Process top-down lateral input
        td_l_processed = (
            self.td_l_conv_pointwise(self.td_l_conv_depthwise(td_l_input))
            if td_l_input is not None
            else 0
        )

        # Process skip connections
        skip_bu_processed = (
            self.skip_bu_pointwise(self.skip_bu_depthwise(bu_skip_input))
            if bu_skip_input is not None
            else 0
        )
        skip_td_processed = (
            self.skip_td_pointwise(self.skip_td_depthwise(td_skip_input))
            if td_skip_input is not None
            else 0
        )

        # Compute sums
        bu_drive = bu_processed + bu_l_processed + skip_bu_processed
        bu_mod = bu_processed + skip_bu_processed
        td_drive = td_processed + td_l_processed + skip_td_processed
        td_mod = td_processed + skip_td_processed

        # Compute bottom-up output
        if isinstance(td_mod, torch.Tensor):
            if isinstance(bu_drive, torch.Tensor):
                bu_output = F.relu(bu_drive) * 2 * torch.sigmoid(td_mod)
            else:
                bu_output = torch.zeros_like(td_mod)
        else:
            bu_output = F.relu(bu_drive)

        # Compute top-down output
        if isinstance(bu_mod, torch.Tensor):
            if isinstance(td_drive, torch.Tensor):
                td_output = F.relu(td_drive) * 2 * torch.sigmoid(bu_mod)
            else:
                td_output = torch.zeros_like(bu_mod)
        else:
            td_output = F.relu(td_drive)

        bu_output = self.layer_norm_bu(bu_output)
        td_output = self.layer_norm_td(td_output)

        return bu_output, td_output
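
# Editor's illustrative sketch (not part of the original model.py): the
# cross-stream gating used in BLT_VS_Layer.forward, in isolation. A drive is
# rectified, then multiplicatively modulated by the opposite stream through
# 2 * sigmoid(mod), which equals 1 for mod = 0 and saturates at 2.
def _demo_gating(drive, mod):
    return F.relu(drive) * 2 * torch.sigmoid(mod)

# _demo_gating(torch.tensor([1.0, -1.0]), torch.zeros(2)) -> tensor([1., 0.])
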
class BLT_VS_Readout(nn.Module):
    """
    Readout layer for the BLT_VS model.

    Parameters:
    -----------
    layer_n : int
        Layer index.
    channel_sizes : list
        List of channel sizes for each layer.
    kernel_sizes : list
        List of kernel sizes for each layer.
    strides : list
        List of strides for each layer.
    num_classes : int
        Number of output classes for classification.
    """

    def __init__(self, layer_n, channel_sizes, kernel_sizes, strides, num_classes):
        super(BLT_VS_Readout, self).__init__()

        self.num_classes = num_classes
        in_channels = channel_sizes[layer_n - 1]
        out_channels = channel_sizes[layer_n]

        self.readout_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_sizes[layer_n],
            stride=strides[layer_n],
            padding=(kernel_sizes[layer_n] - 1) // 2,
        )

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.layer_norm_td = nn.GroupNorm(num_groups=1, num_channels=out_channels)

    def forward(self, bu_input):
        """
        Forward pass for the Readout layer.

        Parameters:
        -----------
        bu_input : torch.Tensor
            Bottom-up input tensor.

        Returns:
        --------
        output : torch.Tensor
            Class scores for classification.
        td_output : torch.Tensor
            Top-down output tensor.
        """
        output_intermediate = self.readout_conv(bu_input)
        output_pooled = self.global_avg_pool(output_intermediate).view(
            output_intermediate.size(0), -1
        )
        output = output_pooled[
            :, : self.num_classes
        ]  # Only pass classes to softmax and loss
        td_output = self.layer_norm_td(F.relu(output_intermediate))

        return output, td_output


class NoOpModule(nn.Module):
    """
    A no-operation module that returns zero regardless of the input.

    This is used in places where an operation is conditionally skipped.
    """

    def __init__(self):
        super(NoOpModule, self).__init__()

    def forward(self, *args, **kwargs):
        """
        Forward pass that returns zero.

        Returns:
        --------
        Zero tensor or zero value as appropriate.
        """
        return 0


def concat_or_not(bu_activation, td_activation, dim=1):
    # If both are None, return None
    if bu_activation is None and td_activation is None:
        return None

    # If bu_activation is None, create a tensor of zeros like td_activation
    if bu_activation is None:
        bu_activation = torch.zeros_like(td_activation)

    # If td_activation is None, create a tensor of zeros like bu_activation
    if td_activation is None:
        td_activation = torch.zeros_like(bu_activation)

    # Concatenate along the specified dimension
    return torch.cat([bu_activation, td_activation], dim=dim)
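
# Editor's illustrative sketch (not part of the original model.py):
# concat_or_not zero-fills whichever stream is missing before concatenating
# along the channel dimension.
def _demo_concat_or_not():
    bu = torch.ones(1, 2, 4, 4)
    both = concat_or_not(bu, None)  # TD side becomes zeros_like(bu)
    assert both.shape == (1, 4, 4, 4)
    assert concat_or_not(None, None) is None
    return both.shape
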
def get_layers(model_name):
    brainscore_layers = LAYERS
    return brainscore_layers


def get_bibtex(model_identifier):
    """
    A method returning the bibtex reference of the requested model as a string.
    """
    return ''


if __name__ == '__main__':
    # Use this method to ensure the correctness of the BaseModel implementations.
    # It executes a mock run of brain-score benchmarks.
    check_models.check_base_models(__name__)