lightly-studio 0.4.6 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356)
  1. lightly_studio/__init__.py +12 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +131 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db_tables.py +35 -0
  6. lightly_studio/api/features.py +5 -0
  7. lightly_studio/api/routes/api/annotation.py +305 -0
  8. lightly_studio/api/routes/api/annotation_label.py +87 -0
  9. lightly_studio/api/routes/api/annotations/__init__.py +7 -0
  10. lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
  11. lightly_studio/api/routes/api/caption.py +100 -0
  12. lightly_studio/api/routes/api/classifier.py +384 -0
  13. lightly_studio/api/routes/api/dataset.py +191 -0
  14. lightly_studio/api/routes/api/dataset_tag.py +266 -0
  15. lightly_studio/api/routes/api/embeddings2d.py +90 -0
  16. lightly_studio/api/routes/api/exceptions.py +114 -0
  17. lightly_studio/api/routes/api/export.py +114 -0
  18. lightly_studio/api/routes/api/features.py +17 -0
  19. lightly_studio/api/routes/api/frame.py +241 -0
  20. lightly_studio/api/routes/api/image.py +155 -0
  21. lightly_studio/api/routes/api/metadata.py +161 -0
  22. lightly_studio/api/routes/api/operator.py +75 -0
  23. lightly_studio/api/routes/api/sample.py +103 -0
  24. lightly_studio/api/routes/api/selection.py +87 -0
  25. lightly_studio/api/routes/api/settings.py +41 -0
  26. lightly_studio/api/routes/api/status.py +19 -0
  27. lightly_studio/api/routes/api/text_embedding.py +50 -0
  28. lightly_studio/api/routes/api/validators.py +17 -0
  29. lightly_studio/api/routes/api/video.py +133 -0
  30. lightly_studio/api/routes/healthz.py +13 -0
  31. lightly_studio/api/routes/images.py +104 -0
  32. lightly_studio/api/routes/video_frames_media.py +116 -0
  33. lightly_studio/api/routes/video_media.py +223 -0
  34. lightly_studio/api/routes/webapp.py +51 -0
  35. lightly_studio/api/server.py +94 -0
  36. lightly_studio/core/__init__.py +0 -0
  37. lightly_studio/core/add_samples.py +533 -0
  38. lightly_studio/core/add_videos.py +294 -0
  39. lightly_studio/core/dataset.py +780 -0
  40. lightly_studio/core/dataset_query/__init__.py +14 -0
  41. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  42. lightly_studio/core/dataset_query/dataset_query.py +317 -0
  43. lightly_studio/core/dataset_query/field.py +113 -0
  44. lightly_studio/core/dataset_query/field_expression.py +79 -0
  45. lightly_studio/core/dataset_query/match_expression.py +23 -0
  46. lightly_studio/core/dataset_query/order_by.py +79 -0
  47. lightly_studio/core/dataset_query/sample_field.py +37 -0
  48. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  49. lightly_studio/core/image_sample.py +36 -0
  50. lightly_studio/core/loading_log.py +56 -0
  51. lightly_studio/core/sample.py +291 -0
  52. lightly_studio/core/start_gui.py +54 -0
  53. lightly_studio/core/video_sample.py +38 -0
  54. lightly_studio/dataset/__init__.py +0 -0
  55. lightly_studio/dataset/edge_embedding_generator.py +155 -0
  56. lightly_studio/dataset/embedding_generator.py +129 -0
  57. lightly_studio/dataset/embedding_manager.py +349 -0
  58. lightly_studio/dataset/env.py +20 -0
  59. lightly_studio/dataset/file_utils.py +49 -0
  60. lightly_studio/dataset/fsspec_lister.py +275 -0
  61. lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
  62. lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
  63. lightly_studio/db_manager.py +166 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
  131. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
  132. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
  133. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
  134. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
  135. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
  136. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
  137. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
  138. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
  139. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
  140. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
  141. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
  142. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
  143. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
  144. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
  145. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
  146. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
  147. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
  148. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
  149. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
  150. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
  151. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
  152. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
  153. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
  154. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
  155. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  156. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  157. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  158. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  159. lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
  160. lightly_studio/errors.py +5 -0
  161. lightly_studio/examples/example.py +25 -0
  162. lightly_studio/examples/example_coco.py +27 -0
  163. lightly_studio/examples/example_coco_caption.py +29 -0
  164. lightly_studio/examples/example_metadata.py +369 -0
  165. lightly_studio/examples/example_operators.py +111 -0
  166. lightly_studio/examples/example_selection.py +28 -0
  167. lightly_studio/examples/example_split_work.py +48 -0
  168. lightly_studio/examples/example_video.py +22 -0
  169. lightly_studio/examples/example_video_annotations.py +157 -0
  170. lightly_studio/examples/example_yolo.py +22 -0
  171. lightly_studio/export/coco_captions.py +69 -0
  172. lightly_studio/export/export_dataset.py +104 -0
  173. lightly_studio/export/lightly_studio_label_input.py +120 -0
  174. lightly_studio/export_schema.py +18 -0
  175. lightly_studio/export_version.py +57 -0
  176. lightly_studio/few_shot_classifier/__init__.py +0 -0
  177. lightly_studio/few_shot_classifier/classifier.py +80 -0
  178. lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
  179. lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
  180. lightly_studio/metadata/complex_metadata.py +47 -0
  181. lightly_studio/metadata/compute_similarity.py +84 -0
  182. lightly_studio/metadata/compute_typicality.py +67 -0
  183. lightly_studio/metadata/gps_coordinate.py +41 -0
  184. lightly_studio/metadata/metadata_protocol.py +17 -0
  185. lightly_studio/models/__init__.py +1 -0
  186. lightly_studio/models/annotation/__init__.py +0 -0
  187. lightly_studio/models/annotation/annotation_base.py +303 -0
  188. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  189. lightly_studio/models/annotation/links.py +17 -0
  190. lightly_studio/models/annotation/object_detection.py +47 -0
  191. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  192. lightly_studio/models/annotation_label.py +47 -0
  193. lightly_studio/models/caption.py +49 -0
  194. lightly_studio/models/classifier.py +20 -0
  195. lightly_studio/models/dataset.py +70 -0
  196. lightly_studio/models/embedding_model.py +30 -0
  197. lightly_studio/models/image.py +96 -0
  198. lightly_studio/models/metadata.py +208 -0
  199. lightly_studio/models/range.py +17 -0
  200. lightly_studio/models/sample.py +154 -0
  201. lightly_studio/models/sample_embedding.py +36 -0
  202. lightly_studio/models/settings.py +69 -0
  203. lightly_studio/models/tag.py +96 -0
  204. lightly_studio/models/two_dim_embedding.py +16 -0
  205. lightly_studio/models/video.py +161 -0
  206. lightly_studio/plugins/__init__.py +0 -0
  207. lightly_studio/plugins/base_operator.py +60 -0
  208. lightly_studio/plugins/operator_registry.py +47 -0
  209. lightly_studio/plugins/parameter.py +70 -0
  210. lightly_studio/py.typed +0 -0
  211. lightly_studio/resolvers/__init__.py +0 -0
  212. lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
  213. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  214. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  215. lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
  216. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  217. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  218. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  219. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  220. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  221. lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
  222. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
  223. lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
  224. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
  225. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
  226. lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
  227. lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
  228. lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
  229. lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
  230. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
  231. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  232. lightly_studio/resolvers/annotations/__init__.py +1 -0
  233. lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
  234. lightly_studio/resolvers/caption_resolver.py +129 -0
  235. lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
  236. lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
  237. lightly_studio/resolvers/dataset_resolver/create.py +20 -0
  238. lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
  239. lightly_studio/resolvers/dataset_resolver/export.py +267 -0
  240. lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
  241. lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
  242. lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
  243. lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
  244. lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
  245. lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
  246. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
  247. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
  248. lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
  249. lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
  250. lightly_studio/resolvers/dataset_resolver/update.py +25 -0
  251. lightly_studio/resolvers/embedding_model_resolver.py +120 -0
  252. lightly_studio/resolvers/image_filter.py +50 -0
  253. lightly_studio/resolvers/image_resolver/__init__.py +21 -0
  254. lightly_studio/resolvers/image_resolver/create_many.py +52 -0
  255. lightly_studio/resolvers/image_resolver/delete.py +20 -0
  256. lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
  257. lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
  258. lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
  259. lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
  260. lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
  261. lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
  262. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  263. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  264. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  265. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  266. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  267. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  268. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  269. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  270. lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
  271. lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
  272. lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
  273. lightly_studio/resolvers/sample_resolver/create.py +16 -0
  274. lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
  275. lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
  276. lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
  277. lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
  278. lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
  279. lightly_studio/resolvers/settings_resolver.py +62 -0
  280. lightly_studio/resolvers/tag_resolver.py +299 -0
  281. lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
  282. lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
  283. lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
  284. lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
  285. lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
  286. lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
  287. lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
  288. lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
  289. lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
  290. lightly_studio/resolvers/video_resolver/__init__.py +27 -0
  291. lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
  292. lightly_studio/resolvers/video_resolver/create_many.py +58 -0
  293. lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
  294. lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
  295. lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
  296. lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
  297. lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
  298. lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
  299. lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
  300. lightly_studio/selection/__init__.py +1 -0
  301. lightly_studio/selection/mundig.py +143 -0
  302. lightly_studio/selection/select.py +203 -0
  303. lightly_studio/selection/select_via_db.py +273 -0
  304. lightly_studio/selection/selection_config.py +49 -0
  305. lightly_studio/services/annotations_service/__init__.py +33 -0
  306. lightly_studio/services/annotations_service/create_annotation.py +64 -0
  307. lightly_studio/services/annotations_service/delete_annotation.py +22 -0
  308. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  309. lightly_studio/services/annotations_service/update_annotation.py +54 -0
  310. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  311. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  312. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  313. lightly_studio/setup_logging.py +59 -0
  314. lightly_studio/type_definitions.py +31 -0
  315. lightly_studio/utils/__init__.py +3 -0
  316. lightly_studio/utils/download.py +94 -0
  317. lightly_studio/vendor/__init__.py +1 -0
  318. lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
  319. lightly_studio/vendor/mobileclip/LICENSE +31 -0
  320. lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
  321. lightly_studio/vendor/mobileclip/README.md +5 -0
  322. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  323. lightly_studio/vendor/mobileclip/clip.py +77 -0
  324. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  325. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  326. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  327. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  328. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  329. lightly_studio/vendor/mobileclip/logger.py +154 -0
  330. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  331. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  332. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  333. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  334. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  335. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  336. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  337. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  338. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  339. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  340. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  341. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  342. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  343. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  344. lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
  345. lightly_studio/vendor/perception_encoder/README.md +11 -0
  346. lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
  347. lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
  348. lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
  349. lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
  350. lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
  351. lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
  352. lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
  353. lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
  354. lightly_studio-0.4.6.dist-info/METADATA +88 -0
  355. lightly_studio-0.4.6.dist-info/RECORD +356 -0
  356. lightly_studio-0.4.6.dist-info/WHEEL +4 -0
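The dist-info entries at the end of the listing (METADATA, RECORD, WHEEL) are standard wheel metadata; RECORD enumerates the installed files. As a quick, non-authoritative way to cross-check this listing against an installed copy, here is a minimal sketch using the standard-library importlib.metadata; it assumes the wheel is installed in the active environment under its published name lightly-studio:

from importlib.metadata import files, version

# Both calls read the lightly_studio-0.4.6.dist-info directory of the installed wheel.
print(version("lightly-studio"))  # expected: 0.4.6

# files() returns the RECORD entries (or None if RECORD is missing).
for path in files("lightly-studio") or []:
    if "vendor/perception_encoder" in str(path):
        print(path)  # the vendored perception-encoder sources listed above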
lightly_studio/vendor/perception_encoder/vision_encoder/rope.py
@@ -0,0 +1,352 @@
+ #
+ # For licensing see accompanying LICENSE.PE file.
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+
+ from math import log, pi
+ from typing import Literal, Optional, Union
+
+ import torch
+ from einops import rearrange, repeat
+ from torch import Tensor, broadcast_tensors, einsum, nn
+ from torch.amp import autocast
+ from torch.nn import Module, ModuleList
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # broadcat, as tortoise-tts was using it
+
+
+ def broadcat(tensors, dim=-1):
+     broadcasted_tensors = broadcast_tensors(*tensors)
+     return torch.cat(broadcasted_tensors, dim=dim)
+
+
+ # rotary embedding helper functions
+
+
+ def rotate_half(x):
+     x = rearrange(x, "... (d r) -> ... d r", r=2)
+     x1, x2 = x.unbind(dim=-1)
+     x = torch.stack((-x2, x1), dim=-1)
+     return rearrange(x, "... d r -> ... (d r)")
+
+
+ @autocast("cuda", enabled=False)
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
+     dtype = t.dtype
+
+     if t.ndim == 3:
+         seq_len = t.shape[seq_dim]
+         freqs = freqs[-seq_len:]
+
+     rot_dim = freqs.shape[-1]
+     end_index = start_index + rot_dim
+
+     assert (
+         rot_dim <= t.shape[-1]
+     ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+
+     t_left, t, t_right = (
+         t[..., :start_index],
+         t[..., start_index:end_index],
+         t[..., end_index:],
+     )
+     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+     out = torch.cat((t_left, t, t_right), dim=-1)
+
+     return out.type(dtype)
+
+
+ # learned rotation helpers
+
+
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+     if exists(freq_ranges):
+         rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+         rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+     rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+     return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+ # classes
+
+
+ class RotaryEmbedding(Module):
+     def __init__(
+         self,
+         dim,
+         custom_freqs: Optional[Tensor] = None,
+         freqs_for: Union[
+             Literal["lang"], Literal["pixel"], Literal["constant"]
+         ] = "lang",
+         theta=10000,
+         max_freq=10,
+         num_freqs=1,
+         learned_freq=False,
+         use_xpos=False,
+         xpos_scale_base=512,
+         interpolate_factor=1.0,
+         theta_rescale_factor=1.0,
+         seq_before_head_dim=False,
+         cache_if_possible=True,
+     ):
+         super().__init__()
+         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+         # has some connection to NTK literature
+         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+         theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+         self.freqs_for = freqs_for
+
+         if exists(custom_freqs):
+             freqs = custom_freqs
+         elif freqs_for == "lang":
+             freqs = 1.0 / (
+                 theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+             )
+         elif freqs_for == "pixel":
+             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+         elif freqs_for == "constant":
+             freqs = torch.ones(num_freqs).float()
+
+         self.cache_if_possible = cache_if_possible
+
+         self.tmp_store("cached_freqs", None)
+         self.tmp_store("cached_scales", None)
+
+         self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+         self.learned_freq = learned_freq
+
+         # dummy for device
+
+         self.tmp_store("dummy", torch.tensor(0))
+
+         # default sequence dimension
+
+         self.seq_before_head_dim = seq_before_head_dim
+         self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+         # interpolation factors
+
+         assert interpolate_factor >= 1.0
+         self.interpolate_factor = interpolate_factor
+
+         # xpos
+
+         self.use_xpos = use_xpos
+         if not use_xpos:
+             self.tmp_store("scale", None)
+             return
+
+         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+
+         self.scale_base = xpos_scale_base
+         self.tmp_store("scale", scale)
+
+         # add apply_rotary_emb as static method
+
+         self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+
+     @property
+     def device(self):
+         return self.dummy.device
+
+     def tmp_store(self, key, value):
+         self.register_buffer(key, value, persistent=False)
+
+     def get_seq_pos(self, seq_len, device, dtype, offset=0):
+         return (
+             torch.arange(seq_len, device=device, dtype=dtype) + offset
+         ) / self.interpolate_factor
+
+     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+
+         assert (
+             not self.use_xpos
+         ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
+
+         device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+
+         freqs = self.forward(
+             self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
+             seq_len=seq_len,
+             offset=offset,
+         )
+
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+
+         return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
+
+     def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+
+         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+         assert q_len <= k_len
+
+         rotated_q = self.rotate_queries_or_keys(
+             q, seq_dim=seq_dim, offset=k_len - q_len + offset
+         )
+         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def rotate_queries_and_keys(self, q, k, seq_dim=None):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+
+         assert self.use_xpos
+         device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+         seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+
+         freqs = self.forward(seq, seq_len=seq_len)
+         scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+             scale = rearrange(scale, "n d -> n 1 d")
+
+         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
+         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
+         assert self.use_xpos
+
+         should_cache = self.cache_if_possible and exists(seq_len)
+
+         if (
+             should_cache
+             and exists(self.cached_scales)
+             and (seq_len + offset) <= self.cached_scales.shape[0]
+         ):
+             return self.cached_scales[offset : (offset + seq_len)]
+
+         scale = 1.0
+         if self.use_xpos:
+             power = (t - len(t) // 2) / self.scale_base
+             scale = self.scale ** rearrange(power, "n -> n 1")
+             scale = torch.cat((scale, scale), dim=-1)
+
+         if should_cache:
+             self.tmp_store("cached_scales", scale)
+
+         return scale
+
+     def get_axial_freqs(self, *dims):
+         Colon = slice(None)
+         all_freqs = []
+
+         for ind, dim in enumerate(dims):
+             if self.freqs_for == "pixel":
+                 pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+             else:
+                 pos = torch.arange(dim, device=self.device)
+
+             freqs = self.forward(pos, seq_len=dim)
+
+             all_axis = [None] * len(dims)
+             all_axis[ind] = Colon
+
+             new_axis_slice = (Ellipsis, *all_axis, Colon)
+             all_freqs.append(freqs[new_axis_slice])
+
+         all_freqs = broadcast_tensors(*all_freqs)
+         return torch.cat(all_freqs, dim=-1)
+
+     @autocast("cuda", enabled=False)
+     def forward(self, t: Tensor, seq_len=None, offset=0):
+         should_cache = (
+             self.cache_if_possible
+             and not self.learned_freq
+             and exists(seq_len)
+             and self.freqs_for != "pixel"
+         )
+
+         if (
+             should_cache
+             and exists(self.cached_freqs)
+             and (offset + seq_len) <= self.cached_freqs.shape[0]
+         ):
+             return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+         freqs = self.freqs
+
+         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+         freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+         if should_cache:
+             self.tmp_store("cached_freqs", freqs.detach())
+
+         return freqs
+
+
+
+
+
+ class Rope2D:
+     """ Helper class to apply RoPE2D as well as interpolate on the fly. """
+
+     def __init__(self, dim, use_cls_token=False):
+         self.dim = dim
+         self.use_cls_token = use_cls_token
+         self.grid_size = None
+         self.freq = None
+
+     def init_tensors(self):
+         self.rope = RotaryEmbedding(self.dim // 2)
+
+     def update_grid(self, device, grid_h, grid_w):
+         if self.grid_size != (grid_h, grid_w):
+             self.grid_size = (grid_h, grid_w)
+
+             self.rope = self.rope.to(device)
+
+             if self.use_cls_token:
+                 # +1 to leave space for the cls token to be (0, 0)
+                 grid_y_range = torch.arange(grid_h, device=device) + 1
+                 grid_x_range = torch.arange(grid_w, device=device) + 1
+             else:
+                 grid_y_range = torch.arange(grid_h, device=device)
+                 grid_x_range = torch.arange(grid_w, device=device)
+
+             freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
+             freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
+             freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
+
+             if self.use_cls_token:
+                 freq = torch.cat(
+                     [torch.zeros(1, freq.shape[-1], device=device), freq], dim=0
+                 )
+
+             self.freq = freq[None, ...]
+
+         self.freq = self.freq.to(device)
+
+     def __call__(self, q, k):
+         # batch, heads, seq, dim = q.shape
+         q = apply_rotary_emb(self.freq[:, None, :, :], q)
+         k = apply_rotary_emb(self.freq[:, None, :, :], k)
+
+         return q, k
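The hunk above (rope.py, per the file listing) vendors a rotary position embedding implementation together with the Rope2D helper used by the perception encoder vision tower. For orientation only, a minimal sketch of how RotaryEmbedding rotates query/key tensors; it assumes the wheel and its dependencies (torch, einops) are installed, the import path follows the vendored location in the listing, and the tensor shapes are illustrative:

import torch

from lightly_studio.vendor.perception_encoder.vision_encoder.rope import RotaryEmbedding

# Rotate the first 32 channels of every head; the head dimension must be at least 32.
rope = RotaryEmbedding(dim=32)

q = torch.randn(2, 8, 128, 64)  # (batch, heads, sequence, head_dim)
k = torch.randn(2, 8, 128, 64)

q_rot = rope.rotate_queries_or_keys(q)  # the sequence dimension defaults to -2
k_rot = rope.rotate_queries_or_keys(k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

Rope2D builds on the same primitives: init_tensors() creates the underlying RotaryEmbedding, update_grid() precomputes per-patch frequencies for a grid of image patches (with an optional all-zero slot for the cls token), and calling the object applies those frequencies to q and k.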
lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py
@@ -0,0 +1,347 @@
+ #
+ # For licensing see accompanying LICENSE.PE file.
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # Original Header:
+ # CLIP tokenizer
+ #
+ # Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+ #
+
+ import gzip
+ import html
+ import os
+ import random
+ import string
+ from functools import lru_cache, partial
+ from typing import Callable, List, Optional, Union
+
+ import ftfy
+ import regex as re
+ import torch
+
+ # https://stackoverflow.com/q/62691279
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ DEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP
+
+
+ @lru_cache()
+ def default_bpe():
+     return os.path.join(
+         os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+     )
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns list of utf-8 byte and a corresponding list of unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     And avoids mapping to whitespace/control characters the bpe code barfs on.
+     """
+     bs = (
+         list(range(ord("!"), ord("~") + 1))
+         + list(range(ord("¡"), ord("¬") + 1))
+         + list(range(ord("®"), ord("ÿ") + 1))
+     )
+     cs = bs[:]
+     n = 0
+     for b in range(2**8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2**8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+     Word is represented as tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ def _clean_canonicalize(x):
+     # basic, remove whitespace, remove punctuation, lower case
+     return canonicalize_text(basic_clean(x))
+
+
+ def _clean_lower(x):
+     # basic, remove whitespace, lower case
+     return whitespace_clean(basic_clean(x)).lower()
+
+
+ def _clean_whitespace(x):
+     # basic, remove whitespace
+     return whitespace_clean(basic_clean(x))
+
+
+ def get_clean_fn(type: str):
+     if type == "canonicalize":
+         return _clean_canonicalize
+     elif type == "lower":
+         return _clean_lower
+     elif type == "whitespace":
+         return _clean_whitespace
+     else:
+         assert False, f"Invalid clean function ({type})."
+
+
+ def canonicalize_text(text, *, keep_punctuation_exact_string=None):
+     """Returns canonicalized `text` (lowercase and punctuation removed).
+
+     From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
+
+     Args:
+       text: string to be canonicalized.
+       keep_punctuation_exact_string: If provided, then this exact string kept.
+         For example providing '{}' will keep any occurrences of '{}' (but will
+         still remove '{' and '}' that appear separately).
+     """
+     text = text.replace("_", " ")
+     if keep_punctuation_exact_string:
+         text = keep_punctuation_exact_string.join(
+             part.translate(str.maketrans("", "", string.punctuation))
+             for part in text.split(keep_punctuation_exact_string)
+         )
+     else:
+         text = text.translate(str.maketrans("", "", string.punctuation))
+     text = text.lower()
+     text = re.sub(r"\s+", " ", text)
+     return text.strip()
+
+
+ class SimpleTokenizer(object):
+     def __init__(
+         self,
+         bpe_path: str = default_bpe(),
+         additional_special_tokens: Optional[List[str]] = None,
+         context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
+         clean: str = "lower",
+         reduction_mask: str = "",
+     ):
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+         merges = merges[1 : 49152 - 256 - 2 + 1]
+         merges = [tuple(merge.split()) for merge in merges]
+         vocab = list(bytes_to_unicode().values())
+         vocab = vocab + [v + "</w>" for v in vocab]
+         for merge in merges:
+             vocab.append("".join(merge))
+         special_tokens = ["<start_of_text>", "<end_of_text>"]
+         if additional_special_tokens:
+             special_tokens += additional_special_tokens
+         vocab.extend(special_tokens)
+         self.encoder = dict(zip(vocab, range(len(vocab))))
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.bpe_ranks = dict(zip(merges, range(len(merges))))
+         self.cache = {t: t for t in special_tokens}
+         special = "|".join(special_tokens)
+         self.pat = re.compile(
+             special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+             re.IGNORECASE,
+         )
+         self.vocab_size = len(self.encoder)
+         self.all_special_ids = [self.encoder[t] for t in special_tokens]
+         self.sot_token_id = self.all_special_ids[0]
+         self.eot_token_id = self.all_special_ids[1]
+         self.context_length = context_length
+         self.clean_fn = get_clean_fn(clean)
+         self.reduction_fn = (
+             get_reduction_mask_fn(reduction_mask) if reduction_mask else None
+         )
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token[:-1]) + (token[-1] + "</w>",)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token + "</w>"
+
+         while True:
+             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except:
+                     new_word.extend(word[i:])
+                     break
+
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+
+     def encode(self, text):
+         bpe_tokens = []
+         text = self.clean_fn(text)
+         for token in re.findall(self.pat, text):
+             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+             bpe_tokens.extend(
+                 self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+             )
+         return bpe_tokens
+
+     def decode(self, tokens):
+         text = "".join([self.decoder[token] for token in tokens])
+         text = (
+             bytearray([self.byte_decoder[c] for c in text])
+             .decode("utf-8", errors="replace")
+             .replace("</w>", " ")
+         )
+         return text
+
+     def __call__(
+         self, texts: Union[str, List[str]], context_length: Optional[int] = None
+     ) -> torch.LongTensor:
+         """Returns the tokenized representation of given input string(s)
+
+         Parameters
+         ----------
+         texts : Union[str, List[str]]
+             An input string or a list of input strings to tokenize
+         context_length : int
+             The context length to use; all CLIP models use 77 as the context length
+
+         Returns
+         -------
+         A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         context_length = context_length or self.context_length
+         assert context_length, "Please set a valid context length"
+
+         if self.reduction_fn is not None:
+             # use reduction strategy for tokenize if set, otherwise default to truncation below
+             return self.reduction_fn(
+                 texts,
+                 context_length=context_length,
+                 sot_token_id=self.sot_token_id,
+                 eot_token_id=self.eot_token_id,
+                 encode_fn=self.encode,
+             )
+
+         all_tokens = [
+             [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
+             for text in texts
+         ]
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+         for i, tokens in enumerate(all_tokens):
+             if len(tokens) > context_length:
+                 tokens = tokens[:context_length]  # Truncate
+                 tokens[-1] = self.eot_token_id
+             result[i, : len(tokens)] = torch.tensor(tokens)
+
+         return result
+
+
+ def random_mask_tokenize(
+     texts: Union[str, List[str]],
+     context_length: int,
+     sot_token_id: int,
+     eot_token_id: int,
+     encode_fn: Callable,
+     shuffle: bool = False,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         tokens = torch.tensor(tokens)
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             indices = torch.randperm(len(tokens))
+             indices = indices[:num_keep]
+             if not shuffle:
+                 indices = indices.msort()
+             tokens = tokens[indices]
+             num_tokens = num_keep
+         result[i, 0] = sot_token_id
+         result[i, 1 : num_tokens + 1] = tokens
+         result[i, num_tokens + 1] = eot_token_id
+
+     return result
+
+
+ def simple_mask_tokenize(
+     texts: Union[str, List[str]],
+     context_length: int,
+     sot_token_id: int,
+     eot_token_id: int,
+     encode_fn: Callable,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             start_index = random.randint(0, num_tokens - num_keep)  # high is incl
+             tokens = tokens[start_index : start_index + num_keep]
+         tokens = [sot_token_id] + tokens + [eot_token_id]
+         result[i, : len(tokens)] = torch.tensor(tokens)
+
+     return result
+
+
+
+ def get_reduction_mask_fn(type: str):
+     """Choose strategy for dropping (masking) tokens to achieve target context length"""
+     assert type in ("simple", "random", "shuffle")
+     if type == "simple":
+         return simple_mask_tokenize  # randomly select block [start:end]
+     elif type == "random":
+         return random_mask_tokenize  # randomly drop tokens (keep order)
+     elif type == "shuffle":
+         return partial(
+             random_mask_tokenize, shuffle=True
+         )  # randomly drop tokens (shuffle order)
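This final hunk (tokenizer.py, per the file listing) vendors the OpenAI CLIP byte-pair-encoding tokenizer used for the text embeddings. For orientation only, a minimal sketch of tokenizing a batch of prompts into the default 77-token context; it assumes the wheel and its ftfy and regex dependencies are installed, and imports from the vendored path in the listing:

from lightly_studio.vendor.perception_encoder.vision_encoder.tokenizer import SimpleTokenizer

# Loads bpe_simple_vocab_16e6.txt.gz from the same vendored directory by default.
tokenizer = SimpleTokenizer()

tokens = tokenizer(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77]); rows are zero-padded to DEFAULT_CONTEXT_LENGTH

# Round-trip the first prompt, dropping the zero padding before decoding.
print(tokenizer.decode([t for t in tokens[0].tolist() if t]))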