lightly_studio-0.3.1-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in the supported public registries. It is provided for informational purposes only and reflects the changes between package versions.

This release of lightly-studio has been flagged as potentially problematic.

Files changed (219)
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,933 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+import copy
+from functools import partial
+from typing import List, Tuple, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.layers import DropPath, trunc_normal_
+from timm.models import register_model
+
+from ..modules.common.mobileone import MobileOneBlock
+from ..modules.image.replknet import ReparamLargeKernelConv
+
+
+def _cfg(url="", **kwargs):
+    return {
+        "url": url,
+        "num_classes": 1000,
+        "input_size": (3, 256, 256),
+        "pool_size": None,
+        "crop_pct": 0.95,
+        "interpolation": "bicubic",
+        "mean": IMAGENET_DEFAULT_MEAN,
+        "std": IMAGENET_DEFAULT_STD,
+        "classifier": "head",
+        **kwargs,
+    }
+
+
+default_cfgs = {
+    "fastvit_t": _cfg(crop_pct=0.9),
+    "fastvit_s": _cfg(crop_pct=0.9),
+    "fastvit_m": _cfg(crop_pct=0.95),
+}
+
+
+def convolutional_stem(
+    in_channels: int, out_channels: int, inference_mode: bool = False
+) -> nn.Sequential:
+    """Build convolutional stem with MobileOne blocks.
+
+    Args:
+        in_channels: Number of input channels.
+        out_channels: Number of output channels.
+        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+
+    Returns:
+        nn.Sequential object with stem elements.
+    """
+    return nn.Sequential(
+        MobileOneBlock(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            groups=1,
+            inference_mode=inference_mode,
+            use_se=False,
+            num_conv_branches=1,
+        ),
+        MobileOneBlock(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            groups=out_channels,
+            inference_mode=inference_mode,
+            use_se=False,
+            num_conv_branches=1,
+        ),
+        MobileOneBlock(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            inference_mode=inference_mode,
+            use_se=False,
+            num_conv_branches=1,
+        ),
+    )
+
+
+class MHSA(nn.Module):
+    """Multi-headed Self Attention module.
+
+    Source modified from:
+    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        head_dim: int = 32,
+        qkv_bias: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        """Build MHSA module that can handle 3D or 4D input tensors.
+
+        Args:
+            dim: Number of embedding dimensions.
+            head_dim: Number of hidden dimensions per head. Default: ``32``
+            qkv_bias: Use bias or not. Default: ``False``
+            attn_drop: Dropout rate for attention tensor.
+            proj_drop: Dropout rate for projection tensor.
+        """
+        super().__init__()
+        assert dim % head_dim == 0, "dim should be divisible by head_dim"
+        self.head_dim = head_dim
+        self.num_heads = dim // head_dim
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shape = x.shape
+        B, C, H, W = shape
+        N = H * W
+        if len(shape) == 4:
+            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+        # trick here to make q@k.t more stable
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        if len(shape) == 4:
+            x = x.transpose(-2, -1).reshape(B, C, H, W)
+
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """Convolutional patch embedding layer."""
+
+    def __init__(
+        self,
+        patch_size: int,
+        stride: int,
+        in_channels: int,
+        embed_dim: int,
+        inference_mode: bool = False,
+        use_se: bool = False,
+    ) -> None:
+        """Build patch embedding layer.
+
+        Args:
+            patch_size: Patch size for embedding computation.
+            stride: Stride for convolutional embedding layer.
+            in_channels: Number of channels of input tensor.
+            embed_dim: Number of embedding dimensions.
+            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+            use_se: If ``True`` SE block will be used.
+        """
+        super().__init__()
+        block = list()
+        block.append(
+            ReparamLargeKernelConv(
+                in_channels=in_channels,
+                out_channels=embed_dim,
+                kernel_size=patch_size,
+                stride=stride,
+                groups=in_channels,
+                small_kernel=3,
+                inference_mode=inference_mode,
+                use_se=use_se,
+            )
+        )
+        block.append(
+            MobileOneBlock(
+                in_channels=embed_dim,
+                out_channels=embed_dim,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                groups=1,
+                inference_mode=inference_mode,
+                use_se=False,
+                num_conv_branches=1,
+            )
+        )
+        self.proj = nn.Sequential(*block)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x
+
+
+class RepMixer(nn.Module):
+    """Reparameterizable token mixer.
+
+    For more details, please refer to our paper:
+    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
+    """
+
+    def __init__(
+        self,
+        dim,
+        kernel_size=3,
+        use_layer_scale=True,
+        layer_scale_init_value=1e-5,
+        inference_mode: bool = False,
+    ):
+        """Build RepMixer Module.
+
+        Args:
+            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
+            kernel_size: Kernel size for spatial mixing. Default: 3
+            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
+            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
+            inference_mode: If True, instantiates model in inference mode. Default: ``False``
+        """
+        super().__init__()
+        self.dim = dim
+        self.kernel_size = kernel_size
+        self.inference_mode = inference_mode
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                in_channels=self.dim,
+                out_channels=self.dim,
+                kernel_size=self.kernel_size,
+                stride=1,
+                padding=self.kernel_size // 2,
+                groups=self.dim,
+                bias=True,
+            )
+        else:
+            self.norm = MobileOneBlock(
+                dim,
+                dim,
+                kernel_size,
+                padding=kernel_size // 2,
+                groups=dim,
+                use_act=False,
+                use_scale_branch=False,
+                num_conv_branches=0,
+            )
+            self.mixer = MobileOneBlock(
+                dim,
+                dim,
+                kernel_size,
+                padding=kernel_size // 2,
+                groups=dim,
+                use_act=False,
+            )
+            self.use_layer_scale = use_layer_scale
+            if use_layer_scale:
+                self.layer_scale = nn.Parameter(
+                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+                )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "reparam_conv"):
+            x = self.reparam_conv(x)
+            return x
+        else:
+            if self.use_layer_scale:
+                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
+            else:
+                x = x + self.mixer(x) - self.norm(x)
+            return x
+
+    def reparameterize(self) -> None:
+        """Reparameterize mixer and norm into a single
+        convolutional layer for efficient inference.
+        """
+        if self.inference_mode:
+            return
+
+        self.mixer.reparameterize()
+        self.norm.reparameterize()
+
+        if self.use_layer_scale:
+            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
+                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
+            )
+            b = torch.squeeze(self.layer_scale) * (
+                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+            )
+        else:
+            w = (
+                self.mixer.id_tensor
+                + self.mixer.reparam_conv.weight
+                - self.norm.reparam_conv.weight
+            )
+            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+
+        self.reparam_conv = nn.Conv2d(
+            in_channels=self.dim,
+            out_channels=self.dim,
+            kernel_size=self.kernel_size,
+            stride=1,
+            padding=self.kernel_size // 2,
+            groups=self.dim,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = w
+        self.reparam_conv.bias.data = b
+
+        for para in self.parameters():
+            para.detach_()
+        self.__delattr__("mixer")
+        self.__delattr__("norm")
+        if self.use_layer_scale:
+            self.__delattr__("layer_scale")
+
+
+class ConvFFN(nn.Module):
+    """Convolutional FFN Module."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        act_layer: nn.Module = nn.GELU,
+        drop: float = 0.0,
+    ) -> None:
+        """Build convolutional FFN module.
+
+        Args:
+            in_channels: Number of input channels.
+            hidden_channels: Number of channels after expansion. Default: None
+            out_channels: Number of output channels. Default: None
+            act_layer: Activation layer. Default: ``GELU``
+            drop: Dropout rate. Default: ``0.0``.
+        """
+        super().__init__()
+        out_channels = out_channels or in_channels
+        hidden_channels = hidden_channels or in_channels
+        self.conv = nn.Sequential()
+        self.conv.add_module(
+            "conv",
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=7,
+                padding=3,
+                groups=in_channels,
+                bias=False,
+            ),
+        )
+        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
+        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
+        self.act = act_layer()
+        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
+        self.drop = nn.Dropout(drop)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m: nn.Module) -> None:
+        if isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv(x)
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class RepCPE(nn.Module):
+    """Implementation of conditional positional encoding.
+
+    For more details refer to paper:
+    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
+
+    In our implementation, we can reparameterize this module to eliminate a skip connection.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        embed_dim: int = 768,
+        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
+        inference_mode=False,
+    ) -> None:
+        """Build reparameterizable conditional positional encoding
+
+        Args:
+            in_channels: Number of input channels.
+            embed_dim: Number of embedding dimensions. Default: 768
+            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
+            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+        """
+        super(RepCPE, self).__init__()
+        if isinstance(spatial_shape, int):
+            spatial_shape = tuple([spatial_shape] * 2)
+        assert isinstance(spatial_shape, Tuple), (
+            f'"spatial_shape" must by a sequence or int, '
+            f"get {type(spatial_shape)} instead."
+        )
+        assert len(spatial_shape) == 2, (
+            f'Length of "spatial_shape" should be 2, '
+            f"got {len(spatial_shape)} instead."
+        )
+
+        self.spatial_shape = spatial_shape
+        self.embed_dim = embed_dim
+        self.in_channels = in_channels
+        self.groups = embed_dim
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.embed_dim,
+                kernel_size=self.spatial_shape,
+                stride=1,
+                padding=int(self.spatial_shape[0] // 2),
+                groups=self.embed_dim,
+                bias=True,
+            )
+        else:
+            self.pe = nn.Conv2d(
+                in_channels,
+                embed_dim,
+                spatial_shape,
+                1,
+                int(spatial_shape[0] // 2),
+                bias=True,
+                groups=embed_dim,
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "reparam_conv"):
+            x = self.reparam_conv(x)
+            return x
+        else:
+            x = self.pe(x) + x
+            return x
+
+    def reparameterize(self) -> None:
+        # Build equivalent Id tensor
+        input_dim = self.in_channels // self.groups
+        kernel_value = torch.zeros(
+            (
+                self.in_channels,
+                input_dim,
+                self.spatial_shape[0],
+                self.spatial_shape[1],
+            ),
+            dtype=self.pe.weight.dtype,
+            device=self.pe.weight.device,
+        )
+        for i in range(self.in_channels):
+            kernel_value[
+                i,
+                i % input_dim,
+                self.spatial_shape[0] // 2,
+                self.spatial_shape[1] // 2,
+            ] = 1
+        id_tensor = kernel_value
+
+        # Reparameterize Id tensor and conv
+        w_final = id_tensor + self.pe.weight
+        b_final = self.pe.bias
+
+        # Introduce reparam conv
+        self.reparam_conv = nn.Conv2d(
+            in_channels=self.in_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.spatial_shape,
+            stride=1,
+            padding=int(self.spatial_shape[0] // 2),
+            groups=self.embed_dim,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = w_final
+        self.reparam_conv.bias.data = b_final
+
+        for para in self.parameters():
+            para.detach_()
+        self.__delattr__("pe")
+
+
+class RepMixerBlock(nn.Module):
+    """Implementation of Metaformer block with RepMixer as token mixer.
+
+    For more details on Metaformer structure, please refer to:
+    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: int = 3,
+        mlp_ratio: float = 4.0,
+        act_layer: nn.Module = nn.GELU,
+        drop: float = 0.0,
+        drop_path: float = 0.0,
+        use_layer_scale: bool = True,
+        layer_scale_init_value: float = 1e-5,
+        inference_mode: bool = False,
+    ):
+        """Build RepMixer Block.
+
+        Args:
+            dim: Number of embedding dimensions.
+            kernel_size: Kernel size for repmixer. Default: 3
+            mlp_ratio: MLP expansion ratio. Default: 4.0
+            act_layer: Activation layer. Default: ``nn.GELU``
+            drop: Dropout rate. Default: 0.0
+            drop_path: Drop path rate. Default: 0.0
+            use_layer_scale: Flag to turn on layer scale. Default: ``True``
+            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+        """
+
+        super().__init__()
+
+        self.token_mixer = RepMixer(
+            dim,
+            kernel_size=kernel_size,
+            use_layer_scale=use_layer_scale,
+            layer_scale_init_value=layer_scale_init_value,
+            inference_mode=inference_mode,
+        )
+
+        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+            mlp_ratio
+        )
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.convffn = ConvFFN(
+            in_channels=dim,
+            hidden_channels=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        # Drop Path
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        # Layer Scale
+        self.use_layer_scale = use_layer_scale
+        if use_layer_scale:
+            self.layer_scale = nn.Parameter(
+                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+            )
+
+    def forward(self, x):
+        if self.use_layer_scale:
+            x = self.token_mixer(x)
+            x = x + self.drop_path(self.layer_scale * self.convffn(x))
+        else:
+            x = self.token_mixer(x)
+            x = x + self.drop_path(self.convffn(x))
+        return x
+
+
+class AttentionBlock(nn.Module):
+    """Implementation of metaformer block with MHSA as token mixer.
+
+    For more details on Metaformer structure, please refer to:
+    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        mlp_ratio: float = 4.0,
+        act_layer: nn.Module = nn.GELU,
+        norm_layer: nn.Module = nn.BatchNorm2d,
+        drop: float = 0.0,
+        drop_path: float = 0.0,
+        use_layer_scale: bool = True,
+        layer_scale_init_value: float = 1e-5,
+    ):
+        """Build Attention Block.
+
+        Args:
+            dim: Number of embedding dimensions.
+            mlp_ratio: MLP expansion ratio. Default: 4.0
+            act_layer: Activation layer. Default: ``nn.GELU``
+            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
+            drop: Dropout rate. Default: 0.0
+            drop_path: Drop path rate. Default: 0.0
+            use_layer_scale: Flag to turn on layer scale. Default: ``True``
+            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+        """
+
+        super().__init__()
+
+        self.norm = norm_layer(dim)
+        self.token_mixer = MHSA(dim=dim)
+
+        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+            mlp_ratio
+        )
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.convffn = ConvFFN(
+            in_channels=dim,
+            hidden_channels=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        # Drop path
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        # Layer Scale
+        self.use_layer_scale = use_layer_scale
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+            )
+            self.layer_scale_2 = nn.Parameter(
+                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+            )
+
+    def forward(self, x):
+        if self.use_layer_scale:
+            x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
+            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm(x)))
+            x = x + self.drop_path(self.convffn(x))
+        return x
+
+
+def basic_blocks(
+    dim: int,
+    block_index: int,
+    num_blocks: List[int],
+    token_mixer_type: str,
+    kernel_size: int = 3,
+    mlp_ratio: float = 4.0,
+    act_layer: nn.Module = nn.GELU,
+    norm_layer: nn.Module = nn.BatchNorm2d,
+    drop_rate: float = 0.0,
+    drop_path_rate: float = 0.0,
+    use_layer_scale: bool = True,
+    layer_scale_init_value: float = 1e-5,
+    inference_mode=False,
+) -> nn.Sequential:
+    """Build FastViT blocks within a stage.
+
+    Args:
+        dim: Number of embedding dimensions.
+        block_index: block index.
+        num_blocks: List containing number of blocks per stage.
+        token_mixer_type: Token mixer type.
+        kernel_size: Kernel size for repmixer.
+        mlp_ratio: MLP expansion ratio.
+        act_layer: Activation layer.
+        norm_layer: Normalization layer.
+        drop_rate: Dropout rate.
+        drop_path_rate: Drop path rate.
+        use_layer_scale: Flag to turn on layer scale regularization.
+        layer_scale_init_value: Layer scale value at initialization.
+        inference_mode: Flag to instantiate block in inference mode.
+
+    Returns:
+        nn.Sequential object of all the blocks within the stage.
+    """
+    blocks = []
+    for block_idx in range(num_blocks[block_index]):
+        block_dpr = (
+            drop_path_rate
+            * (block_idx + sum(num_blocks[:block_index]))
+            / (sum(num_blocks) - 1)
+        )
+        if token_mixer_type == "repmixer":
+            blocks.append(
+                RepMixerBlock(
+                    dim,
+                    kernel_size=kernel_size,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    drop=drop_rate,
+                    drop_path=block_dpr,
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                    inference_mode=inference_mode,
+                )
+            )
+        elif token_mixer_type == "attention":
+            blocks.append(
+                AttentionBlock(
+                    dim,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    drop=drop_rate,
+                    drop_path=block_dpr,
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                )
+            )
+        else:
+            raise ValueError(
+                "Token mixer type: {} not supported".format(token_mixer_type)
+            )
+    blocks = nn.Sequential(*blocks)
+
+    return blocks
+
+
+class FastViT(nn.Module):
+    """
+    This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
+    """
+
+    def __init__(
+        self,
+        layers,
+        token_mixers: Tuple[str, ...],
+        embed_dims=None,
+        mlp_ratios=None,
+        downsamples=None,
+        se_downsamples=None,
+        repmixer_kernel_size=3,
+        norm_layer: nn.Module = nn.BatchNorm2d,
+        act_layer: nn.Module = nn.GELU,
+        num_classes=1000,
+        pos_embs=None,
+        down_patch_size=7,
+        down_stride=2,
+        drop_rate=0.0,
+        drop_path_rate=0.0,
+        use_layer_scale=True,
+        layer_scale_init_value=1e-5,
+        init_cfg=None,
+        pretrained=None,
+        cls_ratio=2.0,
+        inference_mode=False,
+        **kwargs,
+    ) -> None:
+
+        super().__init__()
+
+        self.num_classes = num_classes
+        if pos_embs is None:
+            pos_embs = [None] * len(layers)
+
+        if se_downsamples is None:
+            se_downsamples = [False] * len(layers)
+
+        # Convolutional stem
+        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
+
+        # Build the main stages of the network architecture
+        network = []
+        for i in range(len(layers)):
+            # Add position embeddings if requested
+            if pos_embs[i] is not None:
+                network.append(
+                    pos_embs[i](
+                        embed_dims[i], embed_dims[i], inference_mode=inference_mode
+                    )
+                )
+            stage = basic_blocks(
+                embed_dims[i],
+                i,
+                layers,
+                token_mixer_type=token_mixers[i],
+                kernel_size=repmixer_kernel_size,
+                mlp_ratio=mlp_ratios[i],
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+                drop_rate=drop_rate,
+                drop_path_rate=drop_path_rate,
+                use_layer_scale=use_layer_scale,
+                layer_scale_init_value=layer_scale_init_value,
+                inference_mode=inference_mode,
+            )
+            network.append(stage)
+            if i >= len(layers) - 1:
+                break
+
+            # Patch merging/downsampling between stages.
+            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
+                network.append(
+                    PatchEmbed(
+                        patch_size=down_patch_size,
+                        stride=down_stride,
+                        in_channels=embed_dims[i],
+                        embed_dim=embed_dims[i + 1],
+                        inference_mode=inference_mode,
+                        use_se=se_downsamples[i + 1],
+                    )
+                )
+        self.network = nn.ModuleList(network)
+
+        # Classifier head
+        self.conv_exp = MobileOneBlock(
+            in_channels=embed_dims[-1],
+            out_channels=int(embed_dims[-1] * cls_ratio),
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=embed_dims[-1],
+            inference_mode=inference_mode,
+            use_se=True,
+            num_conv_branches=1,
+        )
+        self.head = (
+            nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
+            if num_classes > 0
+            else nn.Identity()
+        )
+        self.apply(self.cls_init_weights)
+        self.init_cfg = copy.deepcopy(init_cfg)
+
+    def cls_init_weights(self, m: nn.Module) -> None:
+        """Init. for classification"""
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        return x
+
+    def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
+        for idx, block in enumerate(self.network):
+            x = block(x)
+        # output only the features of last layer for image classification
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # input embedding
+        x = self.forward_embeddings(x)
+        # through backbone
+        x = self.forward_tokens(x)
+        # for image classification
+        x = self.conv_exp(x)
+        cls_out = self.head(x)
+        return cls_out
+
+
+@register_model
+def mci0(pretrained=False, **kwargs):
+    """Instantiate MCi0 model variant."""
+    layers = [2, 6, 10, 2]
+    embed_dims = [64, 128, 256, 512]
+    mlp_ratios = [3, 3, 3, 3]
+    downsamples = [True, True, True, True]
+    se_downsamples = [False, False, True, True]
+    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
+    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
+    model = FastViT(
+        layers,
+        token_mixers=token_mixers,
+        embed_dims=embed_dims,
+        pos_embs=pos_embs,
+        mlp_ratios=mlp_ratios,
+        downsamples=downsamples,
+        se_downsamples=se_downsamples,
+        **kwargs,
+    )
+    model.default_cfg = default_cfgs["fastvit_s"]
+    if pretrained:
+        raise ValueError("Functionality not implemented.")
+    return model
+
+
+@register_model
+def mci1(pretrained=False, **kwargs):
+    """Instantiate MCi1 model variant."""
+    layers = [4, 12, 20, 4]
+    embed_dims = [64, 128, 256, 512]
+    mlp_ratios = [3, 3, 3, 3]
+    downsamples = [True, True, True, True]
+    se_downsamples = [False, False, True, True]
+    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
+    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
+    model = FastViT(
+        layers,
+        token_mixers=token_mixers,
+        embed_dims=embed_dims,
+        pos_embs=pos_embs,
+        mlp_ratios=mlp_ratios,
+        downsamples=downsamples,
+        se_downsamples=se_downsamples,
+        **kwargs,
+    )
+    model.default_cfg = default_cfgs["fastvit_s"]
+    if pretrained:
+        raise ValueError("Functionality not implemented.")
+    return model
+
+
+@register_model
+def mci2(pretrained=False, **kwargs):
+    """Instantiate MCi2 model variant."""
+    layers = [4, 12, 24, 4]
+    embed_dims = [80, 160, 320, 640]
+    mlp_ratios = [3, 3, 3, 3]
+    downsamples = [True, True, True, True]
+    se_downsamples = [False, False, True, True]
+    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
+    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
+    model = FastViT(
+        layers,
+        token_mixers=token_mixers,
+        embed_dims=embed_dims,
+        pos_embs=pos_embs,
+        mlp_ratios=mlp_ratios,
+        downsamples=downsamples,
+        se_downsamples=se_downsamples,
+        **kwargs,
+    )
+    model.default_cfg = default_cfgs["fastvit_m"]
+    if pretrained:
+        raise ValueError("Functionality not implemented.")
+    return model
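
For orientation only (not part of the wheel), a minimal sketch of how the MCi backbones registered above could be exercised after installing the package. It assumes torch and timm are installed and that the vendored module is importable under the path shown in the file list; the variant, input size, and num_classes=0 choice are illustrative.

# Hypothetical usage sketch, not shipped with lightly-studio.
import torch

# Importing the vendored module runs the @register_model decorators, which
# also makes "mci0"/"mci1"/"mci2" resolvable via timm.create_model.
from lightly_studio.vendor.mobileclip.models.mci import mci0

# num_classes=0 makes the head an nn.Identity, so the model returns the
# final feature map rather than classification logits.
model = mci0(num_classes=0)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 256, 256))

# Expected roughly (1, 1024, 8, 8) for mci0 at 256x256 input: the stem and
# three PatchEmbed stages downsample by 32x, and conv_exp expands the last
# stage's 512 channels by cls_ratio=2.0.
print(features.shape)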