lightly_studio-0.3.1-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of lightly-studio has been flagged as potentially problematic.

Files changed (219)
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
lightly_studio/vendor/mobileclip/modules/common/transformer.py
@@ -0,0 +1,451 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
+ #
+ """
+ Implementation of the following modules is borrowed from ml-cvnets repo:
+ https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
+ https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py
+
+ Please see ACKNOWLEDGEMENTS for license details.
+ """
+
+ from typing import List, Optional, Union
+
+ import torch
+ from torch import Size, Tensor, nn
+ from torch.nn import functional as F
+ from torchvision.ops import StochasticDepth
+
+ from ... import logger
+
+
+ class LayerNormFP32(nn.LayerNorm):
+     """
+     Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor with FP32 precision
+     """
+
+     def __init__(
+         self,
+         normalized_shape: Union[int, List[int], Size],
+         eps: Optional[float] = 1e-5,
+         elementwise_affine: Optional[bool] = True,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(
+             normalized_shape=normalized_shape,
+             eps=eps,
+             elementwise_affine=elementwise_affine,
+             *args,
+             **kwargs,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         # Convert input from dtype X to FP32 and perform normalization operation.
+         # This may help with underflow/overflow issues that we typically see with normalization layers
+         inp_dtype = x.dtype
+         return super().forward(x.to(torch.float32)).to(inp_dtype)
+
+
+ def get_normalization_layer(norm_type, num_features):
+     if norm_type == "layer_norm":
+         return nn.LayerNorm(num_features)
+     elif norm_type == "layer_norm_fp32":
+         return LayerNormFP32(num_features)
+     else:
+         raise NotImplementedError(f"Option: {norm_type} not supported.")
+
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         padding_idx: Optional[int] = None,
+         is_learnable: Optional[bool] = False,
+         interpolation_mode: Optional[str] = "bilinear",
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+         # Add other pos embedding here and logic to choose between them
+         module = LearnablePositionalEmbedding
+
+         self.pos_embed = module(
+             num_embeddings=num_embeddings,
+             embedding_dim=embedding_dim,
+             padding_idx=padding_idx,
+             interpolation_mode=interpolation_mode,
+             *args,
+             **kwargs,
+         )
+
+     def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
+         return self.pos_embed(seq_len, *args, **kwargs)
+
+     def __repr__(self):
+         return self.pos_embed.__repr__()
+
+
+ class LearnablePositionalEmbedding(nn.Module):
+     """Learnable Positional embedding"""
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         padding_idx: Optional[int] = None,
+         interpolation_mode: Optional[str] = "bilinear",
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+         self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
+         self.embedding_dim = embedding_dim
+         self.num_embeddings = num_embeddings
+         self.padding_idx = padding_idx
+         self.interpolation_mode = interpolation_mode
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
+         if self.padding_idx is not None:
+             with torch.no_grad():
+                 self.pos_embed[:, :, self.padding_idx, ...] = 0.0
+
+     def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
+         # scale pos embedding
+         pos_embed = self.pos_embed
+         if self.padding_idx is not None:
+             with torch.no_grad():
+                 pos_embed[:, :, self.padding_idx, ...] = 0.0
+
+         if seq_len != self.num_embeddings:
+             pos_embed = F.interpolate(
+                 pos_embed,
+                 size=(seq_len, self.embedding_dim),
+                 mode=self.interpolation_mode,
+             )
+
+         # Input is of the form [Batch, Seq_len, Embedding_dim]
+         return pos_embed.reshape(1, seq_len, self.embedding_dim)
+
+     def __repr__(self):
+         return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format(
+             self.__class__.__name__,
+             self.num_embeddings,
+             self.embedding_dim,
+             self.padding_idx,
+         )
+
+
+ class MultiHeadAttention(nn.Module):
+     """
+     This layer applies a multi-head self- or cross-attention as described in
+     `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
+
+     Args:
+         embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
+         num_heads (int): Number of heads in multi-head attention
+         attn_dropout (Optional[float]): Attention dropout. Default: 0.0
+         bias (Optional[bool]): Use bias or not. Default: ``True``
+
+     Shape:
+         - Input:
+            - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
+              and :math:`C_{in}` is input embedding dim
+            - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
+         - Output: same shape as the input
+
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         attn_dropout: Optional[float] = 0.0,
+         bias: Optional[bool] = True,
+         output_dim: Optional[int] = None,
+         *args,
+         **kwargs,
+     ) -> None:
+         if output_dim is None:
+             output_dim = embed_dim
+         super().__init__()
+         if embed_dim % num_heads != 0:
+             logger.error(
+                 "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
+                     self.__class__.__name__, embed_dim, num_heads
+                 )
+             )
+
+         self.qkv_proj = nn.Linear(
+             in_features=embed_dim, out_features=3 * embed_dim, bias=bias
+         )
+
+         self.attn_dropout = nn.Dropout(p=attn_dropout)
+         self.out_proj = nn.Linear(
+             in_features=embed_dim, out_features=output_dim, bias=bias
+         )
+
+         self.head_dim = embed_dim // num_heads
+         self.scaling = self.head_dim**-0.5
+         self.softmax = nn.Softmax(dim=-1)
+         self.num_heads = num_heads
+         self.embed_dim = embed_dim
+         self.use_separate_proj_weight = embed_dim != output_dim
+
+     def __repr__(self):
+         return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
+             self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
+         )
+
+     def _forward_impl(
+         self,
+         x_q: Tensor,
+         x_kv: Optional[Tensor] = None,
+         key_padding_mask: Optional[Tensor] = None,
+         attn_mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         # [N, S, C]
+         b_sz, S_len, in_channels = x_q.shape
+
+         if x_kv is None:
+             # self-attention
+             # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
+             qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
+             # [N, S, 3, h, c] --> [N, h, 3, S, C]
+             qkv = qkv.transpose(1, 3).contiguous()
+
+             # [N, h, 3, S, C] --> [N, h, S, C] x 3
+             query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+         else:
+             T_len = x_kv.shape[1]
+
+             # cross-attention
+             # [N, S, C]
+             query = F.linear(
+                 x_q,
+                 weight=self.qkv_proj.weight[: self.embed_dim, ...],
+                 bias=self.qkv_proj.bias[: self.embed_dim]
+                 if self.qkv_proj.bias is not None
+                 else None,
+             )
+             # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
+             query = (
+                 query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
+                 .transpose(1, 2)
+                 .contiguous()
+             )
+
+             # [N, T, C] --> [N, T, 2C]
+             kv = F.linear(
+                 x_kv,
+                 weight=self.qkv_proj.weight[self.embed_dim :, ...],
+                 bias=self.qkv_proj.bias[self.embed_dim :]
+                 if self.qkv_proj.bias is not None
+                 else None,
+             )
+             # [N, T, 2C] --> [N, T, 2, h, c]
+             kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
+             # [N, T, 2, h, c] --> [N, h, 2, T, c]
+             kv = kv.transpose(1, 3).contiguous()
+             key, value = kv[:, :, 0], kv[:, :, 1]
+
+         query = query * self.scaling
+
+         # [N, h, T, c] --> [N, h, c, T]
+         key = key.transpose(-1, -2)
+
+         # QK^T
+         # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
+         attn = torch.matmul(query, key)
+
+         batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
+         if attn_mask is not None:
+             # attn_mask shape should be the same as attn
+             assert list(attn_mask.shape) == [
+                 batch_size,
+                 num_src_tokens,
+                 num_tgt_tokens,
+             ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
+                 batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
+             )
+             # [N, S, T] --> [N, 1, S, T]
+             attn_mask = attn_mask.unsqueeze(1)
+             attn = attn + attn_mask
+
+         if key_padding_mask is not None:
+             # Do not attend to padding positions
+             # key padding mask size is [N, T]
+             assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
+                 batch_size,
+                 num_tgt_tokens,
+             ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
+                 batch_size, num_tgt_tokens, key_padding_mask.shape
+             )
+             attn = attn.masked_fill(
+                 key_padding_mask.unsqueeze(1)
+                 .unsqueeze(2)
+                 .to(torch.bool),  # [N, T] --> [N, 1, 1, T]
+                 float("-inf"),
+             )
+
+         attn_dtype = attn.dtype
+         attn_as_float = self.softmax(attn.float())
+         attn = attn_as_float.to(attn_dtype)
+         attn = self.attn_dropout(attn)
+
+         # weighted sum
+         # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
+         out = torch.matmul(attn, value)
+
+         # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
+         out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
+         out = self.out_proj(out)
+
+         return out
+
+     def forward(
+         self,
+         x_q: Tensor,
+         x_kv: Optional[Tensor] = None,
+         key_padding_mask: Optional[Tensor] = None,
+         attn_mask: Optional[Tensor] = None,
+         *args,
+         **kwargs,
+     ) -> Tensor:
+         # [Batch, Sequence, Hidden_dim]
+         return self._forward_impl(
+             x_q=x_q,
+             x_kv=x_kv,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+         )
+
+
+ class TransformerEncoder(nn.Module):
+     """
+     This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
+     Args:
+         embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
+         ffn_latent_dim: Inner dimension of the FFN.
+         num_heads: Number of heads in multi-head attention. Default: 8.
+         attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
+         dropout: Dropout rate. Default: 0.0.
+         ffn_dropout: Dropout between FFN layers. Default: 0.0.
+         transformer_norm_layer: Normalization layer. Default: layer_norm.
+         stochastic_dropout: Stochastic dropout setting. Default: 0.0.
+
+     Shape:
+         - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
+           and :math:`C_{in}` is input embedding dim
+         - Output: same shape as the input
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         ffn_latent_dim: int,
+         num_heads: Optional[int] = 8,
+         attn_dropout: Optional[float] = 0.0,
+         dropout: Optional[float] = 0.0,
+         ffn_dropout: Optional[float] = 0.0,
+         transformer_norm_layer: Optional[str] = "layer_norm",
+         stochastic_dropout: Optional[float] = 0.0,
+         *args,
+         **kwargs,
+     ) -> None:
+
+         super().__init__()
+
+         # Build attention layer
+         attn_unit = MultiHeadAttention(
+             embed_dim,
+             num_heads,
+             attn_dropout=attn_dropout,
+             bias=True,
+         )
+
+         self.pre_norm_mha = nn.Sequential(
+             get_normalization_layer(
+                 norm_type=transformer_norm_layer, num_features=embed_dim
+             ),
+             attn_unit,
+             nn.Dropout(p=dropout),
+         )
+
+         act_name = nn.GELU()
+         self.pre_norm_ffn = nn.Sequential(
+             get_normalization_layer(
+                 norm_type=transformer_norm_layer, num_features=embed_dim
+             ),
+             nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
+             act_name,
+             nn.Dropout(p=ffn_dropout),
+             nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
+             nn.Dropout(p=dropout),
+         )
+
+         self.drop_path = nn.Identity()
+         if stochastic_dropout > 0.0:
+             if dropout > 0.0:
+                 logger.error(
+                     "Stochastic dropout and dropout are mutually exclusive. "
+                     "Use either of them, but not both. "
+                     "Got: {} and {}".format(stochastic_dropout, dropout)
+                 )
+             self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")
+
+         self.embed_dim = embed_dim
+         self.ffn_dim = ffn_latent_dim
+         self.ffn_dropout = ffn_dropout
+         self.stochastic_dropout = stochastic_dropout
+         self.std_dropout = dropout
+         self.attn_fn_name = attn_unit.__class__.__name__
+         self.act_fn_name = act_name.__class__.__name__
+         self.norm_type = transformer_norm_layer
+
+     def __repr__(self) -> str:
+         return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
+             self.__class__.__name__,
+             self.embed_dim,
+             self.ffn_dim,
+             self.std_dropout,
+             self.ffn_dropout,
+             self.stochastic_dropout,
+             self.attn_fn_name,
+             self.act_fn_name,
+             self.norm_type,
+         )
+
+     def forward(
+         self,
+         x: Tensor,
+         x_prev: Optional[Tensor] = None,
+         key_padding_mask: Optional[Tensor] = None,
+         attn_mask: Optional[Tensor] = None,
+         *args,
+         **kwargs,
+     ) -> Tensor:
+
+         # Multi-head attention
+         res = x
+         x = self.pre_norm_mha[0](x)  # norm
+         x = self.pre_norm_mha[1](
+             x_q=x,
+             x_kv=x_prev,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+             *args,
+             **kwargs,
+         )  # mha
+
+         x = self.drop_path(self.pre_norm_mha[2](x))  # applying stochastic depth
+         x = x + res
+
+         # Feed forward network
+         x = x + self.drop_path(self.pre_norm_ffn(x))
+         return x
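
The encoder above is self-contained apart from the package-relative logger import, so it can be exercised directly once the wheel is installed. Below is an illustrative sketch that is not part of the diff: the import path is inferred from file 209 in the listing, and the dimensions and causal-mask construction are example choices, not package defaults. The mask shape follows the MultiHeadAttention docstring ([N, S, T], additive).

import torch

from lightly_studio.vendor.mobileclip.modules.common.transformer import (
    TransformerEncoder,
)

batch, seq_len, embed_dim = 2, 77, 512
block = TransformerEncoder(
    embed_dim=embed_dim, ffn_latent_dim=4 * embed_dim, num_heads=8
)
x = torch.randn(batch, seq_len, embed_dim)

# Build an additive causal mask so position i attends only to positions <= i:
# -inf above the diagonal, 0 elsewhere, expanded to [N, S, T] as the layer expects.
causal = torch.full((seq_len, seq_len), float("-inf")).triu(1)
attn_mask = causal.unsqueeze(0).expand(batch, -1, -1)

out = block(x, attn_mask=attn_mask)
assert out.shape == x.shape  # the pre-norm residual block preserves shape
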
lightly_studio/vendor/mobileclip/modules/image/__init__.py
@@ -0,0 +1,4 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
+ #
lightly_studio/vendor/mobileclip/modules/image/image_projection.py
@@ -0,0 +1,113 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
+ #
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from ... import logger
+
+
+ class GlobalPool(nn.Module):
+     """
+     This layer applies global pooling over a 4D or 5D input tensor
+
+     Args:
+         pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
+         keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`
+
+     Shape:
+         - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
+         - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`
+     """
+
+     pool_types = ["mean", "rms", "abs"]
+
+     def __init__(
+         self,
+         pool_type: Optional[str] = "mean",
+         keep_dim: Optional[bool] = False,
+         *args,
+         **kwargs
+     ) -> None:
+         super().__init__()
+         if pool_type not in self.pool_types:
+             logger.error(
+                 "Supported pool types are: {}. Got {}".format(
+                     self.pool_types, pool_type
+                 )
+             )
+         self.pool_type = pool_type
+         self.keep_dim = keep_dim
+
+     def _global_pool(self, x: Tensor, dims: List):
+         if self.pool_type == "rms":  # root mean square
+             x = x**2
+             x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
+             x = x**-0.5
+         elif self.pool_type == "abs":  # absolute
+             x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
+         else:
+             # default is mean
+             # same as AdaptiveAvgPool
+             x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
+         return x
+
+     def forward(self, x: Tensor) -> Tensor:
+         if x.dim() == 4:
+             dims = [-2, -1]
+         elif x.dim() == 5:
+             dims = [-3, -2, -1]
+         else:
+             raise NotImplementedError("Currently 2D and 3D global pooling supported")
+         return self._global_pool(x, dims=dims)
+
+
+ class GlobalPool2D(nn.Module):
+     """This class implements global pooling with linear projection."""
+
+     def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
+         super().__init__()
+         scale = in_dim**-0.5
+         self.pool = GlobalPool(pool_type="mean", keep_dim=False)
+         self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+     def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+         # x is of shape [batch, in_dim, in_height, in_width]
+         assert (
+             x.dim() == 4
+         ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
+             x.shape
+         )
+
+         # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
+         x = self.pool(x)
+         # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
+         x = x @ self.proj
+         return x
+
+
+ class SimpleImageProjectionHead(nn.Module):
+     """This class implements linear projection head."""
+
+     def __init__(self, in_dim: int, out_dim: int) -> None:
+         super().__init__()
+         scale = in_dim**-0.5
+         self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+     def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+         # x is of shape [batch, in_dim]
+         assert (
+             x.dim() == 2
+         ), "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape)
+
+         # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
+         x = x @ self.proj
+         return x
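
As with the transformer block, these projection heads can be tried standalone. The following sketch is not part of the diff; the import path is inferred from file 211 in the listing and the dimensions are made up for the example. It shows the expected input ranks: GlobalPool2D consumes a 4D feature map, while SimpleImageProjectionHead expects an already-pooled 2D tensor.

import torch

from lightly_studio.vendor.mobileclip.modules.image.image_projection import (
    GlobalPool2D,
    SimpleImageProjectionHead,
)

features = torch.randn(8, 1024, 7, 7)  # [batch, in_dim, H, W] CNN feature map

# Mean-pools H and W, then applies the learned [in_dim, out_dim] projection.
pool_head = GlobalPool2D(in_dim=1024, out_dim=512)
embeddings = pool_head(features)
assert embeddings.shape == (8, 512)

# The simple head skips pooling, so pool manually before projecting.
proj_head = SimpleImageProjectionHead(in_dim=1024, out_dim=512)
embeddings = proj_head(features.mean(dim=(2, 3)))
assert embeddings.shape == (8, 512)
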