lightly-studio 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356)
  1. lightly_studio/__init__.py +12 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +131 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db_tables.py +35 -0
  6. lightly_studio/api/features.py +5 -0
  7. lightly_studio/api/routes/api/annotation.py +305 -0
  8. lightly_studio/api/routes/api/annotation_label.py +87 -0
  9. lightly_studio/api/routes/api/annotations/__init__.py +7 -0
  10. lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
  11. lightly_studio/api/routes/api/caption.py +100 -0
  12. lightly_studio/api/routes/api/classifier.py +384 -0
  13. lightly_studio/api/routes/api/dataset.py +191 -0
  14. lightly_studio/api/routes/api/dataset_tag.py +266 -0
  15. lightly_studio/api/routes/api/embeddings2d.py +90 -0
  16. lightly_studio/api/routes/api/exceptions.py +114 -0
  17. lightly_studio/api/routes/api/export.py +114 -0
  18. lightly_studio/api/routes/api/features.py +17 -0
  19. lightly_studio/api/routes/api/frame.py +241 -0
  20. lightly_studio/api/routes/api/image.py +155 -0
  21. lightly_studio/api/routes/api/metadata.py +161 -0
  22. lightly_studio/api/routes/api/operator.py +75 -0
  23. lightly_studio/api/routes/api/sample.py +103 -0
  24. lightly_studio/api/routes/api/selection.py +87 -0
  25. lightly_studio/api/routes/api/settings.py +41 -0
  26. lightly_studio/api/routes/api/status.py +19 -0
  27. lightly_studio/api/routes/api/text_embedding.py +50 -0
  28. lightly_studio/api/routes/api/validators.py +17 -0
  29. lightly_studio/api/routes/api/video.py +133 -0
  30. lightly_studio/api/routes/healthz.py +13 -0
  31. lightly_studio/api/routes/images.py +104 -0
  32. lightly_studio/api/routes/video_frames_media.py +116 -0
  33. lightly_studio/api/routes/video_media.py +223 -0
  34. lightly_studio/api/routes/webapp.py +51 -0
  35. lightly_studio/api/server.py +94 -0
  36. lightly_studio/core/__init__.py +0 -0
  37. lightly_studio/core/add_samples.py +533 -0
  38. lightly_studio/core/add_videos.py +294 -0
  39. lightly_studio/core/dataset.py +780 -0
  40. lightly_studio/core/dataset_query/__init__.py +14 -0
  41. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  42. lightly_studio/core/dataset_query/dataset_query.py +317 -0
  43. lightly_studio/core/dataset_query/field.py +113 -0
  44. lightly_studio/core/dataset_query/field_expression.py +79 -0
  45. lightly_studio/core/dataset_query/match_expression.py +23 -0
  46. lightly_studio/core/dataset_query/order_by.py +79 -0
  47. lightly_studio/core/dataset_query/sample_field.py +37 -0
  48. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  49. lightly_studio/core/image_sample.py +36 -0
  50. lightly_studio/core/loading_log.py +56 -0
  51. lightly_studio/core/sample.py +291 -0
  52. lightly_studio/core/start_gui.py +54 -0
  53. lightly_studio/core/video_sample.py +38 -0
  54. lightly_studio/dataset/__init__.py +0 -0
  55. lightly_studio/dataset/edge_embedding_generator.py +155 -0
  56. lightly_studio/dataset/embedding_generator.py +129 -0
  57. lightly_studio/dataset/embedding_manager.py +349 -0
  58. lightly_studio/dataset/env.py +20 -0
  59. lightly_studio/dataset/file_utils.py +49 -0
  60. lightly_studio/dataset/fsspec_lister.py +275 -0
  61. lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
  62. lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
  63. lightly_studio/db_manager.py +166 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
  131. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
  132. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
  133. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
  134. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
  135. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
  136. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
  137. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
  138. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
  139. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
  140. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
  141. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
  142. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
  143. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
  144. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
  145. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
  146. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
  147. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
  148. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
  149. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
  150. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
  151. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
  152. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
  153. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
  154. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
  155. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  156. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  157. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  158. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  159. lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
  160. lightly_studio/errors.py +5 -0
  161. lightly_studio/examples/example.py +25 -0
  162. lightly_studio/examples/example_coco.py +27 -0
  163. lightly_studio/examples/example_coco_caption.py +29 -0
  164. lightly_studio/examples/example_metadata.py +369 -0
  165. lightly_studio/examples/example_operators.py +111 -0
  166. lightly_studio/examples/example_selection.py +28 -0
  167. lightly_studio/examples/example_split_work.py +48 -0
  168. lightly_studio/examples/example_video.py +22 -0
  169. lightly_studio/examples/example_video_annotations.py +157 -0
  170. lightly_studio/examples/example_yolo.py +22 -0
  171. lightly_studio/export/coco_captions.py +69 -0
  172. lightly_studio/export/export_dataset.py +104 -0
  173. lightly_studio/export/lightly_studio_label_input.py +120 -0
  174. lightly_studio/export_schema.py +18 -0
  175. lightly_studio/export_version.py +57 -0
  176. lightly_studio/few_shot_classifier/__init__.py +0 -0
  177. lightly_studio/few_shot_classifier/classifier.py +80 -0
  178. lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
  179. lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
  180. lightly_studio/metadata/complex_metadata.py +47 -0
  181. lightly_studio/metadata/compute_similarity.py +84 -0
  182. lightly_studio/metadata/compute_typicality.py +67 -0
  183. lightly_studio/metadata/gps_coordinate.py +41 -0
  184. lightly_studio/metadata/metadata_protocol.py +17 -0
  185. lightly_studio/models/__init__.py +1 -0
  186. lightly_studio/models/annotation/__init__.py +0 -0
  187. lightly_studio/models/annotation/annotation_base.py +303 -0
  188. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  189. lightly_studio/models/annotation/links.py +17 -0
  190. lightly_studio/models/annotation/object_detection.py +47 -0
  191. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  192. lightly_studio/models/annotation_label.py +47 -0
  193. lightly_studio/models/caption.py +49 -0
  194. lightly_studio/models/classifier.py +20 -0
  195. lightly_studio/models/dataset.py +70 -0
  196. lightly_studio/models/embedding_model.py +30 -0
  197. lightly_studio/models/image.py +96 -0
  198. lightly_studio/models/metadata.py +208 -0
  199. lightly_studio/models/range.py +17 -0
  200. lightly_studio/models/sample.py +154 -0
  201. lightly_studio/models/sample_embedding.py +36 -0
  202. lightly_studio/models/settings.py +69 -0
  203. lightly_studio/models/tag.py +96 -0
  204. lightly_studio/models/two_dim_embedding.py +16 -0
  205. lightly_studio/models/video.py +161 -0
  206. lightly_studio/plugins/__init__.py +0 -0
  207. lightly_studio/plugins/base_operator.py +60 -0
  208. lightly_studio/plugins/operator_registry.py +47 -0
  209. lightly_studio/plugins/parameter.py +70 -0
  210. lightly_studio/py.typed +0 -0
  211. lightly_studio/resolvers/__init__.py +0 -0
  212. lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
  213. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  214. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  215. lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
  216. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  217. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  218. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  219. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  220. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  221. lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
  222. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
  223. lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
  224. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
  225. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
  226. lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
  227. lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
  228. lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
  229. lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
  230. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
  231. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  232. lightly_studio/resolvers/annotations/__init__.py +1 -0
  233. lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
  234. lightly_studio/resolvers/caption_resolver.py +129 -0
  235. lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
  236. lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
  237. lightly_studio/resolvers/dataset_resolver/create.py +20 -0
  238. lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
  239. lightly_studio/resolvers/dataset_resolver/export.py +267 -0
  240. lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
  241. lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
  242. lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
  243. lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
  244. lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
  245. lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
  246. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
  247. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
  248. lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
  249. lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
  250. lightly_studio/resolvers/dataset_resolver/update.py +25 -0
  251. lightly_studio/resolvers/embedding_model_resolver.py +120 -0
  252. lightly_studio/resolvers/image_filter.py +50 -0
  253. lightly_studio/resolvers/image_resolver/__init__.py +21 -0
  254. lightly_studio/resolvers/image_resolver/create_many.py +52 -0
  255. lightly_studio/resolvers/image_resolver/delete.py +20 -0
  256. lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
  257. lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
  258. lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
  259. lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
  260. lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
  261. lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
  262. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  263. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  264. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  265. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  266. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  267. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  268. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  269. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  270. lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
  271. lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
  272. lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
  273. lightly_studio/resolvers/sample_resolver/create.py +16 -0
  274. lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
  275. lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
  276. lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
  277. lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
  278. lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
  279. lightly_studio/resolvers/settings_resolver.py +62 -0
  280. lightly_studio/resolvers/tag_resolver.py +299 -0
  281. lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
  282. lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
  283. lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
  284. lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
  285. lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
  286. lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
  287. lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
  288. lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
  289. lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
  290. lightly_studio/resolvers/video_resolver/__init__.py +27 -0
  291. lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
  292. lightly_studio/resolvers/video_resolver/create_many.py +58 -0
  293. lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
  294. lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
  295. lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
  296. lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
  297. lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
  298. lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
  299. lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
  300. lightly_studio/selection/__init__.py +1 -0
  301. lightly_studio/selection/mundig.py +143 -0
  302. lightly_studio/selection/select.py +203 -0
  303. lightly_studio/selection/select_via_db.py +273 -0
  304. lightly_studio/selection/selection_config.py +49 -0
  305. lightly_studio/services/annotations_service/__init__.py +33 -0
  306. lightly_studio/services/annotations_service/create_annotation.py +64 -0
  307. lightly_studio/services/annotations_service/delete_annotation.py +22 -0
  308. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  309. lightly_studio/services/annotations_service/update_annotation.py +54 -0
  310. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  311. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  312. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  313. lightly_studio/setup_logging.py +59 -0
  314. lightly_studio/type_definitions.py +31 -0
  315. lightly_studio/utils/__init__.py +3 -0
  316. lightly_studio/utils/download.py +94 -0
  317. lightly_studio/vendor/__init__.py +1 -0
  318. lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
  319. lightly_studio/vendor/mobileclip/LICENSE +31 -0
  320. lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
  321. lightly_studio/vendor/mobileclip/README.md +5 -0
  322. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  323. lightly_studio/vendor/mobileclip/clip.py +77 -0
  324. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  325. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  326. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  327. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  328. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  329. lightly_studio/vendor/mobileclip/logger.py +154 -0
  330. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  331. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  332. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  333. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  334. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  335. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  336. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  337. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  338. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  339. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  340. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  341. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  342. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  343. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  344. lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
  345. lightly_studio/vendor/perception_encoder/README.md +11 -0
  346. lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
  347. lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
  348. lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
  349. lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
  350. lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
  351. lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
  352. lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
  353. lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
  354. lightly_studio-0.4.6.dist-info/METADATA +88 -0
  355. lightly_studio-0.4.6.dist-info/RECORD +356 -0
  356. lightly_studio-0.4.6.dist-info/WHEEL +4 -0
lightly_studio/dataset/fsspec_lister.py
@@ -0,0 +1,275 @@
+ """File listing utilities using fsspec.
+
+ Handles local and remote paths, directories, and glob patterns.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Iterator
+ from typing import Any
+
+ import fsspec
+ from tqdm import tqdm
+
+ # Constants
+ PROTOCOL_SEPARATOR = "://"
+ DEFAULT_PROTOCOL = "file"
+ PATH_SEPARATOR = "/"
+
+ # Glob pattern characters
+ GLOB_CHARS = ["*", "?", "[", "]"]
+
+ # Cloud storage protocols
+ CLOUD_PROTOCOLS = ("s3", "gs", "gcs", "azure", "abfs")
+
+ # Image file extensions
+ IMAGE_EXTENSIONS = {
+     ".png",
+     ".jpg",
+     ".jpeg",
+     ".gif",
+     ".webp",
+     ".bmp",
+     ".tiff",
+ }
+
+
+ def iter_files_from_path(path: str, allowed_extensions: set[str] | None = None) -> Iterator[str]:
+     """List all files from a single path, handling directories, globs, and individual files.
+
+     Args:
+         path: A single path, which can be:
+             - an individual file path
+             - a directory path (listed recursively)
+             - a glob pattern
+             - a remote path (s3://, gcs://, etc.)
+         allowed_extensions: Optional set of allowed file extensions (e.g., {".jpg", ".png"}).
+             If None, the default IMAGE_EXTENSIONS is used.
+
+     Yields:
+         File paths as they are discovered, with progress tracking.
+     """
+     seen: set[str] = set()
+     extensions = allowed_extensions or IMAGE_EXTENSIONS
+     with tqdm(desc="Discovering files", unit=" files", dynamic_ncols=True) as pbar:
+         cleaned_path = str(path).strip()
+         if not cleaned_path:
+             return
+         fs = _get_filesystem(cleaned_path)
+         yield from _process_single_path_streaming(fs, cleaned_path, seen, pbar, extensions)
+
+
+ def _process_single_path_streaming(
+     fs: fsspec.AbstractFileSystem, path: str, seen: set[str], pbar: tqdm[Any], extensions: set[str]
+ ) -> Iterator[str]:
+     """Process a single path and yield matching image files.
+
+     Handles different path types: individual files, directories, and glob patterns.
+
+     Args:
+         fs: The filesystem instance.
+         path: The path to process (file, directory, or glob pattern).
+         seen: Set of already processed paths to avoid duplicates.
+         pbar: Progress bar instance for tracking progress.
+         extensions: Set of allowed file extensions.
+
+     Yields:
+         File paths that match the criteria.
+
+     Raises:
+         ValueError: If the path does not exist, or if it points to a file
+             without an allowed extension.
+     """
+     if _is_glob_pattern(path):
+         yield from _process_glob_pattern(fs, path, seen, pbar, extensions)
+     elif not fs.exists(path):
+         raise ValueError(f"Path does not exist: {path}")
+     elif fs.isfile(path):
+         if not _is_image_file(path, extensions):
+             raise ValueError(f"File is not an image: {path}")
+         if path not in seen:
+             seen.add(path)
+             pbar.update(1)
+             yield path
+     elif fs.isdir(path):
+         for file_path in _stream_files_from_directory(fs, path, extensions):
+             if file_path not in seen:
+                 seen.add(file_path)
+                 pbar.update(1)
+                 yield file_path
+
+
+ def _process_glob_pattern(
+     fs: fsspec.AbstractFileSystem, path: str, seen: set[str], pbar: tqdm[Any], extensions: set[str]
+ ) -> Iterator[str]:
+     """Process a glob pattern and yield matching image files.
+
+     Args:
+         fs: The filesystem instance.
+         path: The glob pattern path.
+         seen: Set of already processed paths to avoid duplicates.
+         pbar: Progress bar instance for tracking progress.
+         extensions: Set of allowed file extensions.
+
+     Yields:
+         File paths that match the glob pattern and allowed extensions.
+     """
+     matching_paths = fs.glob(path)
+     for p in matching_paths:
+         path_str = str(p)
+         if _needs_protocol_prefix(path_str, fs):
+             protocol = _get_protocol_string(fs)
+             path_str = f"{protocol}{PROTOCOL_SEPARATOR}{path_str}"
+         if fs.isfile(path_str) and _is_image_file(path_str, extensions) and path_str not in seen:
+             seen.add(path_str)
+             pbar.update(1)
+             yield path_str
+
+
+ def _stream_files_from_directory(
+     fs: fsspec.AbstractFileSystem, path: str, extensions: set[str]
+ ) -> Iterator[str]:
+     """Stream files from a directory with progress tracking.
+
+     Args:
+         fs: The filesystem instance.
+         path: Directory path to list.
+         extensions: Set of allowed file extensions.
+
+     Yields:
+         File paths as they are discovered.
+     """
+     try:
+         protocol = _get_protocol_string(fs)
+         if protocol in CLOUD_PROTOCOLS:
+             yield from _stream_files_using_walk(fs, path, extensions)
+         else:
+             try:
+                 all_paths = fs.find(path, detail=False)
+                 for p in all_paths:
+                     if fs.isfile(p) and _is_image_file(p, extensions):
+                         yield p
+             except Exception as e:
+                 logging.warning(f"fs.find() failed for {path}, trying alternative method: {e}")
+                 yield from _stream_files_using_walk(fs, path, extensions)
+     except Exception as e:
+         logging.error(f"Error streaming files from '{path}': {e}")
+
+
+ def _stream_files_using_walk(
+     fs: fsspec.AbstractFileSystem, path: str, extensions: set[str]
+ ) -> Iterator[str]:
+     """Stream files using the fs.walk() method.
+
+     Args:
+         fs: The filesystem instance.
+         path: The directory path to walk.
+         extensions: Set of allowed file extensions.
+
+     Yields:
+         File paths that match the allowed extensions.
+     """
+
+     def add_protocol_if_needed(p: str) -> str:
+         if _needs_protocol_prefix(p, fs):
+             protocol = _get_protocol_string(fs)
+             return f"{protocol}{PROTOCOL_SEPARATOR}{p}"
+         return p
+
+     for root, _dirs, files in fs.walk(path):
+         for file in files:
+             if root.endswith(PATH_SEPARATOR):
+                 full_path = f"{root}{file}"
+             else:
+                 full_path = f"{root}{PATH_SEPARATOR}{file}"
+             full_path = add_protocol_if_needed(full_path)
+             if _is_image_file(full_path, extensions):
+                 yield full_path
+
+
+ def _get_filesystem(path: str) -> fsspec.AbstractFileSystem:
+     """Get the appropriate filesystem for the given path.
+
+     Args:
+         path: The path to determine the filesystem for. Can be local or remote.
+
+     Returns:
+         An fsspec filesystem instance appropriate for the path's protocol.
+
+     Raises:
+         ValueError: If the protocol cannot be determined or is invalid.
+     """
+     # str.split() always returns strings, so no further normalization is needed.
+     protocol = path.split(PROTOCOL_SEPARATOR)[0] if PROTOCOL_SEPARATOR in path else DEFAULT_PROTOCOL
+     return fsspec.filesystem(protocol)
+
+
+ def _is_glob_pattern(path: str) -> bool:
+     """Check if a path contains glob pattern characters.
+
+     Args:
+         path: The path to check for glob patterns.
+
+     Returns:
+         True if the path contains glob pattern characters (*, ?, [, ]), False otherwise.
+     """
+     return any(char in path for char in GLOB_CHARS)
+
+
+ def _is_image_file(path: str, extensions: set[str]) -> bool:
+     """Check if a file is an image based on its extension.
+
+     Args:
+         path: The file path to check.
+         extensions: Set of allowed file extensions (e.g., {'.jpg', '.png'}).
+
+     Returns:
+         True if the file has an allowed image extension, False otherwise.
+     """
+     path_lower = path.lower()
+     return any(path_lower.endswith(ext) for ext in extensions)
+
+
+ def _needs_protocol_prefix(path: str, fs: fsspec.AbstractFileSystem) -> bool:
+     """Check if a path needs a protocol prefix.
+
+     Args:
+         path: The path to check.
+         fs: The filesystem instance.
+
+     Returns:
+         True if the path needs a protocol prefix (e.g., for cloud storage),
+         False if it is a local path.
+     """
+     if PROTOCOL_SEPARATOR in path:
+         return False
+
+     protocol = getattr(fs, "protocol", DEFAULT_PROTOCOL)
+     # Handle the case where the protocol is a tuple (common with fsspec).
+     if isinstance(protocol, (list, tuple)):
+         protocol = protocol[0]
+
+     return str(protocol) != DEFAULT_PROTOCOL
+
+
+ def _get_protocol_string(fs: fsspec.AbstractFileSystem) -> str:
+     """Get the protocol string from a filesystem.
+
+     Args:
+         fs: The filesystem instance.
+
+     Returns:
+         The protocol string (e.g., 's3', 'file', 'gcs').
+         Returns 'file' as a default if the protocol cannot be determined.
+     """
+     protocol = getattr(fs, "protocol", DEFAULT_PROTOCOL)
+     if isinstance(protocol, (list, tuple)):
+         return str(protocol[0])
+     return str(protocol)
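
Usage note: a minimal sketch of how the lister above can be called. The directory, bucket, and pattern below are illustrative placeholders, not paths shipped with the package. Remote protocols resolve through fsspec, so the matching backend (e.g., s3fs for s3://) must be installed.

    from lightly_studio.dataset.fsspec_lister import iter_files_from_path

    # Recursively discover images under a local directory.
    for image_path in iter_files_from_path("./data/images"):
        print(image_path)

    # Glob patterns and remote paths go through the same entry point; the
    # extension filter can also be narrowed explicitly.
    pngs = list(
        iter_files_from_path("s3://my-bucket/frames/*.png", allowed_extensions={".png"})
    )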
lightly_studio/dataset/mobileclip_embedding_generator.py
@@ -0,0 +1,158 @@
+ """MobileCLIP embedding generator."""
+
+ from __future__ import annotations
+
+ import tempfile
+ from pathlib import Path
+ from typing import Callable
+ from uuid import UUID
+
+ import fsspec
+ import numpy as np
+ import torch
+ from numpy.typing import NDArray
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+
+ from lightly_studio.models.embedding_model import EmbeddingModelCreate
+ from lightly_studio.vendor import mobileclip
+
+ from . import file_utils
+ from .embedding_generator import ImageEmbeddingGenerator
+
+ MODEL_NAME = "mobileclip_s0"
+ MOBILECLIP_DOWNLOAD_URL = (
+     f"https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/{MODEL_NAME}.pt"
+ )
+ MAX_BATCH_SIZE: int = 16
+ EMBEDDING_DIMENSION: int = 512
+
+
+ # Dataset for efficient batched image loading and preprocessing.
+ class _ImageFileDataset(Dataset[torch.Tensor]):
+     """Dataset wrapping image file paths and a preprocess function."""
+
+     def __init__(
+         self,
+         filepaths: list[str],
+         preprocess: Callable[[Image.Image], torch.Tensor],
+     ) -> None:
+         self.filepaths = filepaths
+         self.preprocess = preprocess
+
+     def __len__(self) -> int:
+         return len(self.filepaths)
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         with fsspec.open(self.filepaths[idx], "rb") as file:
+             image = Image.open(file).convert("RGB")
+             return self.preprocess(image)
+
+
+ class MobileCLIPEmbeddingGenerator(ImageEmbeddingGenerator):
+     """MobileCLIP embedding model."""
+
+     def __init__(self) -> None:
+         """Initialize the MobileCLIP embedding model.
+
+         This method loads the MobileCLIP model and its tokenizer. The model
+         checkpoint is downloaded and cached locally for future use.
+         """
+         model_path = _get_cached_mobileclip_checkpoint()
+         self._model, _, self._preprocess = mobileclip.create_model_and_transforms(
+             model_name=MODEL_NAME, pretrained=str(model_path)
+         )
+
+         # Auto-select device: CUDA > MPS (Apple Silicon) > CPU.
+         self._device = torch.device(
+             "cuda"
+             if torch.cuda.is_available()
+             else "mps"
+             if torch.backends.mps.is_available()
+             else "cpu"
+         )
+         self._model = self._model.to(self._device)
+         self._tokenizer = mobileclip.get_tokenizer(model_name=MODEL_NAME)
+         self._model_hash = file_utils.get_file_xxhash(model_path)
+
+     def get_embedding_model_input(self, dataset_id: UUID) -> EmbeddingModelCreate:
+         """Generate an EmbeddingModelCreate instance.
+
+         Args:
+             dataset_id: The ID of the dataset.
+
+         Returns:
+             An EmbeddingModelCreate instance with the model details.
+         """
+         return EmbeddingModelCreate(
+             name=MODEL_NAME,
+             embedding_model_hash=self._model_hash,
+             embedding_dimension=EMBEDDING_DIMENSION,
+             dataset_id=dataset_id,
+         )
+
+     def embed_text(self, text: str) -> list[float]:
+         """Embed a text with MobileCLIP.
+
+         Args:
+             text: The text to embed.
+
+         Returns:
+             A list of floats representing the generated embedding.
+         """
+         tokenized = self._tokenizer([text]).to(self._device)
+         with torch.no_grad():
+             embedding = self._model.encode_text(tokenized)[0]
+         # Convert the embedding to a list of floats.
+         embedding_list: list[float] = embedding.cpu().numpy().flatten().tolist()
+         return embedding_list
+
+     def embed_images(self, filepaths: list[str]) -> NDArray[np.float32]:
+         """Embed images with MobileCLIP.
+
+         Args:
+             filepaths: A list of file paths to the images to embed.
+
+         Returns:
+             A numpy array representing the generated embeddings,
+             in the same order as the input file paths.
+         """
+         total_images = len(filepaths)
+         if not total_images:
+             return np.empty((0, EMBEDDING_DIMENSION), dtype=np.float32)
+
+         dataset = _ImageFileDataset(filepaths, self._preprocess)
+
+         # To avoid issues with db locking and multiprocessing we set the
+         # number of workers to 0 (no multiprocessing). The DataLoader is
+         # still very useful for batching and async prefetching of images.
+         loader = DataLoader(
+             dataset,
+             batch_size=MAX_BATCH_SIZE,
+             num_workers=0,  # must be 0 to avoid multiprocessing issues
+         )
+
+         embeddings = np.empty((total_images, EMBEDDING_DIMENSION), dtype=np.float32)
+         position = 0
+         with tqdm(
+             total=total_images, desc="Generating embeddings", unit=" images"
+         ) as progress_bar, torch.no_grad():
+             for images_tensor in loader:
+                 imgs = images_tensor.to(self._device, non_blocking=True)
+                 batch_embeddings = self._model.encode_image(imgs).cpu().numpy()
+                 batch_size = imgs.size(0)
+                 embeddings[position : position + batch_size] = batch_embeddings
+                 position += batch_size
+                 progress_bar.update(batch_size)
+
+         return embeddings
+
+
+ def _get_cached_mobileclip_checkpoint() -> Path:
+     """Download the MobileCLIP checkpoint to the temp directory if needed and return its path."""
+     file_path = Path(tempfile.gettempdir()) / f"{MODEL_NAME}.pt"
+     file_utils.download_file_if_does_not_exist(
+         url=MOBILECLIP_DOWNLOAD_URL,
+         local_filename=file_path,
+     )
+     return file_path
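
Usage note: a short sketch of the generator above; the image file names are illustrative. `embed_images` returns one 512-dimensional row per input, so a text query can be scored against images with a manual cosine similarity (the text embedding is not normalized by `embed_text` here):

    import numpy as np

    from lightly_studio.dataset.mobileclip_embedding_generator import (
        MobileCLIPEmbeddingGenerator,
    )

    generator = MobileCLIPEmbeddingGenerator()
    image_embeddings = generator.embed_images(["cat.jpg", "dog.jpg"])  # shape (2, 512)
    text_embedding = np.asarray(generator.embed_text("a photo of a dog"), dtype=np.float32)

    # Cosine similarity between the text query and each image embedding.
    image_norm = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_norm = text_embedding / np.linalg.norm(text_embedding)
    print(image_norm @ text_norm)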
lightly_studio/dataset/perception_encoder_embedding_generator.py
@@ -0,0 +1,260 @@
+ """Perception Encoder embedding generator."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Callable
+ from uuid import UUID
+
+ import fsspec
+ import numpy as np
+ import torch
+ from av import container
+ from numpy.typing import NDArray
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+
+ from lightly_studio.models.embedding_model import EmbeddingModelCreate
+ from lightly_studio.vendor.perception_encoder.vision_encoder import pe, transforms
+
+ from . import file_utils
+ from .embedding_generator import ImageEmbeddingGenerator, VideoEmbeddingGenerator
+
+ MODEL_NAME = "PE-Core-T16-384"
+ DEFAULT_VIDEO_CHANNEL = 0
+ MAX_BATCH_SIZE: int = 16
+ VIDEO_FRAMES_PER_SAMPLE: int = 8
+
+
+ # TODO(Jonas, 12/225): Move to a helper.
+ class _ImageFileDataset(Dataset[torch.Tensor]):
+     """Dataset wrapping image file paths and a preprocess function.
+
+     Used for efficient batched image loading and preprocessing.
+     """
+
+     def __init__(
+         self,
+         filepaths: list[str],
+         preprocess: Callable[[Image.Image], torch.Tensor],
+     ) -> None:
+         self.filepaths = filepaths
+         self.preprocess = preprocess
+
+     def __len__(self) -> int:
+         return len(self.filepaths)
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         with fsspec.open(self.filepaths[idx], "rb") as file:
+             image = Image.open(file).convert("RGB")
+             return self.preprocess(image)
+
+
+ class _VideoFileDataset(Dataset[torch.Tensor]):
+     """Dataset wrapping video file paths and a preprocess function.
+
+     Used for efficient batched video loading and preprocessing.
+     """
+
+     def __init__(
+         self,
+         filepaths: list[str],
+         preprocess: Callable[[Image.Image], torch.Tensor],
+     ) -> None:
+         self.filepaths = filepaths
+         self.preprocess = preprocess
+
+     def __len__(self) -> int:
+         return len(self.filepaths)
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         """Return a tensor [N C H W] for the idx-th video.
+
+         As in the original paper, we subsample N frames from the video and
+         stack them into a tensor, using a default of 8 frames per video
+         (VIDEO_FRAMES_PER_SAMPLE). Note: the video length in the paper was
+         16.7 +/- 9.8 sec, so for longer videos we might consider alternative
+         models or more frames.
+         """
+         video_path = self.filepaths[idx]
+         frames = self._load_frames(video_path)
+         if not frames:
+             raise ValueError(f"Unable to read frames from video '{video_path}'.")
+
+         processed_frames = [self.preprocess(frame) for frame in frames]
+         return torch.stack(processed_frames)
+
+     def _load_frames(self, video_path: str) -> list[Image.Image]:
+         """Sample uniformly spaced frames and return them as PIL images.
+
+         Using seek for sampling is fast; however, it may yield slightly
+         different results on different operating systems (known issue:
+         macOS vs Linux). The alternative, decoding frame-by-frame, is
+         OS-independent but comes with a performance drop.
+         """
+         fs, fs_path = fsspec.core.url_to_fs(url=video_path)
+         with fs.open(path=fs_path, mode="rb") as video_file, container.open(
+             file=video_file
+         ) as video_container:
+             video_stream = video_container.streams.video[DEFAULT_VIDEO_CHANNEL]
+             duration_pts = video_stream.duration
+             raw_time_base = video_stream.time_base
+             # Guard against missing stream metadata before converting to float.
+             if duration_pts is None or raw_time_base is None:
+                 return []
+             time_base = float(raw_time_base)
+             if duration_pts <= 0 or time_base <= 0.0:
+                 return []
+
+             duration_seconds = duration_pts * time_base
+
+             # Sample VIDEO_FRAMES_PER_SAMPLE timestamps evenly spaced inside
+             # [0, duration_seconds).
+             ts_to_sample = np.linspace(
+                 0.0,
+                 duration_seconds,
+                 num=VIDEO_FRAMES_PER_SAMPLE,
+                 endpoint=False,
+                 dtype=np.float64,
+             )
+
+             frames: list[Image.Image] = []
+             for ts_target in ts_to_sample:
+                 pts_target = int(ts_target / time_base)
+                 video_container.seek(offset=pts_target, stream=video_stream)
+                 # The decoder may be exhausted after a seek near the end of the stream.
+                 frame = next(video_container.decode(video=DEFAULT_VIDEO_CHANNEL), None)
+                 if frame is None:
+                     continue
+                 frames.append(frame.to_image())
+
+             return frames
+
+
+ class PerceptionEncoderEmbeddingGenerator(ImageEmbeddingGenerator, VideoEmbeddingGenerator):
+     """Perception Encoder Core embedding model."""
+
+     def __init__(self) -> None:
+         """Initialize the Perception Encoder Core embedding model.
+
+         This method loads the Perception Encoder Core model and its tokenizer.
+         The model checkpoint is downloaded and cached locally for future use.
+         """
+         self._model, model_path = pe.CLIP.from_config(MODEL_NAME, pretrained=True)
+         self._preprocess = transforms.get_image_transform(self._model.image_size)
+         self._tokenizer = transforms.get_text_tokenizer(self._model.context_length)
+
+         # Auto-select device: CUDA > MPS (Apple Silicon) > CPU.
+         self._device = torch.device(
+             "cuda"
+             if torch.cuda.is_available()
+             else "mps"
+             if torch.backends.mps.is_available()
+             else "cpu"
+         )
+         self._model = self._model.to(self._device)
+         self._model_hash = file_utils.get_file_xxhash(Path(model_path))
+
+     def get_embedding_model_input(self, dataset_id: UUID) -> EmbeddingModelCreate:
+         """Generate an EmbeddingModelCreate instance.
+
+         Args:
+             dataset_id: The ID of the dataset.
+
+         Returns:
+             An EmbeddingModelCreate instance with the model details.
+         """
+         return EmbeddingModelCreate(
+             name=MODEL_NAME,
+             embedding_model_hash=self._model_hash,
+             embedding_dimension=self._model.output_dim,
+             dataset_id=dataset_id,
+         )
+
+     def embed_text(self, text: str) -> list[float]:
+         """Embed a text with Perception Encoder.
+
+         Args:
+             text: The text to embed.
+
+         Returns:
+             A list of floats representing the generated embedding.
+         """
+         tokenized = self._tokenizer([text]).to(self._device)
+         with torch.no_grad():
+             embedding = self._model.encode_text(tokenized, normalize=True)[0]
+         # Convert the embedding to a list of floats.
+         embedding_list: list[float] = embedding.cpu().numpy().flatten().tolist()
+         return embedding_list
+
+     def embed_images(self, filepaths: list[str]) -> NDArray[np.float32]:
+         """Embed images with Perception Encoder.
+
+         Args:
+             filepaths: A list of file paths to the images to embed.
+
+         Returns:
+             A numpy array representing the generated embeddings,
+             in the same order as the input file paths.
+         """
+         total_images = len(filepaths)
+         if not total_images:
+             return np.empty((0, self._model.output_dim), dtype=np.float32)
+
+         dataset = _ImageFileDataset(filepaths, self._preprocess)
+
+         # To avoid issues with db locking and multiprocessing we set the
+         # number of workers to 0 (no multiprocessing). The DataLoader is
+         # still very useful for batching and async prefetching of images.
+         loader = DataLoader(
+             dataset,
+             batch_size=MAX_BATCH_SIZE,
+             num_workers=0,  # must be 0 to avoid multiprocessing issues
+         )
+
+         embeddings = np.empty((total_images, self._model.output_dim), dtype=np.float32)
+         position = 0
+         with tqdm(
+             total=total_images, desc="Generating embeddings", unit=" images"
+         ) as progress_bar, torch.no_grad():
+             for images_tensor in loader:
+                 imgs = images_tensor.to(self._device, non_blocking=True)
+                 batch_embeddings = self._model.encode_image(imgs, normalize=True).cpu().numpy()
+                 batch_size = imgs.size(0)
+                 embeddings[position : position + batch_size] = batch_embeddings
+                 position += batch_size
+                 progress_bar.update(batch_size)
+
+         return embeddings
+
+     def embed_videos(self, filepaths: list[str]) -> NDArray[np.float32]:
+         """Embed videos with Perception Encoder.
+
+         Args:
+             filepaths: A list of file paths to the videos to embed.
+
+         Returns:
+             A numpy array representing the generated embeddings,
+             in the same order as the input file paths.
+         """
+         total_videos = len(filepaths)
+         if not total_videos:
+             return np.empty((0, self._model.output_dim), dtype=np.float32)
+
+         dataset = _VideoFileDataset(filepaths, self._preprocess)
+
+         # To avoid issues with db locking and multiprocessing we set the
+         # number of workers to 0 (no multiprocessing). The DataLoader is
+         # still very useful for batching and async prefetching of videos.
+         loader = DataLoader(
+             dataset,
+             batch_size=MAX_BATCH_SIZE,
+             num_workers=0,  # must be 0 to avoid multiprocessing issues
+         )
+
+         embeddings = np.empty((total_videos, self._model.output_dim), dtype=np.float32)
+         position = 0
+         with tqdm(
+             total=total_videos, desc="Generating embeddings", unit=" videos"
+         ) as progress_bar, torch.no_grad():
+             for videos_tensor in loader:
+                 videos = videos_tensor.to(self._device, non_blocking=True)
+                 batch_embeddings = self._model.encode_video(videos, normalize=True).cpu().numpy()
+                 batch_size = videos.size(0)
+                 embeddings[position : position + batch_size] = batch_embeddings
+                 position += batch_size
+                 progress_bar.update(batch_size)
+
+         return embeddings
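
To make the frame sampling in `_load_frames` concrete: for a 10-second video, np.linspace(0.0, 10.0, num=8, endpoint=False) yields the timestamps 0.0, 1.25, 2.5, ..., 8.75 seconds, each of which is divided by the stream time base to obtain a presentation timestamp to seek to. A minimal end-to-end sketch (the video path is an illustrative placeholder):

    from lightly_studio.dataset.perception_encoder_embedding_generator import (
        PerceptionEncoderEmbeddingGenerator,
    )

    generator = PerceptionEncoderEmbeddingGenerator()
    video_embeddings = generator.embed_videos(["clip.mp4"])
    # One row per video; the width is the model's output_dim.
    print(video_embeddings.shape)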