lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (219)
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,50 @@
1
+ ML-MobileCLIP Model Weights and Data
2
+
3
+ Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+
5
+ IMPORTANT: This Apple software is supplied to you by Apple
6
+ Inc. ("Apple") in consideration of your agreement to the following
7
+ terms, and your use, installation, modification or redistribution of
8
+ this Apple software constitutes acceptance of these terms. If you do
9
+ not agree with these terms, please do not use, install, modify or
10
+ redistribute this Apple software.
11
+
12
+ In consideration of your agreement to abide by the following terms, and
13
+ subject to these terms, Apple grants you a personal, non-exclusive
14
+ license, under Apple's copyrights in this original Apple software (the
15
+ "Apple Software"), to use, reproduce, modify and redistribute the Apple
16
+ Software, with or without modifications, in source and/or binary forms;
17
+ provided that if you redistribute the Apple Software in its entirety and
18
+ without modifications, you must retain this notice and the following
19
+ text and disclaimers in all such redistributions of the Apple Software.
20
+ Neither the name, trademarks, service marks or logos of Apple Inc. may
21
+ be used to endorse or promote products derived from the Apple Software
22
+ without specific prior written permission from Apple. Except as
23
+ expressly stated in this notice, no other rights or licenses, express or
24
+ implied, are granted by Apple herein, including but not limited to any
25
+ patent rights that may be infringed by your derivative works or by other
26
+ works in which the Apple Software may be incorporated.
27
+
28
+ The Apple Software is provided by Apple on an "AS IS" basis. APPLE
29
+ MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
30
+ THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
31
+ FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
32
+ OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
33
+
34
+ IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
35
+ OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
36
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
37
+ INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
38
+ MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
39
+ AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
40
+ STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
41
+ POSSIBILITY OF SUCH DAMAGE.
42
+
43
+ -------------------------------------------------------------------------------
44
+ SOFTWARE DISTRIBUTED WITH ML-MobileCLIP:
45
+
46
+ The ML-MobileCLIP software copyright and license terms can be found in LICENSE.
47
+
48
+ The ML-MobileCLIP software includes a number of subcomponents with separate
49
+ copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
50
+ -------------------------------------------------------------------------------
@@ -0,0 +1,5 @@
1
+ # MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training
2
+
3
+ Trimmed-down version of MobileCLIP, with imports modified to use relative paths.
4
+
5
+ Vendored from https://github.com/apple/ml-mobileclip, commit 1140b8d.
@@ -0,0 +1 @@
1
+ """Vendor directory for third-party code."""
@@ -0,0 +1,96 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ import os
6
+ import json
7
+ from typing import Optional, Union, Tuple, Any
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision.transforms import (
12
+ CenterCrop,
13
+ Compose,
14
+ InterpolationMode,
15
+ Resize,
16
+ ToTensor,
17
+ )
18
+
19
+ from .clip import CLIP
20
+ from .modules.text.tokenizer import (
21
+ ClipTokenizer,
22
+ )
23
+ from .modules.common.mobileone import reparameterize_model
24
+
25
+
26
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    reparameterize: Optional[bool] = True,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[nn.Module, Any, Any]:
    """
    Method to instantiate model and pre-processing transforms necessary for inference.

    Args:
        model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
        pretrained: Location of pretrained checkpoint. It is loaded with
            ``weights_only=True`` and its tensors are mapped onto ``device``.
        reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
        device: Device identifier for model placement.

    Returns:
        Tuple ``(model, None, preprocess)``: the instantiated model in eval
        mode, a ``None`` placeholder in the second slot, and the inference
        preprocessing transform.

    Raises:
        ValueError: If no JSON config exists for ``model_name``.
    """
    # Model configs are JSON files shipped in ``configs/`` next to this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Load the JSON config; the context manager closes the file handle promptly.
    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    with open(model_cfg_file, "r") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build preprocessing transforms for inference: Resize followed by a
    # CenterCrop, both driven by the configured image resolution.
    resolution = model_cfg["image_cfg"]["image_size"]
    resize_size = resolution
    centercrop_size = resolution
    aug_list = [
        Resize(
            resize_size,
            interpolation=InterpolationMode.BILINEAR,
        ),
        CenterCrop(centercrop_size),
        ToTensor(),
    ]
    preprocess = Compose(aug_list)

    # Build model
    model = CLIP(cfg=model_cfg)
    model.to(device)
    model.eval()

    # Load checkpoint. ``map_location`` lets a checkpoint saved on another
    # device (e.g. CUDA) be loaded on a host where that device is unavailable.
    if pretrained is not None:
        chkpt = torch.load(pretrained, map_location=device, weights_only=True)
        model.load_state_dict(chkpt)

    # Reparameterize model for inference (if specified)
    if reparameterize:
        model = reparameterize_model(model)

    return model, None, preprocess
83
+
84
+
85
def get_tokenizer(model_name: str) -> nn.Module:
    """Build the text tokenizer for ``model_name``.

    Args:
        model_name: Name of a JSON config in the ``configs`` directory next to
            this module (e.g. ``'mobileclip_s0'``).

    Returns:
        A ``ClipTokenizer`` constructed from the model config.

    Raises:
        FileNotFoundError: If no config file exists for ``model_name``.
    """
    # Config files live in ``configs/`` next to this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Load the JSON config; the context manager closes the file handle promptly.
    with open(model_cfg_file, "r") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build tokenizer
    text_tokenizer = ClipTokenizer(model_cfg)
    return text_tokenizer
@@ -0,0 +1,77 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ """ Model schema in open_clip format for inference only. """
6
+ import math
7
+ from typing import Any, Optional, Dict
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from .text_encoder import (
14
+ TextTransformer,
15
+ )
16
+
17
+ from .image_encoder import MCi
18
+
19
+
20
class CLIP(nn.Module):
    """Image-text model in open_clip format, for inference only.

    Pairs an ``MCi`` image encoder with a ``TextTransformer`` text encoder.
    Both project into an ``embed_dim``-sized space. A learnable logit scale
    is also provided.
    """

    def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None:
        """Build both encoders from ``cfg``.

        Args:
            cfg: Model config. Reads ``embed_dim``, ``image_cfg["model_name"]``
                and ``text_cfg``.
            output_dict: If True, ``forward`` returns a dict instead of a tuple.

        Raises:
            ValueError: If ``embed_dim`` is missing from ``cfg`` or is None.
        """
        super().__init__()
        self.output_dict = output_dict
        # ``.get`` so that a missing key reaches the descriptive ValueError
        # below instead of surfacing as an unexplained KeyError.
        self.projection_dim = cfg.get("embed_dim")
        if self.projection_dim is None:
            raise ValueError("Please specify `embed_dim` in model config.")

        self.image_encoder = MCi(
            model_name=cfg["image_cfg"]["model_name"],
            projection_dim=self.projection_dim,
        )
        self.text_encoder = TextTransformer(
            cfg=cfg["text_cfg"], projection_dim=self.projection_dim
        )
        # Stored in log space; exp(logit_scale) starts at 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))

    def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
        """Return ``exp(logit_scale)`` clamped to the range ``[0, max_scale]``."""
        scale = self.logit_scale.exp()
        scale = torch.clamp(scale, 0, max_scale)
        return scale

    def encode_image(self, image: torch.Tensor, normalize: bool = False):
        """Encode ``image`` with the image encoder.

        If the encoder returns a dict, its ``"logits"`` entry is used as the
        features. When ``normalize`` is True, the features are L2-normalized
        along the last dimension.
        """
        image_encoder_out = self.image_encoder(image)
        if isinstance(image_encoder_out, dict):
            features = image_encoder_out["logits"]
        else:
            features = image_encoder_out
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text: torch.Tensor, normalize: bool = False):
        """Encode token ids ``text``, optionally L2-normalizing along the last dim.

        No key padding mask is passed to the text encoder.
        """
        text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
        return F.normalize(text_features, dim=-1) if normalize else text_features

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> Any:
        """Return normalized embeddings for whichever inputs are given.

        A missing modality yields ``None`` in its slot. The clamped logit
        scale is always included. The result is a dict with keys
        ``image_features``, ``text_features`` and ``logit_scale`` when
        ``output_dict`` is set, otherwise a 3-tuple in that order.
        """
        image_embeddings = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_embeddings = (
            self.encode_text(text, normalize=True) if text is not None else None
        )

        if self.output_dict:
            return {
                "image_features": image_embeddings,
                "text_features": text_embeddings,
                "logit_scale": self._exponentiate_and_clip_logits(),
            }
        return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_cfg": {
4
+ "image_size": 224,
5
+ "model_name": "vit_b16"
6
+ },
7
+ "text_cfg": {
8
+ "context_length": 77,
9
+ "vocab_size": 49408,
10
+ "dim": 512,
11
+ "ffn_multiplier_per_layer": 4.0,
12
+ "n_heads_per_layer": 8,
13
+ "n_transformer_layers": 12,
14
+ "norm_layer": "layer_norm_fp32",
15
+ "causal_masking": true,
16
+ "model_name": "base"
17
+ }
18
+ }
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_cfg": {
4
+ "image_size": 256,
5
+ "model_name": "mci0"
6
+ },
7
+ "text_cfg": {
8
+ "context_length": 77,
9
+ "vocab_size": 49408,
10
+ "dim": 512,
11
+ "ffn_multiplier_per_layer": 4.0,
12
+ "n_heads_per_layer": 8,
13
+ "n_transformer_layers": 4,
14
+ "norm_layer": "layer_norm_fp32",
15
+ "causal_masking": false,
16
+ "model_name": "mct"
17
+ }
18
+ }
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_cfg": {
4
+ "image_size": 256,
5
+ "model_name": "mci1"
6
+ },
7
+ "text_cfg": {
8
+ "context_length": 77,
9
+ "vocab_size": 49408,
10
+ "dim": 512,
11
+ "ffn_multiplier_per_layer": 4.0,
12
+ "n_heads_per_layer": 8,
13
+ "n_transformer_layers": 12,
14
+ "norm_layer": "layer_norm_fp32",
15
+ "causal_masking": false,
16
+ "model_name": "base"
17
+ }
18
+ }
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_cfg": {
4
+ "image_size": 256,
5
+ "model_name": "mci2"
6
+ },
7
+ "text_cfg": {
8
+ "context_length": 77,
9
+ "vocab_size": 49408,
10
+ "dim": 512,
11
+ "ffn_multiplier_per_layer": 4.0,
12
+ "n_heads_per_layer": 8,
13
+ "n_transformer_layers": 12,
14
+ "norm_layer": "layer_norm_fp32",
15
+ "causal_masking": false,
16
+ "model_name": "base"
17
+ }
18
+ }
@@ -0,0 +1,67 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ from typing import Any
6
+
7
+ import torch.nn as nn
8
+ from timm.models import create_model
9
+
10
+ from . import models # Added to register models
11
+ from .modules.image.image_projection import GlobalPool2D
12
+
13
+
14
class MCi(nn.Module):
    """Image encoder wrapping an MCi backbone created through timm.

    See `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_. When a
    ``projection_dim`` keyword argument is supplied and the backbone has a
    ``head`` attribute, that head is swapped for a ``GlobalPool2D`` projection
    producing ``projection_dim`` outputs.
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        super().__init__()
        # Absent keyword -> None, which also disables the head replacement.
        self.projection_dim = kwargs.get("projection_dim")

        # Backbone is looked up by name in timm's registry (this module imports
        # ``.models`` to register the vendored variants).
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        wants_projection = self.projection_dim is not None
        if wants_projection and hasattr(self.model, "head"):
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Run the wrapped backbone on ``x``; extra arguments are ignored."""
        return self.model(x)

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return ``in_features`` of the classification head's first linear layer.

        Accepts either a bare ``nn.Linear`` or an ``nn.Sequential``. For a
        Sequential, its first ``nn.Linear`` child is used.

        Raises:
            NotImplementedError: If no linear layer can be found.
        """
        dim = None
        if isinstance(image_classifier, nn.Sequential):
            # Sequential heads typically contain pooling followed by one or
            # more linear layers; only the first linear layer matters here.
            dim = next(
                (
                    module.in_features
                    for module in image_classifier
                    if isinstance(module, nn.Linear)
                ),
                None,
            )
        elif isinstance(image_classifier, nn.Linear):
            dim = image_classifier.in_features

        if dim is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return dim

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        """Build a ``GlobalPool2D`` head that maps the old head's input width to ``projection_dim``."""
        return GlobalPool2D(
            in_dim=MCi._get_in_feature_dimension(image_classifier),
            out_dim=projection_dim,
        )
@@ -0,0 +1,154 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import traceback
10
+ from typing import Optional, Union
11
+
12
# ANSI escape sequences used to colorize console output. "\033" is the ESC
# character; the number before "m" is the SGR attribute code.
# NOTE(review): "light_red" maps to SGR 36, which is cyan in standard ANSI.
text_colors = dict(
    logs="\033[34m",
    info="\033[32m",
    warning="\033[33m",
    debug="\033[93m",
    error="\033[31m",
    bold="\033[1m",
    end_color="\033[0m",
    light_red="\033[36m",
)
22
+
23
+
24
def get_curr_time_stamp() -> str:
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    now = time.localtime()
    return time.strftime("%Y-%m-%d %H:%M:%S", now)
26
+
27
+
28
def error(message: str) -> None:
    """Print a traceback, then terminate via ``sys.exit`` with a colored ERROR line.

    If an exception is currently being handled, its traceback is printed.
    Otherwise the current call stack is printed.
    """
    time_stamp = get_curr_time_stamp()
    label = "".join(
        (text_colors["error"], text_colors["bold"], "ERROR ", text_colors["end_color"])
    )

    # Exiting with a bare numeric code (e.g. -1) says nothing about the cause
    # (such as a NaN loss). Passing the formatted message to sys.exit() raises
    # SystemExit carrying that text, which also lets tests catch and inspect
    # specific failures.
    handling_exception = sys.exc_info()[0] is not None
    if handling_exception:
        traceback.print_exc()
    else:
        traceback.print_stack()
    sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, label, message))
50
+
51
+
52
def color_text(in_text: str) -> str:
    """Wrap ``in_text`` between the "light_red" color code and the reset code."""
    prefix = text_colors["light_red"]
    suffix = text_colors["end_color"]
    return prefix + in_text + suffix
54
+
55
+
56
def log(message: str, end="\n") -> None:
    """Print ``message`` with a timestamp and a bold blue ``LOGS`` label.

    Args:
        message: Text to print.
        end: Terminator passed through to ``print``.
    """
    stamp = get_curr_time_stamp()
    label = "".join(
        [text_colors["logs"], text_colors["bold"], "LOGS ", text_colors["end_color"]]
    )
    print(f"{stamp} - {label} - {message}", end=end)
62
+
63
+
64
def warning(message: Union[str, Warning]) -> None:
    """Print ``message`` with a timestamp and a bold yellow ``WARNING`` label.

    Args:
        message: Text to print, or a ``Warning`` instance. A Warning is
            rendered as ``ClassName(repr(arg1),repr(arg2),...)`` from its
            class name and ``args``.
    """
    if isinstance(message, Warning):
        # The closing parenthesis was previously missing, producing output
        # such as "UserWarning('x'" instead of "UserWarning('x')".
        rendered_args = ",".join(map(repr, message.args))
        message = f"{type(message).__name__}({rendered_args})"

    time_stamp = get_curr_time_stamp()
    warn_str = (
        text_colors["warning"]
        + text_colors["bold"]
        + "WARNING"
        + text_colors["end_color"]
    )
    print("{} - {} - {}".format(time_stamp, warn_str, message))
76
+
77
+
78
def ignore_exception_with_warning(message: str) -> None:
    """
    Log the exception currently being handled as a warning instead of an error.

    After catching a tolerable exception E1 (e.g. when Model.forward() fails
    during profiling inside a try/except), it is helpful to log E1 for later
    investigation. Printing its stack trace as-is, however, could be confusing
    if an uncaught (non-tolerable) exception E2 is raised later. The log would
    then contain stack traces for both E1 and E2. Users looking for errors
    should find E2, but they may find E1 first.

    To tell E1 apart from an uncaught E2, every newline in the output is
    replaced by a newline followed by "(WARNING)". As a result, each line
    after the first (i.e. every line of the E1 traceback) starts with the
    "(WARNING)" prefix.

    Args:
        message: Extra explanation and context for debugging. (Note: the
            exception object is fetched automatically via
            ``traceback.format_exc()``; there is no need to pass it as an
            argument or include it in the message.)
    """
    warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)"))
96
+
97
+
98
def info(message: str, print_line: Optional[bool] = False) -> None:
    """Print ``message`` with a timestamp and a bold green ``INFO`` label.

    Args:
        message: Text to print.
        print_line: When truthy, follow the message with a 150-character
            ``=`` separator line.
    """
    stamp = get_curr_time_stamp()
    label = "".join(
        (text_colors["info"], text_colors["bold"], "INFO ", text_colors["end_color"])
    )
    print(f"{stamp} - {label} - {message}")
    if not print_line:
        return
    double_dash_line(dashes=150)
106
+
107
+
108
def debug(message: str) -> None:
    """Print ``message`` with a timestamp and a bold bright-yellow ``DEBUG`` label."""
    stamp = get_curr_time_stamp()
    label = "".join(
        (text_colors["debug"], text_colors["bold"], "DEBUG ", text_colors["end_color"])
    )
    print(f"{stamp} - {label} - {message}")
117
+
118
+
119
def double_dash_line(dashes: Optional[int] = 75) -> None:
    """Print a separator of ``dashes`` ``=`` characters in the "error" color."""
    rule = "=" * dashes
    start, reset = text_colors["error"], text_colors["end_color"]
    print(start + rule + reset)
121
+
122
+
123
def singe_dash_line(dashes: Optional[int] = 67) -> None:
    """Print an uncolored separator of ``dashes`` ``-`` characters.

    The function name's spelling is kept unchanged for existing callers.
    """
    separator = "-" * dashes
    print(separator)
125
+
126
+
127
def print_header(header: str) -> None:
    """Print ``header`` in bold green after 50 ``=`` characters, framed by separator lines."""
    double_dash_line()
    color_on = text_colors["info"] + text_colors["bold"]
    banner = color_on + "=" * 50 + str(header) + text_colors["end_color"]
    print(banner)
    double_dash_line()
137
+
138
+
139
def print_header_minor(header: str) -> None:
    """Print ``header`` in bold yellow after 25 ``=`` characters."""
    color_on = text_colors["warning"] + text_colors["bold"]
    banner = color_on + "=" * 25 + str(header) + text_colors["end_color"]
    print(banner)
147
+
148
+
149
def disable_printing():
    """Silence ``print`` by replacing ``sys.stdout`` with a handle to the null device.

    NOTE(review): the handle opened here is never closed (``enable_printing``
    only rebinds ``sys.stdout``), so each call leaves one file object open.
    """
    null_sink = open(os.devnull, "w")
    sys.stdout = null_sink
151
+
152
+
153
def enable_printing():
    """Point ``sys.stdout`` back at the interpreter's original stream, ``sys.__stdout__``."""
    original_stream = sys.__stdout__
    sys.stdout = original_stream
@@ -0,0 +1,10 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
4
+ #
5
+ from .mci import (
6
+ mci0,
7
+ mci1,
8
+ mci2,
9
+ )
10
+ from .vit import vit_b16