lightly_studio-0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
lightly_studio/vendor/mobileclip/modules/common/mobileone.py
@@ -0,0 +1,341 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+from typing import Union, Tuple
+
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["MobileOneBlock", "reparameterize_model"]
+
+
+class SEBlock(nn.Module):
+    """Squeeze and Excite module.
+
+    Pytorch implementation of `Squeeze-and-Excitation Networks` -
+    https://arxiv.org/pdf/1709.01507.pdf
+    """
+
+    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
+        """Construct a Squeeze and Excite Module.
+
+        Args:
+            in_channels: Number of input channels.
+            rd_ratio: Input channel reduction ratio.
+        """
+        super(SEBlock, self).__init__()
+        self.reduce = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=int(in_channels * rd_ratio),
+            kernel_size=1,
+            stride=1,
+            bias=True,
+        )
+        self.expand = nn.Conv2d(
+            in_channels=int(in_channels * rd_ratio),
+            out_channels=in_channels,
+            kernel_size=1,
+            stride=1,
+            bias=True,
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Apply forward pass."""
+        b, c, h, w = inputs.size()
+        x = F.avg_pool2d(inputs, kernel_size=[h, w])
+        x = self.reduce(x)
+        x = F.relu(x)
+        x = self.expand(x)
+        x = torch.sigmoid(x)
+        x = x.view(-1, c, 1, 1)
+        return inputs * x
+
+
+class MobileOneBlock(nn.Module):
+    """MobileOne building block.
+
+    This block has a multi-branched architecture at train-time
+    and plain-CNN style architecture at inference time
+    For more details, please refer to our paper:
+    `An Improved One millisecond Mobile Backbone` -
+    https://arxiv.org/pdf/2206.04040.pdf
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        inference_mode: bool = False,
+        use_se: bool = False,
+        use_act: bool = True,
+        use_scale_branch: bool = True,
+        num_conv_branches: int = 1,
+        activation: nn.Module = nn.GELU(),
+    ) -> None:
+        """Construct a MobileOneBlock module.
+
+        Args:
+            in_channels: Number of channels in the input.
+            out_channels: Number of channels produced by the block.
+            kernel_size: Size of the convolution kernel.
+            stride: Stride size.
+            padding: Zero-padding size.
+            dilation: Kernel dilation factor.
+            groups: Group number.
+            inference_mode: If True, instantiates model in inference mode.
+            use_se: Whether to use SE-ReLU activations.
+            use_act: Whether to use activation. Default: ``True``
+            use_scale_branch: Whether to use scale branch. Default: ``True``
+            num_conv_branches: Number of linear conv branches.
+        """
+        super(MobileOneBlock, self).__init__()
+        self.inference_mode = inference_mode
+        self.groups = groups
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.kernel_size = kernel_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_conv_branches = num_conv_branches
+
+        # Check if SE-ReLU is requested
+        if use_se:
+            self.se = SEBlock(out_channels)
+        else:
+            self.se = nn.Identity()
+
+        if use_act:
+            self.activation = activation
+        else:
+            self.activation = nn.Identity()
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=True,
+            )
+        else:
+            # Re-parameterizable skip connection
+            self.rbr_skip = (
+                nn.BatchNorm2d(num_features=in_channels)
+                if out_channels == in_channels and stride == 1
+                else None
+            )
+
+            # Re-parameterizable conv branches
+            if num_conv_branches > 0:
+                rbr_conv = list()
+                for _ in range(self.num_conv_branches):
+                    rbr_conv.append(
+                        self._conv_bn(kernel_size=kernel_size, padding=padding)
+                    )
+                self.rbr_conv = nn.ModuleList(rbr_conv)
+            else:
+                self.rbr_conv = None
+
+            # Re-parameterizable scale branch
+            self.rbr_scale = None
+            if not isinstance(kernel_size, int):
+                kernel_size = kernel_size[0]
+            if (kernel_size > 1) and use_scale_branch:
+                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply forward pass."""
+        # Inference mode forward pass.
+        if self.inference_mode:
+            return self.activation(self.se(self.reparam_conv(x)))
+
+        # Multi-branched train-time forward pass.
+        # Skip branch output
+        identity_out = 0
+        if self.rbr_skip is not None:
+            identity_out = self.rbr_skip(x)
+
+        # Scale branch output
+        scale_out = 0
+        if self.rbr_scale is not None:
+            scale_out = self.rbr_scale(x)
+
+        # Other branches
+        out = scale_out + identity_out
+        if self.rbr_conv is not None:
+            for ix in range(self.num_conv_branches):
+                out += self.rbr_conv[ix](x)
+
+        return self.activation(self.se(out))
+
+    def reparameterize(self):
+        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
+        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
+        architecture used at training time to obtain a plain CNN-like structure
+        for inference.
+        """
+        if self.inference_mode:
+            return
+        kernel, bias = self._get_kernel_bias()
+        self.reparam_conv = nn.Conv2d(
+            in_channels=self.in_channels,
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = kernel
+        self.reparam_conv.bias.data = bias
+
+        # Delete un-used branches
+        for para in self.parameters():
+            para.detach_()
+        self.__delattr__("rbr_conv")
+        self.__delattr__("rbr_scale")
+        if hasattr(self, "rbr_skip"):
+            self.__delattr__("rbr_skip")
+
+        self.inference_mode = True
+
+    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Method to obtain re-parameterized kernel and bias.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
+
+        Returns:
+            Tuple of (kernel, bias) after fusing branches.
+        """
+        # get weights and bias of scale branch
+        kernel_scale = 0
+        bias_scale = 0
+        if self.rbr_scale is not None:
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
+            # Pad scale branch kernel to match conv branch kernel size.
+            pad = self.kernel_size // 2
+            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
+
+        # get weights and bias of skip branch
+        kernel_identity = 0
+        bias_identity = 0
+        if self.rbr_skip is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
+
+        # get weights and bias of conv branches
+        kernel_conv = 0
+        bias_conv = 0
+        if self.rbr_conv is not None:
+            for ix in range(self.num_conv_branches):
+                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
+                kernel_conv += _kernel
+                bias_conv += _bias
+
+        kernel_final = kernel_conv + kernel_scale + kernel_identity
+        bias_final = bias_conv + bias_scale + bias_identity
+        return kernel_final, bias_final
+
+    def _fuse_bn_tensor(
+        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Method to fuse batchnorm layer with preceeding conv layer.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
+
+        Args:
+            branch: Sequence of ops to be fused.
+
+        Returns:
+            Tuple of (kernel, bias) after fusing batchnorm.
+        """
+        if isinstance(branch, nn.Sequential):
+            kernel = branch.conv.weight
+            running_mean = branch.bn.running_mean
+            running_var = branch.bn.running_var
+            gamma = branch.bn.weight
+            beta = branch.bn.bias
+            eps = branch.bn.eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, "id_tensor"):
+                input_dim = self.in_channels // self.groups
+
+                kernel_size = self.kernel_size
+                if isinstance(self.kernel_size, int):
+                    kernel_size = (self.kernel_size, self.kernel_size)
+
+                kernel_value = torch.zeros(
+                    (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
+                    dtype=branch.weight.dtype,
+                    device=branch.weight.device,
+                )
+                for i in range(self.in_channels):
+                    kernel_value[
+                        i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
+                    ] = 1
+                self.id_tensor = kernel_value
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
+        """Helper method to construct conv-batchnorm layers.
+
+        Args:
+            kernel_size: Size of the convolution kernel.
+            padding: Zero-padding size.
+
+        Returns:
+            Conv-BN module.
+        """
+        mod_list = nn.Sequential()
+        mod_list.add_module(
+            "conv",
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=kernel_size,
+                stride=self.stride,
+                padding=padding,
+                groups=self.groups,
+                bias=False,
+            ),
+        )
+        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
+        return mod_list
+
+
+def reparameterize_model(model: torch.nn.Module) -> nn.Module:
+    """Method returns a model where a multi-branched structure
+    used in training is re-parameterized into a single branch
+    for inference.
+
+    Args:
+        model: MobileOne model in train mode.
+
+    Returns:
+        MobileOne model in inference mode.
+    """
+    # Avoid editing original graph
+    model = copy.deepcopy(model)
+    for module in model.modules():
+        if hasattr(module, "reparameterize"):
+            module.reparameterize()
+    return model
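The vendored code above trains MobileOneBlock with several parallel branches (per-branch conv+BN, an optional 1x1 scale branch, and a BN-only skip branch) and folds them into a single convolution for inference: each conv-BN pair is fused as W' = W * gamma / sigma and b' = beta - mu * gamma / sigma (see _fuse_bn_tensor), the fused kernels and biases are summed across branches, and reparameterize_model applies this to every block. A minimal usage sketch, assuming the module is importable under the packaged path shown in the file list; the block size and tolerance below are illustrative choices, not values from the package:

import torch

from lightly_studio.vendor.mobileclip.modules.common.mobileone import (
    MobileOneBlock,
    reparameterize_model,
)

# A 3x3 block in its multi-branched train-time form.
block = MobileOneBlock(
    in_channels=16,
    out_channels=16,
    kernel_size=3,
    stride=1,
    padding=1,
    num_conv_branches=1,
)
block.eval()  # use BN running stats so train-time and fused outputs are comparable

x = torch.randn(1, 16, 32, 32)
with torch.no_grad():
    y_multi = block(x)                   # multi-branch forward pass
    fused = reparameterize_model(block)  # deep-copies, then folds branches into one conv
    y_fused = fused(x)                   # single-conv forward pass

# The fused block should reproduce the multi-branch output up to float error.
assert torch.allclose(y_multi, y_fused, atol=1e-5)

Because reparameterize_model deep-copies before fusing, the original train-mode model is left untouched; the returned copy carries only reparam_conv, which is why inference is a plain-CNN forward pass.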