lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (219) hide show
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,245 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ import math
6
+ from typing import Optional, Sequence
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+
11
+ from .modules.common.transformer import (
12
+ PositionalEmbedding,
13
+ TransformerEncoder,
14
+ get_normalization_layer,
15
+ )
16
+ from .modules.text.repmixer import RepMixerBlock
17
+ from . import logger
18
+
19
+
20
class TextTransformer(nn.Module):
    """Text encoder built from a token embedding and a stack of layers.

    The forward pass embeds the token ids, optionally adds a positional
    embedding, applies dropout, runs every layer in ``self.transformer``,
    applies a final normalization layer and, unless all tokens are requested,
    selects one token per sequence and multiplies it by ``projection_layer``.
    """

    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        """Build the encoder from a configuration dictionary.

        Args:
            cfg: model configuration. Required keys read here: ``"dim"``,
                ``"norm_layer"``, ``"model_name"`` (``"base"`` or ``"mct"``),
                ``"vocab_size"``, ``"context_length"``,
                ``"n_transformer_layers"``, ``"ffn_multiplier_per_layer"``,
                ``"n_heads_per_layer"`` and ``"causal_masking"``. Optional
                keys: ``"no_scale_embedding"`` and ``"no_pos_embedding"``
                (default ``False``) and ``"embed_dropout"`` (default ``0.0``).
            projection_dim: number of columns of the output projection.
            *args: accepted for signature compatibility; unused.
            **kwargs: accepted for signature compatibility; unused.

        Raises:
            ValueError: if ``cfg["model_name"]`` is neither ``"base"`` nor
                ``"mct"``.
        """
        super().__init__()

        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim

        # Token embedding layer
        self.embedding_layer = nn.Embedding(
            embedding_dim=model_dim, num_embeddings=self.vocab_size
        )
        # NOTE(review): embed_scale is stored but not applied anywhere in this
        # class; it is kept because external code may read the attribute.
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        # Context length
        context_length = cfg["context_length"]
        assert (
            context_length is not None
        ), "Context length can't be None. Please set value accordingly."

        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(
                num_embeddings=context_length, embedding_dim=model_dim
            )
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)

        # Transformer layer
        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers for transformer layer; a scalar is broadcast to
        # every layer.
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        # NOTE(review): the code below assumes logger.error aborts
        # construction on invalid configuration - confirm in the logger module.
        if not isinstance(ffn_multipliers, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects FFN multipliers as a list,"
                " whose length is the same as number of transformer layers."
                f" Got: {type(ffn_multipliers)}"
            )
        elif len(ffn_multipliers) != n_transformer_layers:
            logger.error(
                "We need FFN multiplier for each transformer layer. Got"
                f" {len(ffn_multipliers)} ffn multipliers while number of"
                f" transformer layers = {n_transformer_layers}"
            )
        # Round every FFN width up to the next multiple of 16.
        ffn_dims = [
            int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
            for ffn_mult in ffn_multipliers
        ]

        # Heads for transformer layers; a scalar is broadcast to every layer.
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects MHA heads as a list, whose"
                " length is the same as number of transformer layers. Got:"
                f" {type(mha_heads)}"
            )
        elif len(mha_heads) != n_transformer_layers:
            logger.error(
                f"{self.__class__.__name__} needs MHA heads for each"
                f" transformer layer. Got {len(mha_heads)} mha heads while"
                f" number of transformer layers = {n_transformer_layers}"
            )

        # Modules are created in the same order for each variant so that any
        # random initialization performed by the layers is consumed in a
        # stable sequence.
        if variant == "base":
            self.transformer = nn.ModuleList(
                self._build_encoder_layers(
                    model_dim, mha_heads, ffn_dims, norm_layer, n_transformer_layers
                )
            )
        elif variant == "mct":
            # One RepMixerBlock before and one after the transformer encoders.
            self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
            self.transformer.extend(
                self._build_encoder_layers(
                    model_dim, mha_heads, ffn_dims, norm_layer, n_transformer_layers
                )
            )
            self.transformer.extend([RepMixerBlock(dim=model_dim)])
        else:
            raise ValueError(f"Unrecognized text encoder variant {variant}")

        self.final_layer_norm = get_normalization_layer(
            num_features=model_dim, norm_type=norm_layer
        )

        # NOTE: torch.empty does not initialize memory, so this parameter
        # holds arbitrary values until weights are loaded or it is explicitly
        # initialized by the caller.
        self.projection_layer = nn.Parameter(
            torch.empty(model_dim, self.projection_dim)
        )
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    @staticmethod
    def _build_encoder_layers(
        model_dim: int,
        heads_per_layer: Sequence,
        ffn_dims: Sequence,
        norm_layer,
        n_layers: int,
    ) -> list:
        """Return ``n_layers`` TransformerEncoder modules, one per layer index."""
        return [
            TransformerEncoder(
                embed_dim=model_dim,
                num_heads=heads_per_layer[layer_idx],
                ffn_latent_dim=ffn_dims[layer_idx],
                transformer_norm_layer=norm_layer,
            )
            for layer_idx in range(n_layers)
        ]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embedding for all tokens.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]

        Returns:
            A tensor of [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            # Cast so the sum keeps the dtype of the token embeddings.
            token_emb = token_emb + self.positional_embedding(seq_len).to(
                token_emb.dtype
            )
        return self.embedding_dropout(token_emb)

    def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
        """Build causal attention mask [batch_size, context_length, context_length].

        The mask is additive: entries strictly above the main diagonal are
        ``-inf`` and all remaining entries are ``0``. The batch dimension is an
        expanded view of a single [context_length, context_length] matrix.
        """
        # pytorch uses additive attention mask; start from all -inf ...
        mask = torch.full((context_length, context_length), float("-inf"))
        # ... and zero out the diagonal and everything below it.
        mask.triu_(1)
        # Add a batch dimension and broadcast it without copying.
        return mask.unsqueeze(0).expand(batch_size, -1, -1)

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: a tensor of boolean values as the padding mask.
                Shape: [batch_size, context_length]. Ignored (replaced by
                ``None``) when ``self.causal_masking`` is true.
            return_all_tokens: a boolean flag to return all tokens, defaults to False
                to return only EOT token embedding.
            *args: accepted for signature compatibility; unused.
            **kwargs: accepted for signature compatibility; unused.

        Returns:
            A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
            True, otherwise a tensor of [batch_size, projection_dim] obtained by
            projecting the selected token with ``self.projection_layer``.
        """
        # Discrete tokens to continuous embeddings
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # [batch_size, context_length, context_length] when causal masking is on
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(
                context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
            )
            attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
            # With causal masking the padding mask is not forwarded to the layers.
            key_padding_mask = None

        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        # Apply layer norm
        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(text_tokens.shape[0], device=text_tokens.device)
        eot_index = text_tokens.argmax(dim=-1)
        token_emb = token_emb[batch_index, eot_index]

        return token_emb @ self.projection_layer

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """Encode a batch of token sequences; thin wrapper around ``encode_text``.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: optional padding mask, forwarded to ``encode_text``.
            return_all_tokens: forwarded to ``encode_text``.
            *args: extra positional arguments, forwarded to ``encode_text``.
            **kwargs: extra keyword arguments, forwarded to ``encode_text``.

        Returns:
            The tensor returned by ``encode_text``.
        """
        # Image-text pair data with single caption
        # [B, CL] --> [B, d]
        # The named parameters are passed positionally: unpacking *args after
        # keyword arguments would bind the first extra positional to
        # `text_tokens` again and raise TypeError.
        return self.encode_text(
            text_tokens,
            key_padding_mask,
            return_all_tokens,
            *args,
            **kwargs
        )