lightly-studio 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356) hide show
  1. lightly_studio/__init__.py +12 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +131 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db_tables.py +35 -0
  6. lightly_studio/api/features.py +5 -0
  7. lightly_studio/api/routes/api/annotation.py +305 -0
  8. lightly_studio/api/routes/api/annotation_label.py +87 -0
  9. lightly_studio/api/routes/api/annotations/__init__.py +7 -0
  10. lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
  11. lightly_studio/api/routes/api/caption.py +100 -0
  12. lightly_studio/api/routes/api/classifier.py +384 -0
  13. lightly_studio/api/routes/api/dataset.py +191 -0
  14. lightly_studio/api/routes/api/dataset_tag.py +266 -0
  15. lightly_studio/api/routes/api/embeddings2d.py +90 -0
  16. lightly_studio/api/routes/api/exceptions.py +114 -0
  17. lightly_studio/api/routes/api/export.py +114 -0
  18. lightly_studio/api/routes/api/features.py +17 -0
  19. lightly_studio/api/routes/api/frame.py +241 -0
  20. lightly_studio/api/routes/api/image.py +155 -0
  21. lightly_studio/api/routes/api/metadata.py +161 -0
  22. lightly_studio/api/routes/api/operator.py +75 -0
  23. lightly_studio/api/routes/api/sample.py +103 -0
  24. lightly_studio/api/routes/api/selection.py +87 -0
  25. lightly_studio/api/routes/api/settings.py +41 -0
  26. lightly_studio/api/routes/api/status.py +19 -0
  27. lightly_studio/api/routes/api/text_embedding.py +50 -0
  28. lightly_studio/api/routes/api/validators.py +17 -0
  29. lightly_studio/api/routes/api/video.py +133 -0
  30. lightly_studio/api/routes/healthz.py +13 -0
  31. lightly_studio/api/routes/images.py +104 -0
  32. lightly_studio/api/routes/video_frames_media.py +116 -0
  33. lightly_studio/api/routes/video_media.py +223 -0
  34. lightly_studio/api/routes/webapp.py +51 -0
  35. lightly_studio/api/server.py +94 -0
  36. lightly_studio/core/__init__.py +0 -0
  37. lightly_studio/core/add_samples.py +533 -0
  38. lightly_studio/core/add_videos.py +294 -0
  39. lightly_studio/core/dataset.py +780 -0
  40. lightly_studio/core/dataset_query/__init__.py +14 -0
  41. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  42. lightly_studio/core/dataset_query/dataset_query.py +317 -0
  43. lightly_studio/core/dataset_query/field.py +113 -0
  44. lightly_studio/core/dataset_query/field_expression.py +79 -0
  45. lightly_studio/core/dataset_query/match_expression.py +23 -0
  46. lightly_studio/core/dataset_query/order_by.py +79 -0
  47. lightly_studio/core/dataset_query/sample_field.py +37 -0
  48. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  49. lightly_studio/core/image_sample.py +36 -0
  50. lightly_studio/core/loading_log.py +56 -0
  51. lightly_studio/core/sample.py +291 -0
  52. lightly_studio/core/start_gui.py +54 -0
  53. lightly_studio/core/video_sample.py +38 -0
  54. lightly_studio/dataset/__init__.py +0 -0
  55. lightly_studio/dataset/edge_embedding_generator.py +155 -0
  56. lightly_studio/dataset/embedding_generator.py +129 -0
  57. lightly_studio/dataset/embedding_manager.py +349 -0
  58. lightly_studio/dataset/env.py +20 -0
  59. lightly_studio/dataset/file_utils.py +49 -0
  60. lightly_studio/dataset/fsspec_lister.py +275 -0
  61. lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
  62. lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
  63. lightly_studio/db_manager.py +166 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
  131. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
  132. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
  133. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
  134. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
  135. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
  136. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
  137. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
  138. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
  139. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
  140. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
  141. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
  142. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
  143. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
  144. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
  145. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
  146. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
  147. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
  148. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
  149. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
  150. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
  151. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
  152. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
  153. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
  154. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
  155. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  156. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  157. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  158. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  159. lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
  160. lightly_studio/errors.py +5 -0
  161. lightly_studio/examples/example.py +25 -0
  162. lightly_studio/examples/example_coco.py +27 -0
  163. lightly_studio/examples/example_coco_caption.py +29 -0
  164. lightly_studio/examples/example_metadata.py +369 -0
  165. lightly_studio/examples/example_operators.py +111 -0
  166. lightly_studio/examples/example_selection.py +28 -0
  167. lightly_studio/examples/example_split_work.py +48 -0
  168. lightly_studio/examples/example_video.py +22 -0
  169. lightly_studio/examples/example_video_annotations.py +157 -0
  170. lightly_studio/examples/example_yolo.py +22 -0
  171. lightly_studio/export/coco_captions.py +69 -0
  172. lightly_studio/export/export_dataset.py +104 -0
  173. lightly_studio/export/lightly_studio_label_input.py +120 -0
  174. lightly_studio/export_schema.py +18 -0
  175. lightly_studio/export_version.py +57 -0
  176. lightly_studio/few_shot_classifier/__init__.py +0 -0
  177. lightly_studio/few_shot_classifier/classifier.py +80 -0
  178. lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
  179. lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
  180. lightly_studio/metadata/complex_metadata.py +47 -0
  181. lightly_studio/metadata/compute_similarity.py +84 -0
  182. lightly_studio/metadata/compute_typicality.py +67 -0
  183. lightly_studio/metadata/gps_coordinate.py +41 -0
  184. lightly_studio/metadata/metadata_protocol.py +17 -0
  185. lightly_studio/models/__init__.py +1 -0
  186. lightly_studio/models/annotation/__init__.py +0 -0
  187. lightly_studio/models/annotation/annotation_base.py +303 -0
  188. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  189. lightly_studio/models/annotation/links.py +17 -0
  190. lightly_studio/models/annotation/object_detection.py +47 -0
  191. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  192. lightly_studio/models/annotation_label.py +47 -0
  193. lightly_studio/models/caption.py +49 -0
  194. lightly_studio/models/classifier.py +20 -0
  195. lightly_studio/models/dataset.py +70 -0
  196. lightly_studio/models/embedding_model.py +30 -0
  197. lightly_studio/models/image.py +96 -0
  198. lightly_studio/models/metadata.py +208 -0
  199. lightly_studio/models/range.py +17 -0
  200. lightly_studio/models/sample.py +154 -0
  201. lightly_studio/models/sample_embedding.py +36 -0
  202. lightly_studio/models/settings.py +69 -0
  203. lightly_studio/models/tag.py +96 -0
  204. lightly_studio/models/two_dim_embedding.py +16 -0
  205. lightly_studio/models/video.py +161 -0
  206. lightly_studio/plugins/__init__.py +0 -0
  207. lightly_studio/plugins/base_operator.py +60 -0
  208. lightly_studio/plugins/operator_registry.py +47 -0
  209. lightly_studio/plugins/parameter.py +70 -0
  210. lightly_studio/py.typed +0 -0
  211. lightly_studio/resolvers/__init__.py +0 -0
  212. lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
  213. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  214. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  215. lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
  216. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  217. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  218. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  219. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  220. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  221. lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
  222. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
  223. lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
  224. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
  225. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
  226. lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
  227. lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
  228. lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
  229. lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
  230. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
  231. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  232. lightly_studio/resolvers/annotations/__init__.py +1 -0
  233. lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
  234. lightly_studio/resolvers/caption_resolver.py +129 -0
  235. lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
  236. lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
  237. lightly_studio/resolvers/dataset_resolver/create.py +20 -0
  238. lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
  239. lightly_studio/resolvers/dataset_resolver/export.py +267 -0
  240. lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
  241. lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
  242. lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
  243. lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
  244. lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
  245. lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
  246. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
  247. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
  248. lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
  249. lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
  250. lightly_studio/resolvers/dataset_resolver/update.py +25 -0
  251. lightly_studio/resolvers/embedding_model_resolver.py +120 -0
  252. lightly_studio/resolvers/image_filter.py +50 -0
  253. lightly_studio/resolvers/image_resolver/__init__.py +21 -0
  254. lightly_studio/resolvers/image_resolver/create_many.py +52 -0
  255. lightly_studio/resolvers/image_resolver/delete.py +20 -0
  256. lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
  257. lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
  258. lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
  259. lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
  260. lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
  261. lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
  262. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  263. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  264. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  265. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  266. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  267. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  268. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  269. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  270. lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
  271. lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
  272. lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
  273. lightly_studio/resolvers/sample_resolver/create.py +16 -0
  274. lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
  275. lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
  276. lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
  277. lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
  278. lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
  279. lightly_studio/resolvers/settings_resolver.py +62 -0
  280. lightly_studio/resolvers/tag_resolver.py +299 -0
  281. lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
  282. lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
  283. lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
  284. lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
  285. lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
  286. lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
  287. lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
  288. lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
  289. lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
  290. lightly_studio/resolvers/video_resolver/__init__.py +27 -0
  291. lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
  292. lightly_studio/resolvers/video_resolver/create_many.py +58 -0
  293. lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
  294. lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
  295. lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
  296. lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
  297. lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
  298. lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
  299. lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
  300. lightly_studio/selection/__init__.py +1 -0
  301. lightly_studio/selection/mundig.py +143 -0
  302. lightly_studio/selection/select.py +203 -0
  303. lightly_studio/selection/select_via_db.py +273 -0
  304. lightly_studio/selection/selection_config.py +49 -0
  305. lightly_studio/services/annotations_service/__init__.py +33 -0
  306. lightly_studio/services/annotations_service/create_annotation.py +64 -0
  307. lightly_studio/services/annotations_service/delete_annotation.py +22 -0
  308. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  309. lightly_studio/services/annotations_service/update_annotation.py +54 -0
  310. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  311. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  312. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  313. lightly_studio/setup_logging.py +59 -0
  314. lightly_studio/type_definitions.py +31 -0
  315. lightly_studio/utils/__init__.py +3 -0
  316. lightly_studio/utils/download.py +94 -0
  317. lightly_studio/vendor/__init__.py +1 -0
  318. lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
  319. lightly_studio/vendor/mobileclip/LICENSE +31 -0
  320. lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
  321. lightly_studio/vendor/mobileclip/README.md +5 -0
  322. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  323. lightly_studio/vendor/mobileclip/clip.py +77 -0
  324. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  325. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  326. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  327. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  328. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  329. lightly_studio/vendor/mobileclip/logger.py +154 -0
  330. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  331. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  332. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  333. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  334. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  335. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  336. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  337. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  338. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  339. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  340. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  341. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  342. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  343. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  344. lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
  345. lightly_studio/vendor/perception_encoder/README.md +11 -0
  346. lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
  347. lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
  348. lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
  349. lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
  350. lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
  351. lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
  352. lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
  353. lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
  354. lightly_studio-0.4.6.dist-info/METADATA +88 -0
  355. lightly_studio-0.4.6.dist-info/RECORD +356 -0
  356. lightly_studio-0.4.6.dist-info/WHEEL +4 -0
@@ -0,0 +1,433 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ """
6
+ Implementation of the following modules is borrowed from ml-cvnets repo:
7
+ https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py
8
+
9
+ Please see ACKNOWLEDGEMENTS for license details.
10
+ """
11
+
12
+ from typing import Dict, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch import Tensor, nn
17
+
18
+ from timm.models import register_model
19
+ from ..modules.common.transformer import (
20
+ PositionalEmbedding,
21
+ TransformerEncoder,
22
+ get_normalization_layer,
23
+ )
24
+ from ..modules.image.image_projection import SimpleImageProjectionHead
25
+ from .. import logger
26
+
27
+
28
+ class ConvNormAct(nn.Module):
29
+ """
30
+ Applies an N-dimensional convolution over an input.
31
+
32
+ Args:
33
+ cfg: Model configuration.
34
+ in_channels: :math:`C_{out}` from an expected output of size
35
+ :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
36
+ out_channels: :math:`C_{out}` from an expected output of size
37
+ :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
38
+ kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
39
+ stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
40
+ dilation: Dilation rate for convolution. An integer, or tuple of length ``N``.
41
+ Default: ``1``.
42
+ padding: Padding for convolution. An integer, or tuple of length ``N``.
43
+ If not specified, padding is automatically computed based on kernel size and
44
+ dilation range. Default : ``None`` (equivalent to ``[
45
+ int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
46
+ groups: Number of groups in convolution. Default: ``1``.
47
+ bias: Use bias. Default: ``False``.
48
+ padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular').
49
+ Default: ``zeros``.
50
+ use_norm: Use normalization layer after convolution. Default: ``True``.
51
+ use_act: Use activation layer after convolution (or convolution and normalization).
52
+ Default: ``True``.
53
+ norm_layer: If not None, the provided normalization layer object will be used.
54
+ Otherwise, a normalization object will be created based on config
55
+ ``model.normalization.*`` opts.
56
+ act_layer: If not None, the provided activation function will be used.
57
+ Otherwise, an activation function will be created based on config
58
+ ``model.activation.*`` opts.
59
+
60
+ Shape:
61
+ - Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
62
+ - Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
63
+
64
+ .. note::
65
+ For depth-wise convolution, `groups=C_{in}=C_{out}`.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ cfg: Dict,
71
+ in_channels: int,
72
+ out_channels: int,
73
+ kernel_size: Union[int, Tuple[int, ...]],
74
+ stride: Union[int, Tuple[int, ...]] = 1,
75
+ dilation: Union[int, Tuple[int, ...]] = 1,
76
+ padding: Optional[Union[int, Tuple[int, ...]]] = None,
77
+ groups: int = 1,
78
+ bias: bool = False,
79
+ padding_mode: str = "zeros",
80
+ use_norm: bool = True,
81
+ use_act: bool = True,
82
+ norm_layer: Optional[nn.Module] = None,
83
+ act_layer: Optional[nn.Module] = None,
84
+ *args,
85
+ **kwargs,
86
+ ) -> None:
87
+ super().__init__()
88
+ self.ndim = 2
89
+
90
+ if norm_layer is None and use_norm:
91
+ norm_type = cfg.get("normalization", "batch_norm")
92
+ if norm_type == "batch_norm":
93
+ norm_layer = nn.BatchNorm2d(
94
+ num_features=out_channels,
95
+ momentum=cfg.get("momentum", 0.1),
96
+ )
97
+ else:
98
+ norm_layer = get_normalization_layer(
99
+ num_features=out_channels, norm_type=norm_type
100
+ )
101
+ elif norm_layer is not None and use_norm:
102
+ logger.error(
103
+ f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided."
104
+ )
105
+
106
+ if act_layer is None and use_act:
107
+ act_layer = nn.GELU() # Default to GELU
108
+ elif act_layer is not None and use_act:
109
+ logger.error(
110
+ f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided."
111
+ )
112
+
113
+ if (
114
+ use_norm
115
+ and any(param[0] == "bias" for param in norm_layer.named_parameters())
116
+ and bias
117
+ ):
118
+ assert (
119
+ not bias
120
+ ), "Do not use bias when using normalization layers with bias."
121
+
122
+ if isinstance(kernel_size, int):
123
+ kernel_size = (kernel_size,) * self.ndim
124
+
125
+ if isinstance(stride, int):
126
+ stride = (stride,) * self.ndim
127
+
128
+ if isinstance(dilation, int):
129
+ dilation = (dilation,) * self.ndim
130
+
131
+ assert isinstance(kernel_size, Tuple)
132
+ assert isinstance(stride, Tuple)
133
+ assert isinstance(dilation, Tuple)
134
+
135
+ if padding is None:
136
+ padding = (
137
+ int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim)
138
+ )
139
+
140
+ if in_channels % groups != 0:
141
+ logger.error(
142
+ "Input channels are not divisible by groups. {}%{} != 0 ".format(
143
+ in_channels, groups
144
+ )
145
+ )
146
+ if out_channels % groups != 0:
147
+ logger.error(
148
+ "Output channels are not divisible by groups. {}%{} != 0 ".format(
149
+ out_channels, groups
150
+ )
151
+ )
152
+
153
+ block = nn.Sequential()
154
+
155
+ conv_layer = nn.Conv2d(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ kernel_size=kernel_size, # type: ignore
159
+ stride=stride, # type: ignore
160
+ padding=padding,
161
+ dilation=dilation, # type: ignore
162
+ groups=groups,
163
+ bias=bias,
164
+ padding_mode=padding_mode,
165
+ )
166
+
167
+ block.add_module(name="conv", module=conv_layer)
168
+
169
+ self.norm_name = None
170
+ if use_norm:
171
+ block.add_module(name="norm", module=norm_layer)
172
+ self.norm_name = norm_layer.__class__.__name__
173
+
174
+ self.act_name = None
175
+ if use_act:
176
+ block.add_module(name="act", module=act_layer)
177
+ self.act_name = act_layer.__class__.__name__
178
+
179
+ self.block = block
180
+ self.in_channels = in_channels
181
+ self.out_channels = out_channels
182
+ self.stride = stride
183
+ self.groups = groups
184
+ self.kernel_size = conv_layer.kernel_size
185
+ self.bias = bias
186
+ self.dilation = dilation
187
+
188
def forward(self, x: Tensor) -> Tensor:
    """Apply the assembled ``self.block`` pipeline to ``x`` and return its output."""
    # ``self.block`` is the sequential container assembled in ``__init__``:
    # the convolution, followed by the optional norm and activation modules.
    block_output = self.block(x)
    return block_output
190
+
191
+
192
class VisionTransformer(nn.Module):
    """
    Vision Transformer (`ViT <https://arxiv.org/abs/2010.11929>`_) whose patch embedding is a small convolutional
    stem, in the spirit of `Early Convolutions Help Transformers See Better <https://arxiv.org/abs/2106.14881>`_.

    .. note::
        What this class does, as far as its own code shows:
            1. Patches come from a three-layer strided conv stem (kernel sizes and strides 4, 2, 2), i.e. an
               effective patch size of 16.
            2. Positional encodings are requested from ``PositionalEmbedding`` for the actual number of patches of
               every input (constructed with ``interpolation_mode="bilinear"``).
            3. A per-layer ``stochastic_dropout`` rate, linearly spaced from 0 to ``cfg["stochastic_dropout"]``
               (default 0.0), is forwarded to every ``TransformerEncoder``.
            4. Positional encodings are added before the class token (if enabled) is prepended, so the class token
               receives none, as suggested in the `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_.

    Args:
        cfg: Mapping with the required keys ``embed_dim``, ``n_transformer_layers`` and ``n_attn_heads``. Optional
            keys read here: ``n_classes`` (1000), ``pos_emb_drop_p``, ``attn_dropout``, ``dropout``,
            ``ffn_dropout``, ``stochastic_dropout`` (all 0.0), ``norm_layer`` ("layer_norm") and
            ``no_cls_token`` (False). The mapping is also handed to every ``ConvNormAct`` of the stem.
        **kwargs: ``projection_dim`` (optional). When given and not None, the head is a
            ``SimpleImageProjectionHead`` instead of an ``nn.Linear`` classifier.
    """

    def __init__(self, cfg, *args, **kwargs) -> None:
        super().__init__()

        num_classes = cfg.get("n_classes", 1000)
        # ``kwargs.get`` yields None when the caller did not ask for a projection head.
        self.projection_dim = kwargs.get("projection_dim")

        # Required settings; a missing key raises KeyError in this order.
        embed_dim = cfg["embed_dim"]
        depth = cfg["n_transformer_layers"]
        num_heads = cfg["n_attn_heads"]

        # Optional settings and their defaults.
        encoder_norm = cfg.get("norm_layer", "layer_norm")
        attn_dropout = cfg.get("attn_dropout", 0.0)
        dropout = cfg.get("dropout", 0.0)
        ffn_dropout = cfg.get("ffn_dropout", 0.0)
        max_stochastic_drop = cfg.get("stochastic_dropout", 0.0)

        # ---- Convolutional patch stem --------------------------------------
        # Three strided convs with an overall stride of 4 * 2 * 2 = 16. Only the
        # first two carry norm + activation; the last one is a plain biased conv
        # projecting to ``embed_dim``.
        stem_kernels = (4, 2, 2)
        stem_strides = (4, 2, 2)
        stem_width = max(32, embed_dim // 4)
        # Channel count before/after each stem layer; the input image has 3 channels.
        stem_channels = (3, stem_width, stem_width, embed_dim)
        final_idx = len(stem_kernels) - 1

        stem_layers = []
        for idx, (kernel, stride) in enumerate(zip(stem_kernels, stem_strides)):
            is_final = idx == final_idx
            stem_layers.append(
                ConvNormAct(
                    cfg=cfg,
                    in_channels=stem_channels[idx],
                    out_channels=stem_channels[idx + 1],
                    kernel_size=kernel,
                    stride=stride,
                    bias=is_final,
                    use_norm=not is_final,
                    use_act=not is_final,
                )
            )
        self.patch_emb = nn.Sequential(*stem_layers)

        # ---- Transformer encoder stack -------------------------------------
        # One drop rate per layer, linearly spaced in [0, max_stochastic_drop].
        drop_schedule = [
            round(rate, 3) for rate in np.linspace(0, max_stochastic_drop, depth)
        ]
        encoder_layers = []
        for layer_drop_rate in drop_schedule:
            encoder_layers.append(
                TransformerEncoder(
                    embed_dim=embed_dim,
                    ffn_latent_dim=embed_dim * 4,
                    num_heads=num_heads,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    ffn_dropout=ffn_dropout,
                    transformer_norm_layer=encoder_norm,
                    stochastic_dropout=layer_drop_rate,
                )
            )

        # The final norm is registered before the encoder stack; this order is
        # kept so that the module/state-dict ordering does not change.
        self.post_transformer_norm = get_normalization_layer(
            num_features=embed_dim, norm_type=encoder_norm
        )
        self.transformer = nn.Sequential(*encoder_layers)

        # ---- Head ----------------------------------------------------------
        if self.projection_dim is not None:
            self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
        else:
            self.classifier = nn.Linear(embed_dim, num_classes)

        # ---- Class token ---------------------------------------------------
        if cfg.get("no_cls_token", False):
            self.cls_token = None
        else:
            cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
            torch.nn.init.trunc_normal_(cls_token, std=0.02)
            self.cls_token = cls_token

        # ---- Positional embedding ------------------------------------------
        # Sized for a 224x224 input with patch size 16, i.e. (224 / 16)^2 positions.
        self.pos_embed = PositionalEmbedding(
            num_embeddings=(224 // 16) ** 2,
            embedding_dim=embed_dim,
            padding_idx=None,
            interpolation_mode="bilinear",
        )
        self.emb_dropout = nn.Dropout(p=cfg.get("pos_emb_drop_p", 0.0))

    def extract_patch_embeddings(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        """Turn an image batch into a token sequence.

        Args:
            x: Input of shape [Batch, in_channels, height, width].

        Returns:
            The tokens of shape [Batch, num_patches (+1 with class token), emb_dim] and the patch grid
            ``(num_patches_height, num_patches_width)``.
        """
        batch_size = x.shape[0]

        # [Batch, in_channels, height, width] --> [Batch, emb_dim, grid_h, grid_w]
        stem_out = self.patch_emb(x)
        n_h, n_w = stem_out.shape[-2:]

        # [Batch, emb_dim, grid_h, grid_w] --> [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
        tokens = stem_out.flatten(2).transpose(1, 2).contiguous()

        # Positional encodings are requested for the current number of patches
        # and are added before the class token is attached.
        tokens = self.pos_embed(tokens.shape[1]).to(tokens.dtype) + tokens

        if self.cls_token is not None:
            # [1, 1, emb_dim] --> [Batch, 1, emb_dim], then prepend along the token axis.
            batch_cls = self.cls_token.expand(batch_size, -1, -1)
            tokens = torch.cat((batch_cls, tokens), dim=1)

        return self.emb_dropout(tokens), (n_h, n_w)

    def _features_from_transformer(
        self, x: Tensor, *args, **kwargs
    ) -> Tuple[Tensor, Tuple[int, int]]:
        """Embed the patches, run every encoder layer, then apply the final norm.

        The token tensor stays batch-first, [B, N, embed_dim], throughout; the patch grid ``(n_h, n_w)`` is
        passed through unchanged.
        """
        tokens, grid = self.extract_patch_embeddings(x)

        for encoder_layer in self.transformer:
            tokens = encoder_layer(tokens)

        return self.post_transformer_norm(tokens), grid

    def extract_features(
        self, x: Tensor, *args, **kwargs
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Return the pooled embedding and, on request, the spatial image embeddings.

        Args:
            x: Input of shape [B, C, H, W].
            **kwargs: ``return_image_embeddings`` (default False) enables the second output.

        Returns:
            ``(cls_embedding, image_embedding)``. ``cls_embedding`` is [B, embed_dim]: the class token when one
            exists, otherwise the mean over all tokens. ``image_embedding`` is [B, embed_dim, n_h, n_w] when
            requested and None otherwise.
        """
        want_image_embeddings = kwargs.get("return_image_embeddings", False)

        # [B, C, H, W] --> [B, N + 1, embed_dim] (with class token) or [B, N, embed_dim]
        tokens, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)

        if self.cls_token is None:
            # No class token: pool all tokens. [B, N, embed_dim] --> [B, embed_dim]
            cls_embedding = torch.mean(tokens, dim=1)
            image_embedding = tokens
        else:
            # [B, N + 1, embed_dim] --> [B, 1, embed_dim], [B, N, embed_dim]
            cls_embedding, image_embedding = torch.split(
                tokens, split_size_or_sections=[1, tokens.shape[1] - 1], dim=1
            )
            cls_embedding = cls_embedding.squeeze(1)

        if not want_image_embeddings:
            return cls_embedding, None

        # [B, N, embed_dim] --> [B, embed_dim, N] --> [B, embed_dim, n_h, n_w]
        image_embedding = image_embedding.transpose(1, 2).contiguous()
        batch = image_embedding.shape[0]
        return cls_embedding, image_embedding.reshape(batch, -1, n_h, n_w)

    def forward_classifier(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Apply the head (``self.classifier``) to the pooled embedding.

        The second element is passed through from ``extract_features`` untouched.
        """
        pooled, image_embedding = self.extract_features(x, *args, **kwargs)
        return self.classifier(pooled), image_embedding

    def forward(self, x: Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
        """Return the head output, or a dict when ``return_image_embeddings`` is truthy.

        The dict holds ``"logits"`` and, when available, ``"image_embeddings"``.
        """
        prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)

        if not kwargs.get("return_image_embeddings", False):
            return prediction

        out_dict = {"logits": prediction}
        if image_embedding is not None:
            out_dict["image_embeddings"] = image_embedding
        return out_dict
418
+
419
+
420
@register_model
def vit_b16(pretrained=False, **kwargs):
    """Build a ``VisionTransformer`` with the B/16 configuration defined below.

    Args:
        pretrained: Must be falsy; loading pretrained weights is not implemented.
        **kwargs: Forwarded unchanged to ``VisionTransformer`` (e.g. ``projection_dim``).

    Returns:
        The freshly constructed ``VisionTransformer``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Reject the unsupported option before building anything: the original
    # constructed the whole model first and only then raised.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    # Vision transformer config
    cfg = {
        "norm_layer": "layer_norm_fp32",
        "act_layer": "gelu",
        "embed_dim": 768,
        "n_transformer_layers": 12,
        "n_attn_heads": 12,
    }
    return VisionTransformer(cfg=cfg, **kwargs)
@@ -0,0 +1,4 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
4
+ #
@@ -0,0 +1,4 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
4
+ #