lightly-studio 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356):
  1. lightly_studio/__init__.py +12 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +131 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db_tables.py +35 -0
  6. lightly_studio/api/features.py +5 -0
  7. lightly_studio/api/routes/api/annotation.py +305 -0
  8. lightly_studio/api/routes/api/annotation_label.py +87 -0
  9. lightly_studio/api/routes/api/annotations/__init__.py +7 -0
  10. lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
  11. lightly_studio/api/routes/api/caption.py +100 -0
  12. lightly_studio/api/routes/api/classifier.py +384 -0
  13. lightly_studio/api/routes/api/dataset.py +191 -0
  14. lightly_studio/api/routes/api/dataset_tag.py +266 -0
  15. lightly_studio/api/routes/api/embeddings2d.py +90 -0
  16. lightly_studio/api/routes/api/exceptions.py +114 -0
  17. lightly_studio/api/routes/api/export.py +114 -0
  18. lightly_studio/api/routes/api/features.py +17 -0
  19. lightly_studio/api/routes/api/frame.py +241 -0
  20. lightly_studio/api/routes/api/image.py +155 -0
  21. lightly_studio/api/routes/api/metadata.py +161 -0
  22. lightly_studio/api/routes/api/operator.py +75 -0
  23. lightly_studio/api/routes/api/sample.py +103 -0
  24. lightly_studio/api/routes/api/selection.py +87 -0
  25. lightly_studio/api/routes/api/settings.py +41 -0
  26. lightly_studio/api/routes/api/status.py +19 -0
  27. lightly_studio/api/routes/api/text_embedding.py +50 -0
  28. lightly_studio/api/routes/api/validators.py +17 -0
  29. lightly_studio/api/routes/api/video.py +133 -0
  30. lightly_studio/api/routes/healthz.py +13 -0
  31. lightly_studio/api/routes/images.py +104 -0
  32. lightly_studio/api/routes/video_frames_media.py +116 -0
  33. lightly_studio/api/routes/video_media.py +223 -0
  34. lightly_studio/api/routes/webapp.py +51 -0
  35. lightly_studio/api/server.py +94 -0
  36. lightly_studio/core/__init__.py +0 -0
  37. lightly_studio/core/add_samples.py +533 -0
  38. lightly_studio/core/add_videos.py +294 -0
  39. lightly_studio/core/dataset.py +780 -0
  40. lightly_studio/core/dataset_query/__init__.py +14 -0
  41. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  42. lightly_studio/core/dataset_query/dataset_query.py +317 -0
  43. lightly_studio/core/dataset_query/field.py +113 -0
  44. lightly_studio/core/dataset_query/field_expression.py +79 -0
  45. lightly_studio/core/dataset_query/match_expression.py +23 -0
  46. lightly_studio/core/dataset_query/order_by.py +79 -0
  47. lightly_studio/core/dataset_query/sample_field.py +37 -0
  48. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  49. lightly_studio/core/image_sample.py +36 -0
  50. lightly_studio/core/loading_log.py +56 -0
  51. lightly_studio/core/sample.py +291 -0
  52. lightly_studio/core/start_gui.py +54 -0
  53. lightly_studio/core/video_sample.py +38 -0
  54. lightly_studio/dataset/__init__.py +0 -0
  55. lightly_studio/dataset/edge_embedding_generator.py +155 -0
  56. lightly_studio/dataset/embedding_generator.py +129 -0
  57. lightly_studio/dataset/embedding_manager.py +349 -0
  58. lightly_studio/dataset/env.py +20 -0
  59. lightly_studio/dataset/file_utils.py +49 -0
  60. lightly_studio/dataset/fsspec_lister.py +275 -0
  61. lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
  62. lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
  63. lightly_studio/db_manager.py +166 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
  102. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
  103. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
  131. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
  132. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
  133. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
  134. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
  135. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
  136. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
  137. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
  138. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
  139. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
  140. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
  141. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
  142. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
  143. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
  144. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
  145. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
  146. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
  147. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
  148. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
  149. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
  150. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
  151. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
  152. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
  153. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
  154. lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
  155. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  156. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  157. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  158. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  159. lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
  160. lightly_studio/errors.py +5 -0
  161. lightly_studio/examples/example.py +25 -0
  162. lightly_studio/examples/example_coco.py +27 -0
  163. lightly_studio/examples/example_coco_caption.py +29 -0
  164. lightly_studio/examples/example_metadata.py +369 -0
  165. lightly_studio/examples/example_operators.py +111 -0
  166. lightly_studio/examples/example_selection.py +28 -0
  167. lightly_studio/examples/example_split_work.py +48 -0
  168. lightly_studio/examples/example_video.py +22 -0
  169. lightly_studio/examples/example_video_annotations.py +157 -0
  170. lightly_studio/examples/example_yolo.py +22 -0
  171. lightly_studio/export/coco_captions.py +69 -0
  172. lightly_studio/export/export_dataset.py +104 -0
  173. lightly_studio/export/lightly_studio_label_input.py +120 -0
  174. lightly_studio/export_schema.py +18 -0
  175. lightly_studio/export_version.py +57 -0
  176. lightly_studio/few_shot_classifier/__init__.py +0 -0
  177. lightly_studio/few_shot_classifier/classifier.py +80 -0
  178. lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
  179. lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
  180. lightly_studio/metadata/complex_metadata.py +47 -0
  181. lightly_studio/metadata/compute_similarity.py +84 -0
  182. lightly_studio/metadata/compute_typicality.py +67 -0
  183. lightly_studio/metadata/gps_coordinate.py +41 -0
  184. lightly_studio/metadata/metadata_protocol.py +17 -0
  185. lightly_studio/models/__init__.py +1 -0
  186. lightly_studio/models/annotation/__init__.py +0 -0
  187. lightly_studio/models/annotation/annotation_base.py +303 -0
  188. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  189. lightly_studio/models/annotation/links.py +17 -0
  190. lightly_studio/models/annotation/object_detection.py +47 -0
  191. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  192. lightly_studio/models/annotation_label.py +47 -0
  193. lightly_studio/models/caption.py +49 -0
  194. lightly_studio/models/classifier.py +20 -0
  195. lightly_studio/models/dataset.py +70 -0
  196. lightly_studio/models/embedding_model.py +30 -0
  197. lightly_studio/models/image.py +96 -0
  198. lightly_studio/models/metadata.py +208 -0
  199. lightly_studio/models/range.py +17 -0
  200. lightly_studio/models/sample.py +154 -0
  201. lightly_studio/models/sample_embedding.py +36 -0
  202. lightly_studio/models/settings.py +69 -0
  203. lightly_studio/models/tag.py +96 -0
  204. lightly_studio/models/two_dim_embedding.py +16 -0
  205. lightly_studio/models/video.py +161 -0
  206. lightly_studio/plugins/__init__.py +0 -0
  207. lightly_studio/plugins/base_operator.py +60 -0
  208. lightly_studio/plugins/operator_registry.py +47 -0
  209. lightly_studio/plugins/parameter.py +70 -0
  210. lightly_studio/py.typed +0 -0
  211. lightly_studio/resolvers/__init__.py +0 -0
  212. lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
  213. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  214. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  215. lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
  216. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  217. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  218. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  219. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  220. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  221. lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
  222. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
  223. lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
  224. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
  225. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
  226. lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
  227. lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
  228. lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
  229. lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
  230. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
  231. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  232. lightly_studio/resolvers/annotations/__init__.py +1 -0
  233. lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
  234. lightly_studio/resolvers/caption_resolver.py +129 -0
  235. lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
  236. lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
  237. lightly_studio/resolvers/dataset_resolver/create.py +20 -0
  238. lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
  239. lightly_studio/resolvers/dataset_resolver/export.py +267 -0
  240. lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
  241. lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
  242. lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
  243. lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
  244. lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
  245. lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
  246. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
  247. lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
  248. lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
  249. lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
  250. lightly_studio/resolvers/dataset_resolver/update.py +25 -0
  251. lightly_studio/resolvers/embedding_model_resolver.py +120 -0
  252. lightly_studio/resolvers/image_filter.py +50 -0
  253. lightly_studio/resolvers/image_resolver/__init__.py +21 -0
  254. lightly_studio/resolvers/image_resolver/create_many.py +52 -0
  255. lightly_studio/resolvers/image_resolver/delete.py +20 -0
  256. lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
  257. lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
  258. lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
  259. lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
  260. lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
  261. lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
  262. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  263. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  264. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  265. lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
  266. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  267. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  268. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  269. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  270. lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
  271. lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
  272. lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
  273. lightly_studio/resolvers/sample_resolver/create.py +16 -0
  274. lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
  275. lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
  276. lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
  277. lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
  278. lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
  279. lightly_studio/resolvers/settings_resolver.py +62 -0
  280. lightly_studio/resolvers/tag_resolver.py +299 -0
  281. lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
  282. lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
  283. lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
  284. lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
  285. lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
  286. lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
  287. lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
  288. lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
  289. lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
  290. lightly_studio/resolvers/video_resolver/__init__.py +27 -0
  291. lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
  292. lightly_studio/resolvers/video_resolver/create_many.py +58 -0
  293. lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
  294. lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
  295. lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
  296. lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
  297. lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
  298. lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
  299. lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
  300. lightly_studio/selection/__init__.py +1 -0
  301. lightly_studio/selection/mundig.py +143 -0
  302. lightly_studio/selection/select.py +203 -0
  303. lightly_studio/selection/select_via_db.py +273 -0
  304. lightly_studio/selection/selection_config.py +49 -0
  305. lightly_studio/services/annotations_service/__init__.py +33 -0
  306. lightly_studio/services/annotations_service/create_annotation.py +64 -0
  307. lightly_studio/services/annotations_service/delete_annotation.py +22 -0
  308. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  309. lightly_studio/services/annotations_service/update_annotation.py +54 -0
  310. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  311. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  312. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  313. lightly_studio/setup_logging.py +59 -0
  314. lightly_studio/type_definitions.py +31 -0
  315. lightly_studio/utils/__init__.py +3 -0
  316. lightly_studio/utils/download.py +94 -0
  317. lightly_studio/vendor/__init__.py +1 -0
  318. lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
  319. lightly_studio/vendor/mobileclip/LICENSE +31 -0
  320. lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
  321. lightly_studio/vendor/mobileclip/README.md +5 -0
  322. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  323. lightly_studio/vendor/mobileclip/clip.py +77 -0
  324. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  325. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  326. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  327. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  328. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  329. lightly_studio/vendor/mobileclip/logger.py +154 -0
  330. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  331. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  332. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  333. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  334. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  335. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  336. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  337. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  338. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  339. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  340. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  341. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  342. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  343. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  344. lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
  345. lightly_studio/vendor/perception_encoder/README.md +11 -0
  346. lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
  347. lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
  348. lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
  349. lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
  350. lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
  351. lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
  352. lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
  353. lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
  354. lightly_studio-0.4.6.dist-info/METADATA +88 -0
  355. lightly_studio-0.4.6.dist-info/RECORD +356 -0
  356. lightly_studio-0.4.6.dist-info/WHEEL +4 -0
@@ -0,0 +1,644 @@
1
+ """ClassifierManager implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import io
7
+ from collections.abc import Sequence
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from random import sample
11
+ from uuid import UUID, uuid4
12
+
13
+ from sqlmodel import Session
14
+
15
+ from lightly_studio.few_shot_classifier import random_forest_classifier
16
+ from lightly_studio.few_shot_classifier.classifier import (
17
+ AnnotatedEmbedding,
18
+ ExportType,
19
+ )
20
+ from lightly_studio.few_shot_classifier.random_forest_classifier import (
21
+ RandomForest,
22
+ )
23
+ from lightly_studio.models.annotation.annotation_base import (
24
+ AnnotationCreate,
25
+ AnnotationType,
26
+ )
27
+ from lightly_studio.models.annotation_label import (
28
+ AnnotationLabelCreate,
29
+ )
30
+ from lightly_studio.models.classifier import EmbeddingClassifier
31
+ from lightly_studio.models.image import ImageTable
32
+ from lightly_studio.resolvers import (
33
+ annotation_label_resolver,
34
+ annotation_resolver,
35
+ embedding_model_resolver,
36
+ image_resolver,
37
+ sample_embedding_resolver,
38
+ )
39
+
40
# Confidence cut-off values. Both are 0.5; how they are compared against
# classifier outputs is defined elsewhere in this module (not visible here).
HIGH_CONFIDENCE_THRESHOLD: float = 0.5
LOW_CONFIDENCE_THRESHOLD: float = 0.5

# Sample counts paired with the high-/low-confidence cut-offs above.
# NOTE(review): presumably the number of samples to collect per bucket —
# confirm against the code that reads these constants.
HIGH_CONFIDENCE_SAMPLES_NEEDED: int = 10
LOW_CONFIDENCE_SAMPLES_NEEDED: int = 10

# Prefix for annotation task names; "FSC" presumably abbreviates
# "few-shot classifier" (the name of this package) — verify at use sites.
FSC_ANNOTATION_TASK_PREFIX: str = "FSC_"
47
+
48
+
49
class ClassifierManagerProvider:
    """Provider for the ClassifierManager singleton instance.

    The manager is created lazily on first request and the same object is
    returned on every later request.
    """

    # Shared manager instance; stays None until the first call to
    # get_classifier_manager(). No locking is performed around creation.
    _instance: ClassifierManager | None = None

    @classmethod
    def get_classifier_manager(cls) -> ClassifierManager:
        """Get the singleton instance of ClassifierManager.

        A new ClassifierManager is constructed only if none exists yet;
        otherwise the previously created instance is returned.

        Returns:
            The singleton instance of ClassifierManager.
        """
        if cls._instance is None:
            cls._instance = ClassifierManager()
        return cls._instance
67
+
68
+
69
@dataclass
class ClassifierEntry:
    """In-memory record for one classifier held by ClassifierManager.

    Bundles the classifier's ID, the underlying classifier model, the sample
    IDs annotated for each class, and the classifier's lifecycle state.
    """

    classifier_id: UUID

    # TODO(Horatiu, 05/2025): Use FewShotClassifier instead of RandomForest
    # when the interface is ready. Add method to get classifier info.
    few_shot_classifier: RandomForest

    # Annotations history is used to keep track of the samples that have
    # been used for training. It is a dictionary keyed by class name whose
    # values are the lists of sample IDs that belong to that class.
    # This is used to avoid using the same samples for fine tuning multiple
    # times.
    annotations: dict[str, list[UUID]]

    # Inactive classifiers are used for handling the fine tuning process.
    # From the moment the classifier is created until it is saved, is_active
    # will be False.
    is_active: bool = False

    # Optional list of annotation label IDs tied to this classifier; None
    # until explicitly set.
    annotation_label_ids: list[UUID] | None = None
92
+
93
+
94
+ class ClassifierManager:
95
+ """ClassifierManager class.
96
+
97
+ This class manages the lifecycle of a few-shot classifier,
98
+ including training, exporting, and loading the classifier.
99
+ """
100
+
101
+ def __init__(self) -> None:
102
+ """Initialize the ClassifierManager."""
103
+ self._classifiers: dict[UUID, ClassifierEntry] = {}
104
+
105
+ def create_classifier(
106
+ self,
107
+ session: Session,
108
+ name: str,
109
+ class_list: list[str],
110
+ dataset_id: UUID,
111
+ ) -> ClassifierEntry:
112
+ """Create a new classifier.
113
+
114
+ Args:
115
+ session: Database session for resolver operations.
116
+ name: The name of the classifier.
117
+ class_list: List of classes to be used for training.
118
+ dataset_id: The dataset_id to which the samples belong.
119
+
120
+ Returns:
121
+ The created classifier name and ID.
122
+ """
123
+ embedding_models = embedding_model_resolver.get_all_by_dataset_id(
124
+ session=session,
125
+ dataset_id=dataset_id,
126
+ )
127
+ if len(embedding_models) == 0:
128
+ raise ValueError("No embedding model found for the given dataset ID.")
129
+ # TODO(Horatiu, 05/2025): Handle multiple models correctly when
130
+ # available
131
+ if len(embedding_models) > 1:
132
+ raise ValueError("Multiple embedding models found for the given dataset ID.")
133
+ embedding_model = embedding_models[0]
134
+ classifier = RandomForest(
135
+ name=name,
136
+ classes=class_list,
137
+ embedding_model_hash=embedding_model.embedding_model_hash,
138
+ embedding_model_name=embedding_model.name,
139
+ )
140
+
141
+ classifier_id = uuid4()
142
+ self._classifiers[classifier_id] = ClassifierEntry(
143
+ classifier_id=classifier_id,
144
+ few_shot_classifier=classifier,
145
+ is_active=False,
146
+ annotations={class_name: [] for class_name in class_list},
147
+ )
148
+
149
+ return self._classifiers[classifier_id]
150
+
151
def train_classifier(self, session: Session, classifier_id: UUID) -> None:
    """(Re)train a classifier on the samples currently annotated for it.

    The annotated sample IDs stored on the classifier entry are turned into
    embeddings of the classifier's embedding model and passed to the
    underlying few-shot classifier's ``train``.

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to train.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or if
            no embedding model matches the classifier's embedding model hash.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    few_shot = entry.few_shot_classifier
    model_hash = few_shot.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if embedding_model is None:
        raise ValueError(f"No embedding model found for hash '{model_hash}'")

    training_data = _create_annotated_embeddings(
        session=session,
        class_to_sample_ids=entry.annotations,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    # Training replaces whatever the classifier learned previously.
    few_shot.train(training_data)
185
+
186
def commit_temp_classifier(self, classifier_id: UUID) -> None:
    """Mark a trained classifier as active.

    Args:
        classifier_id: The ID of the classifier to activate.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or if the classifier is not yet trained.
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    # Use truthiness instead of ``is False``: an identity check lets falsy
    # non-bool results (e.g. ``None`` or ``numpy.bool_(False)``) slip through
    # and would activate an untrained classifier.
    if not classifier.few_shot_classifier.is_trained():
        raise ValueError(f"Classifier with ID {classifier_id} is not trained yet.")
    classifier.is_active = True
202
+
203
def drop_temp_classifier(self, classifier_id: UUID) -> None:
    """Discard a classifier that has not been committed (is not active).

    Args:
        classifier_id: The ID of the classifier to discard.

    Raises:
        ValueError: If no classifier with the given ID exists, or if the
            classifier is active (active classifiers are never dropped here).
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if entry.is_active:
        raise ValueError(f"Classifier with ID {classifier_id} is active and cannot be dropped.")
    # The key is known to be present at this point.
    del self._classifiers[classifier_id]
218
+
219
def save_classifier_to_file(self, classifier_id: UUID, file_path: Path) -> None:
    """Save an active classifier to a file in the ``"sklearn"`` export format.

    Args:
        classifier_id: The ID of the classifier to save.
        file_path: The path to save the classifier to.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or if
            it is not active (i.e. has not been committed yet).
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not classifier.is_active:
        raise ValueError(
            f"Classifier with ID {classifier_id} is not active and cannot be saved."
        )
    classifier.few_shot_classifier.export(export_path=file_path, export_type="sklearn")
237
+
238
def load_classifier_from_file(self, session: Session, file_path: Path) -> ClassifierEntry:
    """Load a classifier from a file and register it as an active classifier.

    Args:
        session: Database session for resolver operations.
        file_path: The path from where to load the classifier.

    Returns:
        The newly registered ClassifierEntry. It gets a freshly generated
        ``classifier_id``, is marked active, and starts with an empty
        annotation list for every class of the loaded classifier.

    Raises:
        ValueError: If no embedding model in the database matches the loaded
            classifier's embedding model hash.
    """
    classifier = random_forest_classifier.load_random_forest_classifier(
        classifier_path=file_path, buffer=None
    )
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=classifier.embedding_model_hash,
    )
    if embedding_model is None:
        # Same wording as ``load_classifier_from_buffer`` (note the space
        # after the colon, previously missing here).
        raise ValueError(
            "No matching embedding model found for the classifier's hash: "
            f"'{classifier.embedding_model_hash}'."
        )

    classifier_id = uuid4()
    self._classifiers[classifier_id] = ClassifierEntry(
        classifier_id=classifier_id,
        few_shot_classifier=classifier,
        is_active=True,
        annotations={class_name: [] for class_name in classifier.classes},
    )
    return self._classifiers[classifier_id]
269
+
270
def provide_negative_samples(
    self, session: Session, dataset_id: UUID, selected_samples: list[UUID], limit: int = 10
) -> Sequence[ImageTable]:
    """Return up to ``limit`` samples of a dataset that are not in ``selected_samples``.

    The query is delegated to ``image_resolver.get_samples_excluding``.
    NOTE(review): whether the returned samples are random or ordered depends
    on that resolver; confirm there.

    Args:
        session: Database session for resolver operations.
        dataset_id: The dataset to draw samples from.
        selected_samples: Sample UUIDs that must not be returned.
        limit: Maximum number of samples to return.

    Returns:
        The samples returned by the resolver.
    """
    negatives = image_resolver.get_samples_excluding(
        session=session,
        dataset_id=dataset_id,
        excluded_sample_ids=selected_samples,
        limit=limit,
    )
    return negatives
291
+
292
def update_classifiers_annotations(
    self,
    classifier_id: UUID,
    new_annotations: dict[str, list[UUID]],
) -> None:
    """Merge newly annotated samples into a classifier's per-class annotations.

    A sample that appears in ``new_annotations`` is first removed from every
    class it was previously annotated with, then added to the class(es) it is
    now listed under. Classes not mentioned in ``new_annotations`` only lose
    the re-annotated samples. Resulting per-class lists have no duplicates;
    their order is not preserved.

    Args:
        classifier_id: The ID of the classifier.
        new_annotations: Mapping from class name to the sample IDs newly
            annotated with that class. Only existing class names are allowed.

    Raises:
        ValueError: If the classifier doesn't exist, or if
            ``new_annotations`` contains a class name the classifier lacks.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    current = entry.annotations
    unknown_classes = set(new_annotations) - set(current)
    if unknown_classes:
        raise ValueError(
            f"Cannot add new classes {unknown_classes} to existing"
            f" classifier. Allowed classes are: {set(current)}"
        )

    # Every sample mentioned anywhere in the update, regardless of class.
    reassigned = set()
    for sample_ids in new_annotations.values():
        reassigned.update(sample_ids)

    for class_name in current:
        incoming = set(new_annotations.get(class_name, []))
        kept = set(current[class_name]) - reassigned
        current[class_name] = list(kept | incoming)
333
+
334
def get_annotations(self, classifier_id: UUID) -> dict[str, list[UUID]]:
    """Return a deep copy of the per-class training annotations of a classifier.

    The copy can be modified freely without affecting the stored annotations.

    Args:
        classifier_id: The ID of the classifier.

    Returns:
        Mapping from class name to the list of annotated sample IDs.

    Raises:
        ValueError: If the classifier doesn't exist.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    snapshot = copy.deepcopy(entry.annotations)
    return snapshot
351
+
352
def get_samples_for_fine_tuning(
    self, session: Session, dataset_id: UUID, classifier_id: UUID
) -> dict[str, list[UUID]]:
    """Propose not-yet-annotated samples for fine-tuning the classifier.

    Every sample of the dataset that has an embedding for the classifier's
    embedding model and is not yet annotated for this classifier is scored.
    Candidates are split by the predicted score of the first class
    (``classes[0]``):

    - score ``> HIGH_CONFIDENCE_THRESHOLD``: proposed for the first class;
      at most ``HIGH_CONFIDENCE_SAMPLES_NEEDED`` are returned.
    - score ``<= LOW_CONFIDENCE_THRESHOLD``: proposed for the second class;
      at most ``LOW_CONFIDENCE_SAMPLES_NEEDED`` are returned.

    Samples scoring between the two thresholds (if they differ) are not
    proposed. When fewer candidates exist than requested, all of them are
    returned. The subset is drawn with ``sample``.

    Args:
        session: Database session for resolver operations.
        dataset_id: The ID of the dataset to pull samples from.
        classifier_id: The ID of the classifier to use.

    Returns:
        Mapping with exactly two keys: ``classes[0]`` -> IDs of samples with
        a high first-class score, and ``classes[1]`` -> IDs of samples with a
        low first-class score. The classifier is assumed to have at least two
        classes.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or there is no appropriate embedding model.
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    few_shot_classifier = classifier.few_shot_classifier

    # Samples already annotated for this classifier are never proposed again.
    used_samples = {
        sample_id for samples in classifier.annotations.values() for sample_id in samples
    }

    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=few_shot_classifier.embedding_model_hash,
    )
    if embedding_model is None:
        raise ValueError(
            "No embedding model found for hash '"
            f"{few_shot_classifier.embedding_model_hash}'"
        )

    sample_embeddings = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    # Filter before predicting: scores of already-annotated samples would be
    # discarded anyway, so there is no need to compute them.
    candidates = [se for se in sample_embeddings if se.sample_id not in used_samples]

    high_conf = []  # first-class score > HIGH_CONFIDENCE_THRESHOLD
    low_conf = []  # first-class score <= LOW_CONFIDENCE_THRESHOLD
    # Skip prediction entirely when nothing is left to score.
    if candidates:
        predictions = few_shot_classifier.predict([se.embedding for se in candidates])
        for sample_embedding, pred in zip(candidates, predictions):
            if pred[0] > HIGH_CONFIDENCE_THRESHOLD:
                high_conf.append(sample_embedding.sample_id)
            elif pred[0] <= LOW_CONFIDENCE_THRESHOLD:
                low_conf.append(sample_embedding.sample_id)

    return {
        few_shot_classifier.classes[0]: sample(
            high_conf, min(len(high_conf), HIGH_CONFIDENCE_SAMPLES_NEEDED)
        ),
        few_shot_classifier.classes[1]: sample(
            low_conf, min(len(low_conf), LOW_CONFIDENCE_SAMPLES_NEEDED)
        ),
    }
425
+
426
def run_classifier(self, session: Session, classifier_id: UUID, dataset_id: UUID) -> None:
    """Annotate every embedded sample of a dataset with an active classifier.

    Each sample that has an embedding for the classifier's embedding model
    gets one CLASSIFICATION annotation. Its label is the one for the class
    with the highest predicted score; ties go to the first such class. Its
    confidence is that score. All annotations carrying this classifier's
    annotation labels are deleted before the new ones are inserted.

    NOTE(review): annotation labels are created only once per classifier entry
    (with the ``dataset_id`` of the first run) and reused afterwards; confirm
    this is intended when running on several datasets.

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to run.
        dataset_id: The ID of the dataset to run the classifier on.

    Raises:
        ValueError: If the classifier does not exist or is not active, if no
            embedding model matches the classifier's hash, if prediction
            returns nothing, or if the classifier has no annotation labels.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not entry.is_active:
        raise ValueError(
            f"Classifier with ID {classifier_id} is not active and cannot be used."
        )

    model_hash = entry.few_shot_classifier.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if embedding_model is None:
        raise ValueError(f"No embedding model found for hash '{model_hash}'")

    # The SampleEmbedding rows are kept so predictions can be mapped back
    # to their sample IDs by position.
    dataset_embeddings = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    scores_per_sample = entry.few_shot_classifier.predict(
        [item.embedding for item in dataset_embeddings]
    )
    if len(scores_per_sample) == 0:
        raise ValueError(f"Predict returned empty list for classifier:'{classifier_id}'")

    # Creates one label per class on the first run; no-op afterwards.
    _create_annotation_labels_for_classifier(
        classifier=entry,
        session=session,
        dataset_id=dataset_id,
    )
    label_ids = entry.annotation_label_ids
    if not label_ids:
        raise ValueError(f"Classifier with ID '{classifier_id}' has no annotation labels")

    annotations_to_create = []
    for item, scores in zip(dataset_embeddings, scores_per_sample):
        best = scores.index(max(scores))
        annotations_to_create.append(
            AnnotationCreate(
                parent_sample_id=item.sample_id,
                dataset_id=dataset_id,
                annotation_label_id=label_ids[best],
                annotation_type=AnnotationType.CLASSIFICATION,
                confidence=scores[best],
            )
        )

    # Replace the results of any previous run of this classifier.
    annotation_resolver.delete_annotations(
        session=session,
        annotation_label_ids=label_ids,
    )
    annotation_resolver.create_many(
        session=session, parent_dataset_id=dataset_id, annotations=annotations_to_create
    )
500
+
501
def get_all_classifiers(self) -> list[EmbeddingClassifier]:
    """List all active classifiers.

    Inactive (uncommitted) classifiers are skipped.

    Returns:
        One EmbeddingClassifier per active classifier, in registration order.
    """
    result = []
    for entry in self._classifiers.values():
        if not entry.is_active:
            continue
        few_shot = entry.few_shot_classifier
        result.append(
            EmbeddingClassifier(
                classifier_name=few_shot.name,
                classifier_id=entry.classifier_id,
                class_list=few_shot.classes,
            )
        )
    return result
516
+
517
def get_classifier_by_id(self, classifier_id: UUID) -> EmbeddingClassifier:
    """Get a single classifier by its ID.

    Unlike ``get_all_classifiers``, this does not filter on active state:
    inactive (uncommitted) classifiers are returned as well.

    Args:
        classifier_id: The ID of the classifier to get.

    Raises:
        ValueError: If the classifier with the given ID does not exist.

    Returns:
        EmbeddingClassifier object describing the classifier.
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    return EmbeddingClassifier(
        classifier_name=classifier.few_shot_classifier.name,
        classifier_id=classifier_id,
        class_list=classifier.few_shot_classifier.classes,
    )
537
+
538
def save_classifier_to_buffer(
    self, classifier_id: UUID, buffer: io.BytesIO, export_type: ExportType
) -> None:
    """Write a classifier into an in-memory buffer.

    No active-state check is performed here; any registered classifier can be
    exported.

    Args:
        classifier_id: The ID of the classifier to export.
        buffer: Destination buffer passed to the classifier's ``export``.
        export_type: Export format forwarded to ``export``.

    Raises:
        ValueError: If the classifier with the given ID does not exist.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    exporter = entry.few_shot_classifier
    exporter.export(buffer=buffer, export_type=export_type)
555
+
556
def load_classifier_from_buffer(self, session: Session, buffer: io.BytesIO) -> ClassifierEntry:
    """Load a classifier from a buffer and register it as an active classifier.

    Args:
        session: Database session for resolver operations.
        buffer: The buffer containing the classifier data.

    Returns:
        The newly registered ClassifierEntry. It gets a freshly generated
        ``classifier_id``, is marked active, and starts with an empty
        annotation list for every class of the loaded classifier.

    Raises:
        ValueError: If no matching embedding model is found for the
            classifier.
    """
    classifier = random_forest_classifier.load_random_forest_classifier(
        buffer=buffer, classifier_path=None
    )
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=classifier.embedding_model_hash,
    )
    if embedding_model is None:
        raise ValueError(
            "No matching embedding model found for the classifier's hash: "
            f"'{classifier.embedding_model_hash}'."
        )

    classifier_id = uuid4()
    self._classifiers[classifier_id] = ClassifierEntry(
        classifier_id=classifier_id,
        few_shot_classifier=classifier,
        is_active=True,
        annotations={class_name: [] for class_name in classifier.classes},
    )
    return self._classifiers[classifier_id]
591
+
592
+
593
+ def _create_annotation_labels_for_classifier(
594
+ session: Session,
595
+ dataset_id: UUID,
596
+ classifier: ClassifierEntry,
597
+ ) -> None:
598
+ """Create annotation labels for the classifier.
599
+
600
+ Args:
601
+ session: Database session.
602
+ dataset_id: The dataset ID to which the samples belong.
603
+ classifier: The classifier object to update.
604
+ """
605
+ # Check if the annotation label with the classifier name and class
606
+ # names exists and if not create it.
607
+ if classifier.annotation_label_ids is None:
608
+ annotation_label_ids = []
609
+ for class_name in classifier.few_shot_classifier.classes:
610
+ annotation_label = annotation_label_resolver.create(
611
+ session=session,
612
+ label=AnnotationLabelCreate(
613
+ dataset_id=dataset_id,
614
+ annotation_label_name=classifier.few_shot_classifier.name + "_" + class_name,
615
+ ),
616
+ )
617
+ annotation_label_ids.append(annotation_label.annotation_label_id)
618
+ classifier.annotation_label_ids = annotation_label_ids
619
+
620
+
621
+ def _create_annotated_embeddings(
622
+ session: Session,
623
+ class_to_sample_ids: dict[str, list[UUID]],
624
+ embedding_model_id: UUID,
625
+ ) -> list[AnnotatedEmbedding]:
626
+ """Create annotated embeddings from input data.
627
+
628
+ Args:
629
+ session: Database session.
630
+ class_to_sample_ids: Dictionary mapping class names to sample UUIDs.
631
+ embedding_model_id: The embedding model ID to filter by.
632
+
633
+ Returns:
634
+ List of annotated embeddings for training.
635
+ """
636
+ return [
637
+ AnnotatedEmbedding(embedding=embedding.embedding, annotation=class_name)
638
+ for class_name, sample_uuids in class_to_sample_ids.items()
639
+ for embedding in sample_embedding_resolver.get_by_sample_ids(
640
+ session=session,
641
+ sample_ids=sample_uuids,
642
+ embedding_model_id=embedding_model_id,
643
+ )
644
+ ]