lightly_studio-0.4.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356)
  1. lightly_studio/__init__.py +12 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +131 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db_tables.py +35 -0
  6. lightly_studio/api/features.py +5 -0
  7. lightly_studio/api/routes/api/annotation.py +305 -0
  8. lightly_studio/api/routes/api/annotation_label.py +87 -0
  9. lightly_studio/api/routes/api/annotations/__init__.py +7 -0
  10. lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
  11. lightly_studio/api/routes/api/caption.py +100 -0
  12. lightly_studio/api/routes/api/classifier.py +384 -0
  13. lightly_studio/api/routes/api/dataset.py +191 -0
  14. lightly_studio/api/routes/api/dataset_tag.py +266 -0
  15. lightly_studio/api/routes/api/embeddings2d.py +90 -0
  16. lightly_studio/api/routes/api/exceptions.py +114 -0
  17. lightly_studio/api/routes/api/export.py +114 -0
  18. lightly_studio/api/routes/api/features.py +17 -0
  19. lightly_studio/api/routes/api/frame.py +241 -0
  20. lightly_studio/api/routes/api/image.py +155 -0
  21. lightly_studio/api/routes/api/metadata.py +161 -0
  22. lightly_studio/api/routes/api/operator.py +75 -0
  23. lightly_studio/api/routes/api/sample.py +103 -0
  24. lightly_studio/api/routes/api/selection.py +87 -0
  25. lightly_studio/api/routes/api/settings.py +41 -0
  26. lightly_studio/api/routes/api/status.py +19 -0
  27. lightly_studio/api/routes/api/text_embedding.py +50 -0
  28. lightly_studio/api/routes/api/validators.py +17 -0
  29. lightly_studio/api/routes/api/video.py +133 -0
  30. lightly_studio/api/routes/healthz.py +13 -0
  31. lightly_studio/api/routes/images.py +104 -0
  32. lightly_studio/api/routes/video_frames_media.py +116 -0
  33. lightly_studio/api/routes/video_media.py +223 -0
  34. lightly_studio/api/routes/webapp.py +51 -0
  35. lightly_studio/api/server.py +94 -0
  36. lightly_studio/core/__init__.py +0 -0
  37. lightly_studio/core/add_samples.py +533 -0
  38. lightly_studio/core/add_videos.py +294 -0
  39. lightly_studio/core/dataset.py +780 -0
  40. lightly_studio/core/dataset_query/__init__.py +14 -0
  41. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  42. lightly_studio/core/dataset_query/dataset_query.py +317 -0
  43. lightly_studio/core/dataset_query/field.py +113 -0
  44. lightly_studio/core/dataset_query/field_expression.py +79 -0
  45. lightly_studio/core/dataset_query/match_expression.py +23 -0
  46. lightly_studio/core/dataset_query/order_by.py +79 -0
  47. lightly_studio/core/dataset_query/sample_field.py +37 -0
  48. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  49. lightly_studio/core/image_sample.py +36 -0
  50. lightly_studio/core/loading_log.py +56 -0
  51. lightly_studio/core/sample.py +291 -0
  52. lightly_studio/core/start_gui.py +54 -0
  53. lightly_studio/core/video_sample.py +38 -0
  54. lightly_studio/dataset/__init__.py +0 -0
  55. lightly_studio/dataset/edge_embedding_generator.py +155 -0
  56. lightly_studio/dataset/embedding_generator.py +129 -0
  57. lightly_studio/dataset/embedding_manager.py +349 -0
  58. lightly_studio/dataset/env.py +20 -0
  59. lightly_studio/dataset/file_utils.py +49 -0
  60. lightly_studio/dataset/fsspec_lister.py +275 -0
  61. lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
  62. lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
  63. lightly_studio/db_manager.py +166 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
  131. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
  132. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
  133. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
  134. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
  135. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
  136. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
  137. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
  138. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
  139. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
  140. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
  141. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
  142. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
  143. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
  144. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
  145. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
  146. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
  147. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
  148. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
  149. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
  150. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
  151. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
  152. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
  153. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
  154. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
  155. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  156. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  157. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  158. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  159. lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
  160. lightly_studio/errors.py +5 -0
  161. lightly_studio/examples/example.py +25 -0
  162. lightly_studio/examples/example_coco.py +27 -0
  163. lightly_studio/examples/example_coco_caption.py +29 -0
  164. lightly_studio/examples/example_metadata.py +369 -0
  165. lightly_studio/examples/example_operators.py +111 -0
  166. lightly_studio/examples/example_selection.py +28 -0
  167. lightly_studio/examples/example_split_work.py +48 -0
  168. lightly_studio/examples/example_video.py +22 -0
  169. lightly_studio/examples/example_video_annotations.py +157 -0
  170. lightly_studio/examples/example_yolo.py +22 -0
  171. lightly_studio/export/coco_captions.py +69 -0
  172. lightly_studio/export/export_dataset.py +104 -0
  173. lightly_studio/export/lightly_studio_label_input.py +120 -0
  174. lightly_studio/export_schema.py +18 -0
  175. lightly_studio/export_version.py +57 -0
  176. lightly_studio/few_shot_classifier/__init__.py +0 -0
  177. lightly_studio/few_shot_classifier/classifier.py +80 -0
  178. lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
  179. lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
  180. lightly_studio/metadata/complex_metadata.py +47 -0
  181. lightly_studio/metadata/compute_similarity.py +84 -0
  182. lightly_studio/metadata/compute_typicality.py +67 -0
  183. lightly_studio/metadata/gps_coordinate.py +41 -0
  184. lightly_studio/metadata/metadata_protocol.py +17 -0
  185. lightly_studio/models/__init__.py +1 -0
  186. lightly_studio/models/annotation/__init__.py +0 -0
  187. lightly_studio/models/annotation/annotation_base.py +303 -0
  188. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  189. lightly_studio/models/annotation/links.py +17 -0
  190. lightly_studio/models/annotation/object_detection.py +47 -0
  191. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  192. lightly_studio/models/annotation_label.py +47 -0
  193. lightly_studio/models/caption.py +49 -0
  194. lightly_studio/models/classifier.py +20 -0
  195. lightly_studio/models/dataset.py +70 -0
  196. lightly_studio/models/embedding_model.py +30 -0
  197. lightly_studio/models/image.py +96 -0
  198. lightly_studio/models/metadata.py +208 -0
  199. lightly_studio/models/range.py +17 -0
  200. lightly_studio/models/sample.py +154 -0
  201. lightly_studio/models/sample_embedding.py +36 -0
  202. lightly_studio/models/settings.py +69 -0
  203. lightly_studio/models/tag.py +96 -0
  204. lightly_studio/models/two_dim_embedding.py +16 -0
  205. lightly_studio/models/video.py +161 -0
  206. lightly_studio/plugins/__init__.py +0 -0
  207. lightly_studio/plugins/base_operator.py +60 -0
  208. lightly_studio/plugins/operator_registry.py +47 -0
  209. lightly_studio/plugins/parameter.py +70 -0
  210. lightly_studio/py.typed +0 -0
  211. lightly_studio/resolvers/__init__.py +0 -0
  212. lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
  213. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  214. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  215. lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
  216. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  217. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  218. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  219. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  220. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  221. lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
  222. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
  223. lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
  224. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
  225. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
  226. lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
  227. lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
  228. lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
  229. lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
  230. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
  231. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  232. lightly_studio/resolvers/annotations/__init__.py +1 -0
  233. lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
  234. lightly_studio/resolvers/caption_resolver.py +129 -0
  235. lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
  236. lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
  237. lightly_studio/resolvers/dataset_resolver/create.py +20 -0
  238. lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
  239. lightly_studio/resolvers/dataset_resolver/export.py +267 -0
  240. lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
  241. lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
  242. lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
  243. lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
  244. lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
  245. lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
  246. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
  247. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
  248. lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
  249. lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
  250. lightly_studio/resolvers/dataset_resolver/update.py +25 -0
  251. lightly_studio/resolvers/embedding_model_resolver.py +120 -0
  252. lightly_studio/resolvers/image_filter.py +50 -0
  253. lightly_studio/resolvers/image_resolver/__init__.py +21 -0
  254. lightly_studio/resolvers/image_resolver/create_many.py +52 -0
  255. lightly_studio/resolvers/image_resolver/delete.py +20 -0
  256. lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
  257. lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
  258. lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
  259. lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
  260. lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
  261. lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
  262. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  263. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  264. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  265. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  266. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  267. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  268. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  269. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  270. lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
  271. lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
  272. lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
  273. lightly_studio/resolvers/sample_resolver/create.py +16 -0
  274. lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
  275. lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
  276. lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
  277. lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
  278. lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
  279. lightly_studio/resolvers/settings_resolver.py +62 -0
  280. lightly_studio/resolvers/tag_resolver.py +299 -0
  281. lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
  282. lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
  283. lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
  284. lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
  285. lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
  286. lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
  287. lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
  288. lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
  289. lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
  290. lightly_studio/resolvers/video_resolver/__init__.py +27 -0
  291. lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
  292. lightly_studio/resolvers/video_resolver/create_many.py +58 -0
  293. lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
  294. lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
  295. lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
  296. lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
  297. lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
  298. lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
  299. lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
  300. lightly_studio/selection/__init__.py +1 -0
  301. lightly_studio/selection/mundig.py +143 -0
  302. lightly_studio/selection/select.py +203 -0
  303. lightly_studio/selection/select_via_db.py +273 -0
  304. lightly_studio/selection/selection_config.py +49 -0
  305. lightly_studio/services/annotations_service/__init__.py +33 -0
  306. lightly_studio/services/annotations_service/create_annotation.py +64 -0
  307. lightly_studio/services/annotations_service/delete_annotation.py +22 -0
  308. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  309. lightly_studio/services/annotations_service/update_annotation.py +54 -0
  310. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  311. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  312. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  313. lightly_studio/setup_logging.py +59 -0
  314. lightly_studio/type_definitions.py +31 -0
  315. lightly_studio/utils/__init__.py +3 -0
  316. lightly_studio/utils/download.py +94 -0
  317. lightly_studio/vendor/__init__.py +1 -0
  318. lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
  319. lightly_studio/vendor/mobileclip/LICENSE +31 -0
  320. lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
  321. lightly_studio/vendor/mobileclip/README.md +5 -0
  322. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  323. lightly_studio/vendor/mobileclip/clip.py +77 -0
  324. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  325. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  326. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  327. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  328. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  329. lightly_studio/vendor/mobileclip/logger.py +154 -0
  330. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  331. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  332. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  333. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  334. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  335. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  336. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  337. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  338. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  339. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  340. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  341. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  342. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  343. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  344. lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
  345. lightly_studio/vendor/perception_encoder/README.md +11 -0
  346. lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
  347. lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
  348. lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
  349. lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
  350. lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
  351. lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
  352. lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
  353. lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
  354. lightly_studio-0.4.6.dist-info/METADATA +88 -0
  355. lightly_studio-0.4.6.dist-info/RECORD +356 -0
  356. lightly_studio-0.4.6.dist-info/WHEEL +4 -0
lightly_studio/vendor/perception_encoder/vision_encoder/pe.py
@@ -0,0 +1,766 @@
+ #
+ # For licensing see accompanying LICENSE.PE file.
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+
+ import copy
+ import math
+ import random
+ from collections import OrderedDict
+ from dataclasses import asdict
+ from functools import partial
+ from logging import getLogger
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+ from torch.nn.parameter import Parameter
+ from torch.utils.checkpoint import checkpoint
+
+ from lightly_studio.vendor.perception_encoder.vision_encoder.rope import Rope2D
+ from lightly_studio.vendor.perception_encoder.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint
+
+
+
+ logger = getLogger()
+
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-5, inplace=False):
+         super().__init__()
+         self.inplace = inplace
+         self.dim = dim
+         self.init_values = init_values
+
+     def forward(self, x):
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+     def init_tensors(self):
+         self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))
+
+
+ class AttentionPooling(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         num_probe: int = 1,
+         mlp_ratio: int = 4,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+     ):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+
+         assert (
+             self.embed_dim % num_heads == 0
+         ), "embed_dim must be divisible by num_heads"
+
+         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
+         self.attn = nn.MultiheadAttention(
+             self.embed_dim, self.num_heads, batch_first=True
+         )
+
+         self.layernorm = norm_layer(embed_dim)
+         self.mlp_width = int(embed_dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
+                 ]
+             )
+         )
+
+     def forward(self, x: torch.Tensor):
+         batch, _, _ = x.shape
+
+         q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
+         x = self.attn(q, x, x, need_weights=False)[0]
+         x = x + self.mlp(self.layernorm(x))
+
+         return x
+
+
+ class SelfAttention(nn.Module):
+     r"""
+     Implements sequence-packed attention and RoPE
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         rope: Optional[nn.Module] = None,
+     ):
+         super(SelfAttention, self).__init__()
+         self.embed_dim = embed_dim
+
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert (
+             self.head_dim * num_heads == self.embed_dim
+         ), "embed_dim must be divisible by num_heads"
+
+         # To make this compatible with nn.MultiheadAttention
+         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+         self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+         self.rope = rope
+         self.scale = self.head_dim ** (-0.5)
+
+     def init_tensors(self):
+         xavier_uniform_(self.in_proj_weight)
+         constant_(self.in_proj_bias, 0.0)
+         constant_(self.out_proj.bias, 0.0)
+
+     def forward(self, x, attn_mask=None):
+         batch, seq, embed_dim = x.shape
+         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
+
+         # reshape to (3, E) and not (E, 3); this is deliberate for better memory coalescing and keeps the same order as chunk()
+         proj = (
+             proj.unflatten(-1, (3, embed_dim))
+             .unsqueeze(0)
+             .transpose(0, -2)
+             .squeeze(-2)
+             .contiguous()
+         )
+         q, k, v = proj[0], proj[1], proj[2]
+
+         # Use "q_" so that we don't accidentally quit in pdb :)
+         q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
+         k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
+         v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+
+         if self.rope:
+             q, k = self.rope(q, k)
+
+         attn = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
+         )
+         attn = rearrange(attn, "b h s d -> b s (h d)")
+
+         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+
+         if rope:
+             self.attn = SelfAttention(d_model, n_head, rope=rope)
+         else:
+             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
+
+         self.ls_1 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+         self.ls_2 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+
+         self.ln_1 = norm_layer(d_model)
+         self.ln_2 = norm_layer(d_model)
+
+         self.drop_path1 = nn.Identity()
+         self.drop_path2 = nn.Identity()
+
+         mlp_width = int(d_model * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(d_model, mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(mlp_width, d_model)),
+                 ]
+             )
+         )
+
+     def _call_attn(
+         self,
+         q_x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+
+         if attn_mask is not None:
+             # Leave boolean masks as is
+             if not attn_mask.dtype == torch.bool:
+                 attn_mask = attn_mask.to(q_x.dtype)
+
+         if isinstance(self.attn, SelfAttention):
+             return self.attn(q_x, attn_mask=attn_mask)
+         else:
+             return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+         x = x + self.drop_path1(
+             self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
+         )
+         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.grad_checkpointing = False
+
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     width,
+                     heads,
+                     mlp_ratio,
+                     ls_init_value=ls_init_value,
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                     drop_path=drop_path,
+                     rope=rope,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def truncate(self, layer_idx: int):
+         """ Delete layers so the last layer is the given layer index. """
+         self.layers = ((self.layers + layer_idx) % self.layers) + 1
+         self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         layer_idx: int = -1,
+     ):
+         stop_idx = (self.layers + layer_idx) % self.layers
+
+         for i, r in enumerate(self.resblocks):
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                 x = checkpoint(r, x, None, None, attn_mask)
+             else:
+                 x = r(x, attn_mask=attn_mask)
+
+             if i == stop_idx:
+                 break
+
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(
+         self,
+         patch_size: int,
+         width: int,
+         layers: int,
+         heads: int,
+         mlp_ratio: float,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
+         use_ln_pre: bool = True,
+         use_ln_post: bool = True,
+         ls_init_value: float = None,
+         drop_path: float = 0.0,
+         image_size: int = 448,  # Pretrain image size only; you can pass in any image size
+         use_abs_posemb: bool = True,
+         use_rope2d: bool = True,
+         use_cls_token: bool = False,
+         output_dim: Optional[int] = 1280,
+         attn_pooler_heads: int = 8,
+         pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
+     ):
+         super().__init__()
+         assert pool_type in ("attn", "tok", "avg", "none")
+         self.pool_type = pool_type
+         self.patch_size = patch_size
+
+         self.output_dim = output_dim or width
+         self.proj_dim = output_dim
+         self.heads = heads
+         self.width = width
+         self.layers = layers
+
+         self.use_abs_posemb = use_abs_posemb
+         self.use_cls_token = use_cls_token
+         self.use_rope2d = use_rope2d
+         self.image_size = image_size
+
+         self.conv1 = nn.Conv2d(
+             in_channels=3,
+             out_channels=width,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+         )
+         self.rope = (
+             Rope2D(
+                 dim=width // heads,
+                 use_cls_token=self.use_cls_token,
+             )
+             if self.use_rope2d
+             else None
+         )
+
+         self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
+         self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()
+
+         self.transformer = Transformer(
+             width,
+             layers,
+             heads,
+             mlp_ratio,
+             ls_init_value=ls_init_value,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+             drop_path=drop_path,
+             rope=self.rope,
+         )
+
+         if pool_type == "attn":
+             self.attn_pool = AttentionPooling(
+                 embed_dim=width,
+                 num_heads=attn_pooler_heads,
+                 act_layer=act_layer,
+                 norm_layer=norm_layer,
+             )
+         else:
+             self.attn_pool = None
+
+         self.init_tensors()
+
+
+     def init_tensors(self):
+         def init_submodule_tensors(module):
+             for name, child in module.named_children():
+                 if hasattr(child, "init_tensors"):
+                     logger.debug(f"Initializing tensors for submodule: {name}")
+                     child.init_tensors()
+                 init_submodule_tensors(child)
+
+         init_submodule_tensors(self)
+         self.rope.init_tensors()
+
+         # class embeddings and positional embeddings
+         init_scale = self.width**-0.5
+
+         if self.use_cls_token:
+             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
+
+         if self.use_abs_posemb:
+             self.posemb_grid_size = self.image_size // self.patch_size
+             self.positional_embedding = nn.Parameter(
+                 init_scale
+                 * torch.randn(
+                     int(self.use_cls_token) + self.posemb_grid_size**2, self.width
+                 )
+             )
+
+         if self.proj_dim is not None:
+             self.proj = nn.Parameter(
+                 init_scale * torch.randn(self.width, self.proj_dim)
+             )
+
+
+     def load_ckpt(self, ckpt_path: str, verbose: bool = True):
+         _sd = torch.load(ckpt_path, weights_only=True)
+         if "state_dict" in _sd:
+             _sd = _sd["state_dict"]
+         elif "weights" in _sd:
+             _sd = _sd["weights"]
+
+         # for backwards compatibility
+         _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
+         if any(k.startswith("visual.") for k in _sd):
+             _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}
+
+         m, u = self.load_state_dict(_sd, strict=False)
+
+         if verbose or (m or u):
+             logger.info(f"Missing keys for loading vision encoder: {m}")
+             logger.info(f"Unexpected keys for loading vision encoder: {u}")
+             print(f"Missing keys for loading vision encoder: {m}")
+             print(f"Unexpected keys for loading vision encoder: {u}")
+
+
+     def truncate(self, layer_idx: int):
+         """ Delete layers so the last layer is the given layer index. """
+         self.transformer.truncate(layer_idx)
+         self.layers = self.transformer.layers
+
+
+     @classmethod
+     def from_config(
+         cls,
+         name: str,
+         pretrained: bool = False,
+         checkpoint_path: Optional[str] = None,
+         **kwdargs
+     ):
+         if name not in PE_VISION_CONFIG:
+             raise RuntimeError(f"{name} not found in configs.")
+
+         args = asdict(PE_VISION_CONFIG[name])
+         args.update(kwdargs)
+
+         model = cls(**args)
+         if pretrained:
+             model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
+
+         return model
+
+     @classmethod
+     def available_configs(cls):
+         return list(PE_VISION_CONFIG.keys())
+
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.transformer.set_grad_checkpointing(enable=enable)
+
+     def _sample_abs_posemb(self, grid_h: int, grid_w: int):
+         """Interpolates the absolute position embedding if necessary."""
+         if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
+             return self.positional_embedding[None, ...]
+
+         pos_embed = self.positional_embedding
+         if self.use_cls_token:
+             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
+
+         pos_embed = (
+             pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
+             .permute(0, 3, 1, 2)
+             .contiguous()
+         )
+         pos_embed = F.interpolate(
+             pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+         )
+         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
+
+         if self.use_cls_token:
+             pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
+
+         return pos_embed[None, ...]
+
+     def _pool(self, x: torch.Tensor):
+         if self.pool_type == "tok":
+             return x[:, 0]
+         elif self.pool_type == "avg":
+             return x.mean(dim=1)
+         elif self.pool_type == "attn":
+             return self.attn_pool(x).squeeze(1)
+         elif self.pool_type == "none":
+             return x
+         else:
+             raise NotImplementedError
+
+     def forward_features(
+         self,
+         x: torch.Tensor,
+         norm: bool = False,
+         layer_idx: int = -1,
+         strip_cls_token: bool = False
+     ):
+         batch, _, h, w = x.shape
+         grid_h, grid_w = h // self.patch_size, w // self.patch_size
+
+         x = self.conv1(x)
+         x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
+
+         if self.use_cls_token:
+             x = torch.cat(
+                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
+                 dim=1,
+             )
+
+         if self.use_abs_posemb:
+             x = x + self._sample_abs_posemb(grid_h, grid_w)
+
+         if self.use_rope2d:
+             self.rope.update_grid(x.device, grid_h, grid_w)
+
+         x = self.ln_pre(x)
+         x = self.transformer(x, layer_idx=layer_idx)
+
+         if norm:
+             x = self.ln_post(x)
+
+         if strip_cls_token and self.use_cls_token:
+             x = x[:, 1:, :]
+
+         return x
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         x = self.forward_features(x, norm=True, **kwargs)
+         x = self._pool(x)
+
+         if self.proj_dim is not None:
+             x = x @ self.proj
+
+         return x
+
+
+
+
+
+
+
+
+
+
+ class TextTransformer(nn.Module):
+     def __init__(
+         self,
+         context_length: int = 72,
+         vocab_size: int = 49408,
+         width: int = 512,
+         heads: int = 8,
+         layers: int = 12,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         output_dim: int = 1280,
+         no_causal_mask: bool = False,
+         pad_id: int = 0,
+         pool_type: str = "argmax",
+         proj_bias: bool = False,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
+         output_tokens: bool = False,
+         use_ln_post: bool = True,
+     ):
+         super().__init__()
+         assert pool_type in ("first", "last", "argmax", "none")
+         self.pool_type = pool_type
+         self.output_tokens = output_tokens
+         self.num_pos = self.context_length = context_length
+         self.vocab_size = vocab_size
+         self.width = width
+         self.output_dim = output_dim
+         self.heads = heads
+         self.pad_id = pad_id
+         self.layers = layers
+
+         self.token_embedding = nn.Embedding(vocab_size, width)
+         self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+
+         self.transformer = Transformer(
+             width=width,
+             layers=layers,
+             heads=heads,
+             mlp_ratio=mlp_ratio,
+             ls_init_value=ls_init_value,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+         )
+
+         self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
+
+         if no_causal_mask:
+             self.attn_mask = None
+         else:
+             self.register_buffer(
+                 "attn_mask", self.build_causal_mask(), persistent=False
+             )
+
+         if pool_type == "attn" or pool_type == "attn_eos":
+             self.attn_pool = AttentionPooling(
+                 embed_dim=width,
+                 num_heads=heads,
+                 act_layer=act_layer,
+                 norm_layer=norm_layer,
+             )
+         else:  # argmax
+             self.attn_pool = None
+
+         if proj_bias:
+             self.text_projection = nn.Linear(width, output_dim)
+         else:
+             self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+     def build_causal_mask(self):
+         # lazily create causal attention mask, with full attention between the tokens
+         # pytorch uses additive attention mask; fill with -inf
+         mask = torch.empty(self.num_pos, self.num_pos)
+         mask.fill_(float("-inf"))
+         mask.triu_(1)  # zero out the lower diagonal
+         return mask
+
+     def load_ckpt(self, ckpt_path: str, verbose: bool = True):
+         _sd = torch.load(ckpt_path, weights_only=True)
+         if "state_dict" in _sd:
+             _sd = _sd["state_dict"]
+         elif "weights" in _sd:
+             _sd = _sd["weights"]
+
+         _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
+
+         m, u = self.load_state_dict(_sd, strict=False)
+
+         if verbose or (m or u):
+             logger.info(f"Missing keys for loading model: {m}")
+             logger.info(f"Unexpected keys for loading model: {u}")
+
+     def build_cls_mask(self, text):
+         cls_mask = (text != self.pad_id).unsqueeze(1)
+         cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
+         additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
+         additive_mask.fill_(0)
+         additive_mask.masked_fill_(~cls_mask, float("-inf"))
+         additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+         return additive_mask
+
+     def text_global_pool(
+         self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
+     ):
+         if pool_type == "first":
+             pooled, tokens = x[:, 0], x[:, 1:]
+         elif pool_type == "last":
+             pooled, tokens = x[:, -1], x[:, :-1]
+         elif pool_type == "argmax":
+             # take features from the eot embedding (eot_token is the highest number in each sequence)
+             assert text is not None
+             pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+         else:
+             pooled = tokens = x
+
+         return pooled, tokens
+
+     def forward(self, text):
+         seq_len = text.shape[1]
+         x = self.token_embedding(
+             text
+         )
+         attn_mask = self.attn_mask
+         if attn_mask is not None:
+             attn_mask = attn_mask[:seq_len, :seq_len]
+
+         x = x + self.positional_embedding[:seq_len]
+         x = self.transformer(x, attn_mask=attn_mask)
+
+         x = self.ln_final(x)
+         pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)
+
+         if self.text_projection is not None:
+             if isinstance(self.text_projection, nn.Linear):
+                 pooled = self.text_projection(pooled)
+             else:
+                 pooled = pooled @ self.text_projection
+
+         if self.output_tokens:
+             return pooled, tokens
+
+         return pooled
+
+
+
+
+ class CLIP(TextTransformer):
+     def __init__(
+         self,
+         vision_cfg: PEConfig,
+         text_cfg: PETextConfig,
+         init_logit_scale: float = np.log(1 / 0.07)
+     ):
+         super(CLIP, self).__init__(**asdict(text_cfg))
+         self.visual = VisionTransformer(**asdict(vision_cfg))
+         self.image_size = self.visual.image_size  # For ease of use
+         self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+
+
+     def encode_image(self, image, normalize: bool = False):
+         x = self.visual(image)
+         return F.normalize(x, dim=-1) if normalize else x
+
+     def encode_video(self, video, normalize: bool = False):  # b n c h w
+         b, n, c, h, w = video.shape
+         frms = video.reshape(b * n, c, h, w)
+         frm_feats = self.encode_image(frms, normalize=normalize)
+         video_feats = frm_feats.reshape(b, n, -1)
+         video_feats = video_feats.mean(dim=1)
+         return video_feats
+
+     def encode_text(self, text, normalize: bool = False):
+         x = super().forward(text)
+         return F.normalize(x, dim=-1) if normalize else x
+
+     def forward(
+         self,
+         image: Optional[torch.Tensor] = None,
+         text: Optional[torch.Tensor] = None,
+     ):
+         image_features = (
+             self.encode_image(image, normalize=True) if image is not None else None
+         )
+         text_features = (
+             self.encode_text(text, normalize=True) if text is not None else None
+         )
+         return image_features, text_features, self.logit_scale.exp()
+
+
+     @classmethod
+     def from_config(
+         cls,
+         name: str,
+         pretrained: bool = False,
+         checkpoint_path: Optional[str] = None  # To load your own
+     ):
+         if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
+             raise RuntimeError(f"{name} not found in configs.")
+
+         model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
+         model_path = ""
+         if pretrained:
+             model_path = fetch_pe_checkpoint(name, checkpoint_path)
+             model.load_ckpt(model_path)
+
+         # CHANGED: unlike the original implementation, the model_path is returned as well.
+         return model, model_path
+
+     @classmethod
+     def available_configs(cls):
+         return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
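
For orientation, below is a minimal sketch of how this vendored module can be exercised once the wheel is installed. The config name is a placeholder (the real keys come from PE_VISION_CONFIG and PE_TEXT_CONFIG in the sibling config.py; query CLIP.available_configs() for the list), and with pretrained=False the weights are untrained, so this only checks shapes and wiring:

    import torch

    from lightly_studio.vendor.perception_encoder.vision_encoder.pe import CLIP

    # "PE-Core-B16-224" is a placeholder name; pick a real one from
    # CLIP.available_configs(). Note the vendored change above:
    # from_config returns the checkpoint path alongside the model.
    model, _model_path = CLIP.from_config("PE-Core-B16-224", pretrained=False)
    model.eval()

    # Dummy inputs: an image batch at the pretrain resolution, and a batch of token ids.
    image = torch.randn(1, 3, model.image_size, model.image_size)
    text = torch.randint(0, model.vocab_size, (1, model.context_length))

    with torch.no_grad():
        image_features = model.encode_image(image, normalize=True)  # (1, output_dim)
        text_features = model.encode_text(text, normalize=True)     # (1, output_dim)
        # Patch-token features from the second-to-last block, via forward_features.
        patch_tokens = model.visual.forward_features(image, norm=True, layer_idx=-2)

    # Cosine similarities scaled by the learned temperature, mirroring CLIP.forward.
    logits = model.logit_scale.exp() * image_features @ text_features.T

In the package itself, token ids would come from the vendored tokenizer (vision_encoder/tokenizer.py) and image preprocessing from transforms.py; fetch_pe_checkpoint in config.py supplies the weights when pretrained=True.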