lightly-studio 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For acknowledgement see accompanying ACKNOWLEDGEMENTS file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
from timm.models.layers import SqueezeExcite
|
|
11
|
+
|
|
12
|
+
__all__ = ["ReparamLargeKernelConv"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        use_se: bool = False,
        activation: Optional[nn.Module] = None,
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of the small kernel conv branch. If ``None``,
                no small-kernel branch is built. Must not exceed ``kernel_size``.
            inference_mode: If True, instantiates model in inference mode
                (a single fused conv). Default: ``False``
            use_se: If True, apply ``SqueezeExcite(out_channels, rd_ratio=0.25)``
                before the activation. Default: ``False``
            activation: Activation module. Default: a fresh ``nn.GELU()`` per
                instance.
        """
        super().__init__()

        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Create the default here: an ``nn.GELU()`` default in the signature is
        # evaluated once and the same module object would be shared by all blocks.
        self.activation = nn.GELU() if activation is None else activation

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2

        # Check if SE is requested
        if use_se:
            self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
        else:
            self.se = nn.Identity()

        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass: conv branch(es), then SE, then activation."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)

        return self.activation(self.se(out))

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            # Zero-pad the small kernel on all four sides so it lines up with the
            # centre of the large kernel (exactly centred when
            # kernel_size - small_kernel is even).
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        Calling this on an already fused (or inference-mode) module is a no-op.
        """
        # The training branches are deleted after fusion, so a second call (or a
        # call on an inference-mode instance) has nothing left to fuse.
        if hasattr(self, "lkb_reparam"):
            return

        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        ``bn(conv(x)) = gamma * (W x - mean) / std + beta``, which equals a conv
        with kernel ``W * gamma / std`` and bias ``beta - mean * gamma / std``.

        Args:
            conv: Convolution layer (bias-free) whose ``weight`` is fused.
            bn: Batchnorm 2d layer following ``conv``.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        # Per-output-channel scale, broadcast over (in, kh, kw).
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            A nn.Sequential Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
from timm.models.layers import DropPath, trunc_normal_
|
|
11
|
+
from ..common.mobileone import MobileOneBlock
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        context_size: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Pipeline: depthwise ``(1, context_size)`` conv + BN over ``in_channels``,
        then 1x1 ``fc1`` (in -> hidden), activation, dropout, 1x1 ``fc2``
        (hidden -> out), dropout.

        Args:
            in_channels: Number of input channels.
            context_size: Context size for 1D signals (width of the depthwise kernel).
            hidden_channels: Number of channels after expansion. Default: ``in_channels``
            out_channels: Number of output channels. Default: ``in_channels``
            act_layer: Activation layer class (instantiated with no arguments).
                Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        # The depthwise conv must keep `in_channels`: `fc1` consumes its output
        # and expects `in_channels`, and `groups=in_channels` requires the output
        # channel count to be a multiple of it. The projection to `out_channels`
        # is done by `fc2`.
        self.conv = nn.Sequential()
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, int(context_size)),
                padding=(0, int(context_size // 2)),
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=in_channels))
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal (std 0.02) init for conv weights, zero for conv biases."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FFN; spatial size is preserved for odd ``context_size``."""
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ) -> None:
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing (applied as ``(1, kernel_size)``). Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=(1, self.kernel_size),
                stride=1,
                padding=(0, self.kernel_size // 2),
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + scale * (mixer(x) - norm(x))``, or the fused conv output."""
        if hasattr(self, "reparam_conv"):
            return self.reparam_conv(x)
        if self.use_layer_scale:
            return x + self.layer_scale * (self.mixer(x) - self.norm(x))
        return x + self.mixer(x) - self.norm(x)

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        After fusion, ``mixer``, ``norm`` and (if used) ``layer_scale`` are
        removed and ``inference_mode`` is set to ``True``. Calling this again is
        a no-op.
        """
        # Already fused, either built in inference mode or reparameterized
        # earlier. `mixer`/`norm` no longer exist, so there is nothing to fuse.
        if self.inference_mode or hasattr(self, "reparam_conv"):
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        # forward() computes x + ls * (mixer(x) - norm(x)). The identity term is
        # folded in through `mixer.id_tensor`, presumably the identity kernel
        # built by MobileOneBlock.reparameterize (not visible here; confirm).
        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=(1, self.kernel_size),
            stride=1,
            padding=(0, self.kernel_size // 2),
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")
        # Record the new state so later calls take the early-return path.
        self.inference_mode = True
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 11,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
        *args,
        **kwargs,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer (also the ConvFFN context size). Default: 11
            mlp_ratio: MLP expansion ratio; must be > 0. Default: 4.0
            act_layer: Activation layer class. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
            *args: Ignored.
            **kwargs: Ignored.
        """

        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            context_size=kernel_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop Path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x, *args, **kwargs):
        """Apply token mixing and the residual ConvFFN to a ``(B, C, D)`` tensor.

        ``C`` is the context length and ``D`` the embedding dim. The output has the
        same ``(B, C, D)`` layout. Extra positional/keyword arguments are ignored.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}"
            )

        # (B, C, D) -> (B, D, 1, C) so the (1, k) convolutions in RepMixer/ConvFFN
        # slide along the context axis.
        x = torch.unsqueeze(x.permute(0, 2, 1), dim=2)

        x = self.token_mixer(x)
        ffn_out = self.convffn(x)
        if self.use_layer_scale:
            ffn_out = self.layer_scale * ffn_out
        x = x + self.drop_path(ffn_out)

        # (B, D, 1, C) -> (B, C, D)
        return x.squeeze(dim=2).permute(0, 2, 1)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
from typing import Dict
|
|
6
|
+
|
|
7
|
+
import open_clip
|
|
8
|
+
from torch import Tensor, nn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ClipTokenizer(nn.Module):
    """Wrap an ``open_clip`` tokenizer as an ``nn.Module``."""

    def __init__(self, cfg, *args, **kwargs):
        """Create the tokenizer from a model config.

        Args:
            cfg: Model config. ``cfg["text_cfg"]["context_length"]`` is required;
                ``cfg["text_cfg"]["open_clip_tokenizer"]`` optionally names the
                open_clip tokenizer (default ``"ViT-B-16"``).
            *args: Ignored.
            **kwargs: Ignored.
        """
        super().__init__()
        text_cfg = cfg["text_cfg"]
        self.context_length = text_cfg["context_length"]
        # `text_cfg` is subscripted above, so it is a mapping, and `getattr` on a
        # dict never sees its keys. Use `.get` when it exists; fall back to
        # attribute lookup for attribute-style config objects.
        getter = getattr(text_cfg, "get", None)
        if callable(getter):
            model_name = getter("open_clip_tokenizer", "ViT-B-16")
        else:
            model_name = getattr(text_cfg, "open_clip_tokenizer", "ViT-B-16")
        self.tokenizer = open_clip.get_tokenizer(model_name)

    def get_vocab_size(self) -> int:
        """Return the number of entries in the tokenizer's ``encoder`` table."""
        return len(self.tokenizer.encoder)

    def get_encodings(self) -> Dict[str, int]:
        """Return the tokenizer's token-string -> id mapping."""
        return self.tokenizer.encoder

    def _empty_string_token_ids(self) -> list:
        """Tokenize ``""`` and return the resulting ids as a flat list of ints.

        The first two ids are the start-of-text and end-of-text tokens.
        """
        tokens = self.tokenizer("")
        # forward() treats the tokenizer output as an array whose last dim is
        # `context_length`; for one string that is presumably [1, context_length]
        # laid out as [sot, eot, 0, ...]. Flatten so that positions 0/1 index the
        # tokens rather than the batch dimension.
        if isinstance(tokens, Tensor):
            tokens = tokens.reshape(-1).tolist()
        return [int(t) for t in tokens]

    def get_eot_token(self) -> int:
        """Return the end-of-text token id."""
        return self._empty_string_token_ids()[1]

    def get_sot_token(self) -> int:
        """Return the start-of-text token id."""
        return self._empty_string_token_ids()[0]

    def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
        """Tokenize ``input_sentence``, padded/truncated to ``context_length``."""
        # The tokenizer returns a tensor of token ids.
        tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
        assert (
            tokenized_sentence.shape[-1] == self.context_length
        ), "Tokenized tensor should be exactly `context_length` long."
        return tokenized_sentence
|