lightly-studio 0.3.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lightly-studio might be problematic.

Files changed (219)
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,188 @@
+ #
+ # For acknowledgement see accompanying ACKNOWLEDGEMENTS file.
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
+ #
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from timm.models.layers import SqueezeExcite
+
+ __all__ = ["ReparamLargeKernelConv"]
+
+
+ class ReparamLargeKernelConv(nn.Module):
+     """Building block of RepLKNet.
+
+     This class defines the overparameterized large kernel conv block
+     introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_.
+
+     Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int,
+         groups: int,
+         small_kernel: Optional[int],
+         inference_mode: bool = False,
+         use_se: bool = False,
+         activation: nn.Module = nn.GELU(),
+     ) -> None:
+         """Construct a ReparamLargeKernelConv module.
+
+         Args:
+             in_channels: Number of input channels.
+             out_channels: Number of output channels.
+             kernel_size: Kernel size of the large kernel conv branch.
+             stride: Stride size.
+             groups: Group number.
+             small_kernel: Kernel size of the small kernel conv branch, or None to disable it.
+             inference_mode: If True, instantiates model in inference mode. Default: ``False``
+             use_se: If True, applies squeeze-and-excitation on the output. Default: ``False``
+             activation: Activation module. Default: ``nn.GELU``
+         """
+         super(ReparamLargeKernelConv, self).__init__()
+
+         self.stride = stride
+         self.groups = groups
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.activation = activation
+
+         self.kernel_size = kernel_size
+         self.small_kernel = small_kernel
+         self.padding = kernel_size // 2
+
+         # Check if SE is requested
+         if use_se:
+             self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
+         else:
+             self.se = nn.Identity()
+
+         if inference_mode:
+             self.lkb_reparam = nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=self.padding,
+                 dilation=1,
+                 groups=groups,
+                 bias=True,
+             )
+         else:
+             self.lkb_origin = self._conv_bn(
+                 kernel_size=kernel_size, padding=self.padding
+             )
+             if small_kernel is not None:
+                 assert (
+                     small_kernel <= kernel_size
+                 ), "The kernel size for re-param cannot be larger than the large kernel!"
+                 self.small_conv = self._conv_bn(
+                     kernel_size=small_kernel, padding=small_kernel // 2
+                 )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply forward pass."""
+         if hasattr(self, "lkb_reparam"):
+             out = self.lkb_reparam(x)
+         else:
+             out = self.lkb_origin(x)
+             if hasattr(self, "small_conv"):
+                 out += self.small_conv(x)
+
+         return self.activation(self.se(out))
+
+     def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Obtain the re-parameterized kernel and bias.
+
+         Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
+
+         Returns:
+             Tuple of (kernel, bias) after fusing branches.
+         """
+         eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+         if hasattr(self, "small_conv"):
+             small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
+             eq_b += small_b
+             eq_k += nn.functional.pad(
+                 small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
+             )
+         return eq_k, eq_b
+
+     def reparameterize(self) -> None:
+         """Re-parameterize the multi-branched architecture used at training
+         time into a plain CNN-like structure for inference, following works
+         like `RepVGG: Making VGG-style ConvNets Great Again` -
+         https://arxiv.org/pdf/2101.03697.pdf.
+         """
+         eq_k, eq_b = self.get_kernel_bias()
+         self.lkb_reparam = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             kernel_size=self.kernel_size,
+             stride=self.stride,
+             padding=self.padding,
+             dilation=self.lkb_origin.conv.dilation,
+             groups=self.groups,
+             bias=True,
+         )
+
+         self.lkb_reparam.weight.data = eq_k
+         self.lkb_reparam.bias.data = eq_b
+         self.__delattr__("lkb_origin")
+         if hasattr(self, "small_conv"):
+             self.__delattr__("small_conv")
+
+     @staticmethod
+     def _fuse_bn(
+         conv: nn.Conv2d, bn: nn.BatchNorm2d
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Fuse a batchnorm layer into the preceding conv layer.
+
+         Args:
+             conv: Convolutional layer.
+             bn: Batchnorm 2d layer.
+
+         Returns:
+             Tuple of (kernel, bias) after fusing batchnorm.
+         """
+         kernel = conv.weight
+         running_mean = bn.running_mean
+         running_var = bn.running_var
+         gamma = bn.weight
+         beta = bn.bias
+         eps = bn.eps
+         std = (running_var + eps).sqrt()
+         t = (gamma / std).reshape(-1, 1, 1, 1)
+         return kernel * t, beta - running_mean * gamma / std
+
+     def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
+         """Construct a conv-batchnorm sequential block.
+
+         Args:
+             kernel_size: Size of the convolution kernel.
+             padding: Zero-padding size.
+
+         Returns:
+             A nn.Sequential Conv-BN module.
+         """
+         mod_list = nn.Sequential()
+         mod_list.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 kernel_size=kernel_size,
+                 stride=self.stride,
+                 padding=padding,
+                 groups=self.groups,
+                 bias=False,
+             ),
+         )
+         mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
+         return mod_list
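
The reparameterization above folds each conv-BN branch into a single convolution: the fused kernel is W * (gamma / std) and the fused bias is beta - running_mean * gamma / std, with the small-kernel branch zero-padded up to the large kernel size before being added. The following sketch is not part of the package; it is a minimal equivalence check, assuming torch and timm are installed, with arbitrary shapes and hyperparameters chosen for illustration:

import torch

from lightly_studio.vendor.mobileclip.modules.image.replknet import (
    ReparamLargeKernelConv,
)

# Depthwise large-kernel block with a parallel 3x3 branch (illustrative sizes).
block = ReparamLargeKernelConv(
    in_channels=8,
    out_channels=8,
    kernel_size=7,
    stride=1,
    groups=8,
    small_kernel=3,
)
block.eval()  # BN must use running stats for the fusion to be exact

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y_branches = block(x)
    block.reparameterize()  # fuse BN into conv, merge small kernel into large
    y_fused = block(x)

# The fused single-conv path should reproduce the multi-branch output.
assert torch.allclose(y_branches, y_fused, atol=1e-5)

In eval mode both paths normalize with the BatchNorm running statistics, so the fused convolution matches the multi-branch output up to floating-point error.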
@@ -0,0 +1,4 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
+ #
@@ -0,0 +1,281 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
+ #
+ from typing import Optional, Type
+
+ import torch
+ import torch.nn as nn
+
+ from timm.models.layers import DropPath, trunc_normal_
+ from ..common.mobileone import MobileOneBlock
+
+
+ class ConvFFN(nn.Module):
+     """Convolutional FFN module."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         context_size: int,
+         hidden_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         act_layer: Type[nn.Module] = nn.GELU,
+         drop: float = 0.0,
+     ) -> None:
+         """Build convolutional FFN module.
+
+         Args:
+             in_channels: Number of input channels.
+             context_size: Context size for 1D signals.
+             hidden_channels: Number of channels after expansion. Default: None
+             out_channels: Number of output channels. Default: None
+             act_layer: Activation layer. Default: ``nn.GELU``
+             drop: Dropout rate. Default: ``0.0``.
+         """
+         super().__init__()
+         out_channels = out_channels or in_channels
+         hidden_channels = hidden_channels or in_channels
+         self.conv = nn.Sequential()
+         self.conv.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(1, int(context_size)),
+                 padding=(0, int(context_size // 2)),
+                 groups=in_channels,
+                 bias=False,
+             ),
+         )
+         self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
+         self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
+         self.drop = nn.Dropout(drop)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m: nn.Module) -> None:
+         if isinstance(m, nn.Conv2d):
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv(x)
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class RepMixer(nn.Module):
+     """Reparameterizable token mixer.
+
+     For more details, please refer to our paper:
+     `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         use_layer_scale: bool = True,
+         layer_scale_init_value: float = 1e-5,
+         inference_mode: bool = False,
+     ) -> None:
+         """Build RepMixer module.
+
+         Args:
+             dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
+             kernel_size: Kernel size for spatial mixing. Default: 3
+             use_layer_scale: If True, learnable layer scale is used. Default: ``True``
+             layer_scale_init_value: Initial value for layer scale. Default: 1e-5
+             inference_mode: If True, instantiates model in inference mode. Default: ``False``
+         """
+         super().__init__()
+         self.dim = dim
+         self.kernel_size = kernel_size
+         self.inference_mode = inference_mode
+
+         if inference_mode:
+             self.reparam_conv = nn.Conv2d(
+                 in_channels=self.dim,
+                 out_channels=self.dim,
+                 kernel_size=(1, self.kernel_size),
+                 stride=1,
+                 padding=(0, self.kernel_size // 2),
+                 groups=self.dim,
+                 bias=True,
+             )
+         else:
+             self.norm = MobileOneBlock(
+                 dim,
+                 dim,
+                 (1, kernel_size),
+                 padding=(0, kernel_size // 2),
+                 groups=dim,
+                 use_act=False,
+                 use_scale_branch=False,
+                 num_conv_branches=0,
+             )
+             self.mixer = MobileOneBlock(
+                 dim,
+                 dim,
+                 (1, kernel_size),
+                 padding=(0, kernel_size // 2),
+                 groups=dim,
+                 use_act=False,
+             )
+             self.use_layer_scale = use_layer_scale
+             if use_layer_scale:
+                 self.layer_scale = nn.Parameter(
+                     layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+                 )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, "reparam_conv"):
+             x = self.reparam_conv(x)
+             return x
+         else:
+             if self.use_layer_scale:
+                 x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
+             else:
+                 x = x + self.mixer(x) - self.norm(x)
+             return x
+
+     def reparameterize(self) -> None:
+         """Reparameterize mixer and norm into a single
+         convolutional layer for efficient inference.
+         """
+         if self.inference_mode:
+             return
+
+         self.mixer.reparameterize()
+         self.norm.reparameterize()
+
+         if self.use_layer_scale:
+             w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
+                 self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
+             )
+             b = torch.squeeze(self.layer_scale) * (
+                 self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+             )
+         else:
+             w = (
+                 self.mixer.id_tensor
+                 + self.mixer.reparam_conv.weight
+                 - self.norm.reparam_conv.weight
+             )
+             b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+
+         self.reparam_conv = nn.Conv2d(
+             in_channels=self.dim,
+             out_channels=self.dim,
+             kernel_size=(1, self.kernel_size),
+             stride=1,
+             padding=(0, self.kernel_size // 2),
+             groups=self.dim,
+             bias=True,
+         )
+         self.reparam_conv.weight.data = w
+         self.reparam_conv.bias.data = b
+
+         for para in self.parameters():
+             para.detach_()
+         self.__delattr__("mixer")
+         self.__delattr__("norm")
+         if self.use_layer_scale:
+             self.__delattr__("layer_scale")
+
+
+ class RepMixerBlock(nn.Module):
+     """Implementation of a MetaFormer block with RepMixer as the token mixer.
+
+     For more details on the MetaFormer structure, please refer to:
+     `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 11,
+         mlp_ratio: float = 4.0,
+         act_layer: Type[nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         drop_path: float = 0.0,
+         use_layer_scale: bool = True,
+         layer_scale_init_value: float = 1e-5,
+         inference_mode: bool = False,
+         *args,
+         **kwargs,
+     ) -> None:
+         """Build RepMixer block.
+
+         Args:
+             dim: Number of embedding dimensions.
+             kernel_size: Kernel size for the RepMixer. Default: 11
+             mlp_ratio: MLP expansion ratio. Default: 4.0
+             act_layer: Activation layer. Default: ``nn.GELU``
+             drop: Dropout rate. Default: 0.0
+             drop_path: Drop path rate. Default: 0.0
+             use_layer_scale: Flag to turn on layer scale. Default: ``True``
+             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+         """
+
+         super().__init__()
+
+         self.token_mixer = RepMixer(
+             dim,
+             kernel_size=kernel_size,
+             use_layer_scale=use_layer_scale,
+             layer_scale_init_value=layer_scale_init_value,
+             inference_mode=inference_mode,
+         )
+
+         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+             mlp_ratio
+         )
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.convffn = ConvFFN(
+             in_channels=dim,
+             context_size=kernel_size,
+             hidden_channels=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+         )
+
+         # Drop Path
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         # Layer Scale
+         self.use_layer_scale = use_layer_scale
+         if use_layer_scale:
+             self.layer_scale = nn.Parameter(
+                 layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+             )
+
+     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         if x.dim() == 3:
+             # (B, C, D) --- where C is the context length.
+             # Convert to (B, D, C) --- to match the RepMixer impl.
+             x = x.permute(0, 2, 1)
+             x = torch.unsqueeze(x, dim=2)
+         else:
+             raise ValueError(
+                 f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}"
+             )
+
+         if self.use_layer_scale:
+             x = self.token_mixer(x)
+             x = x + self.drop_path(self.layer_scale * self.convffn(x))
+         else:
+             x = self.token_mixer(x)
+             x = x + self.drop_path(self.convffn(x))
+
+         # Convert tensors back to (B, C, D)
+         x = x.squeeze(dim=2).permute(0, 2, 1)
+         return x
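
RepMixerBlock treats a 1D token sequence as a (B, D, 1, C) feature map, mixes along the context axis with a (1, k) depthwise convolution, and then applies the ConvFFN. A minimal usage sketch, not part of the package, assuming the vendored module and its timm dependency are importable; the batch, context, and embedding sizes below are illustrative:

import torch

from lightly_studio.vendor.mobileclip.modules.text.repmixer import RepMixerBlock

# Text-side MetaFormer block; dim must match the embedding width of the input.
block = RepMixerBlock(dim=64, kernel_size=11)
block.eval()  # use BatchNorm running statistics

tokens = torch.randn(2, 77, 64)  # (B, C, D): batch 2, context length 77, dim 64
with torch.no_grad():
    out = block(tokens)

# The block permutes to (B, D, 1, C) internally and restores the input layout.
assert out.shape == tokens.shape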
@@ -0,0 +1,38 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
+ #
+ from typing import Dict
+
+ import open_clip
+ from torch import Tensor, nn
+
+
+ class ClipTokenizer(nn.Module):
+     """Wrap an open_clip tokenizer behind an nn.Module interface."""
+
+     def __init__(self, cfg: Dict, *args, **kwargs) -> None:
+         super().__init__()
+         self.context_length = cfg["text_cfg"]["context_length"]
+         model_name = cfg["text_cfg"].get("open_clip_tokenizer", "ViT-B-16")
+         self.tokenizer = open_clip.get_tokenizer(model_name)
+
+     def get_vocab_size(self) -> int:
+         return len(self.tokenizer.encoder)
+
+     def get_encodings(self) -> Dict[str, int]:
+         return self.tokenizer.encoder
+
+     def get_eot_token(self) -> int:
+         # Tokenizing an empty string returns a list [sot_id, eot_id]
+         return self.tokenizer("")[1]
+
+     def get_sot_token(self) -> int:
+         # Tokenizing an empty string returns a list [sot_id, eot_id]
+         return self.tokenizer("")[0]
+
+     def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
+         # The tokenizer returns a tensor of token indices.
+         tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
+         assert (
+             tokenized_sentence.shape[-1] == self.context_length
+         ), "Tokenized tensor should be exactly `context_length` long."
+         return tokenized_sentence
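
ClipTokenizer pads or truncates every sentence to the configured context length. A minimal usage sketch, not part of the package, assuming open_clip is installed; the cfg dict below is hypothetical and only mirrors the cfg["text_cfg"]["context_length"] access used by the class (the real values ship in the mobileclip_*.json configs, and 77 is an assumed context length for illustration):

from lightly_studio.vendor.mobileclip.modules.text.tokenizer import ClipTokenizer

# Hypothetical config; the packaged configs define the actual context length.
cfg = {"text_cfg": {"context_length": 77}}
tokenizer = ClipTokenizer(cfg)

tokens = tokenizer("a photo of a dog")  # LongTensor of shape (1, 77)
print(tokens.shape)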