lightly-studio 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightly_studio/__init__.py +12 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +131 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db_tables.py +35 -0
- lightly_studio/api/features.py +5 -0
- lightly_studio/api/routes/api/annotation.py +305 -0
- lightly_studio/api/routes/api/annotation_label.py +87 -0
- lightly_studio/api/routes/api/annotations/__init__.py +7 -0
- lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
- lightly_studio/api/routes/api/caption.py +100 -0
- lightly_studio/api/routes/api/classifier.py +384 -0
- lightly_studio/api/routes/api/dataset.py +191 -0
- lightly_studio/api/routes/api/dataset_tag.py +266 -0
- lightly_studio/api/routes/api/embeddings2d.py +90 -0
- lightly_studio/api/routes/api/exceptions.py +114 -0
- lightly_studio/api/routes/api/export.py +114 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/frame.py +241 -0
- lightly_studio/api/routes/api/image.py +155 -0
- lightly_studio/api/routes/api/metadata.py +161 -0
- lightly_studio/api/routes/api/operator.py +75 -0
- lightly_studio/api/routes/api/sample.py +103 -0
- lightly_studio/api/routes/api/selection.py +87 -0
- lightly_studio/api/routes/api/settings.py +41 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +50 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/api/video.py +133 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/video_frames_media.py +116 -0
- lightly_studio/api/routes/video_media.py +223 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +94 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/add_samples.py +533 -0
- lightly_studio/core/add_videos.py +294 -0
- lightly_studio/core/dataset.py +780 -0
- lightly_studio/core/dataset_query/__init__.py +14 -0
- lightly_studio/core/dataset_query/boolean_expression.py +67 -0
- lightly_studio/core/dataset_query/dataset_query.py +317 -0
- lightly_studio/core/dataset_query/field.py +113 -0
- lightly_studio/core/dataset_query/field_expression.py +79 -0
- lightly_studio/core/dataset_query/match_expression.py +23 -0
- lightly_studio/core/dataset_query/order_by.py +79 -0
- lightly_studio/core/dataset_query/sample_field.py +37 -0
- lightly_studio/core/dataset_query/tags_expression.py +46 -0
- lightly_studio/core/image_sample.py +36 -0
- lightly_studio/core/loading_log.py +56 -0
- lightly_studio/core/sample.py +291 -0
- lightly_studio/core/start_gui.py +54 -0
- lightly_studio/core/video_sample.py +38 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +155 -0
- lightly_studio/dataset/embedding_generator.py +129 -0
- lightly_studio/dataset/embedding_manager.py +349 -0
- lightly_studio/dataset/env.py +20 -0
- lightly_studio/dataset/file_utils.py +49 -0
- lightly_studio/dataset/fsspec_lister.py +275 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
- lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
- lightly_studio/db_manager.py +166 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
- lightly_studio/errors.py +5 -0
- lightly_studio/examples/example.py +25 -0
- lightly_studio/examples/example_coco.py +27 -0
- lightly_studio/examples/example_coco_caption.py +29 -0
- lightly_studio/examples/example_metadata.py +369 -0
- lightly_studio/examples/example_operators.py +111 -0
- lightly_studio/examples/example_selection.py +28 -0
- lightly_studio/examples/example_split_work.py +48 -0
- lightly_studio/examples/example_video.py +22 -0
- lightly_studio/examples/example_video_annotations.py +157 -0
- lightly_studio/examples/example_yolo.py +22 -0
- lightly_studio/export/coco_captions.py +69 -0
- lightly_studio/export/export_dataset.py +104 -0
- lightly_studio/export/lightly_studio_label_input.py +120 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/export_version.py +57 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/compute_similarity.py +84 -0
- lightly_studio/metadata/compute_typicality.py +67 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +303 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/caption.py +49 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +70 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/image.py +96 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/range.py +17 -0
- lightly_studio/models/sample.py +154 -0
- lightly_studio/models/sample_embedding.py +36 -0
- lightly_studio/models/settings.py +69 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/models/two_dim_embedding.py +16 -0
- lightly_studio/models/video.py +161 -0
- lightly_studio/plugins/__init__.py +0 -0
- lightly_studio/plugins/base_operator.py +60 -0
- lightly_studio/plugins/operator_registry.py +47 -0
- lightly_studio/plugins/parameter.py +70 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +0 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
- lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
- lightly_studio/resolvers/caption_resolver.py +129 -0
- lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
- lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
- lightly_studio/resolvers/dataset_resolver/create.py +20 -0
- lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
- lightly_studio/resolvers/dataset_resolver/export.py +267 -0
- lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
- lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
- lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
- lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
- lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
- lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
- lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
- lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
- lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
- lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
- lightly_studio/resolvers/dataset_resolver/update.py +25 -0
- lightly_studio/resolvers/embedding_model_resolver.py +120 -0
- lightly_studio/resolvers/image_filter.py +50 -0
- lightly_studio/resolvers/image_resolver/__init__.py +21 -0
- lightly_studio/resolvers/image_resolver/create_many.py +52 -0
- lightly_studio/resolvers/image_resolver/delete.py +20 -0
- lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
- lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
- lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
- lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
- lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
- lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
- lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
- lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
- lightly_studio/resolvers/sample_resolver/create.py +16 -0
- lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
- lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
- lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
- lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
- lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
- lightly_studio/resolvers/settings_resolver.py +62 -0
- lightly_studio/resolvers/tag_resolver.py +299 -0
- lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
- lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
- lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
- lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
- lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
- lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
- lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
- lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
- lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
- lightly_studio/resolvers/video_resolver/__init__.py +27 -0
- lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
- lightly_studio/resolvers/video_resolver/create_many.py +58 -0
- lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
- lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
- lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
- lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
- lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
- lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
- lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
- lightly_studio/selection/__init__.py +1 -0
- lightly_studio/selection/mundig.py +143 -0
- lightly_studio/selection/select.py +203 -0
- lightly_studio/selection/select_via_db.py +273 -0
- lightly_studio/selection/selection_config.py +49 -0
- lightly_studio/services/annotations_service/__init__.py +33 -0
- lightly_studio/services/annotations_service/create_annotation.py +64 -0
- lightly_studio/services/annotations_service/delete_annotation.py +22 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +54 -0
- lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +59 -0
- lightly_studio/type_definitions.py +31 -0
- lightly_studio/utils/__init__.py +3 -0
- lightly_studio/utils/download.py +94 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/mobileclip/LICENSE +31 -0
- lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
- lightly_studio/vendor/mobileclip/README.md +5 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
- lightly_studio/vendor/perception_encoder/README.md +11 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
- lightly_studio-0.4.6.dist-info/METADATA +88 -0
- lightly_studio-0.4.6.dist-info/RECORD +356 -0
- lightly_studio-0.4.6.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,644 @@
|
|
|
1
|
+
"""ClassifierManager implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import io
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from random import sample
|
|
11
|
+
from uuid import UUID, uuid4
|
|
12
|
+
|
|
13
|
+
from sqlmodel import Session
|
|
14
|
+
|
|
15
|
+
from lightly_studio.few_shot_classifier import random_forest_classifier
|
|
16
|
+
from lightly_studio.few_shot_classifier.classifier import (
|
|
17
|
+
AnnotatedEmbedding,
|
|
18
|
+
ExportType,
|
|
19
|
+
)
|
|
20
|
+
from lightly_studio.few_shot_classifier.random_forest_classifier import (
|
|
21
|
+
RandomForest,
|
|
22
|
+
)
|
|
23
|
+
from lightly_studio.models.annotation.annotation_base import (
|
|
24
|
+
AnnotationCreate,
|
|
25
|
+
AnnotationType,
|
|
26
|
+
)
|
|
27
|
+
from lightly_studio.models.annotation_label import (
|
|
28
|
+
AnnotationLabelCreate,
|
|
29
|
+
)
|
|
30
|
+
from lightly_studio.models.classifier import EmbeddingClassifier
|
|
31
|
+
from lightly_studio.models.image import ImageTable
|
|
32
|
+
from lightly_studio.resolvers import (
|
|
33
|
+
annotation_label_resolver,
|
|
34
|
+
annotation_resolver,
|
|
35
|
+
embedding_model_resolver,
|
|
36
|
+
image_resolver,
|
|
37
|
+
sample_embedding_resolver,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Confidence cut-offs. Their consumers are not visible in this part of the
# module.
# NOTE(review): both thresholds are 0.5, so "high" and "low" confidence share
# a single boundary -- confirm this is intended.
HIGH_CONFIDENCE_THRESHOLD = 0.5
LOW_CONFIDENCE_THRESHOLD = 0.5

# Number of samples wanted on each side of the confidence cut-offs above
# (presumably; their consumers are not visible here -- TODO confirm).
HIGH_CONFIDENCE_SAMPLES_NEEDED = 10
LOW_CONFIDENCE_SAMPLES_NEEDED = 10

# Prefix string for few-shot-classifier ("FSC") annotation tasks.
FSC_ANNOTATION_TASK_PREFIX = "FSC_"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ClassifierManagerProvider:
    """Provider for the ClassifierManager singleton instance."""

    # Cached singleton; None until the first call to get_classifier_manager.
    _instance: ClassifierManager | None = None

    @classmethod
    def get_classifier_manager(cls) -> ClassifierManager:
        """Get the singleton instance of ClassifierManager.

        The instance is created lazily on the first call and the same object
        is returned by every subsequent call.

        Returns:
            The singleton instance of ClassifierManager.
        """
        if cls._instance is None:
            cls._instance = ClassifierManager()
        return cls._instance
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class ClassifierEntry:
    """In-memory record of one classifier managed by ClassifierManager."""

    # ID under which this entry is registered in the manager.
    classifier_id: UUID

    # TODO(Horatiu, 05/2025): Use FewShotClassifier instead of RandomForest
    # when the interface is ready. Add method to get classifier info.
    few_shot_classifier: RandomForest

    # Annotation history: keeps track of the samples that have been used for
    # training. Maps each class name to the list of sample IDs that belong to
    # that class. This is used to avoid using the same samples for fine
    # tuning multiple times.
    annotations: dict[str, list[UUID]]

    # Inactive classifiers are used for handling the fine tuning process.
    # From the moment the classifier is created until it is committed as
    # active, is_active is False. Only inactive classifiers may be dropped.
    is_active: bool = False

    # NOTE(review): presumably the IDs of the annotation labels associated
    # with this classifier's classes; it is not set in the code visible here
    # -- TODO confirm.
    annotation_label_ids: list[UUID] | None = None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ClassifierManager:
|
|
95
|
+
"""ClassifierManager class.
|
|
96
|
+
|
|
97
|
+
This class manages the lifecycle of a few-shot classifier,
|
|
98
|
+
including training, exporting, and loading the classifier.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
def __init__(self) -> None:
    """Start the manager with no registered classifiers."""
    # Registry of every classifier known to this manager, keyed by its ID.
    self._classifiers: dict[UUID, ClassifierEntry] = dict()
|
|
104
|
+
|
|
105
|
+
def create_classifier(
    self,
    session: Session,
    name: str,
    class_list: list[str],
    dataset_id: UUID,
) -> ClassifierEntry:
    """Create a new classifier and register it with this manager.

    The classifier is bound to the single embedding model registered for
    ``dataset_id``. The new entry starts inactive (see
    ``commit_temp_classifier``) and with an empty annotation history for
    every class in ``class_list``.

    Args:
        session: Database session for resolver operations.
        name: The name of the classifier.
        class_list: List of classes to be used for training.
        dataset_id: The dataset_id to which the samples belong.

    Returns:
        The created ClassifierEntry, registered under a freshly generated
        random (UUID4) ID.

    Raises:
        ValueError: If no embedding model, or more than one embedding model,
            is found for the given dataset ID.
    """
    embedding_models = embedding_model_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
    )
    if not embedding_models:
        raise ValueError("No embedding model found for the given dataset ID.")
    # TODO(Horatiu, 05/2025): Handle multiple models correctly when
    # available
    if len(embedding_models) > 1:
        raise ValueError("Multiple embedding models found for the given dataset ID.")
    embedding_model = embedding_models[0]

    classifier = RandomForest(
        name=name,
        classes=class_list,
        embedding_model_hash=embedding_model.embedding_model_hash,
        embedding_model_name=embedding_model.name,
    )

    classifier_id = uuid4()
    entry = ClassifierEntry(
        classifier_id=classifier_id,
        few_shot_classifier=classifier,
        is_active=False,
        # NOTE: duplicate names in class_list collapse into a single key here.
        annotations={class_name: [] for class_name in class_list},
    )
    self._classifiers[classifier_id] = entry
    return entry
|
|
150
|
+
|
|
151
|
+
def train_classifier(self, session: Session, classifier_id: UUID) -> None:
    """Train the classifier on its full annotation history.

    The embeddings of all annotated samples are looked up for the embedding
    model the classifier was created with. The classifier is then trained on
    them, overwriting any previous training.

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to train.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or if
            no embedding model matches the classifier's embedding model hash.
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    embedding_model_hash = classifier.few_shot_classifier.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=embedding_model_hash,
    )
    if embedding_model is None:
        raise ValueError(f"No embedding model found for hash '{embedding_model_hash}'")

    annotated_embeddings = _create_annotated_embeddings(
        session=session,
        class_to_sample_ids=classifier.annotations,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    # Train the classifier with the annotated embeddings.
    # This will overwrite the previous training.
    classifier.few_shot_classifier.train(annotated_embeddings)
|
|
185
|
+
|
|
186
|
+
def commit_temp_classifier(self, classifier_id: UUID) -> None:
    """Mark a trained temporary classifier as active.

    Args:
        classifier_id: The ID of the classifier to activate.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or if the classifier is not yet trained.
    """
    classifier = self._classifiers.get(classifier_id)
    if classifier is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    # Use truthiness, not ``is False``: an identity check against the
    # ``False`` singleton lets other falsy results (e.g. a NumPy ``bool_``)
    # slip through and would activate an untrained classifier.
    if not classifier.few_shot_classifier.is_trained():
        raise ValueError(f"Classifier with ID {classifier_id} is not trained yet.")
    classifier.is_active = True
|
|
202
|
+
|
|
203
|
+
def drop_temp_classifier(self, classifier_id: UUID) -> None:
    """Discard a classifier that has not been committed (is inactive).

    Args:
        classifier_id: Identifier of the classifier to remove.

    Raises:
        ValueError: If no classifier is registered under ``classifier_id``,
            or if that classifier is already active.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if entry.is_active:
        raise ValueError(f"Classifier with ID {classifier_id} is active and cannot be dropped.")
    # Existence was verified above, so a plain ``del`` is safe.
    del self._classifiers[classifier_id]
|
|
218
|
+
|
|
219
|
+
def save_classifier_to_file(self, classifier_id: UUID, file_path: Path) -> None:
    """Export an active classifier to a file using the ``"sklearn"`` format.

    Args:
        classifier_id: Identifier of the classifier to persist.
        file_path: Destination path for the exported classifier.

    Raises:
        ValueError: If no classifier is registered under ``classifier_id``,
            or if that classifier is not active.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not entry.is_active:
        raise ValueError(
            f"Classifier with ID {classifier_id} is not active and cannot be saved."
        )
    entry.few_shot_classifier.export(export_path=file_path, export_type="sklearn")
|
|
237
|
+
|
|
238
|
+
def load_classifier_from_file(self, session: Session, file_path: Path) -> ClassifierEntry:
    """Load a classifier from file and register it as an active classifier.

    Args:
        session: Database session for resolver operations.
        file_path: The path from where to load the classifier.

    Returns:
        The newly registered ``ClassifierEntry``. It is active and starts
        with an empty annotation list for each of the classifier's classes.

    Raises:
        ValueError: If no embedding model matches the hash stored in the
            loaded classifier.
    """
    classifier = random_forest_classifier.load_random_forest_classifier(
        classifier_path=file_path, buffer=None
    )
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=classifier.embedding_model_hash,
    )
    if embedding_model is None:
        # Same wording as ``load_classifier_from_buffer``, including the
        # space after the colon.
        raise ValueError(
            "No matching embedding model found for the classifier's hash: "
            f"'{classifier.embedding_model_hash}'."
        )

    classifier_id = uuid4()
    self._classifiers[classifier_id] = ClassifierEntry(
        classifier_id=classifier_id,
        few_shot_classifier=classifier,
        is_active=True,
        annotations={class_name: [] for class_name in classifier.classes},
    )
    return self._classifiers[classifier_id]
|
|
269
|
+
|
|
270
|
+
def provide_negative_samples(
    self, session: Session, dataset_id: UUID, selected_samples: list[UUID], limit: int = 10
) -> Sequence[ImageTable]:
    """Fetch images from a dataset while excluding the given sample IDs.

    Args:
        session: Database session used by the image resolver.
        dataset_id: Dataset to draw the samples from.
        selected_samples: Sample IDs that must not be returned.
        limit: Requested number of samples, forwarded to the resolver.

    Returns:
        The images returned by ``image_resolver.get_samples_excluding``.
        Their selection strategy (e.g. random) and the cap on their count
        are decided by that resolver. NOTE(review): presumably at most
        ``limit`` images; confirm against the resolver.
    """
    negatives = image_resolver.get_samples_excluding(
        session=session,
        dataset_id=dataset_id,
        excluded_sample_ids=selected_samples,
        limit=limit,
    )
    return negatives
|
|
291
|
+
|
|
292
|
+
def update_classifiers_annotations(
    self,
    classifier_id: UUID,
    new_annotations: dict[str, list[UUID]],
) -> None:
    """Assign samples to classes, moving samples that had another class.

    Every sample mentioned anywhere in ``new_annotations`` is first removed
    from all classes and then added to the class it is listed under.
    Classes absent from ``new_annotations`` only lose such samples.

    Args:
        classifier_id: Identifier of the classifier whose annotations change.
        new_annotations: Mapping from class name to the sample IDs that
            should now belong to that class.

    Raises:
        ValueError: If the classifier doesn't exist, or if
            ``new_annotations`` names a class the classifier does not have.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    current = entry.annotations
    # The set of classes is fixed at creation time; reject unknown ones.
    invalid_classes = set(new_annotations.keys()) - set(current.keys())
    if invalid_classes:
        raise ValueError(
            f"Cannot add new classes {invalid_classes} to existing"
            f" classifier. Allowed classes are: {set(current.keys())}"
        )

    touched = set()
    for sample_ids in new_annotations.values():
        touched.update(sample_ids)

    # Only values are replaced (no keys added or removed), so rewriting the
    # dict while iterating its items is safe.
    for class_name, old_ids in current.items():
        added = set(new_annotations.get(class_name, []))
        current[class_name] = list((set(old_ids) - touched) | added)
|
|
333
|
+
|
|
334
|
+
def get_annotations(self, classifier_id: UUID) -> dict[str, list[UUID]]:
    """Return the training sample IDs of each class as an independent copy.

    Args:
        classifier_id: Identifier of the classifier to inspect.

    Returns:
        A deep copy of the class-name -> sample-ID-list mapping. Mutating it
        does not affect the stored annotations.

    Raises:
        ValueError: If the classifier doesn't exist.
    """
    if (entry := self._classifiers.get(classifier_id)) is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    snapshot = copy.deepcopy(entry.annotations)
    return snapshot
|
|
351
|
+
|
|
352
|
+
def get_samples_for_fine_tuning(
    self, session: Session, dataset_id: UUID, classifier_id: UUID
) -> dict[str, list[UUID]]:
    """Propose not-yet-annotated samples for the next annotation round.

    The classifier scores every sample embedding of the dataset. Samples
    already present in the classifier's annotations are skipped. The rest
    are grouped by their score for the first class (``prediction[0]``):

    - score > ``HIGH_CONFIDENCE_THRESHOLD``: candidates for the first class
    - score <= ``LOW_CONFIDENCE_THRESHOLD``: candidates for the second class

    ``sample`` then picks at most ``HIGH_CONFIDENCE_SAMPLES_NEEDED`` (or
    ``LOW_CONFIDENCE_SAMPLES_NEEDED``) IDs from each group. A smaller group
    is returned in full. NOTE(review): earlier docs quoted thresholds of 0.5
    and 10 samples per group; confirm against the module constants.

    Args:
        session: Database session for resolver operations.
        dataset_id: The ID of the dataset to pull samples from.
        classifier_id: The ID of the classifier to use.

    Returns:
        A dict mapping ``classes[0]`` to the high-score picks and
        ``classes[1]`` to the low-score picks.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or there is no appropriate embedding model.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    few_shot = entry.few_shot_classifier

    # Samples the user has already labeled are never proposed again.
    already_annotated = set()
    for ids in entry.annotations.values():
        already_annotated.update(ids)

    model_hash = few_shot.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if embedding_model is None:
        raise ValueError(f"No embedding model found for hash '{model_hash}'")

    sample_embeddings = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    scores = few_shot.predict([se.embedding for se in sample_embeddings])

    high_conf = []
    low_conf = []
    for sample_embedding, score in zip(sample_embeddings, scores):
        sample_id = sample_embedding.sample_id
        if sample_id in already_annotated:
            continue
        first_class_score = score[0]
        if first_class_score > HIGH_CONFIDENCE_THRESHOLD:
            high_conf.append(sample_id)
        elif first_class_score <= LOW_CONFIDENCE_THRESHOLD:
            low_conf.append(sample_id)

    return {
        few_shot.classes[0]: sample(
            high_conf, min(len(high_conf), HIGH_CONFIDENCE_SAMPLES_NEEDED)
        ),
        few_shot.classes[1]: sample(
            low_conf, min(len(low_conf), LOW_CONFIDENCE_SAMPLES_NEEDED)
        ),
    }
|
|
425
|
+
|
|
426
|
+
def run_classifier(self, session: Session, classifier_id: UUID, dataset_id: UUID) -> None:
    """Annotate every embedded sample of a dataset with the classifier's prediction.

    Existing annotations under this classifier's annotation labels are
    deleted. They are replaced by one ``AnnotationType.CLASSIFICATION``
    annotation per sample, carrying the arg-max class label and its score
    as confidence. The classifier's annotation labels are created on the
    first run (when ``annotation_label_ids`` is still ``None``).

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to run.
        dataset_id: The ID of the dataset to run the classifier on.

    Raises:
        ValueError: If the classifier does not exist or is inactive, if no
            embedding model matches its hash, if prediction returns nothing,
            or if the classifier ends up without annotation labels.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not entry.is_active:
        raise ValueError(
            f"Classifier with ID {classifier_id} is not active and cannot be used."
        )

    few_shot = entry.few_shot_classifier
    model_hash = few_shot.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if embedding_model is None:
        raise ValueError(f"No embedding model found for hash '{model_hash}'")

    sample_embeddings = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=embedding_model.embedding_model_id,
    )
    predictions = few_shot.predict([se.embedding for se in sample_embeddings])
    if len(predictions) == 0:
        raise ValueError(f"Predict returned empty list for classifier:'{classifier_id}'")

    # No-op when the labels already exist from an earlier run.
    _create_annotation_labels_for_classifier(
        classifier=entry,
        session=session,
        dataset_id=dataset_id,
    )
    label_ids = entry.annotation_label_ids
    if not label_ids:
        raise ValueError(f"Classifier with ID '{classifier_id}' has no annotation labels")

    new_annotations = []
    for sample_embedding, class_scores in zip(sample_embeddings, predictions):
        best = class_scores.index(max(class_scores))
        new_annotations.append(
            AnnotationCreate(
                parent_sample_id=sample_embedding.sample_id,
                dataset_id=dataset_id,
                annotation_label_id=label_ids[best],
                annotation_type=AnnotationType.CLASSIFICATION,
                confidence=class_scores[best],
            )
        )

    # Replace whatever a previous run of this classifier produced.
    annotation_resolver.delete_annotations(
        session=session,
        annotation_label_ids=label_ids,
    )
    annotation_resolver.create_many(
        session=session, parent_dataset_id=dataset_id, annotations=new_annotations
    )
|
|
500
|
+
|
|
501
|
+
def get_all_classifiers(self) -> list[EmbeddingClassifier]:
    """Describe every committed (active) classifier.

    Returns:
        One ``EmbeddingClassifier`` per active entry, in registry order.
        Inactive (temporary) classifiers are left out.
    """
    summaries = []
    for entry in self._classifiers.values():
        if not entry.is_active:
            continue
        summaries.append(
            EmbeddingClassifier(
                classifier_name=entry.few_shot_classifier.name,
                classifier_id=entry.classifier_id,
                class_list=entry.few_shot_classifier.classes,
            )
        )
    return summaries
|
|
516
|
+
|
|
517
|
+
def get_classifier_by_id(self, classifier_id: UUID) -> EmbeddingClassifier:
    """Describe a single registered classifier.

    The lookup does not check ``is_active``, so temporary (uncommitted)
    classifiers are returned as well.

    Args:
        classifier_id: Identifier of the classifier to describe.

    Returns:
        An ``EmbeddingClassifier`` with the classifier's name, ID and classes.

    Raises:
        ValueError: If no classifier is registered under ``classifier_id``.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    few_shot = entry.few_shot_classifier
    return EmbeddingClassifier(
        classifier_name=few_shot.name,
        classifier_id=classifier_id,
        class_list=few_shot.classes,
    )
|
|
537
|
+
|
|
538
|
+
def save_classifier_to_buffer(
    self, classifier_id: UUID, buffer: io.BytesIO, export_type: ExportType
) -> None:
    """Serialize a classifier into an in-memory buffer.

    Unlike ``save_classifier_to_file``, this method does not require the
    classifier to be active.

    Args:
        classifier_id: Identifier of the classifier to serialize.
        buffer: Buffer that receives the exported classifier.
        export_type: Export format, forwarded unchanged to ``export``.

    Raises:
        ValueError: If no classifier is registered under ``classifier_id``.
    """
    if (entry := self._classifiers.get(classifier_id)) is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    entry.few_shot_classifier.export(buffer=buffer, export_type=export_type)
|
|
555
|
+
|
|
556
|
+
def load_classifier_from_buffer(self, session: Session, buffer: io.BytesIO) -> ClassifierEntry:
    """Deserialize a classifier from an in-memory buffer and register it.

    Args:
        session: Database session used to look up the embedding model.
        buffer: Buffer holding the serialized classifier.

    Returns:
        The newly created ``ClassifierEntry``. It is already active and has
        an empty annotation list for every class of the loaded classifier.

    Raises:
        ValueError: If no embedding model matches the hash stored in the
            loaded classifier.
    """
    loaded = random_forest_classifier.load_random_forest_classifier(
        buffer=buffer, classifier_path=None
    )
    model_hash = loaded.embedding_model_hash
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if embedding_model is None:
        raise ValueError(
            "No matching embedding model found for the classifier's hash: "
            f"'{model_hash}'."
        )

    new_id = uuid4()
    entry = ClassifierEntry(
        classifier_id=new_id,
        few_shot_classifier=loaded,
        is_active=True,
        annotations={name: [] for name in loaded.classes},
    )
    self._classifiers[new_id] = entry
    return entry
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _create_annotation_labels_for_classifier(
    session: Session,
    dataset_id: UUID,
    classifier: ClassifierEntry,
) -> None:
    """Ensure the classifier has one annotation label ID per class.

    Nothing happens when ``classifier.annotation_label_ids`` is already set.
    Otherwise a new label named ``"<classifier name>_<class name>"`` is
    created in ``dataset_id`` for every class. The resulting IDs are stored
    on the classifier in class order. Existing labels in the database are
    not looked up; new ones are always created while the IDs are unset.

    Args:
        session: Database session.
        dataset_id: Dataset the new labels are created in.
        classifier: Classifier entry whose ``annotation_label_ids`` is filled.
    """
    if classifier.annotation_label_ids is not None:
        return

    few_shot = classifier.few_shot_classifier
    created_ids = []
    for class_name in few_shot.classes:
        label = annotation_label_resolver.create(
            session=session,
            label=AnnotationLabelCreate(
                dataset_id=dataset_id,
                annotation_label_name=few_shot.name + "_" + class_name,
            ),
        )
        created_ids.append(label.annotation_label_id)
    # Only assigned after every label was created successfully.
    classifier.annotation_label_ids = created_ids
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def _create_annotated_embeddings(
    session: Session,
    class_to_sample_ids: dict[str, list[UUID]],
    embedding_model_id: UUID,
) -> list[AnnotatedEmbedding]:
    """Turn per-class sample IDs into training pairs of embedding and class.

    For each class, the stored embeddings of its samples are fetched for the
    given embedding model. Each one is wrapped in an ``AnnotatedEmbedding``
    labeled with that class name.

    Args:
        session: Database session.
        class_to_sample_ids: Mapping from class name to sample IDs.
        embedding_model_id: Embedding model whose vectors should be used.

    Returns:
        The annotated embeddings, grouped by class in mapping order.
    """
    annotated = []
    for class_name, sample_uuids in class_to_sample_ids.items():
        stored = sample_embedding_resolver.get_by_sample_ids(
            session=session,
            sample_ids=sample_uuids,
            embedding_model_id=embedding_model_id,
        )
        annotated.extend(
            AnnotatedEmbedding(embedding=row.embedding, annotation=class_name)
            for row in stored
        )
    return annotated
|