brainscore-vision 2.2.2__py3-none-any.whl → 2.2.4__py3-none-any.whl

Files changed (213)
  1. brainscore_vision/models/alexnet_less_variation_1/__init__.py +6 -0
  2. brainscore_vision/models/alexnet_less_variation_1/model.py +200 -0
  3. brainscore_vision/models/alexnet_less_variation_1/region_layer_map/alexnet_less_variation_iteration=1.json +6 -0
  4. brainscore_vision/models/alexnet_less_variation_1/setup.py +29 -0
  5. brainscore_vision/models/alexnet_less_variation_1/test.py +3 -0
  6. brainscore_vision/models/alexnet_less_variation_2/__init__.py +6 -0
  7. brainscore_vision/models/alexnet_less_variation_2/model.py +200 -0
  8. brainscore_vision/models/alexnet_less_variation_2/region_layer_map/alexnet_less_variation_iteration=2.json +6 -0
  9. brainscore_vision/models/alexnet_less_variation_2/setup.py +29 -0
  10. brainscore_vision/models/alexnet_less_variation_2/test.py +3 -0
  11. brainscore_vision/models/alexnet_less_variation_4/__init__.py +6 -0
  12. brainscore_vision/models/alexnet_less_variation_4/model.py +200 -0
  13. brainscore_vision/models/alexnet_less_variation_4/region_layer_map/alexnet_less_variation_iteration=4.json +6 -0
  14. brainscore_vision/models/alexnet_less_variation_4/setup.py +29 -0
  15. brainscore_vision/models/alexnet_less_variation_4/test.py +3 -0
  16. brainscore_vision/models/alexnet_no_specular_2/__init__.py +6 -0
  17. brainscore_vision/models/alexnet_no_specular_2/model.py +200 -0
  18. brainscore_vision/models/alexnet_no_specular_2/region_layer_map/alexnet_no_specular_iteration=2.json +6 -0
  19. brainscore_vision/models/alexnet_no_specular_2/setup.py +29 -0
  20. brainscore_vision/models/alexnet_no_specular_2/test.py +3 -0
  21. brainscore_vision/models/alexnet_no_specular_4/__init__.py +6 -0
  22. brainscore_vision/models/alexnet_no_specular_4/model.py +200 -0
  23. brainscore_vision/models/alexnet_no_specular_4/region_layer_map/alexnet_no_specular_iteration=4.json +6 -0
  24. brainscore_vision/models/alexnet_no_specular_4/setup.py +29 -0
  25. brainscore_vision/models/alexnet_no_specular_4/test.py +3 -0
  26. brainscore_vision/models/alexnet_no_variation_4/__init__.py +6 -0
  27. brainscore_vision/models/alexnet_no_variation_4/model.py +200 -0
  28. brainscore_vision/models/alexnet_no_variation_4/region_layer_map/alexnet_no_variation_iteration=4.json +6 -0
  29. brainscore_vision/models/alexnet_no_variation_4/setup.py +29 -0
  30. brainscore_vision/models/alexnet_no_variation_4/test.py +3 -0
  31. brainscore_vision/models/alexnet_original_3/__init__.py +6 -0
  32. brainscore_vision/models/alexnet_original_3/model.py +200 -0
  33. brainscore_vision/models/alexnet_original_3/region_layer_map/alexnet_original_iteration=3.json +6 -0
  34. brainscore_vision/models/alexnet_original_3/setup.py +29 -0
  35. brainscore_vision/models/alexnet_original_3/test.py +3 -0
  36. brainscore_vision/models/alexnet_wo_shading_4/__init__.py +6 -0
  37. brainscore_vision/models/alexnet_wo_shading_4/model.py +200 -0
  38. brainscore_vision/models/alexnet_wo_shading_4/region_layer_map/alexnet_wo_shading_iteration=4.json +6 -0
  39. brainscore_vision/models/alexnet_wo_shading_4/setup.py +29 -0
  40. brainscore_vision/models/alexnet_wo_shading_4/test.py +3 -0
  41. brainscore_vision/models/alexnet_wo_shadows_5/__init__.py +6 -0
  42. brainscore_vision/models/alexnet_wo_shadows_5/model.py +200 -0
  43. brainscore_vision/models/alexnet_wo_shadows_5/region_layer_map/alexnet_wo_shadows_iteration=5.json +6 -0
  44. brainscore_vision/models/alexnet_wo_shadows_5/setup.py +29 -0
  45. brainscore_vision/models/alexnet_wo_shadows_5/test.py +3 -0
  46. brainscore_vision/models/alexnet_z_axis_1/__init__.py +6 -0
  47. brainscore_vision/models/alexnet_z_axis_1/model.py +200 -0
  48. brainscore_vision/models/alexnet_z_axis_1/region_layer_map/alexnet_z_axis_iteration=1.json +6 -0
  49. brainscore_vision/models/alexnet_z_axis_1/setup.py +29 -0
  50. brainscore_vision/models/alexnet_z_axis_1/test.py +3 -0
  51. brainscore_vision/models/alexnet_z_axis_2/__init__.py +6 -0
  52. brainscore_vision/models/alexnet_z_axis_2/model.py +200 -0
  53. brainscore_vision/models/alexnet_z_axis_2/region_layer_map/alexnet_z_axis_iteration=2.json +6 -0
  54. brainscore_vision/models/alexnet_z_axis_2/setup.py +29 -0
  55. brainscore_vision/models/alexnet_z_axis_2/test.py +3 -0
  56. brainscore_vision/models/alexnet_z_axis_3/__init__.py +6 -0
  57. brainscore_vision/models/alexnet_z_axis_3/model.py +200 -0
  58. brainscore_vision/models/alexnet_z_axis_3/region_layer_map/alexnet_z_axis_iteration=3.json +6 -0
  59. brainscore_vision/models/alexnet_z_axis_3/setup.py +29 -0
  60. brainscore_vision/models/alexnet_z_axis_3/test.py +3 -0
  61. brainscore_vision/models/alexnet_z_axis_4/__init__.py +6 -0
  62. brainscore_vision/models/alexnet_z_axis_4/model.py +200 -0
  63. brainscore_vision/models/alexnet_z_axis_4/region_layer_map/alexnet_z_axis_iteration=4.json +6 -0
  64. brainscore_vision/models/alexnet_z_axis_4/setup.py +29 -0
  65. brainscore_vision/models/alexnet_z_axis_4/test.py +3 -0
  66. brainscore_vision/models/artResNet18_1/__init__.py +5 -0
  67. brainscore_vision/models/artResNet18_1/model.py +66 -0
  68. brainscore_vision/models/artResNet18_1/requirements.txt +4 -0
  69. brainscore_vision/models/artResNet18_1/test.py +12 -0
  70. brainscore_vision/models/barlow_twins_custom/__init__.py +5 -0
  71. brainscore_vision/models/barlow_twins_custom/model.py +58 -0
  72. brainscore_vision/models/barlow_twins_custom/requirements.txt +4 -0
  73. brainscore_vision/models/barlow_twins_custom/test.py +12 -0
  74. brainscore_vision/models/blt-vs/__init__.py +15 -0
  75. brainscore_vision/models/blt-vs/model.py +962 -0
  76. brainscore_vision/models/blt-vs/pretrained.py +219 -0
  77. brainscore_vision/models/blt-vs/region_layer_map/blt_vs.json +6 -0
  78. brainscore_vision/models/blt-vs/setup.py +22 -0
  79. brainscore_vision/models/blt-vs/test.py +0 -0
  80. brainscore_vision/models/cifar_resnet18_1/__init__.py +5 -0
  81. brainscore_vision/models/cifar_resnet18_1/model.py +68 -0
  82. brainscore_vision/models/cifar_resnet18_1/requirements.txt +4 -0
  83. brainscore_vision/models/cifar_resnet18_1/test.py +10 -0
  84. brainscore_vision/models/resnet18_random/__init__.py +5 -0
  85. brainscore_vision/models/resnet18_random/archive_name.zip +0 -0
  86. brainscore_vision/models/resnet18_random/model.py +42 -0
  87. brainscore_vision/models/resnet18_random/requirements.txt +2 -0
  88. brainscore_vision/models/resnet18_random/test.py +12 -0
  89. brainscore_vision/models/resnet50_less_variation_1/__init__.py +6 -0
  90. brainscore_vision/models/resnet50_less_variation_1/model.py +200 -0
  91. brainscore_vision/models/resnet50_less_variation_1/region_layer_map/resnet50_less_variation_iteration=1.json +6 -0
  92. brainscore_vision/models/resnet50_less_variation_1/setup.py +29 -0
  93. brainscore_vision/models/resnet50_less_variation_1/test.py +3 -0
  94. brainscore_vision/models/resnet50_less_variation_2/__init__.py +6 -0
  95. brainscore_vision/models/resnet50_less_variation_2/model.py +200 -0
  96. brainscore_vision/models/resnet50_less_variation_2/region_layer_map/resnet50_less_variation_iteration=2.json +6 -0
  97. brainscore_vision/models/resnet50_less_variation_2/setup.py +29 -0
  98. brainscore_vision/models/resnet50_less_variation_2/test.py +3 -0
  99. brainscore_vision/models/resnet50_less_variation_3/__init__.py +6 -0
  100. brainscore_vision/models/resnet50_less_variation_3/model.py +200 -0
  101. brainscore_vision/models/resnet50_less_variation_3/region_layer_map/resnet50_less_variation_iteration=3.json +6 -0
  102. brainscore_vision/models/resnet50_less_variation_3/setup.py +29 -0
  103. brainscore_vision/models/resnet50_less_variation_3/test.py +3 -0
  104. brainscore_vision/models/resnet50_less_variation_4/__init__.py +6 -0
  105. brainscore_vision/models/resnet50_less_variation_4/model.py +200 -0
  106. brainscore_vision/models/resnet50_less_variation_4/region_layer_map/resnet50_less_variation_iteration=4.json +6 -0
  107. brainscore_vision/models/resnet50_less_variation_4/setup.py +29 -0
  108. brainscore_vision/models/resnet50_less_variation_4/test.py +3 -0
  109. brainscore_vision/models/resnet50_less_variation_5/__init__.py +6 -0
  110. brainscore_vision/models/resnet50_less_variation_5/model.py +200 -0
  111. brainscore_vision/models/resnet50_less_variation_5/region_layer_map/resnet50_less_variation_iteration=5.json +6 -0
  112. brainscore_vision/models/resnet50_less_variation_5/setup.py +29 -0
  113. brainscore_vision/models/resnet50_less_variation_5/test.py +3 -0
  114. brainscore_vision/models/resnet50_no_variation_1/__init__.py +6 -0
  115. brainscore_vision/models/resnet50_no_variation_1/model.py +200 -0
  116. brainscore_vision/models/resnet50_no_variation_1/region_layer_map/resnet50_no_variation_iteration=1.json +6 -0
  117. brainscore_vision/models/resnet50_no_variation_1/setup.py +29 -0
  118. brainscore_vision/models/resnet50_no_variation_1/test.py +3 -0
  119. brainscore_vision/models/resnet50_no_variation_2/__init__.py +6 -0
  120. brainscore_vision/models/resnet50_no_variation_2/model.py +200 -0
  121. brainscore_vision/models/resnet50_no_variation_2/region_layer_map/resnet50_no_variation_iteration=2.json +6 -0
  122. brainscore_vision/models/resnet50_no_variation_2/setup.py +29 -0
  123. brainscore_vision/models/resnet50_no_variation_2/test.py +3 -0
  124. brainscore_vision/models/resnet50_no_variation_5/__init__.py +6 -0
  125. brainscore_vision/models/resnet50_no_variation_5/model.py +200 -0
  126. brainscore_vision/models/resnet50_no_variation_5/region_layer_map/resnet50_no_variation_iteration=5.json +6 -0
  127. brainscore_vision/models/resnet50_no_variation_5/setup.py +29 -0
  128. brainscore_vision/models/resnet50_no_variation_5/test.py +3 -0
  129. brainscore_vision/models/resnet50_original_1/__init__.py +6 -0
  130. brainscore_vision/models/resnet50_original_1/model.py +200 -0
  131. brainscore_vision/models/resnet50_original_1/region_layer_map/resnet50_original_iteration=1.json +6 -0
  132. brainscore_vision/models/resnet50_original_1/setup.py +29 -0
  133. brainscore_vision/models/resnet50_original_1/test.py +3 -0
  134. brainscore_vision/models/resnet50_original_2/__init__.py +6 -0
  135. brainscore_vision/models/resnet50_original_2/model.py +200 -0
  136. brainscore_vision/models/resnet50_original_2/region_layer_map/resnet50_original_iteration=2.json +6 -0
  137. brainscore_vision/models/resnet50_original_2/setup.py +29 -0
  138. brainscore_vision/models/resnet50_original_2/test.py +3 -0
  139. brainscore_vision/models/resnet50_original_5/__init__.py +6 -0
  140. brainscore_vision/models/resnet50_original_5/model.py +200 -0
  141. brainscore_vision/models/resnet50_original_5/region_layer_map/resnet50_original_iteration=5.json +6 -0
  142. brainscore_vision/models/resnet50_original_5/setup.py +29 -0
  143. brainscore_vision/models/resnet50_original_5/test.py +3 -0
  144. brainscore_vision/models/resnet50_textures_1/__init__.py +6 -0
  145. brainscore_vision/models/resnet50_textures_1/model.py +200 -0
  146. brainscore_vision/models/resnet50_textures_1/region_layer_map/resnet50_textures_iteration=1.json +6 -0
  147. brainscore_vision/models/resnet50_textures_1/setup.py +29 -0
  148. brainscore_vision/models/resnet50_textures_1/test.py +3 -0
  149. brainscore_vision/models/resnet50_textures_2/__init__.py +6 -0
  150. brainscore_vision/models/resnet50_textures_2/model.py +200 -0
  151. brainscore_vision/models/resnet50_textures_2/region_layer_map/resnet50_textures_iteration=2.json +6 -0
  152. brainscore_vision/models/resnet50_textures_2/setup.py +29 -0
  153. brainscore_vision/models/resnet50_textures_2/test.py +3 -0
  154. brainscore_vision/models/resnet50_textures_3/__init__.py +6 -0
  155. brainscore_vision/models/resnet50_textures_3/model.py +200 -0
  156. brainscore_vision/models/resnet50_textures_3/region_layer_map/resnet50_textures_iteration=3.json +6 -0
  157. brainscore_vision/models/resnet50_textures_3/setup.py +29 -0
  158. brainscore_vision/models/resnet50_textures_3/test.py +3 -0
  159. brainscore_vision/models/resnet50_textures_4/__init__.py +6 -0
  160. brainscore_vision/models/resnet50_textures_4/model.py +200 -0
  161. brainscore_vision/models/resnet50_textures_4/region_layer_map/resnet50_textures_iteration=4.json +6 -0
  162. brainscore_vision/models/resnet50_textures_4/setup.py +29 -0
  163. brainscore_vision/models/resnet50_textures_4/test.py +3 -0
  164. brainscore_vision/models/resnet50_textures_5/__init__.py +6 -0
  165. brainscore_vision/models/resnet50_textures_5/model.py +200 -0
  166. brainscore_vision/models/resnet50_textures_5/region_layer_map/resnet50_textures_iteration=5.json +6 -0
  167. brainscore_vision/models/resnet50_textures_5/setup.py +29 -0
  168. brainscore_vision/models/resnet50_textures_5/test.py +3 -0
  169. brainscore_vision/models/resnet50_wo_shading_1/__init__.py +6 -0
  170. brainscore_vision/models/resnet50_wo_shading_1/model.py +200 -0
  171. brainscore_vision/models/resnet50_wo_shading_1/region_layer_map/resnet50_wo_shading_iteration=1.json +6 -0
  172. brainscore_vision/models/resnet50_wo_shading_1/setup.py +29 -0
  173. brainscore_vision/models/resnet50_wo_shading_1/test.py +3 -0
  174. brainscore_vision/models/resnet50_wo_shading_3/__init__.py +6 -0
  175. brainscore_vision/models/resnet50_wo_shading_3/model.py +200 -0
  176. brainscore_vision/models/resnet50_wo_shading_3/region_layer_map/resnet50_wo_shading_iteration=3.json +6 -0
  177. brainscore_vision/models/resnet50_wo_shading_3/setup.py +29 -0
  178. brainscore_vision/models/resnet50_wo_shading_3/test.py +3 -0
  179. brainscore_vision/models/resnet50_wo_shading_4/__init__.py +6 -0
  180. brainscore_vision/models/resnet50_wo_shading_4/model.py +200 -0
  181. brainscore_vision/models/resnet50_wo_shading_4/region_layer_map/resnet50_wo_shading_iteration=4.json +6 -0
  182. brainscore_vision/models/resnet50_wo_shading_4/setup.py +29 -0
  183. brainscore_vision/models/resnet50_wo_shading_4/test.py +3 -0
  184. brainscore_vision/models/resnet50_wo_shadows_4/__init__.py +6 -0
  185. brainscore_vision/models/resnet50_wo_shadows_4/model.py +200 -0
  186. brainscore_vision/models/resnet50_wo_shadows_4/region_layer_map/resnet50_wo_shadows_iteration=4.json +6 -0
  187. brainscore_vision/models/resnet50_wo_shadows_4/setup.py +29 -0
  188. brainscore_vision/models/resnet50_wo_shadows_4/test.py +3 -0
  189. brainscore_vision/models/resnet50_z_axis_1/__init__.py +6 -0
  190. brainscore_vision/models/resnet50_z_axis_1/model.py +200 -0
  191. brainscore_vision/models/resnet50_z_axis_1/region_layer_map/resnet50_z_axis_iteration=1.json +6 -0
  192. brainscore_vision/models/resnet50_z_axis_1/setup.py +29 -0
  193. brainscore_vision/models/resnet50_z_axis_1/test.py +3 -0
  194. brainscore_vision/models/resnet50_z_axis_2/__init__.py +6 -0
  195. brainscore_vision/models/resnet50_z_axis_2/model.py +200 -0
  196. brainscore_vision/models/resnet50_z_axis_2/region_layer_map/resnet50_z_axis_iteration=2.json +6 -0
  197. brainscore_vision/models/resnet50_z_axis_2/setup.py +29 -0
  198. brainscore_vision/models/resnet50_z_axis_2/test.py +3 -0
  199. brainscore_vision/models/resnet50_z_axis_3/__init__.py +6 -0
  200. brainscore_vision/models/resnet50_z_axis_3/model.py +200 -0
  201. brainscore_vision/models/resnet50_z_axis_3/region_layer_map/resnet50_z_axis_iteration=3.json +6 -0
  202. brainscore_vision/models/resnet50_z_axis_3/setup.py +29 -0
  203. brainscore_vision/models/resnet50_z_axis_3/test.py +3 -0
  204. brainscore_vision/models/resnet50_z_axis_5/__init__.py +6 -0
  205. brainscore_vision/models/resnet50_z_axis_5/model.py +200 -0
  206. brainscore_vision/models/resnet50_z_axis_5/region_layer_map/resnet50_z_axis_iteration=5.json +6 -0
  207. brainscore_vision/models/resnet50_z_axis_5/setup.py +29 -0
  208. brainscore_vision/models/resnet50_z_axis_5/test.py +3 -0
  209. {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/METADATA +1 -1
  210. {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/RECORD +213 -5
  211. {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/LICENSE +0 -0
  212. {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/WHEEL +0 -0
  213. {brainscore_vision-2.2.2.dist-info → brainscore_vision-2.2.4.dist-info}/top_level.txt +0 -0
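Every model directory above follows the same Brain-Score plugin layout: an `__init__.py` that registers the model, a `model.py` that builds it, `setup.py` or `requirements.txt` for dependencies, and a `test.py`. As a plausible sketch of the standard registration pattern (the exact arguments live in the files themselves and are not shown in this diff), one of the short `__init__.py` files would look roughly like:

from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

# Illustrative identifier; each directory registers its own model name.
model_registry['alexnet_less_variation_1'] = lambda: ModelCommitment(
    identifier='alexnet_less_variation_1',
    activations_model=get_model('alexnet_less_variation_1'),
    layers=get_layers('alexnet_less_variation_1'))

The remainder of this page shows the diff of the new blt-vs model definition.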
brainscore_vision/models/blt-vs/model.py
@@ -0,0 +1,962 @@
+ import os
+ import sys
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import functools
+ from torchvision import transforms
+ from brainscore_vision.model_helpers.check_submission import check_models
+ from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
+ from pretrained import get_model_instance, clear_models_and_aliases, register_model, register_aliases
+ from PIL import Image
+
+ SUBMODULE_SEPARATOR = '.'
+
+ LAYERS = ['Retina_5', 'LGN_5', 'V1_5', 'V2_5', 'V3_5', 'V4_5', 'LOC_5', 'logits']
+
+
+ def get_model(model_name='blt_vs', key_or_alias='blt_vs', image_size=224):
+     """
+     Get a model instance with preprocessing wrapped in a PytorchWrapper.
+
+     Args:
+         model_name (str): Identifier for the model.
+         key_or_alias (str): Key or alias for the registered model.
+         image_size (int): Input image size for preprocessing.
+
+     Returns:
+         PytorchWrapper: A wrapper around the model with preprocessing.
+     """
+     clear_models_and_aliases(BLT_VS)
+
+     register_model(
+         BLT_VS,
+         'blt_vs',
+         'https://zenodo.org/records/14223659/files/blt_vs.zip',
+         '36d74a367a261e788028c6c9caa7a5675fee48e938a6b86a6c62655b23afaf53'
+     )
+
+     register_aliases(BLT_VS, 'blt_vs', 'blt_vs')
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     preprocessing = functools.partial(load_preprocess_images_sush, image_size=image_size)
+     blt_model = get_model_instance(BLT_VS, key_or_alias)
+     blt_model.to(device)
+     wrapper = PytorchWrapper(identifier=model_name, model=blt_model, preprocessing=preprocessing)
+     return wrapper
+
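A minimal usage sketch, assuming the Zenodo weights URL above is reachable; `blt_vs` is the only key registered here:

wrapper = get_model()          # downloads the pretrained weights and wraps the network
print(wrapper.identifier)      # 'blt_vs'
print(get_layers('blt_vs'))    # ['Retina_5', 'LGN_5', ..., 'logits'] (defined at the bottom of this file)
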
+ def load_preprocess_images_sush(image_filepaths, image_size, **kwargs):
+     images = load_images(image_filepaths)
+     images = preprocess_images_sush(images, image_size=image_size, **kwargs)
+     return images
+
+
+ def load_images(image_filepaths):
+     return [load_image(image_filepath) for image_filepath in image_filepaths]
+
+
+ def preprocess_images_sush(images, image_size, **kwargs):
+     preprocess = torchvision_preprocess_input_sush(image_size, **kwargs)
+     images = [preprocess(image) for image in images]
+     images = np.concatenate(images)
+     return images
+
+
+ def torchvision_preprocess_sush(normalize_mean=(0.5, 0.5, 0.5), normalize_std=(0.5, 0.5, 0.5)):
+     return transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=normalize_mean, std=normalize_std),
+         lambda img: img.unsqueeze(0)
+     ])
+
+
+ def torchvision_preprocess_input_sush(image_size, **kwargs):
+     return transforms.Compose([
+         transforms.Resize((image_size, image_size)),
+         transforms.CenterCrop(image_size),
+         torchvision_preprocess_sush(**kwargs),
+     ])
+
+
+ def load_image(image_filepath):
+     with Image.open(image_filepath) as pil_image:
+         if 'L' not in pil_image.mode.upper() and 'A' not in pil_image.mode.upper() \
+                 and 'P' not in pil_image.mode.upper():  # not binary, not alpha, not palettized
+             # work around https://github.com/python-pillow/Pillow/issues/1144,
+             # see https://stackoverflow.com/a/30376272/2225200
+             return pil_image.copy()
+         else:  # make sure potential binary images are in RGB
+             rgb_image = Image.new("RGB", pil_image.size)
+             rgb_image.paste(pil_image)
+             return rgb_image
+
+
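A quick shape check of this preprocessing pipeline on an in-memory image (the (0.5, 0.5, 0.5) normalization maps pixels to roughly [-1, 1]):

from PIL import Image
img = Image.new('RGB', (256, 192))                # any size; Resize forces 224x224
x = torchvision_preprocess_input_sush(224)(img)   # torch.Tensor of shape (1, 3, 224, 224)
batch = preprocess_images_sush([img, img], 224)   # np.ndarray of shape (2, 3, 224, 224)
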
+ class BLT_VS(nn.Module):
+     """
+     The BLT_VS model simulates the ventral stream of the visual cortex. See BLT_VS_info.txt for more details on motivation and design.
+
+     Parameters:
+     -----------
+     timesteps : int
+         Number of time steps for the recurrent computation.
+     num_classes : int
+         Number of output classes for classification.
+     add_feats : int
+         Additional features to maintain orientation, color, etc.
+     lateral_connections : bool
+         Whether to include lateral connections.
+     topdown_connections : bool
+         Whether to include top-down connections.
+     skip_connections : bool
+         Whether to include skip connections.
+     bio_unroll : bool
+         Whether to use biological unrolling.
+     image_size : int
+         Size of the input image (height and width).
+     hook_type : str
+         What kind of area/timestep hooks to register. Options are 'concat' (concatenate BU/TD), 'separate', 'None'.
+     """
+
+     def __init__(
+         self,
+         timesteps=12,
+         num_classes=565,
+         add_feats=100,
+         lateral_connections=True,
+         topdown_connections=True,
+         skip_connections=True,
+         bio_unroll=True,
+         image_size=224,
+         hook_type='None',
+     ):
+         super(BLT_VS, self).__init__()
+
+         self.timesteps = timesteps
+         self.num_classes = num_classes
+         self.add_feats = add_feats
+         self.lateral_connections = lateral_connections
+         self.topdown_connections = topdown_connections
+         self.skip_connections = skip_connections
+         self.bio_unroll = bio_unroll
+         self.image_size = image_size
+         self.hook_type = hook_type
+
+         # Define network areas and configurations
+         self.areas = ["Retina", "LGN", "V1", "V2", "V3", "V4", "LOC", "Readout"]
+
+         if image_size == 224:
+             self.kernel_sizes = [7, 7, 5, 1, 5, 3, 3, 5]
+             self.kernel_sizes_lateral = [0, 0, 5, 5, 5, 5, 5, 0]
+         else:
+             self.kernel_sizes = [5, 3, 3, 1, 3, 3, 3, 3]
+             self.kernel_sizes_lateral = [0, 0, 3, 3, 3, 3, 3, 0]
+
+         self.strides = [2, 2, 2, 1, 1, 1, 2, 2]
+         self.paddings = (np.array(self.kernel_sizes) - 1) // 2  # For 'same' padding
+         self.channel_sizes = [
+             32,
+             32,
+             576,
+             480,
+             352,
+             256,
+             352,
+             int(num_classes + add_feats),
+         ]
+
+         # Top-down connections configuration
+         self.topdown_connections_layers = [
+             False,
+             True,
+             True,
+             True,
+             True,
+             True,
+             True,
+             False,
+         ]
+
+         # Initialize network layers
+         self.connections = nn.ModuleDict()
+         for idx in range(len(self.areas) - 1):
+             area = self.areas[idx]
+             self.connections[area] = BLT_VS_Layer(
+                 layer_n=idx,
+                 channel_sizes=self.channel_sizes,
+                 strides=self.strides,
+                 kernel_sizes=self.kernel_sizes,
+                 kernel_sizes_lateral=self.kernel_sizes_lateral,
+                 paddings=self.paddings,
+                 lateral_connections=self.lateral_connections
+                 and (self.kernel_sizes_lateral[idx] > 0),
+                 topdown_connections=self.topdown_connections
+                 and self.topdown_connections_layers[idx],
+                 skip_connections_bu=self.skip_connections and (idx == 5),
+                 skip_connections_td=self.skip_connections and (idx == 2),
+                 image_size=image_size,
+             )
+         self.connections["Readout"] = BLT_VS_Readout(
+             layer_n=7,
+             channel_sizes=self.channel_sizes,
+             kernel_sizes=self.kernel_sizes,
+             strides=self.strides,
+             num_classes=num_classes,
+         )
+
+         # Create an nn.Identity per area and timestep so that hooks can be
+         # registered to acquire BU and TD activations for any area/timestep
+         if self.hook_type != 'None':
+             for area in self.areas:
+                 for t in range(timesteps):
+                     if self.hook_type == 'concat' and area != 'Readout':  # we can't concat for readout
+                         setattr(self, f"{area}_{t}", nn.Identity())
+                     else:
+                         setattr(self, f"{area}_{t}_BU", nn.Identity())
+                         setattr(self, f"{area}_{t}_TD", nn.Identity())
+             setattr(self, "logits", nn.Identity())
+
+         # Precompute output shapes
+         self.output_shapes = self.compute_output_shapes(image_size)
+
+     def compute_output_shapes(self, image_size):
+         """
+         Compute the output shapes for each area based on the image size.
+
+         Parameters:
+         -----------
+         image_size : int
+             The input image size.
+
+         Returns:
+         --------
+         output_shapes : list of tuples
+             The output height and width for each area.
+         """
+         output_shapes = []
+         height = width = image_size
+         for idx in range(len(self.areas)):
+             kernel_size = self.kernel_sizes[idx]
+             stride = self.strides[idx]
+             padding = self.paddings[idx]
+             height = (height + 2 * padding - kernel_size) // stride + 1
+             width = (width + 2 * padding - kernel_size) // stride + 1
+             output_shapes.append((int(height), int(width)))
+         return output_shapes
+
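A hand check of this arithmetic for the default 224-pixel configuration (kernels [7, 7, 5, 1, 5, 3, 3, 5], strides [2, 2, 2, 1, 1, 1, 2, 2], paddings [3, 3, 2, 0, 2, 1, 1, 2]), runnable once the module is fully loaded:

# Spatial sizes via (in + 2*p - k)//s + 1:
# Retina: (224 + 2*3 - 7)//2 + 1 = 112    LGN: (112 + 2*3 - 7)//2 + 1 = 56
# V1:     (56  + 2*2 - 5)//2 + 1 = 28     V2:  (28 + 2*0 - 1)//1 + 1 = 28
# V3:     (28  + 2*2 - 5)//1 + 1 = 28     V4:  (28 + 2*1 - 3)//1 + 1 = 28
# LOC:    (28  + 2*1 - 3)//2 + 1 = 14     Readout: (14 + 2*2 - 5)//2 + 1 = 7
shapes = BLT_VS(image_size=224).compute_output_shapes(224)
assert shapes == [(112, 112), (56, 56), (28, 28), (28, 28), (28, 28), (28, 28), (14, 14), (7, 7)]
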
+     def forward(
+         self,
+         img_input,
+         extract_actvs=False,
+         areas=None,
+         timesteps=None,
+         bu=True,
+         td=True,
+         concat=False,
+     ):
+         """
+         Forward pass for the BLT_VS model.
+
+         Parameters:
+         -----------
+         img_input : torch.Tensor
+             Input image tensor.
+         extract_actvs : bool
+             Whether to extract activations.
+         areas : list of str
+             List of area names to retrieve activations from.
+         timesteps : list of int
+             List of timesteps to retrieve activations at.
+         bu : bool
+             Whether to retrieve bottom-up activations.
+         td : bool
+             Whether to retrieve top-down activations.
+         concat : bool
+             Whether to concatenate BU and TD activations.
+
+         Returns:
+         --------
+         If extract_actvs is False:
+             readout_output : list of torch.Tensor
+                 The readout outputs at each timestep.
+         If extract_actvs is True:
+             (readout_output, activations) : tuple
+                 readout_output is as above.
+                 activations is a dict with structure activations[area][timestep] = activation
+         """
+         # Check that the input has 4 dims, else add a batch dim
+         if len(img_input.shape) == 3:
+             img_input = img_input.unsqueeze(0)
+
+         if extract_actvs:
+             if areas is None or timesteps is None:
+                 raise ValueError(
+                     "When extract_actvs is True, areas and timesteps must be specified."
+                 )
+             activations = {area: {} for area in areas}
+         else:
+             activations = None
+
+         readout_output = []
+         bu_activations = [None for _ in self.areas]
+         td_activations = [None for _ in self.areas]
+         batch_size = img_input.size(0)
+
+         if self.bio_unroll:
+             # Implement the bio_unroll forward pass
+             bu_activations_old = [None for _ in self.areas]
+             td_activations_old = [None for _ in self.areas]
+
+             # Initial activation for Retina
+             bu_activations_old[0], _ = self.connections["Retina"](bu_input=img_input)
+             bu_activations[0] = bu_activations_old[0]
+
+             # Timestep 0 (if 0 is in timesteps)
+             t = 0
+             activations = self.activation_shenanigans(
+                 extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
+             )
+
+             for t in range(1, self.timesteps):
+                 # For each timestep, update the outputs of the areas
+                 for idx, area in enumerate(self.areas[1:-1]):
+                     # Update only if necessary
+                     should_update = any(
+                         [
+                             bu_activations_old[idx] is not None,  # bottom-up connection
+                             (bu_activations_old[2] is not None and (idx + 1) == 5),  # skip connection bu
+                             td_activations_old[idx + 2] is not None,  # top-down connection
+                             (td_activations_old[5] is not None and (idx + 1) == 2),  # skip connection td
+                         ]
+                     )
+                     if should_update:
+                         bu_act, td_act = self.connections[area](
+                             bu_input=bu_activations_old[idx],
+                             bu_l_input=bu_activations_old[idx + 1],
+                             td_input=td_activations_old[idx + 2],
+                             td_l_input=td_activations_old[idx + 1],
+                             bu_skip_input=bu_activations_old[2]
+                             if (idx + 1) == 5
+                             else None,
+                             td_skip_input=td_activations_old[5]
+                             if (idx + 1) == 2
+                             else None,
+                         )
+                         bu_activations[idx + 1] = bu_act
+                         td_activations[idx + 1] = td_act
+
+                 bu_activations_old = bu_activations[:]
+                 td_activations_old = td_activations[:]
+
+                 # Activate readout when LOC output is ready
+                 if bu_activations_old[-2] is not None:
+                     bu_act, td_act = self.connections["Readout"](
+                         bu_input=bu_activations_old[-2]
+                     )
+                     bu_activations_old[-1] = bu_act
+                     td_activations_old[-1] = td_act
+                     readout_output.append(bu_act)
+                     bu_activations[-1] = bu_act
+                     td_activations[-1] = td_act
+
+                 activations = self.activation_shenanigans(
+                     extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
+                 )
+
+         else:
+             # Implement the standard forward pass
+             bu_activations[0], _ = self.connections["Retina"](bu_input=img_input)
+             for idx, area in enumerate(self.areas[1:-1]):
+                 bu_act, _ = self.connections[area](
+                     bu_input=bu_activations[idx],
+                     bu_skip_input=bu_activations[2] if idx + 1 == 5 else None,
+                 )
+                 bu_activations[idx + 1] = bu_act
+
+             bu_act, td_act = self.connections["Readout"](bu_input=bu_activations[-2])
+             bu_activations[-1] = bu_act
+             td_activations[-1] = td_act
+             readout_output.append(bu_act)
+
+             for idx, area in enumerate(reversed(self.areas[1:-1])):
+                 _, td_act = self.connections[area](
+                     bu_input=bu_activations[-(idx + 2) - 1],
+                     td_input=td_activations[-(idx + 2) + 1],
+                     td_skip_input=td_activations[5] if idx + 1 == 2 else None,
+                 )
+                 td_activations[-(idx + 2)] = td_act
+             _, td_act = self.connections["Retina"](
+                 bu_input=img_input,
+                 td_input=td_activations[1],
+             )
+             td_activations[0] = td_act
+
+             t = 0
+             activations = self.activation_shenanigans(
+                 extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
+             )
+
+             for t in range(1, self.timesteps):
+                 # For each timestep, compute the activations
+                 for idx, area in enumerate(self.areas[1:-1]):
+                     bu_act, _ = self.connections[area](
+                         bu_input=bu_activations[idx],
+                         bu_l_input=bu_activations[idx + 1],
+                         td_input=td_activations[idx + 2],
+                         bu_skip_input=bu_activations[2] if idx + 1 == 5 else None,
+                     )
+                     bu_activations[idx + 1] = bu_act
+
+                 bu_act, td_act = self.connections["Readout"](bu_input=bu_activations[-2])
+                 bu_activations[-1] = bu_act
+                 td_activations[-1] = td_act
+                 readout_output.append(bu_act)
+
+                 for idx, area in enumerate(reversed(self.areas[1:-1])):
+                     _, td_act = self.connections[area](
+                         bu_input=bu_activations[-(idx + 2) - 1],
+                         td_input=td_activations[-(idx + 2) + 1],
+                         td_l_input=td_activations[-(idx + 2)],
+                         td_skip_input=td_activations[5] if idx + 1 == 2 else None,
+                     )
+                     td_activations[-(idx + 2)] = td_act
+                 _, td_act = self.connections["Retina"](
+                     bu_input=img_input,
+                     td_input=td_activations[1],
+                     td_l_input=td_activations[0],
+                 )
+                 td_activations[0] = td_act
+
+                 activations = self.activation_shenanigans(
+                     extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
+                 )
+
+         if self.hook_type != 'None':
+             _ = self.logits(readout_output[-1])
+
+         if extract_actvs:
+             return readout_output, activations
+         else:
+             return readout_output
+
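A minimal sketch of pulling per-area activations directly, with randomly initialized weights; the tensor shapes follow compute_output_shapes above:

model = BLT_VS(hook_type='None')
x = torch.randn(1, 3, 224, 224)
logits, actvs = model(x, extract_actvs=True, areas=['V1', 'LOC'], timesteps=[5])
print(len(logits))                 # one readout per timestep at which LOC was ready
print(actvs['V1'][5]['bu'].shape)  # torch.Size([1, 576, 28, 28])
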
+     def activation_shenanigans(
+         self, extract_actvs, areas, timesteps, bu, td, concat, batch_size, bu_activations, td_activations, activations, t
+     ):
+         """
+         Helper function that implements activation collection and computes the values relevant for hook registration.
+
+         Parameters:
+         -----------
+         extract_actvs : bool
+             Whether to extract activations.
+         areas : list of str
+             List of area names to retrieve activations from.
+         timesteps : list of int
+             List of timesteps to retrieve activations at.
+         bu : bool
+             Whether to retrieve bottom-up activations.
+         td : bool
+             Whether to retrieve top-down activations.
+         concat : bool
+             Whether to concatenate BU and TD activations.
+         batch_size : int
+             Batch size of the input data.
+         bu_activations : list of torch.Tensor
+             List of bottom-up activations.
+         td_activations : list of torch.Tensor
+             List of top-down activations.
+         activations : dict
+             Dictionary to store activations.
+         t : int
+             Current timestep.
+
+         Returns:
+         --------
+         activations : dict
+             Updated activations dictionary.
+         """
+         if extract_actvs and t in timesteps:
+             for idx, area in enumerate(self.areas):
+                 if area in areas:
+                     # If concat is True and area is 'Readout', skip
+                     if concat and area == 'Readout':
+                         continue
+                     activation = self.collect_activation(
+                         bu_activations[idx],
+                         td_activations[idx],
+                         bu,
+                         td,
+                         concat,
+                         idx,
+                         batch_size,
+                     )
+                     activations[area][t] = activation
+
+         if self.hook_type != 'None':
+             for idx, area in enumerate(self.areas):
+                 if self.hook_type == 'concat' and area != 'Readout':
+                     _ = getattr(self, f"{area}_{t}")(concat_or_not(bu_activations[idx], td_activations[idx], dim=1))
+                 elif self.hook_type == 'separate':
+                     _ = getattr(self, f"{area}_{t}_BU")(bu_activations[idx])
+                     _ = getattr(self, f"{area}_{t}_TD")(td_activations[idx])
+
+         return activations
+
+     def collect_activation(
+         self, bu_activation, td_activation, bu_flag, td_flag, concat, area_idx, batch_size
+     ):
+         """
+         Helper function to collect activations, handling None values and concatenation.
+
+         Parameters:
+         -----------
+         bu_activation : torch.Tensor or None
+             Bottom-up activation.
+         td_activation : torch.Tensor or None
+             Top-down activation.
+         bu_flag : bool
+             Whether to collect BU activations.
+         td_flag : bool
+             Whether to collect TD activations.
+         concat : bool
+             Whether to concatenate BU and TD activations.
+         area_idx : int
+             Index of the area in self.areas.
+         batch_size : int
+             Batch size of the input data.
+
+         Returns:
+         --------
+         activation : torch.Tensor or dict
+             The collected activation. If concat is True, returns a single tensor.
+             If concat is False, returns a dict with keys 'bu' and/or 'td'.
+         """
+         device = next(self.parameters()).device  # Get the device of the model
+
+         if concat:
+             # Handle None activations
+             if bu_activation is None and td_activation is None:
+                 # Get output shape and channels
+                 channels = self.channel_sizes[area_idx] * 2  # BU and TD activations concatenated
+                 height, width = self.output_shapes[area_idx]
+                 zeros = torch.zeros((batch_size, channels, height, width), device=device)
+                 return zeros
+             if bu_activation is None:
+                 bu_activation = torch.zeros_like(td_activation)
+             if td_activation is None:
+                 td_activation = torch.zeros_like(bu_activation)
+             activation = torch.cat([bu_activation, td_activation], dim=1)
+             return activation
+         else:
+             activation = {}
+             if bu_flag:
+                 if bu_activation is not None:
+                     activation['bu'] = bu_activation
+                 elif td_activation is not None:
+                     activation['bu'] = torch.zeros_like(td_activation)
+                 else:
+                     # Create zeros of appropriate shape
+                     channels = self.channel_sizes[area_idx]
+                     height, width = self.output_shapes[area_idx]
+                     activation['bu'] = torch.zeros(
+                         (batch_size, channels, height, width), device=device
+                     )
+             if td_flag:
+                 if td_activation is not None:
+                     activation['td'] = td_activation
+                 elif bu_activation is not None:
+                     activation['td'] = torch.zeros_like(bu_activation)
+                 else:
+                     channels = self.channel_sizes[area_idx]
+                     height, width = self.output_shapes[area_idx]
+                     activation['td'] = torch.zeros(
+                         (batch_size, channels, height, width), device=device
+                     )
+             return activation
+
+
+ class BLT_VS_Layer(nn.Module):
+     """
+     A single layer in the BLT_VS model, representing a cortical area.
+
+     Parameters:
+     -----------
+     layer_n : int
+         Layer index.
+     channel_sizes : list
+         List of channel sizes for each layer.
+     strides : list
+         List of strides for each layer.
+     kernel_sizes : list
+         List of kernel sizes for each layer.
+     kernel_sizes_lateral : list
+         List of lateral kernel sizes for each layer.
+     paddings : list
+         List of paddings for each layer.
+     lateral_connections : bool
+         Whether to include lateral connections.
+     topdown_connections : bool
+         Whether to include top-down connections.
+     skip_connections_bu : bool
+         Whether to include bottom-up skip connections.
+     skip_connections_td : bool
+         Whether to include top-down skip connections.
+     image_size : int
+         Size of the input image (height and width).
+     """
+
+     def __init__(
+         self,
+         layer_n,
+         channel_sizes,
+         strides,
+         kernel_sizes,
+         kernel_sizes_lateral,
+         paddings,
+         lateral_connections=True,
+         topdown_connections=True,
+         skip_connections_bu=False,
+         skip_connections_td=False,
+         image_size=224,
+     ):
+         super(BLT_VS_Layer, self).__init__()
+
+         in_channels = 3 if layer_n == 0 else channel_sizes[layer_n - 1]
+         out_channels = channel_sizes[layer_n]
+
+         # Bottom-up convolution
+         self.bu_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_sizes[layer_n],
+             stride=strides[layer_n],
+             padding=paddings[layer_n],
+         )
+
+         # Lateral connections
+         if lateral_connections:
+             kernel_size_lateral = kernel_sizes_lateral[layer_n]
+             self.bu_l_conv_depthwise = nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size_lateral,
+                 stride=1,
+                 padding='same',
+                 groups=out_channels,
+             )
+             self.bu_l_conv_pointwise = nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+         else:
+             self.bu_l_conv_depthwise = NoOpModule()
+             self.bu_l_conv_pointwise = NoOpModule()
+
+         # Top-down connections
+         if topdown_connections:
+             self.td_conv = nn.ConvTranspose2d(
+                 in_channels=channel_sizes[layer_n + 1],
+                 out_channels=out_channels,
+                 kernel_size=kernel_sizes[layer_n + 1],
+                 stride=strides[layer_n + 1],
+                 padding=(kernel_sizes[layer_n + 1] - 1) // 2
+             )
+             if lateral_connections:
+                 self.td_l_conv_depthwise = nn.Conv2d(
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     kernel_size=kernel_sizes_lateral[layer_n],
+                     stride=1,
+                     padding='same',
+                     groups=out_channels,
+                 )
+                 self.td_l_conv_pointwise = nn.Conv2d(
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                 )
+             else:
+                 self.td_l_conv_depthwise = NoOpModule()
+                 self.td_l_conv_pointwise = NoOpModule()
+         else:
+             self.td_conv = NoOpModule()
+             self.td_l_conv_depthwise = NoOpModule()
+             self.td_l_conv_pointwise = NoOpModule()
+
+         # Skip connections
+         if skip_connections_bu:
+             self.skip_bu_depthwise = nn.Conv2d(
+                 in_channels=channel_sizes[2],  # From V1
+                 out_channels=out_channels,
+                 kernel_size=7 if image_size == 224 else 5,
+                 stride=1,
+                 padding='same',
+                 groups=np.gcd(channel_sizes[2], out_channels),
+             )
+             self.skip_bu_pointwise = nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+         else:
+             self.skip_bu_depthwise = NoOpModule()
+             self.skip_bu_pointwise = NoOpModule()
+
+         if skip_connections_td:
+             self.skip_td_depthwise = nn.Conv2d(
+                 in_channels=channel_sizes[5],  # From V4
+                 out_channels=out_channels,
+                 kernel_size=3,  # V4 to V1 skip connection
+                 stride=1,
+                 padding='same',
+                 groups=np.gcd(channel_sizes[5], out_channels),
+             )
+             self.skip_td_pointwise = nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             )
+         else:
+             self.skip_td_depthwise = NoOpModule()
+             self.skip_td_pointwise = NoOpModule()
+
+         self.layer_norm_bu = nn.GroupNorm(num_groups=1, num_channels=out_channels)
+         self.layer_norm_td = nn.GroupNorm(num_groups=1, num_channels=out_channels)
+
+     def forward(
+         self,
+         bu_input,
+         bu_l_input=None,
+         td_input=None,
+         td_l_input=None,
+         bu_skip_input=None,
+         td_skip_input=None,
+     ):
+         """
+         Forward pass for a single BLT_VS layer.
+
+         Parameters:
+         -----------
+         bu_input : torch.Tensor or None
+             Bottom-up input tensor.
+         bu_l_input : torch.Tensor or None
+             Bottom-up lateral input tensor.
+         td_input : torch.Tensor or None
+             Top-down input tensor.
+         td_l_input : torch.Tensor or None
+             Top-down lateral input tensor.
+         bu_skip_input : torch.Tensor or None
+             Bottom-up skip connection input.
+         td_skip_input : torch.Tensor or None
+             Top-down skip connection input.
+
+         Returns:
+         --------
+         bu_output : torch.Tensor
+             Bottom-up output tensor.
+         td_output : torch.Tensor
+             Top-down output tensor.
+         """
+         # Process bottom-up input (expected to be present whenever the layer is
+         # stepped with a top-down input; the callers guard on this)
+         bu_processed = self.bu_conv(bu_input) if bu_input is not None else 0
+
+         # Process top-down input
+         td_processed = (
+             self.td_conv(td_input, output_size=bu_processed.size())
+             if td_input is not None
+             else 0
+         )
+
+         # Process bottom-up lateral input
+         bu_l_processed = (
+             self.bu_l_conv_pointwise(self.bu_l_conv_depthwise(bu_l_input))
+             if bu_l_input is not None
+             else 0
+         )
+
+         # Process top-down lateral input
+         td_l_processed = (
+             self.td_l_conv_pointwise(self.td_l_conv_depthwise(td_l_input))
+             if td_l_input is not None
+             else 0
+         )
+
+         # Process skip connections
+         skip_bu_processed = (
+             self.skip_bu_pointwise(self.skip_bu_depthwise(bu_skip_input))
+             if bu_skip_input is not None
+             else 0
+         )
+         skip_td_processed = (
+             self.skip_td_pointwise(self.skip_td_depthwise(td_skip_input))
+             if td_skip_input is not None
+             else 0
+         )
+
+         # Compute sums
+         bu_drive = bu_processed + bu_l_processed + skip_bu_processed
+         bu_mod = bu_processed + skip_bu_processed
+         td_drive = td_processed + td_l_processed + skip_td_processed
+         td_mod = td_processed + skip_td_processed
+
+         # Compute bottom-up output
+         if isinstance(td_mod, torch.Tensor):
+             if isinstance(bu_drive, torch.Tensor):
+                 bu_output = F.relu(bu_drive) * 2 * torch.sigmoid(td_mod)
+             else:
+                 bu_output = torch.zeros_like(td_mod)
+         else:
+             bu_output = F.relu(bu_drive)
+
+         # Compute top-down output
+         if isinstance(bu_mod, torch.Tensor):
+             if isinstance(td_drive, torch.Tensor):
+                 td_output = F.relu(td_drive) * 2 * torch.sigmoid(bu_mod)
+             else:
+                 td_output = torch.zeros_like(bu_mod)
+         else:
+             td_output = F.relu(td_drive)
+
+         bu_output = self.layer_norm_bu(bu_output)
+         td_output = self.layer_norm_td(td_output)
+
+         return bu_output, td_output
+
+
+ class BLT_VS_Readout(nn.Module):
+     """
+     Readout layer for the BLT_VS model.
+
+     Parameters:
+     -----------
+     layer_n : int
+         Layer index.
+     channel_sizes : list
+         List of channel sizes for each layer.
+     kernel_sizes : list
+         List of kernel sizes for each layer.
+     strides : list
+         List of strides for each layer.
+     num_classes : int
+         Number of output classes for classification.
+     """
+
+     def __init__(self, layer_n, channel_sizes, kernel_sizes, strides, num_classes):
+         super(BLT_VS_Readout, self).__init__()
+
+         self.num_classes = num_classes
+         in_channels = channel_sizes[layer_n - 1]
+         out_channels = channel_sizes[layer_n]
+
+         self.readout_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_sizes[layer_n],
+             stride=strides[layer_n],
+             padding=(kernel_sizes[layer_n] - 1) // 2,
+         )
+
+         self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+         self.layer_norm_td = nn.GroupNorm(num_groups=1, num_channels=out_channels)
+
+     def forward(self, bu_input):
+         """
+         Forward pass for the Readout layer.
+
+         Parameters:
+         -----------
+         bu_input : torch.Tensor
+             Bottom-up input tensor.
+
+         Returns:
+         --------
+         output : torch.Tensor
+             Class scores for classification.
+         td_output : torch.Tensor
+             Top-down output tensor.
+         """
+         output_intermediate = self.readout_conv(bu_input)
+         output_pooled = self.global_avg_pool(output_intermediate).view(
+             output_intermediate.size(0), -1
+         )
+         output = output_pooled[
+             :, : self.num_classes
+         ]  # Only pass classes to softmax and loss
+         td_output = self.layer_norm_td(F.relu(output_intermediate))
+
+         return output, td_output
+
+
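With the defaults num_classes=565 and add_feats=100, readout_conv emits 665 channels; only the first 565 pooled values become class scores, while the full 665-channel map feeds back as the top-down signal. A shape sketch using the default 224-pixel channel configuration:

readout = BLT_VS_Readout(layer_n=7,
                         channel_sizes=[32, 32, 576, 480, 352, 256, 352, 665],
                         kernel_sizes=[7, 7, 5, 1, 5, 3, 3, 5],
                         strides=[2, 2, 2, 1, 1, 1, 2, 2],
                         num_classes=565)
scores, td = readout(torch.randn(1, 352, 14, 14))  # LOC output shape at 224px
print(scores.shape, td.shape)                      # torch.Size([1, 565]) torch.Size([1, 665, 7, 7])
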
+ class NoOpModule(nn.Module):
+     """
+     A no-operation module that returns zero regardless of the input.
+
+     This is used in places where an operation is conditionally skipped.
+     """
+
+     def __init__(self):
+         super(NoOpModule, self).__init__()
+
+     def forward(self, *args, **kwargs):
+         """
+         Forward pass that returns zero.
+
+         Returns:
+         --------
+         Zero tensor or zero value as appropriate.
+         """
+         return 0
+
+
+ def concat_or_not(bu_activation, td_activation, dim=1):
+     # If both are None, return None
+     if bu_activation is None and td_activation is None:
+         return None
+
+     # If bu_activation is None, create a tensor of zeros like td_activation
+     if bu_activation is None:
+         bu_activation = torch.zeros_like(td_activation)
+
+     # If td_activation is None, create a tensor of zeros like bu_activation
+     if td_activation is None:
+         td_activation = torch.zeros_like(bu_activation)
+
+     # Concatenate along the specified dimension
+     return torch.cat([bu_activation, td_activation], dim=dim)
+
+
+ def get_layers(model_name):
+     brainscore_layers = LAYERS
+     return brainscore_layers
+
+
+ def get_bibtex(model_identifier):
+     """
+     A method returning the bibtex reference of the requested model as a string.
+     """
+     return ''
+
+
+ if __name__ == '__main__':
+     # Use this method to ensure the correctness of the BaseModel implementations.
+     # It executes a mock run of brain-score benchmarks.
+     check_models.check_base_models(__name__)