lightly_studio-0.4.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightly_studio/__init__.py +12 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +131 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db_tables.py +35 -0
- lightly_studio/api/features.py +5 -0
- lightly_studio/api/routes/api/annotation.py +305 -0
- lightly_studio/api/routes/api/annotation_label.py +87 -0
- lightly_studio/api/routes/api/annotations/__init__.py +7 -0
- lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
- lightly_studio/api/routes/api/caption.py +100 -0
- lightly_studio/api/routes/api/classifier.py +384 -0
- lightly_studio/api/routes/api/dataset.py +191 -0
- lightly_studio/api/routes/api/dataset_tag.py +266 -0
- lightly_studio/api/routes/api/embeddings2d.py +90 -0
- lightly_studio/api/routes/api/exceptions.py +114 -0
- lightly_studio/api/routes/api/export.py +114 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/frame.py +241 -0
- lightly_studio/api/routes/api/image.py +155 -0
- lightly_studio/api/routes/api/metadata.py +161 -0
- lightly_studio/api/routes/api/operator.py +75 -0
- lightly_studio/api/routes/api/sample.py +103 -0
- lightly_studio/api/routes/api/selection.py +87 -0
- lightly_studio/api/routes/api/settings.py +41 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +50 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/api/video.py +133 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/video_frames_media.py +116 -0
- lightly_studio/api/routes/video_media.py +223 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +94 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/add_samples.py +533 -0
- lightly_studio/core/add_videos.py +294 -0
- lightly_studio/core/dataset.py +780 -0
- lightly_studio/core/dataset_query/__init__.py +14 -0
- lightly_studio/core/dataset_query/boolean_expression.py +67 -0
- lightly_studio/core/dataset_query/dataset_query.py +317 -0
- lightly_studio/core/dataset_query/field.py +113 -0
- lightly_studio/core/dataset_query/field_expression.py +79 -0
- lightly_studio/core/dataset_query/match_expression.py +23 -0
- lightly_studio/core/dataset_query/order_by.py +79 -0
- lightly_studio/core/dataset_query/sample_field.py +37 -0
- lightly_studio/core/dataset_query/tags_expression.py +46 -0
- lightly_studio/core/image_sample.py +36 -0
- lightly_studio/core/loading_log.py +56 -0
- lightly_studio/core/sample.py +291 -0
- lightly_studio/core/start_gui.py +54 -0
- lightly_studio/core/video_sample.py +38 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +155 -0
- lightly_studio/dataset/embedding_generator.py +129 -0
- lightly_studio/dataset/embedding_manager.py +349 -0
- lightly_studio/dataset/env.py +20 -0
- lightly_studio/dataset/file_utils.py +49 -0
- lightly_studio/dataset/fsspec_lister.py +275 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +158 -0
- lightly_studio/dataset/perception_encoder_embedding_generator.py +260 -0
- lightly_studio/db_manager.py +166 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.GcXvs2l7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/12.Dx6SXgAb.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/17.9X9_k6TP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/18.BxiimdIO.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/2.CkOblLn7.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/ClassifierSamplesGrid.BJbCDlvs.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Medium.DVUZMR_6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.CefECEWA.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.D5tDcjY-.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.9X9_k6TP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.BxiimdIO.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_page.Dx6SXgAb.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform._-1mPSEI.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/0dDyq72A.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BK4An2kI.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRmB-kJ9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B_1cpokE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BiqpDEr0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BpLiSKgx.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BscxbINH.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C1FmrZbK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C80h3dJx.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8mfFM-u.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CGY1p9L4.js +517 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/COfLknXM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CmLg0ys7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvGjimpO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D3RDXHoj.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D4y7iiT3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D9SC3jBb.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DCuAdx1Q.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DDBy-_jD.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DL9a7v5o.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DSKECuqX.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D_FFv0Oe.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DiZ5o5vz.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DkbXUtyG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DmK2hulV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqnHaLTj.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DtWZc_tl.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DuUalyFS.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DwIonDAZ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Il-mSPmK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KNLP4aJU.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KjYeVjkE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/MErlcOXj.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VRI4prUD.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VYb2dkNs.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/VqWvU2yF.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/dHC3otuL.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/da7Oy_lO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/eAy8rZzC.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/erjNR5MX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/f1oG3eFE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rsLi1iKv.js +20 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/rwuuBP9f.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/xGHZQ1pe.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.DrTRUgT3.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.BK5EOJl2.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.CIvTuljF.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.UBvSzxdA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.CQ_tiLJa.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.KqkAcaxW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.DoYsmxQc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/13.571n2LZA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/14.DGs689M-.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/15.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/16.Dpq6jbSh.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/17.B5AZbHUU.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/18.CBga8cnq.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.D2HXgz-8.js +1090 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.f4HAg-y3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.BKF4xuKQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.BAE0Pm_f.js +39 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CouWWpzA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.UBHT0ktp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.FiYNElcc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.B3-UaT23.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +45 -0
- lightly_studio/errors.py +5 -0
- lightly_studio/examples/example.py +25 -0
- lightly_studio/examples/example_coco.py +27 -0
- lightly_studio/examples/example_coco_caption.py +29 -0
- lightly_studio/examples/example_metadata.py +369 -0
- lightly_studio/examples/example_operators.py +111 -0
- lightly_studio/examples/example_selection.py +28 -0
- lightly_studio/examples/example_split_work.py +48 -0
- lightly_studio/examples/example_video.py +22 -0
- lightly_studio/examples/example_video_annotations.py +157 -0
- lightly_studio/examples/example_yolo.py +22 -0
- lightly_studio/export/coco_captions.py +69 -0
- lightly_studio/export/export_dataset.py +104 -0
- lightly_studio/export/lightly_studio_label_input.py +120 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/export_version.py +57 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +644 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +495 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/compute_similarity.py +84 -0
- lightly_studio/metadata/compute_typicality.py +67 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +303 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/caption.py +49 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +70 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/image.py +96 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/range.py +17 -0
- lightly_studio/models/sample.py +154 -0
- lightly_studio/models/sample_embedding.py +36 -0
- lightly_studio/models/settings.py +69 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/models/two_dim_embedding.py +16 -0
- lightly_studio/models/video.py +161 -0
- lightly_studio/plugins/__init__.py +0 -0
- lightly_studio/plugins/base_operator.py +60 -0
- lightly_studio/plugins/operator_registry.py +47 -0
- lightly_studio/plugins/parameter.py +70 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +0 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +37 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +40 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +129 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +124 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +87 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +60 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +85 -0
- lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +179 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +34 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id_with_payload.py +130 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +142 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +88 -0
- lightly_studio/resolvers/caption_resolver.py +129 -0
- lightly_studio/resolvers/dataset_resolver/__init__.py +55 -0
- lightly_studio/resolvers/dataset_resolver/check_dataset_type.py +29 -0
- lightly_studio/resolvers/dataset_resolver/create.py +20 -0
- lightly_studio/resolvers/dataset_resolver/delete.py +20 -0
- lightly_studio/resolvers/dataset_resolver/export.py +267 -0
- lightly_studio/resolvers/dataset_resolver/get_all.py +19 -0
- lightly_studio/resolvers/dataset_resolver/get_by_id.py +16 -0
- lightly_studio/resolvers/dataset_resolver/get_by_name.py +12 -0
- lightly_studio/resolvers/dataset_resolver/get_dataset_details.py +27 -0
- lightly_studio/resolvers/dataset_resolver/get_hierarchy.py +31 -0
- lightly_studio/resolvers/dataset_resolver/get_or_create_child_dataset.py +58 -0
- lightly_studio/resolvers/dataset_resolver/get_parent_dataset_by_sample_id.py +27 -0
- lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py +22 -0
- lightly_studio/resolvers/dataset_resolver/get_root_dataset.py +61 -0
- lightly_studio/resolvers/dataset_resolver/get_root_datasets_overview.py +41 -0
- lightly_studio/resolvers/dataset_resolver/update.py +25 -0
- lightly_studio/resolvers/embedding_model_resolver.py +120 -0
- lightly_studio/resolvers/image_filter.py +50 -0
- lightly_studio/resolvers/image_resolver/__init__.py +21 -0
- lightly_studio/resolvers/image_resolver/create_many.py +52 -0
- lightly_studio/resolvers/image_resolver/delete.py +20 -0
- lightly_studio/resolvers/image_resolver/filter_new_paths.py +23 -0
- lightly_studio/resolvers/image_resolver/get_all_by_dataset_id.py +117 -0
- lightly_studio/resolvers/image_resolver/get_by_id.py +14 -0
- lightly_studio/resolvers/image_resolver/get_dimension_bounds.py +75 -0
- lightly_studio/resolvers/image_resolver/get_many_by_id.py +22 -0
- lightly_studio/resolvers/image_resolver/get_samples_excluding.py +43 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_update_metadata.py +46 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +132 -0
- lightly_studio/resolvers/sample_resolver/__init__.py +17 -0
- lightly_studio/resolvers/sample_resolver/count_by_dataset_id.py +16 -0
- lightly_studio/resolvers/sample_resolver/create.py +16 -0
- lightly_studio/resolvers/sample_resolver/create_many.py +25 -0
- lightly_studio/resolvers/sample_resolver/get_by_id.py +14 -0
- lightly_studio/resolvers/sample_resolver/get_filtered_samples.py +56 -0
- lightly_studio/resolvers/sample_resolver/get_many_by_id.py +22 -0
- lightly_studio/resolvers/sample_resolver/sample_filter.py +74 -0
- lightly_studio/resolvers/settings_resolver.py +62 -0
- lightly_studio/resolvers/tag_resolver.py +299 -0
- lightly_studio/resolvers/twodim_embedding_resolver.py +119 -0
- lightly_studio/resolvers/video_frame_resolver/__init__.py +23 -0
- lightly_studio/resolvers/video_frame_resolver/count_video_frames_annotations.py +83 -0
- lightly_studio/resolvers/video_frame_resolver/create_many.py +57 -0
- lightly_studio/resolvers/video_frame_resolver/get_all_by_dataset_id.py +63 -0
- lightly_studio/resolvers/video_frame_resolver/get_by_id.py +13 -0
- lightly_studio/resolvers/video_frame_resolver/get_table_fields_bounds.py +44 -0
- lightly_studio/resolvers/video_frame_resolver/video_frame_annotations_counter_filter.py +47 -0
- lightly_studio/resolvers/video_frame_resolver/video_frame_filter.py +57 -0
- lightly_studio/resolvers/video_resolver/__init__.py +27 -0
- lightly_studio/resolvers/video_resolver/count_video_frame_annotations_by_video_dataset.py +86 -0
- lightly_studio/resolvers/video_resolver/create_many.py +58 -0
- lightly_studio/resolvers/video_resolver/filter_new_paths.py +33 -0
- lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py +181 -0
- lightly_studio/resolvers/video_resolver/get_by_id.py +22 -0
- lightly_studio/resolvers/video_resolver/get_table_fields_bounds.py +72 -0
- lightly_studio/resolvers/video_resolver/get_view_by_id.py +52 -0
- lightly_studio/resolvers/video_resolver/video_count_annotations_filter.py +50 -0
- lightly_studio/resolvers/video_resolver/video_filter.py +98 -0
- lightly_studio/selection/__init__.py +1 -0
- lightly_studio/selection/mundig.py +143 -0
- lightly_studio/selection/select.py +203 -0
- lightly_studio/selection/select_via_db.py +273 -0
- lightly_studio/selection/selection_config.py +49 -0
- lightly_studio/services/annotations_service/__init__.py +33 -0
- lightly_studio/services/annotations_service/create_annotation.py +64 -0
- lightly_studio/services/annotations_service/delete_annotation.py +22 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +54 -0
- lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +59 -0
- lightly_studio/type_definitions.py +31 -0
- lightly_studio/utils/__init__.py +3 -0
- lightly_studio/utils/download.py +94 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/mobileclip/LICENSE +31 -0
- lightly_studio/vendor/mobileclip/LICENSE_weights_data +50 -0
- lightly_studio/vendor/mobileclip/README.md +5 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio/vendor/perception_encoder/LICENSE.PE +201 -0
- lightly_studio/vendor/perception_encoder/README.md +11 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/__init__.py +0 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/bpe_simple_vocab_16e6.txt.gz +0 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/config.py +205 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/config_src.py +264 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/pe.py +766 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/rope.py +352 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/tokenizer.py +347 -0
- lightly_studio/vendor/perception_encoder/vision_encoder/transforms.py +36 -0
- lightly_studio-0.4.6.dist-info/METADATA +88 -0
- lightly_studio-0.4.6.dist-info/RECORD +356 -0
- lightly_studio-0.4.6.dist-info/WHEEL +4 -0
lightly_studio/few_shot_classifier/random_forest_classifier.py
@@ -0,0 +1,495 @@
+"""RandomForest classifier implementations."""
+
+from __future__ import annotations
+
+import io
+import pickle
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+
+import numpy as np
+import sklearn  # type: ignore[import-untyped]
+from sklearn.ensemble import (  # type: ignore[import-untyped]
+    RandomForestClassifier,
+)
+from sklearn.tree import (  # type: ignore[import-untyped]
+    DecisionTreeClassifier,
+)
+from sklearn.utils import validation  # type: ignore[import-untyped]
+from typing_extensions import assert_never
+
+from .classifier import AnnotatedEmbedding, ExportType, FewShotClassifier
+
+# The version of the file format used for exporting and importing classifiers.
+# This is used to ensure compatibility between different versions of the code.
+# If the format changes, this version should be incremented.
+FILE_FORMAT_VERSION = "1.0.0"
+
+
+@dataclass
+class ModelExportMetadata:
+    """Metadata for exporting a model for traceability and reproducibility."""
+
+    name: str
+    file_format_version: str
+    model_type: str
+    created_at: str
+    class_names: list[str]
+    num_input_features: int
+    num_estimators: int
+    embedding_model_hash: str
+    embedding_model_name: str
+    sklearn_version: str
+
+
+@dataclass
+class InnerNode:
+    """Inner node of a decision tree.
+
+    Defaults are used for tree construction.
+    """
+
+    feature_index: int = 0
+    threshold: float = 0.0
+    left_child: int = 0
+    right_child: int = 0
+
+
+@dataclass
+class LeafNode:
+    """Leaf node of a decision tree."""
+
+    class_probabilities: list[float]
+
+
+@dataclass
+class ExportedTree:
+    """Exported tree structure."""
+
+    inner_nodes: list[InnerNode]
+    leaf_nodes: list[LeafNode]
+
+
+@dataclass
+class RandomForestExport:
+    """Datastructure for exporting the RandomForest model."""
+
+    metadata: ModelExportMetadata
+    trees: list[ExportedTree]
+
+
+class RandomForest(FewShotClassifier):
+    """RandomForest classifier."""
+
+    def __init__(
+        self,
+        name: str,
+        classes: list[str],
+        embedding_model_name: str,
+        embedding_model_hash: str,
+    ) -> None:
+        """Initialize the RandomForestClassifier with predefined classes.
+
+        Args:
+            name: Name of the classifier.
+            classes: Ordered list of class labels that will be used for training
+                and predictions. The order of this list determines the order of
+                probability values in predictions.
+            embedding_model_name: Name of the model used for creating the
+                embeddings.
+            embedding_model_hash: Hash of the model used for creating the
+                embeddings.
+            Note: embedding_model_name and embedding_model_hash are used for
+                traceability in the exported model metadata.
+
+        Raises:
+            ValueError: If classes list is empty.
+        """
+        if not classes:
+            raise ValueError("Class list cannot be empty.")
+
+        # Fix the random seed for reproducibility.
+        self._model = RandomForestClassifier(class_weight="balanced", random_state=42)
+        self.name = name
+        self.classes = classes
+        self._class_to_index = {label: idx for idx, label in enumerate(classes)}
+        self._embedding_model_name = embedding_model_name
+        self.embedding_model_hash = embedding_model_hash
+
+    def train(self, annotated_embeddings: list[AnnotatedEmbedding]) -> None:
+        """Trains a classifier using the provided input.
+
+        Args:
+            annotated_embeddings: A list of annotated embeddings to train the
+                classifier.
+
+        Raises:
+            ValueError: If annotated_embeddings is empty or contains invalid
+                classes.
+        """
+        if not annotated_embeddings:
+            raise ValueError("annotated_embeddings cannot be empty.")
+
+        # Extract embeddings and labels.
+        embeddings = [ae.embedding for ae in annotated_embeddings]
+        labels = [ae.annotation for ae in annotated_embeddings]
+        # Validate that all labels are in predefined classes.
+        invalid_labels = set(labels) - set(self.classes)
+        if invalid_labels:
+            raise ValueError(f"Found labels not in predefined classes: {invalid_labels}")
+
+        # Convert to NumPy arrays.
+        embeddings_np = np.array(embeddings)
+        labels_encoded = [self._class_to_index[label] for label in labels]
+
+        # Train the RandomForestClassifier.
+        self._model.fit(embeddings_np, labels_encoded)
+
+    def predict(self, embeddings: list[list[float]]) -> list[list[float]]:
+        """Predicts the classification scores for a list of embeddings.
+
+        Args:
+            embeddings: A list of embeddings, where each embedding is a list of
+                floats.
+
+        Returns:
+            A list of lists, where each inner list represents the probability
+            distribution over classes for the corresponding input embedding.
+            Each value in the inner list corresponds to the likelihood of the
+            embedding belonging to a specific class.
+            If embeddings is empty, returns an empty list.
+        """
+        if len(embeddings) == 0:
+            return []
+
+        # Convert embeddings to a NumPy array.
+        embeddings_np = np.array(embeddings)
+
+        # Get the classes that the model was trained on.
+        trained_classes: list[int] = self._model.classes_
+
+        # Initialize full-size probability array.
+        full_probabilities = []
+
+        # Get raw probabilities from model.
+        raw_probabilities = self._model.predict_proba(embeddings_np)
+
+        for raw_probs in raw_probabilities:
+            # Initialize zeros for all possible classes.
+            full_probs = [0.0 for _ in range(len(self.classes))]
+            # Map probabilities to their correct positions.
+            for trained_class, prob in zip(trained_classes, raw_probs):
+                full_probs[trained_class] = prob
+            full_probabilities.append(full_probs)
+        return full_probabilities
+
+    def export(
+        self,
+        export_path: Path | None = None,
+        buffer: io.BytesIO | None = None,
+        export_type: ExportType = "sklearn",
+    ) -> None:
+        """Exports the classifier to a specified file.
+
+        Args:
+            export_path: The full file path where the export will be saved.
+            buffer: A BytesIO buffer to save the export to.
+            export_type: The type of export. Options are:
+                "sklearn": Exports the RandomForestClassifier instance.
+                "lightly": Exports the model in raw format with metadata
+                    and tree details.
+        """
+        metadata = ModelExportMetadata(
+            name=self.name,
+            file_format_version=FILE_FORMAT_VERSION,
+            model_type="RandomForest",
+            created_at=str(datetime.now(timezone.utc).isoformat()),
+            class_names=self.classes,
+            num_input_features=self._model.n_features_in_,
+            num_estimators=len(self._model.estimators_),
+            embedding_model_hash=self.embedding_model_hash,
+            embedding_model_name=self._embedding_model_name,
+            sklearn_version=sklearn.__version__,
+        )
+
+        if export_type == "sklearn":
+            # Combine the model and metadata into a single dictionary
+            export_data = {
+                "model": self._model,
+                "metadata": metadata,
+            }
+
+            if buffer is not None:
+                pickle.dump(export_data, buffer)
+            elif export_path is not None:
+                # Save to the specified file path.
+                # Ensure parent dirs exist.
+                export_path.parent.mkdir(parents=True, exist_ok=True)
+                with open(export_path, "wb") as f:
+                    pickle.dump(export_data, f)
+
+        elif export_type == "lightly":
+            export_data_raw = _export_random_forest_model(
+                model=self._model,
+                metadata=metadata,
+                all_classes=self.classes,
+            )
+            if buffer is not None:
+                pickle.dump(export_data_raw, buffer)
+            elif export_path is not None:
+                # Save to the specified file path.
+                # Ensure parent dirs exist.
+                export_path.parent.mkdir(parents=True, exist_ok=True)
+                with open(export_path, "wb") as f:
+                    pickle.dump(export_data_raw, f)
+        else:
+            assert_never(export_type)
+
+    def is_trained(self) -> bool:
+        """Checks if the classifier is trained.
+
+        Returns:
+            True if the classifier is trained, False otherwise.
+        """
+        try:
+            validation.check_is_fitted(self._model)
+            return True
+        except sklearn.exceptions.NotFittedError:
+            return False
+
+
+def load_random_forest_classifier(
+    classifier_path: Path | None, buffer: io.BytesIO | None
+) -> RandomForest:
+    """Loads a RandomForest classifier from a file or a buffer.
+
+    Args:
+        classifier_path: The path to the exported classifier file.
+        buffer: A BytesIO buffer containing the exported classifier.
+            If both path and buffer are provided, the path will be used.
+
+    Returns:
+        A fully initialized RandomForest classifier instance.
+
+    Raises:
+        FileNotFoundError: If the classifier_path does not exist.
+        ValueError: If the file is not a valid 'sklearn' pickled export
+            or if the version/format mismatches.
+    """
+    if classifier_path is not None:
+        if not classifier_path.exists():
+            raise FileNotFoundError(f"The file {classifier_path} does not exist.")
+
+        with open(classifier_path, "rb") as f:
+            export_data = pickle.load(f)
+    elif buffer is not None:
+        export_data = pickle.load(buffer)
+
+    model = export_data.get("model")
+    metadata: ModelExportMetadata = export_data.get("metadata")
+
+    if model is None or metadata is None:
+        raise ValueError("The loaded file does not contain a valid model or metadata.")
+
+    if metadata.file_format_version != FILE_FORMAT_VERSION:
+        raise ValueError(
+            f"File format version mismatch. Expected '{FILE_FORMAT_VERSION}', "
+            f"got '{metadata.file_format_version}'."
+        )
+    if metadata.sklearn_version != sklearn.__version__:
+        raise ValueError(
+            f"File format mismatch, loading a file format for a different sklearn version. "
+            f"File format uses '{metadata.sklearn_version}', got '{sklearn.__version__}'."
+        )
+
+    instance = RandomForest(
+        name=metadata.name,
+        classes=metadata.class_names,
+        embedding_model_name=metadata.embedding_model_name,
+        embedding_model_hash=metadata.embedding_model_hash,
+    )
+    # Set the model.
+    instance._model = model  # noqa: SLF001
+    return instance
+
+
+def _export_random_forest_model(
+    model: RandomForestClassifier,
+    metadata: ModelExportMetadata,
+    all_classes: list[str],
+) -> RandomForestExport:
+    """Converts a sk-learn RandomForestClassifier to RandomForestExport format.
+
+    Args:
+        model: The trained random forest model to export.
+        metadata: Metadata describing the dataset and training setup.
+        all_classes: Full list of all class labels.
+
+    Returns:
+        RandomForestExport: The serialized export object containing all trees
+        and metadata.
+    """
+    trained_classes: list[int] = model.classes_
+    trees = [_export_single_tree(tree, trained_classes, all_classes) for tree in model.estimators_]
+    return RandomForestExport(metadata=metadata, trees=trees)
+
+
+def load_lightly_random_forest(path: Path | None, buffer: io.BytesIO | None) -> RandomForestExport:
+    """Loads a Lightly exported RandomForest model from a file or buffer.
+
+    Args:
+        path: The path to the exported classifier file.
+        buffer: A BytesIO buffer containing the exported classifier.
+            If both path and buffer are provided, the path will be used.
+
+    Returns:
+        A RandomForestExport instance.
+
+    Raises:
+        ValueError: If the file is not a valid RandomForestExport or
+            if the version/format mismatches.
+    """
+    if path is not None:
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+    elif buffer is not None:
+        data = pickle.load(buffer)
+
+    if not isinstance(data, RandomForestExport):
+        raise ValueError("Loaded object is not a RandomForestExport instance.")
+
+    if data.metadata.file_format_version != FILE_FORMAT_VERSION:
+        raise ValueError(
+            f"File format version mismatch. Expected '{FILE_FORMAT_VERSION}', "
+            f"got '{data.metadata.file_format_version}'."
+        )
+    return data
+
+
+def predict_with_lightly_random_forest(
+    model: RandomForestExport, embeddings: list[list[float]]
+) -> list[list[float]]:
+    """Predicts the classification scores for a list of embeddings.
+
+    Args:
+        model: A RandomForestExport instance containing the model and metadata.
+        embeddings: A list of embeddings.
+
+    Returns:
+        A list of lists, where each inner list represents the probability
+        distribution over classes for the corresponding input embedding.
+
+    Raises:
+        ValueError: If the provided embeddings have different size than
+            expected.
+    """
+    expected_dim = model.metadata.num_input_features
+    all_probs: list[list[float]] = []
+
+    for embedding in embeddings:
+        if len(embedding) != expected_dim:
+            raise ValueError(
+                f"Embedding has wrong dimensionality: expected {expected_dim},got {len(embedding)}"
+            )
+
+        tree_probs: list[list[float]] = [
+            _predict_tree_probs(tree, embedding) for tree in model.trees
+        ]
+
+        mean_probs = np.mean(tree_probs, axis=0).tolist()
+        all_probs.append(mean_probs)
+
+    return all_probs
+
+
+def _export_single_tree(
+    tree: DecisionTreeClassifier,
+    trained_classes: list[int],
+    all_classes: list[str],
+) -> ExportedTree:
+    """Converts a single sk-learn tree into a serializable ExportedTree format.
+
+    Args:
+        tree: The decision tree to convert.
+        trained_classes: Indices of the classes the tree was trained on.
+        all_classes: Full list of all class labels.
+
+    Returns:
+        ExportedTree: A representation of the tree with explicit node and leaf
+        structures, compatible with the Lightly format.
+    """
+    tree_structure = tree.tree_
+    inner_nodes: list[InnerNode] = []
+    leaf_nodes: list[LeafNode] = []
+    node_map = {}  # Maps node_id to (mapped_index, is_leaf)
+
+    for node_id in range(tree_structure.node_count):
+        is_leaf = tree_structure.children_left[node_id] == tree_structure.children_right[node_id]
+        if is_leaf:
+            leaf_idx = len(leaf_nodes)
+            # value[node_id] is a 2D array of shape [1, n_classes].
+            # [0] is used to extract the inner array and
+            # convert it to a 1D array of class counts.
+            class_weights = tree_structure.value[node_id][0]
+            total = sum(class_weights)
+            probs = (class_weights / total).tolist() if total > 0 else [0.0] * len(class_weights)
+
+            # Order probabilities according to the initial classes.
+            # Initialize zeros for all possible classes.
+            full_probs = [0.0 for _ in range(len(all_classes))]
+            # Map probabilities to their correct positions.
+            for trained_class, prob in zip(trained_classes, probs):
+                full_probs[trained_class] = prob
+
+            leaf_nodes.append(LeafNode(class_probabilities=full_probs))
+            node_map[node_id] = (-leaf_idx - 1, True)
+        else:
+            inner_idx = len(inner_nodes)
+            node_map[node_id] = (inner_idx, False)
+            # Reserve a spot for the inner node.
+            inner_nodes.append(InnerNode())
+
+    # Now populate inner_nodes using mapped indices.
+    for node_id in range(tree_structure.node_count):
+        mapped_idx, is_leaf = node_map[node_id]
+        if is_leaf:
+            continue
+
+        left_id = tree_structure.children_left[node_id]
+        right_id = tree_structure.children_right[node_id]
+        left_mapped = node_map[left_id][0]
+        right_mapped = node_map[right_id][0]
+
+        inner_nodes[mapped_idx] = InnerNode(
+            feature_index=int(tree_structure.feature[node_id]),
+            threshold=float(tree_structure.threshold[node_id]),
+            left_child=left_mapped,
+            right_child=right_mapped,
+        )
+
+    return ExportedTree(inner_nodes=inner_nodes, leaf_nodes=leaf_nodes)
+
+
+def _predict_tree_probs(tree: ExportedTree, embedding: list[float]) -> list[float]:
+    """Predicts class probabilities for an embedding using a single tree.
+
+    Args:
+        tree: A ExportedTree instance used to determine the probability.
+        embedding: A single embedding.
+
+    """
+    if not tree.inner_nodes:
+        return tree.leaf_nodes[0].class_probabilities
+
+    node_idx = 0  # Start at root
+    while node_idx >= 0:
+        node = tree.inner_nodes[node_idx]
+        if embedding[node.feature_index] <= node.threshold:
+            node_idx = node.left_child
+        else:
+            node_idx = node.right_child
+
+    leaf_idx = -node_idx - 1
+    leaf = tree.leaf_nodes[leaf_idx]
+    return leaf.class_probabilities
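The "lightly" export format above is just plain dataclasses, so a forest exported this way can be evaluated with `predict_with_lightly_random_forest` without scikit-learn. The following is a minimal sketch built by hand from the dataclasses shown in this file; all metadata values and class names are illustrative placeholders, not anything shipped with the package. It shows the child-index convention used by `_export_single_tree` and `_predict_tree_probs`: negative child indices encode leaves, with leaf `i` stored as `-i - 1`.

```python
from lightly_studio.few_shot_classifier.random_forest_classifier import (
    ExportedTree,
    InnerNode,
    LeafNode,
    ModelExportMetadata,
    RandomForestExport,
    predict_with_lightly_random_forest,
)

# One tree with a single split on feature 0 at threshold 0.5 and two leaves.
toy_tree = ExportedTree(
    inner_nodes=[InnerNode(feature_index=0, threshold=0.5, left_child=-1, right_child=-2)],
    leaf_nodes=[
        LeafNode(class_probabilities=[0.9, 0.1]),  # reached when feature 0 <= 0.5
        LeafNode(class_probabilities=[0.2, 0.8]),  # reached when feature 0 > 0.5
    ],
)

# All metadata values below are placeholders for illustration only.
toy_export = RandomForestExport(
    metadata=ModelExportMetadata(
        name="toy-forest",
        file_format_version="1.0.0",
        model_type="RandomForest",
        created_at="2025-01-01T00:00:00+00:00",
        class_names=["cat", "dog"],
        num_input_features=2,
        num_estimators=1,
        embedding_model_hash="00000000",
        embedding_model_name="toy-embedding-model",
        sklearn_version="n/a",
    ),
    trees=[toy_tree],
)

# Per-tree probabilities are averaged; with a single tree the leaf values pass through.
print(predict_with_lightly_random_forest(toy_export, [[0.3, 0.0], [0.7, 0.0]]))
# -> [[0.9, 0.1], [0.2, 0.8]]
```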
lightly_studio/metadata/complex_metadata.py
@@ -0,0 +1,47 @@
+"""Complex metadata types that can be stored in JSON columns."""
+
+from typing import Any, Dict, Type
+
+from lightly_studio.metadata.gps_coordinate import GPSCoordinate
+from lightly_studio.metadata.metadata_protocol import ComplexMetadata
+
+# Registry of complex metadata types for automatic serialization/deserialization
+COMPLEX_METADATA_TYPES: Dict[str, Type[ComplexMetadata]] = {
+    "gps_coordinate": GPSCoordinate,
+}
+
+
+def serialize_complex_metadata(value: Any) -> Any:
+    """Serialize complex metadata for JSON storage.
+
+    Args:
+        value: Value to serialize.
+
+    Returns:
+        Serialized value if it is ComplexMetadata, the original
+        value otherwise.
+    """
+    if isinstance(value, ComplexMetadata):
+        return value.as_dict()
+
+    return value
+
+
+def deserialize_complex_metadata(value: Any, expected_type: str) -> Any:
+    """Deserialize complex metadata from JSON storage.
+
+    Args:
+        value: Value to deserialize.
+        expected_type: Expected type name from schema (e.g., "gps_coordinate").
+
+    Returns:
+        Deserialized value (complex metadata object if applicable).
+    """
+    # If we have an expected type and the value is a dict, try to deserialize.
+    if expected_type and isinstance(value, dict) and expected_type in COMPLEX_METADATA_TYPES:
+        try:
+            return COMPLEX_METADATA_TYPES[expected_type].from_dict(value)
+        except (KeyError, TypeError):
+            # If deserialization fails, return the original value.
+            pass
+    return value
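The registry above drives deserialization purely by a string type name and the `as_dict`/`from_dict` pair that the functions exercise. A minimal sketch of how a custom type with that shape could be registered and rebuilt from a stored dict is shown below; `Viewpoint` and the field names are hypothetical, and whether the rest of the metadata stack (schemas, API routes) accepts such a type is not shown in this diff.

```python
from dataclasses import dataclass
from typing import Any, Dict

from lightly_studio.metadata.complex_metadata import (
    COMPLEX_METADATA_TYPES,
    deserialize_complex_metadata,
)


@dataclass
class Viewpoint:
    """Hypothetical complex metadata type exposing the as_dict/from_dict pair used above."""

    azimuth_deg: float
    elevation_deg: float

    def as_dict(self) -> Dict[str, Any]:
        return {"azimuth_deg": self.azimuth_deg, "elevation_deg": self.elevation_deg}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Viewpoint":
        return cls(azimuth_deg=data["azimuth_deg"], elevation_deg=data["elevation_deg"])


# Registering the name lets deserialize_complex_metadata rebuild the object from a dict.
COMPLEX_METADATA_TYPES["viewpoint"] = Viewpoint  # type: ignore[assignment]

stored = {"azimuth_deg": 90.0, "elevation_deg": 10.0}          # as it would sit in a JSON column
restored = deserialize_complex_metadata(stored, "viewpoint")   # -> Viewpoint(90.0, 10.0)
unknown = deserialize_complex_metadata(stored, "unknown_type")  # unregistered type: dict returned as-is
```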
lightly_studio/metadata/compute_similarity.py
@@ -0,0 +1,84 @@
+"""Computes similarity from embeddings."""
+
+from datetime import datetime, timezone
+from typing import Optional
+from uuid import UUID
+
+from lightly_mundig import Similarity  # type: ignore[import-untyped]
+from sqlmodel import Session
+
+from lightly_studio.dataset.env import LIGHTLY_STUDIO_LICENSE_KEY
+from lightly_studio.errors import TagNotFoundError
+from lightly_studio.resolvers import metadata_resolver, sample_embedding_resolver, tag_resolver
+from lightly_studio.resolvers.sample_resolver.sample_filter import SampleFilter
+
+
+def compute_similarity_metadata(
+    session: Session,
+    key_dataset_id: UUID,
+    embedding_model_id: UUID,
+    query_tag_id: UUID,
+    metadata_name: Optional[str] = None,
+) -> str:
+    """Computes similarity for each sample in the dataset from embeddings.
+
+    Similarity is a measure of how similar a sample is to its nearest neighbor
+    in the embedding space. It can be used to find duplicates.
+
+    The computed similarity values are stored as metadata for each sample.
+
+    Args:
+        session:
+            The database session.
+        key_dataset_id:
+            The ID of the dataset the similarity is computed on.
+        embedding_model_id:
+            The ID of the embedding model to use for the computation.
+        query_tag_id:
+            The ID of the tag describing the query.
+        metadata_name:
+            The name of the metadata field to store the similarity values in.
+
+    Raises:
+        TagNotFoundError if tag with ID `query_tag_id` does not exist.
+
+    Returns:
+        The name of the metadata storing the similarity values.
+    """
+    license_key = LIGHTLY_STUDIO_LICENSE_KEY
+    if license_key is None:
+        raise ValueError(
+            "LIGHTLY_STUDIO_LICENSE_KEY environment variable is not set. "
+            "Please set it to your LightlyStudio license key."
+        )
+    key_samples = sample_embedding_resolver.get_all_by_dataset_id(
+        session=session, dataset_id=key_dataset_id, embedding_model_id=embedding_model_id
+    )
+    key_embeddings = [sample.embedding for sample in key_samples]
+    similarity = Similarity(key_embeddings=key_embeddings, token=license_key)
+
+    query_tag = tag_resolver.get_by_id(session=session, tag_id=query_tag_id)
+    if query_tag is None:
+        raise TagNotFoundError(f"Query tag with ID {query_tag_id} not found")
+    tag_filter = SampleFilter(tag_ids=[query_tag_id])
+    query_samples = sample_embedding_resolver.get_all_by_dataset_id(
+        session=session,
+        dataset_id=key_dataset_id,
+        embedding_model_id=embedding_model_id,
+        filters=tag_filter,
+    )
+    query_embeddings = [sample.embedding for sample in query_samples]
+    similarity_values = similarity.calculate_similarity(query_embeddings=query_embeddings)
+    if metadata_name is None:
+        date = datetime.now(timezone.utc)
+        # Only use whole seconds, such as "2025-11-26T10:11:56". This is 19 characters.
+        formatted_date = date.isoformat()[:19]
+        metadata_name = f"similarity_{query_tag.name}_{formatted_date}"
+
+    metadata = [
+        (sample.sample_id, {metadata_name: similarity})
+        for sample, similarity in zip(key_samples, similarity_values)
+    ]
+
+    metadata_resolver.bulk_update_metadata(session, metadata)
+    return metadata_name
lightly_studio/metadata/compute_typicality.py
@@ -0,0 +1,67 @@
+"""Computes typicality from embeddings."""
+
+from uuid import UUID
+
+from lightly_mundig import Typicality  # type: ignore[import-untyped]
+from sqlmodel import Session
+
+from lightly_studio.dataset.env import LIGHTLY_STUDIO_LICENSE_KEY
+from lightly_studio.resolvers import (
+    metadata_resolver,
+    sample_embedding_resolver,
+)
+
+DEFAULT_NUM_NEAREST_NEIGHBORS = 20
+
+
+def compute_typicality_metadata(
+    session: Session,
+    dataset_id: UUID,
+    embedding_model_id: UUID,
+    metadata_name: str = "typicality",
+) -> None:
+    """Computes typicality for each sample in the dataset from embeddings.
+
+    Typicality is a measure of how representative a sample is of the dataset.
+    It is calculated for each sample from its K-nearest neighbors in the
+    embedding space.
+
+    The computed typicality values are stored as metadata for each sample.
+
+    Args:
+        session:
+            The database session.
+        dataset_id:
+            The ID of the dataset for which to compute the typicality.
+        embedding_model_id:
+            The ID of the embedding model to use for the computation.
+        metadata_name:
+            The name of the metadata field to store the typicality values in.
+            Defaults to "typicality".
+    """
+    license_key = LIGHTLY_STUDIO_LICENSE_KEY
+    if license_key is None:
+        raise ValueError(
+            "LIGHTLY_STUDIO_LICENSE_KEY environment variable is not set. "
+            "Please set it to your LightlyStudio license key."
+        )
+
+    samples = sample_embedding_resolver.get_all_by_dataset_id(
+        session=session, dataset_id=dataset_id, embedding_model_id=embedding_model_id
+    )
+
+    embeddings = [sample.embedding for sample in samples]
+    typicality = Typicality(embeddings=embeddings, token=license_key)
+    typicality_values = typicality.calculate_typicality(
+        num_nearest_neighbors=DEFAULT_NUM_NEAREST_NEIGHBORS
+    )
+    assert len(samples) == len(typicality_values), (
+        "The number of samples and computed typicality values must match"
+    )
+
+    metadata = [
+        (sample.sample_id, {metadata_name: typicality})
+        for sample, typicality in zip(samples, typicality_values)
+    ]
+
+    metadata_resolver.bulk_update_metadata(session, metadata)
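Both metadata helpers above need a live LightlyStudio database session, existing dataset and embedding-model rows, and the LIGHTLY_STUDIO_LICENSE_KEY environment variable. The sketch below shows one way a caller might wire them up; the connection string and every UUID are placeholders, and the package's own dataset API (not shown in this diff) is presumably the supported entry point rather than a hand-built engine.

```python
import os
from uuid import UUID

from sqlmodel import Session, create_engine

from lightly_studio.metadata.compute_similarity import compute_similarity_metadata
from lightly_studio.metadata.compute_typicality import compute_typicality_metadata

# Assumptions: the license key is already set, the URL points at an existing
# LightlyStudio database, and the UUIDs reference rows that already exist.
assert os.environ.get("LIGHTLY_STUDIO_LICENSE_KEY"), "license key must be set"

engine = create_engine("sqlite:///lightly_studio.db")  # placeholder connection string
dataset_id = UUID("00000000-0000-0000-0000-000000000000")          # placeholder
embedding_model_id = UUID("00000000-0000-0000-0000-000000000001")  # placeholder
query_tag_id = UUID("00000000-0000-0000-0000-000000000002")        # placeholder

with Session(engine) as session:
    # Stores per-sample typicality under the default "typicality" metadata key.
    compute_typicality_metadata(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=embedding_model_id,
    )
    # Stores per-sample similarity to the tagged query samples; returns the metadata key.
    similarity_key = compute_similarity_metadata(
        session=session,
        key_dataset_id=dataset_id,
        embedding_model_id=embedding_model_id,
        query_tag_id=query_tag_id,
    )
    print(similarity_key)
```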