lightly-studio 0.3.1 (lightly_studio-0.3.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
lightly_studio/vendor/mobileclip/models/vit.py
@@ -0,0 +1,433 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+"""
+Implementation of the following modules is borrowed from ml-cvnets repo:
+https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py
+
+Please see ACKNOWLEDGEMENTS for license details.
+"""
+
+from typing import Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from timm.models import register_model
+from ..modules.common.transformer import (
+    PositionalEmbedding,
+    TransformerEncoder,
+    get_normalization_layer,
+)
+from ..modules.image.image_projection import SimpleImageProjectionHead
+from .. import logger
+
+
+class ConvNormAct(nn.Module):
+    """
+    Applies an N-dimensional convolution over an input.
+
+    Args:
+        cfg: Model configuration.
+        in_channels: :math:`C_{in}` from an expected input of size
+            :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
+        out_channels: :math:`C_{out}` from an expected output of size
+            :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
+        kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
+        stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
+        dilation: Dilation rate for convolution. An integer, or tuple of length ``N``.
+            Default: ``1``.
+        padding: Padding for convolution. An integer, or tuple of length ``N``.
+            If not specified, padding is automatically computed based on kernel size and
+            dilation range. Default : ``None`` (equivalent to ``[
+            int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
+        groups: Number of groups in convolution. Default: ``1``.
+        bias: Use bias. Default: ``False``.
+        padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular').
+            Default: ``zeros``.
+        use_norm: Use normalization layer after convolution. Default: ``True``.
+        use_act: Use activation layer after convolution (or convolution and normalization).
+            Default: ``True``.
+        norm_layer: If not None, the provided normalization layer object will be used.
+            Otherwise, a normalization object will be created based on config
+            ``model.normalization.*`` opts.
+        act_layer: If not None, the provided activation function will be used.
+            Otherwise, an activation function will be created based on config
+            ``model.activation.*`` opts.
+
+    Shape:
+        - Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
+        - Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
+
+    .. note::
+        For depth-wise convolution, `groups=C_{in}=C_{out}`.
+    """
+
+    def __init__(
+        self,
+        cfg: Dict,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, ...]],
+        stride: Union[int, Tuple[int, ...]] = 1,
+        dilation: Union[int, Tuple[int, ...]] = 1,
+        padding: Optional[Union[int, Tuple[int, ...]]] = None,
+        groups: int = 1,
+        bias: bool = False,
+        padding_mode: str = "zeros",
+        use_norm: bool = True,
+        use_act: bool = True,
+        norm_layer: Optional[nn.Module] = None,
+        act_layer: Optional[nn.Module] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.ndim = 2
+
+        if norm_layer is None and use_norm:
+            norm_type = cfg.get("normalization", "batch_norm")
+            if norm_type == "batch_norm":
+                norm_layer = nn.BatchNorm2d(
+                    num_features=out_channels,
+                    momentum=cfg.get("momentum", 0.1),
+                )
+            else:
+                norm_layer = get_normalization_layer(
+                    num_features=out_channels, norm_type=norm_type
+                )
+        elif norm_layer is not None and use_norm:
+            logger.error(
+                f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided."
+            )
+
+        if act_layer is None and use_act:
+            act_layer = nn.GELU()  # Default to GELU
+        elif act_layer is not None and use_act:
+            logger.error(
+                f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided."
+            )
+
+        if (
+            use_norm
+            and any(param[0] == "bias" for param in norm_layer.named_parameters())
+            and bias
+        ):
+            assert (
+                not bias
+            ), "Do not use bias when using normalization layers with bias."
+
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size,) * self.ndim
+
+        if isinstance(stride, int):
+            stride = (stride,) * self.ndim
+
+        if isinstance(dilation, int):
+            dilation = (dilation,) * self.ndim
+
+        assert isinstance(kernel_size, Tuple)
+        assert isinstance(stride, Tuple)
+        assert isinstance(dilation, Tuple)
+
+        if padding is None:
+            padding = (
+                int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim)
+            )
+
+        if in_channels % groups != 0:
+            logger.error(
+                "Input channels are not divisible by groups. {}%{} != 0 ".format(
+                    in_channels, groups
+                )
+            )
+        if out_channels % groups != 0:
+            logger.error(
+                "Output channels are not divisible by groups. {}%{} != 0 ".format(
+                    out_channels, groups
+                )
+            )
+
+        block = nn.Sequential()
+
+        conv_layer = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,  # type: ignore
+            stride=stride,  # type: ignore
+            padding=padding,
+            dilation=dilation,  # type: ignore
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+        )
+
+        block.add_module(name="conv", module=conv_layer)
+
+        self.norm_name = None
+        if use_norm:
+            block.add_module(name="norm", module=norm_layer)
+            self.norm_name = norm_layer.__class__.__name__
+
+        self.act_name = None
+        if use_act:
+            block.add_module(name="act", module=act_layer)
+            self.act_name = act_layer.__class__.__name__
+
+        self.block = block
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.stride = stride
+        self.groups = groups
+        self.kernel_size = conv_layer.kernel_size
+        self.bias = bias
+        self.dilation = dilation
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.block(x)
+
+
+class VisionTransformer(nn.Module):
+    """
+    This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model implementation
+    is inspired from `Early Convolutions Help Transformers See Better <https://arxiv.org/abs/2106.14881>`_
+
+    .. note::
+        Our implementation is different from the original implementation in four ways:
+        1. Kernel size is odd.
+        2. Our positional encoding implementation allows us to use ViT with any multiple input scales
+        3. We do not use StochasticDepth
+        4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
+    """
+
+    def __init__(self, cfg, *args, **kwargs) -> None:
+        super().__init__()
+        image_channels = 3
+        num_classes = cfg.get("n_classes", 1000)
+
+        self.projection_dim = None
+        if "projection_dim" in kwargs:
+            self.projection_dim = kwargs.get("projection_dim")
+
+        kernel_sizes_conv_stem = [4, 2, 2]
+        strides_conv_stem = [4, 2, 2]
+
+        # Typically, in the ImageNet dataset, we use 224x224 as a resolution.
+        # For our ViT implementation, patch size is 16 (16 = 4 * 2 * 2)
+        # Therefore, the total number of embeddings along width and height is (224 / 16)^2
+        num_embeddings = (224 // 16) ** 2
+
+        embed_dim = cfg["embed_dim"]
+        ffn_dim = cfg["embed_dim"] * 4
+        pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
+        n_transformer_layers = cfg["n_transformer_layers"]
+        num_heads = cfg["n_attn_heads"]
+        attn_dropout = cfg.get("attn_dropout", 0.0)
+        dropout = cfg.get("dropout", 0.0)
+        ffn_dropout = cfg.get("ffn_dropout", 0.0)
+        norm_layer = cfg.get("norm_layer", "layer_norm")
+
+        conv_stem_proj_dim = max(32, embed_dim // 4)
+        patch_emb = [
+            ConvNormAct(
+                cfg=cfg,
+                in_channels=image_channels,
+                out_channels=conv_stem_proj_dim,
+                kernel_size=kernel_sizes_conv_stem[0],
+                stride=strides_conv_stem[0],
+                bias=False,
+                use_norm=True,
+                use_act=True,
+            ),
+            ConvNormAct(
+                cfg=cfg,
+                in_channels=conv_stem_proj_dim,
+                out_channels=conv_stem_proj_dim,
+                kernel_size=kernel_sizes_conv_stem[1],
+                stride=strides_conv_stem[1],
+                bias=False,
+                use_norm=True,
+                use_act=True,
+            ),
+            ConvNormAct(
+                cfg=cfg,
+                in_channels=conv_stem_proj_dim,
+                out_channels=embed_dim,
+                kernel_size=kernel_sizes_conv_stem[2],
+                stride=strides_conv_stem[2],
+                bias=True,
+                use_norm=False,
+                use_act=False,
+            ),
+        ]
+
+        self.patch_emb = nn.Sequential(*patch_emb)
+
+        use_cls_token = not cfg.get("no_cls_token", False)
+        stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
+        per_layer_stochastic_drop_rate = [
+            round(x, 3)
+            for x in np.linspace(0, stochastic_dropout, n_transformer_layers)
+        ]
+        transformer_blocks = [
+            TransformerEncoder(
+                embed_dim=embed_dim,
+                ffn_latent_dim=ffn_dim,
+                num_heads=num_heads,
+                attn_dropout=attn_dropout,
+                dropout=dropout,
+                ffn_dropout=ffn_dropout,
+                transformer_norm_layer=norm_layer,
+                stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
+            )
+            for layer_idx in range(n_transformer_layers)
+        ]
+
+        self.post_transformer_norm = get_normalization_layer(
+            num_features=embed_dim, norm_type=norm_layer
+        )
+
+        self.transformer = nn.Sequential(*transformer_blocks)
+
+        if self.projection_dim is None:
+            self.classifier = nn.Linear(embed_dim, num_classes)
+        else:
+            self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
+
+        if use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
+            torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
+        else:
+            self.cls_token = None
+
+        self.pos_embed = PositionalEmbedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embed_dim,
+            padding_idx=None,
+            interpolation_mode="bilinear",
+        )
+        self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)
+
+    def extract_patch_embeddings(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
+        # input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
+        batch_size = x.shape[0]
+
+        # [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
+        patch_emb = self.patch_emb(x)
+        n_h, n_w = patch_emb.shape[-2:]
+
+        # [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
+        patch_emb = patch_emb.flatten(2)
+        # [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
+        patch_emb = patch_emb.transpose(1, 2).contiguous()
+
+        n_patches = patch_emb.shape[1]
+        # we resize the positional encodings dynamically.
+        pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)
+
+        # add positional encodings
+        patch_emb = pos_emb + patch_emb
+
+        # add classification token
+        if self.cls_token is not None:
+            # [1, 1, emb_dim] --> [Batch, 1, emb_dim]
+            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+            # Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
+            patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)
+
+        # dropout
+        patch_emb = self.emb_dropout(patch_emb)
+        return patch_emb, (n_h, n_w)
+
+    def _features_from_transformer(
+        self, x: Tensor, *args, **kwargs
+    ) -> Tuple[Tensor, Tuple[int, int]]:
+        # this function extracts patch embeddings and then applies the transformer module to learn
+        # inter-patch representations
+
+        # [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
+        # and embed_dim is feature dim
+        x, (n_h, n_w) = self.extract_patch_embeddings(x)
+
+        for layer in self.transformer:
+            x = layer(x)
+        x = self.post_transformer_norm(x)
+
+        return x, (n_h, n_w)
+
+    def extract_features(
+        self, x: Tensor, *args, **kwargs
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        # The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
+        # and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
+        return_image_embeddings = kwargs.get("return_image_embeddings", False)
+
+        # [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
+        # here, B is batch size, C is input channels
+        # H and W are input height and width
+        # N is the number of pixels (or tokens) after processing input with conv stem and reshaping
+        # We add +1 for cls token (if applicable)
+        # embed_dim --> embedding dimension
+        x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)
+
+        if self.cls_token is not None:
+            # [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
+            cls_embedding, image_embedding = torch.split(
+                x, split_size_or_sections=[1, x.shape[1] - 1], dim=1
+            )
+            cls_embedding = cls_embedding.squeeze(1)
+        else:
+            # [B, N, embed_dim] -> [B, embed_dim]
+            cls_embedding = torch.mean(x, dim=1)
+            # [B, N, embed_dim]
+            image_embedding = x
+
+        if return_image_embeddings:
+            # reshape image embedding to 4-D tensor
+            # [B, N, C] --> [B, C, N]
+            image_embedding = image_embedding.transpose(1, 2).contiguous()
+            image_embedding = image_embedding.reshape(
+                image_embedding.shape[0], -1, n_h, n_w
+            )
+
+            return cls_embedding, image_embedding
+        else:
+            return cls_embedding, None
+
+    def forward_classifier(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
+        cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
+        # classify based on CLS token
+        cls_embedding = self.classifier(cls_embedding)
+        return cls_embedding, image_embedding
+
+    def forward(self, x: Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
+        # In the ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
+        # To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
+        if kwargs.get("return_image_embeddings", False):
+            out_dict = dict()
+            prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
+            out_dict.update({"logits": prediction})
+            if image_embedding is not None:
+                out_dict.update({"image_embeddings": image_embedding})
+            return out_dict
+        else:
+            prediction, _ = self.forward_classifier(x, *args, **kwargs)
+            return prediction


+@register_model
+def vit_b16(pretrained=False, **kwargs):
+    # Vision transformer config
+    cfg = {
+        "norm_layer": "layer_norm_fp32",
+        "act_layer": "gelu",
+        "embed_dim": 768,
+        "n_transformer_layers": 12,
+        "n_attn_heads": 12,
+    }
+    model = VisionTransformer(cfg=cfg, **kwargs)
+    if pretrained:
+        raise ValueError("Functionality not implemented.")
+    return model
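For orientation only, and not part of the published wheel: a minimal sketch of how the vendored vit_b16 model defined above could be exercised after installing the wheel. It assumes torch, timm, and numpy are installed and that the vendored transformer modules behave as their ml-cvnets counterparts; the expected output shapes follow the shape comments in extract_features.

import torch

# Importing the vendored module registers "vit_b16" with timm via @register_model.
from lightly_studio.vendor.mobileclip.models import vit

model = vit.vit_b16()  # default config: 768-dim embeddings, 12 layers, 12 heads, linear head
model.eval()

# The conv stem downsamples by 4 * 2 * 2 = 16, so a 224x224 input yields a 14x14 patch grid.
images = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = model(images, return_image_embeddings=True)

print(out["logits"].shape)            # expected: torch.Size([2, 1000])
print(out["image_embeddings"].shape)  # expected: torch.Size([2, 768, 14, 14])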