lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic; see the registry's release page for more details.

Files changed (219)
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,341 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ from typing import Union, Tuple
6
+
7
+ import copy
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ __all__ = ["MobileOneBlock", "reparameterize_model"]
13
+
14
+
15
class SEBlock(nn.Module):
    """Squeeze-and-Excitation (channel attention) module.

    PyTorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf

    Each input channel is rescaled by a gate in ``(0, 1)``. The gate is
    computed from the spatially averaged input through a 1x1 bottleneck
    (reduce -> ReLU -> expand -> sigmoid).
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Build the squeeze (``reduce``) and excite (``expand``) 1x1 convs.

        Args:
            in_channels: Number of input (and output) channels.
            rd_ratio: Bottleneck width as a fraction of ``in_channels``; the
                hidden width is ``int(in_channels * rd_ratio)``.
        """
        super().__init__()
        # NOTE(review): int() truncates, so the bottleneck is 0 channels wide
        # whenever in_channels * rd_ratio < 1 (e.g. < 16 channels by default).
        bottleneck = int(in_channels * rd_ratio)
        self.reduce = nn.Conv2d(
            in_channels, bottleneck, kernel_size=1, stride=1, bias=True
        )
        self.expand = nn.Conv2d(
            bottleneck, in_channels, kernel_size=1, stride=1, bias=True
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Rescale every channel of ``inputs`` by its excitation gate.

        Args:
            inputs: 4-D tensor, unpacked here as (batch, channels, height, width).

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        _, channels, height, width = inputs.size()
        # A pooling window covering the whole feature map collapses each
        # channel to a single 1x1 value (global average pooling).
        squeezed = F.avg_pool2d(inputs, kernel_size=[height, width])
        gate = torch.sigmoid(self.expand(F.relu(self.reduce(squeezed))))
        return inputs * gate.view(-1, channels, 1, 1)
55
+
56
+
57
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf

    Train-time branches, summed before the SE block and activation:

    * ``rbr_conv``: ``num_conv_branches`` parallel conv-BN branches.
    * ``rbr_scale``: a 1x1 conv-BN branch, only when the kernel height is
      larger than 1 and ``use_scale_branch`` is set.
    * ``rbr_skip``: a BatchNorm identity branch, only when
      ``in_channels == out_channels`` and ``stride == 1``.

    :meth:`reparameterize` fuses all of them into the single ``reparam_conv``
    that is used in inference mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        activation: Union[nn.Module, None] = None,
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_channels: Number of channels in the input.
            out_channels: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel, either an int or a
                ``(height, width)`` pair.
            stride: Stride size.
            padding: Zero-padding size. Branch outputs are summed, so when a
                scale or skip branch exists the padding must keep the conv
                branch output the same size as those unpadded branches.
            dilation: Kernel dilation factor. It is applied to the train-time
                conv branches and to the fused inference conv alike.
            groups: Group number.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
            activation: Activation module applied after the SE block when
                ``use_act`` is True. ``None`` (the default) creates a new
                ``nn.GELU()`` for this block.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        if not use_act:
            self.activation = nn.Identity()
        elif activation is None:
            # A fresh instance per block, instead of one module object that a
            # default argument would share between every block.
            self.activation = nn.GELU()
        else:
            self.activation = activation

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches
            if num_conv_branches > 0:
                self.rbr_conv = nn.ModuleList(
                    self._conv_bn(kernel_size=kernel_size, padding=padding)
                    for _ in range(self.num_conv_branches)
                )
            else:
                self.rbr_conv = None

            # Re-parameterizable 1x1 scale branch; only added when the main
            # kernel height is larger than 1.
            self.rbr_scale = None
            kernel_h = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
            if (kernel_h > 1) and use_scale_branch:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output (0 when the branch is absent).
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output (0 when the branch is absent).
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        if self.rbr_conv is not None:
            for conv_bn in self.rbr_conv:
                out += conv_bn(x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Fuse the train-time branches into one conv for inference.

        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        After this call the block is in inference mode, and ``rbr_conv``,
        ``rbr_scale`` and ``rbr_skip`` have been removed. Calling it on a
        block that is already in inference mode does nothing.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        del self.rbr_conv
        del self.rbr_scale
        if hasattr(self, "rbr_skip"):
            del self.rbr_skip

        self.inference_mode = True

    def _kernel_hw(self) -> Tuple[int, int]:
        """Return ``self.kernel_size`` as an explicit ``(height, width)`` pair."""
        if isinstance(self.kernel_size, int):
            return self.kernel_size, self.kernel_size
        kernel_h, kernel_w = self.kernel_size
        return kernel_h, kernel_w

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches: the sum of the
            BN-fused kernels/biases of the conv, scale and skip branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad the 1x1 scale kernel so it sits at the centre of the
            # (possibly non-square) conv branch kernel. F.pad takes the
            # last dim first: [left, right, top, bottom].
            kernel_h, kernel_w = self._kernel_hw()
            pad_h, pad_w = kernel_h // 2, kernel_w // 2
            kernel_scale = F.pad(kernel_scale, [pad_w, pad_w, pad_h, pad_h])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.rbr_conv is not None:
            for conv_bn in self.rbr_conv:
                _kernel, _bias = self._fuse_bn_tensor(conv_bn)
                kernel_conv += _kernel
                bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        Args:
            branch: Sequence of ops to be fused: a ``conv``/``bn`` Sequential,
                or a bare BatchNorm2d (the identity branch).

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm, computed from the
            BN running statistics as ``kernel * gamma / std`` and
            ``beta - mean * gamma / std``.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                # Build (once, then cached) an identity kernel: for channel i,
                # a single 1 at the kernel centre of input channel i within
                # its group.
                input_dim = self.in_channels // self.groups
                kernel_h, kernel_w = self._kernel_hw()
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, kernel_h, kernel_w),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, kernel_h // 2, kernel_w // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(
        self, kernel_size: Union[int, Tuple[int, int]], padding: int
    ) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        The conv uses the block's stride, groups and dilation. Dilation must
        match ``reparam_conv`` so that the fused kernel reproduces this
        branch; it has no effect on the 1x1 scale branch.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
323
+
324
+
325
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Return an inference-ready copy of ``model`` with fused branches.

    The input is deep-copied first, so the caller's (train-time) model is
    left untouched. Every submodule of the copy that defines a
    ``reparameterize`` method has it called, collapsing multi-branched
    training structures into single-branch ones.

    Args:
        model: MobileOne model in train mode.

    Returns:
        MobileOne model in inference mode.
    """
    inference_model = copy.deepcopy(model)
    # Walk the copy lazily: a module's children are visited after its own
    # reparameterize() has run, so freshly created submodules are seen.
    for submodule in inference_model.modules():
        if not hasattr(submodule, "reparameterize"):
            continue
        submodule.reparameterize()
    return inference_model