lightly-studio 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,933 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
import copy
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import List, Tuple, Optional, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
13
|
+
from timm.models.layers import DropPath, trunc_normal_
|
|
14
|
+
from timm.models import register_model
|
|
15
|
+
|
|
16
|
+
from ..modules.common.mobileone import MobileOneBlock
|
|
17
|
+
from ..modules.image.replknet import ReparamLargeKernelConv
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _cfg(url="", **kwargs):
    """Assemble a model configuration dictionary from fixed defaults.

    Args:
        url: Value stored under the ``"url"`` key. Default: ``""``
        **kwargs: Extra entries merged in last; they override the defaults
            below and may introduce additional keys.

    Returns:
        Dictionary holding the default entries updated with ``kwargs``.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=None,
        crop_pct=0.95,
        interpolation="bicubic",
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier="head",
    )
    # Caller-supplied entries win over the defaults above.
    config.update(kwargs)
    return config
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Default configurations keyed by variant name; only ``crop_pct`` differs
# between the entries.
default_cfgs = {
    variant: _cfg(crop_pct=crop_pct)
    for variant, crop_pct in (
        ("fastvit_t", 0.9),
        ("fastvit_s", 0.9),
        ("fastvit_m", 0.95),
    )
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def convolutional_stem(
    in_channels: int, out_channels: int, inference_mode: bool = False
) -> nn.Sequential:
    """Assemble the three-stage MobileOne convolutional stem.

    The stem is a stride-2 3x3 block, a stride-2 3x3 block with one group per
    channel, and a stride-1 1x1 block, in that order.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``

    Returns:
        nn.Sequential object with stem elements.
    """
    # Settings shared by every stage of the stem.
    shared = dict(
        out_channels=out_channels,
        inference_mode=inference_mode,
        use_se=False,
        num_conv_branches=1,
    )

    # Stage 1: dense 3x3 conv, stride 2, maps in_channels -> out_channels.
    dense_stage = MobileOneBlock(
        in_channels=in_channels,
        kernel_size=3,
        stride=2,
        padding=1,
        groups=1,
        **shared,
    )
    # Stage 2: 3x3 conv with groups == channels, stride 2.
    grouped_stage = MobileOneBlock(
        in_channels=out_channels,
        kernel_size=3,
        stride=2,
        padding=1,
        groups=out_channels,
        **shared,
    )
    # Stage 3: 1x1 conv, stride 1.
    pointwise_stage = MobileOneBlock(
        in_channels=out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        groups=1,
        **shared,
    )
    return nn.Sequential(dense_stage, grouped_stage, pointwise_stage)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class MHSA(nn.Module):
    """Multi-headed Self Attention module.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Build MHSA module that can handle 3D or 4D input tensors.

        Args:
            dim: Number of embedding dimensions.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Use bias or not. Default: ``False``
            attn_drop: Dropout rate for attention tensor.
            proj_drop: Dropout rate for projection tensor.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self attention over the tokens of ``x``.

        Args:
            x: Either a 4D tensor ``(B, C, H, W)``, whose spatial positions are
                flattened into ``H * W`` tokens, or a 3D tensor ``(B, N, C)``
                of already flattened tokens.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 3D nor 4D (shape unpacking fails).
        """
        shape = x.shape
        is_spatial = len(shape) == 4
        if is_spatial:
            B, C, H, W = shape
            N = H * W
            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
        else:
            # Token input: unpack only three dims so 3D tensors are accepted
            # instead of failing on a four-way unpack.
            B, N, C = shape
            # Placeholders; only read on the 4D path below.
            H = 0
            W = 0

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # trick here to make q@k.t more stable
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if is_spatial:
            # Restore the original (B, C, H, W) layout.
            x = x.transpose(-2, -1).reshape(B, C, H, W)

        return x
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer."""

    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
        inference_mode: bool = False,
        use_se: bool = False,
    ) -> None:
        """Build patch embedding layer.

        The layer is a reparameterizable large-kernel convolution followed by
        a 1x1 MobileOne block, stored together as ``self.proj``.

        Args:
            patch_size: Patch size for embedding computation.
            stride: Stride for convolutional embedding layer.
            in_channels: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
            use_se: If ``True`` SE block will be used.
        """
        super().__init__()
        # Large-kernel conv: maps in_channels -> embed_dim with the given stride.
        large_kernel_conv = ReparamLargeKernelConv(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            groups=in_channels,
            small_kernel=3,
            inference_mode=inference_mode,
            use_se=use_se,
        )
        # 1x1 block applied after the large-kernel conv; SE is never used here.
        pointwise = MobileOneBlock(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        )
        self.proj = nn.Sequential(large_kernel_conv, pointwise)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the embedding stack and return the result."""
        return self.proj(x)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            # Inference mode holds only the single fused convolution.
            self.reparam_conv = self._make_fused_conv()
            return

        # Settings common to the two training-time branches.
        branch_kwargs = dict(padding=kernel_size // 2, groups=dim, use_act=False)
        self.norm = MobileOneBlock(
            dim,
            dim,
            kernel_size,
            use_scale_branch=False,
            num_conv_branches=0,
            **branch_kwargs,
        )
        self.mixer = MobileOneBlock(dim, dim, kernel_size, **branch_kwargs)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def _make_fused_conv(self) -> nn.Conv2d:
        """Create the per-channel convolution used once the module is fused."""
        return nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens of ``x``; uses the fused conv when it is present."""
        if hasattr(self, "reparam_conv"):
            return self.reparam_conv(x)

        mixed = self.mixer(x)
        normed = self.norm(x)
        if self.use_layer_scale:
            return x + self.layer_scale * (mixed - normed)
        return x + mixed - normed

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        NOTE(review): the early exit checks only the constructor flag
        ``inference_mode``, which is not updated here, so calling this twice
        on a module built with ``inference_mode=False`` will fail on the
        already deleted ``mixer``.
        """
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        mixer_conv = self.mixer.reparam_conv
        norm_conv = self.norm.reparam_conv
        if self.use_layer_scale:
            # Fold identity, layer scale and (mixer - norm) into one kernel.
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                mixer_conv.weight - norm_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (mixer_conv.bias - norm_conv.bias)
        else:
            w = self.mixer.id_tensor + mixer_conv.weight - norm_conv.weight
            b = mixer_conv.bias - norm_conv.bias

        self.reparam_conv = self._make_fused_conv()
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        # Drop the training-time branches; forward() now uses reparam_conv.
        del self.mixer
        del self.norm
        if self.use_layer_scale:
            del self.layer_scale
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            hidden_channels: Number of channels after expansion. Falls back to
                ``in_channels`` when not given. Default: None
            out_channels: Number of output channels. Falls back to
                ``in_channels`` when not given. Default: None
            act_layer: Activation layer class; it is instantiated with no
                arguments. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        if not hidden_channels:
            hidden_channels = in_channels

        # NOTE(review): this 7x7 conv emits ``out_channels`` while ``fc1``
        # below consumes ``in_channels``; the two only line up when
        # out_channels == in_channels -- confirm callers never differ.
        spatial_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
            bias=False,
        )
        self.conv = nn.Sequential()
        self.conv.add_module("conv", spatial_conv)
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))

        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        # Re-initialize every Conv2d created above.
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal init for Conv2d weights and zero init for biases."""
        if not isinstance(m, nn.Conv2d):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv+bn, then fc1 -> act -> drop -> fc2 -> drop."""
        hidden = self.drop(self.act(self.fc1(self.conv(x))))
        return self.drop(self.fc2(hidden))
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class RepCPE(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    In our implementation, we can reparameterize this module to eliminate a skip connection.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode=False,
    ) -> None:
        """Build reparameterizable conditional positional encoding

        Args:
            in_channels: Number of input channels.
            embed_dim: Number of embedding dimensions. Default: 768
            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``

        NOTE(review): the skip connection in ``forward`` and the identity
        kernel built in ``reparameterize`` appear to require
        ``in_channels == embed_dim`` -- confirm against callers.
        """
        super(RepCPE, self).__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        assert isinstance(spatial_shape, tuple), (
            f'"spatial_shape" must by a sequence or int, '
            f"get {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.groups = embed_dim

        if inference_mode:
            self.reparam_conv = self._build_conv()
        else:
            self.pe = self._build_conv()

    def _build_conv(self) -> nn.Conv2d:
        """Create the grouped positional-encoding convolution.

        Each axis is padded by half of its own kernel extent, so with stride 1
        and odd extents the output keeps the input's spatial size even when
        the kernel is not square.
        """
        kernel_h, kernel_w = self.spatial_shape
        return nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=(int(kernel_h // 2), int(kernel_w // 2)),
            groups=self.embed_dim,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``x``.

        Uses the fused convolution when present; otherwise applies the
        convolution and adds the input back as a skip connection.
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

    def reparameterize(self) -> None:
        """Fold the skip connection into a single convolution.

        Does nothing when the module already holds ``reparam_conv`` (built in
        inference mode or reparameterized earlier), so it is safe to call
        more than once.
        """
        if hasattr(self, "reparam_conv"):
            return

        # Build equivalent Id tensor: a kernel that is 1 at the spatial
        # center of each channel's own filter and 0 elsewhere.
        input_dim = self.in_channels // self.groups
        kernel_value = torch.zeros(
            (
                self.in_channels,
                input_dim,
                self.spatial_shape[0],
                self.spatial_shape[1],
            ),
            dtype=self.pe.weight.dtype,
            device=self.pe.weight.device,
        )
        for i in range(self.in_channels):
            kernel_value[
                i,
                i % input_dim,
                self.spatial_shape[0] // 2,
                self.spatial_shape[1] // 2,
            ] = 1
        id_tensor = kernel_value

        # Reparameterize Id tensor and conv
        w_final = id_tensor + self.pe.weight
        b_final = self.pe.bias

        # Introduce reparam conv
        self.reparam_conv = self._build_conv()
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for para in self.parameters():
            para.detach_()
        self.__delattr__("pe")
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
class RepMixerBlock(nn.Module):
    """MetaFormer-style block whose token mixer is a RepMixer.

    Background on the MetaFormer structure:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ):
        """Construct the block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size handed to the RepMixer. Default: 3
            mlp_ratio: Expansion ratio of the feed-forward part; must be
                positive. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate; ``nn.Identity`` is used when it is
                not greater than 0. Default: 0.0
            use_layer_scale: Whether layer scale is used. Default: ``True``
            layer_scale_init_value: Initial layer scale value. Default: 1e-5
            inference_mode: Build the block in inference mode.
                Default: ``False``
        """
        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Per-channel scale applied to the feed-forward branch.
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        """Mix tokens, then add the (optionally scaled) feed-forward branch."""
        x = self.token_mixer(x)
        ffn_out = self.convffn(x)
        if self.use_layer_scale:
            ffn_out = self.layer_scale * ffn_out
        return x + self.drop_path(ffn_out)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class AttentionBlock(nn.Module):
    """MetaFormer-style block whose token mixer is MHSA.

    Background on the MetaFormer structure:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm2d,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        """Construct the block.

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: Expansion ratio of the feed-forward part; must be
                positive. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer applied before the token mixer.
                Default: ``nn.BatchNorm2d``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate; ``nn.Identity`` is used when it is
                not greater than 0. Default: 0.0
            use_layer_scale: Whether layer scale is used. Default: ``True``
            layer_scale_init_value: Initial layer scale value. Default: 1e-5
        """
        super().__init__()

        self.norm = norm_layer(dim)
        self.token_mixer = MHSA(dim=dim)

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # One per-channel scale for each of the two residual branches.
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        """Apply the attention branch and then the feed-forward branch.

        Both branches are residual; each is multiplied by its layer scale
        when ``use_layer_scale`` is set.
        """
        mixed = self.token_mixer(self.norm(x))
        if self.use_layer_scale:
            mixed = self.layer_scale_1 * mixed
        x = x + self.drop_path(mixed)

        ffn_out = self.convffn(x)
        if self.use_layer_scale:
            ffn_out = self.layer_scale_2 * ffn_out
        return x + self.drop_path(ffn_out)
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def basic_blocks(
    dim: int,
    block_index: int,
    num_blocks: List[int],
    token_mixer_type: str,
    kernel_size: int = 3,
    mlp_ratio: float = 4.0,
    act_layer: nn.Module = nn.GELU,
    norm_layer: nn.Module = nn.BatchNorm2d,
    drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
    use_layer_scale: bool = True,
    layer_scale_init_value: float = 1e-5,
    inference_mode=False,
) -> nn.Sequential:
    """Build FastViT blocks within a stage.

    The drop path rate of each block grows linearly with the block's global
    position in the network, reaching ``drop_path_rate`` at the last block.

    Args:
        dim: Number of embedding dimensions.
        block_index: Index of the stage to build (index into ``num_blocks``).
        num_blocks: List containing number of blocks per stage.
        token_mixer_type: Token mixer type, ``"repmixer"`` or ``"attention"``.
        kernel_size: Kernel size for repmixer.
        mlp_ratio: MLP expansion ratio.
        act_layer: Activation layer.
        norm_layer: Normalization layer (used by attention blocks only).
        drop_rate: Dropout rate.
        drop_path_rate: Drop path rate.
        use_layer_scale: Flag to turn on layer scale regularization.
        layer_scale_init_value: Layer scale value at initialization.
        inference_mode: Flag to instantiate block in inference mode
            (used by repmixer blocks only).

    Returns:
        nn.Sequential object of all the blocks within the stage.

    Raises:
        ValueError: If a block has to be built and ``token_mixer_type`` is
            not one of the supported values.
    """
    # Loop invariants: blocks in earlier stages and the schedule denominator.
    blocks_before_stage = sum(num_blocks[:block_index])
    # A network with a single block in total would divide by zero here;
    # clamping to 1 gives that block a rate of 0.
    dpr_denominator = max(sum(num_blocks) - 1, 1)

    blocks = []
    for block_idx in range(num_blocks[block_index]):
        block_dpr = (
            drop_path_rate * (block_idx + blocks_before_stage) / dpr_denominator
        )
        if token_mixer_type == "repmixer":
            blocks.append(
                RepMixerBlock(
                    dim,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )
        elif token_mixer_type == "attention":
            blocks.append(
                AttentionBlock(
                    dim,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        else:
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )

    return nn.Sequential(*blocks)
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
class FastViT(nn.Module):
    """
    Network following the `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_:
    a convolutional stem, a list of stages and a classifier head.
    """

    def __init__(
        self,
        layers,
        token_mixers: Tuple[str, ...],
        embed_dims=None,
        mlp_ratios=None,
        downsamples=None,
        se_downsamples=None,
        repmixer_kernel_size=3,
        norm_layer: nn.Module = nn.BatchNorm2d,
        act_layer: nn.Module = nn.GELU,
        num_classes=1000,
        pos_embs=None,
        down_patch_size=7,
        down_stride=2,
        drop_rate=0.0,
        drop_path_rate=0.0,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        init_cfg=None,
        pretrained=None,
        cls_ratio=2.0,
        inference_mode=False,
        **kwargs,
    ) -> None:
        """Assemble stem, stages and head.

        ``embed_dims``, ``mlp_ratios`` and ``downsamples`` default to
        ``None`` but are indexed unconditionally, so callers have to supply
        them.

        Args:
            layers: Number of blocks per stage.
            token_mixers: Token mixer type per stage.
            embed_dims: Embedding dimension per stage.
            mlp_ratios: MLP expansion ratio per stage.
            downsamples: Per-stage flags; a ``PatchEmbed`` follows stage
                ``i`` when ``downsamples[i]`` is truthy or the embedding
                dimension changes.
            se_downsamples: Per-stage flags forwarded as ``use_se`` to the
                ``PatchEmbed`` placed in front of that stage. Defaults to
                all ``False``.
            repmixer_kernel_size: Kernel size for repmixer blocks.
            norm_layer: Normalization layer passed to the stage builder.
            act_layer: Activation layer passed to the stage builder.
            num_classes: Output size of the linear head; a value that is
                not greater than 0 selects ``nn.Identity``.
            pos_embs: Per-stage factories for a position-embedding module,
                or ``None`` entries for stages without one. Defaults to
                all ``None``.
            down_patch_size: Patch size of the ``PatchEmbed`` layers.
            down_stride: Stride of the ``PatchEmbed`` layers.
            drop_rate: Dropout rate.
            drop_path_rate: Drop path rate.
            use_layer_scale: Flag to turn on layer scale.
            layer_scale_init_value: Layer scale value at initialization.
            init_cfg: Stored on the instance as a deep copy.
            pretrained: Not used by this constructor.
            cls_ratio: Factor applied to the last embedding dimension to
                get the width of ``conv_exp``.
            inference_mode: Flag to instantiate the model in inference mode.
            **kwargs: Accepted and not used by this constructor.
        """
        super().__init__()

        self.num_classes = num_classes
        num_stages = len(layers)
        if pos_embs is None:
            pos_embs = [None] * num_stages
        if se_downsamples is None:
            se_downsamples = [False] * num_stages

        # Convolutional stem
        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)

        # Stages, with optional position embedding before and optional
        # patch merging after each of them.
        network = []
        for i in range(num_stages):
            pos_emb_factory = pos_embs[i]
            if pos_emb_factory is not None:
                network.append(
                    pos_emb_factory(
                        embed_dims[i], embed_dims[i], inference_mode=inference_mode
                    )
                )

            network.append(
                basic_blocks(
                    embed_dims[i],
                    i,
                    layers,
                    token_mixer_type=token_mixers[i],
                    kernel_size=repmixer_kernel_size,
                    mlp_ratio=mlp_ratios[i],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_rate=drop_rate,
                    drop_path_rate=drop_path_rate,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )

            # Nothing follows the last stage.
            if i >= num_stages - 1:
                break

            # Patch merging/downsampling between stages.
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        in_channels=embed_dims[i],
                        embed_dim=embed_dims[i + 1],
                        inference_mode=inference_mode,
                        use_se=se_downsamples[i + 1],
                    )
                )
        self.network = nn.ModuleList(network)

        # Classifier head
        head_width = int(embed_dims[-1] * cls_ratio)
        self.conv_exp = MobileOneBlock(
            in_channels=embed_dims[-1],
            out_channels=head_width,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=embed_dims[-1],
            inference_mode=inference_mode,
            use_se=True,
            num_conv_branches=1,
        )
        if num_classes > 0:
            self.head = nn.Linear(head_width, num_classes)
        else:
            self.head = nn.Identity()
        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

    def cls_init_weights(self, m: nn.Module) -> None:
        """Initialize ``nn.Linear`` modules; other modules are left alone."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the convolutional stem."""
        return self.patch_embed(x)

    def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every module of ``network`` in order."""
        for block in self.network:
            x = block(x)
        # only the features of the last module are returned
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem, stages, ``conv_exp`` and ``head`` and return the result."""
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        # for image classification
        x = self.conv_exp(x)
        return self.head(x)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
@register_model
def mci0(pretrained=False, **kwargs):
    """Instantiate MCi0 model variant.

    Args:
        pretrained: Loading pretrained weights is not implemented; any
            truthy value raises ``ValueError``.
        **kwargs: Forwarded to ``FastViT``.

    Returns:
        A ``FastViT`` instance with ``default_cfg`` set to the
        ``"fastvit_s"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option before building the whole network.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 6, 10, 2]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [3, 3, 3, 3]
    downsamples = [True, True, True, True]
    se_downsamples = [False, False, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastViT(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        se_downsamples=se_downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_s"]
    return model
|
|
882
|
+
|
|
883
|
+
|
|
884
|
+
@register_model
def mci1(pretrained=False, **kwargs):
    """Instantiate MCi1 model variant.

    Args:
        pretrained: Loading pretrained weights is not implemented; any
            truthy value raises ``ValueError``.
        **kwargs: Forwarded to ``FastViT``.

    Returns:
        A ``FastViT`` instance with ``default_cfg`` set to the
        ``"fastvit_s"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option before building the whole network.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [4, 12, 20, 4]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [3, 3, 3, 3]
    downsamples = [True, True, True, True]
    se_downsamples = [False, False, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastViT(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        se_downsamples=se_downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_s"]
    return model
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
@register_model
def mci2(pretrained=False, **kwargs):
    """Instantiate MCi2 model variant.

    Args:
        pretrained: Loading pretrained weights is not implemented; any
            truthy value raises ``ValueError``.
        **kwargs: Forwarded to ``FastViT``.

    Returns:
        A ``FastViT`` instance with ``default_cfg`` set to the
        ``"fastvit_m"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option before building the whole network.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [4, 12, 24, 4]
    embed_dims = [80, 160, 320, 640]
    mlp_ratios = [3, 3, 3, 3]
    downsamples = [True, True, True, True]
    se_downsamples = [False, False, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastViT(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        se_downsamples=se_downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_m"]
    return model
|