lightly-studio 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +4 -4
- lightly_studio/api/app.py +7 -5
- lightly_studio/api/db_tables.py +0 -3
- lightly_studio/api/routes/api/annotation.py +32 -16
- lightly_studio/api/routes/api/annotation_label.py +2 -5
- lightly_studio/api/routes/api/annotations/__init__.py +7 -0
- lightly_studio/api/routes/api/annotations/create_annotation.py +52 -0
- lightly_studio/api/routes/api/classifier.py +2 -5
- lightly_studio/api/routes/api/dataset.py +5 -8
- lightly_studio/api/routes/api/dataset_tag.py +2 -3
- lightly_studio/api/routes/api/embeddings2d.py +104 -0
- lightly_studio/api/routes/api/export.py +73 -0
- lightly_studio/api/routes/api/metadata.py +2 -4
- lightly_studio/api/routes/api/sample.py +5 -13
- lightly_studio/api/routes/api/selection.py +87 -0
- lightly_studio/api/routes/api/settings.py +2 -6
- lightly_studio/api/routes/images.py +6 -6
- lightly_studio/core/add_samples.py +374 -0
- lightly_studio/core/dataset.py +272 -400
- lightly_studio/core/dataset_query/boolean_expression.py +67 -0
- lightly_studio/core/dataset_query/dataset_query.py +216 -0
- lightly_studio/core/dataset_query/field.py +113 -0
- lightly_studio/core/dataset_query/field_expression.py +79 -0
- lightly_studio/core/dataset_query/match_expression.py +23 -0
- lightly_studio/core/dataset_query/order_by.py +79 -0
- lightly_studio/core/dataset_query/sample_field.py +28 -0
- lightly_studio/core/dataset_query/tags_expression.py +46 -0
- lightly_studio/core/sample.py +159 -32
- lightly_studio/core/start_gui.py +35 -0
- lightly_studio/dataset/edge_embedding_generator.py +13 -8
- lightly_studio/dataset/embedding_generator.py +2 -3
- lightly_studio/dataset/embedding_manager.py +74 -6
- lightly_studio/dataset/env.py +4 -0
- lightly_studio/dataset/file_utils.py +13 -2
- lightly_studio/dataset/fsspec_lister.py +275 -0
- lightly_studio/dataset/loader.py +49 -84
- lightly_studio/dataset/mobileclip_embedding_generator.py +9 -6
- lightly_studio/db_manager.py +145 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.CA_CXIBb.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.DS78jgNY.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/index.BVs_sZj9.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/transform.D487hwJk.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/6t3IJ0vQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D6su9Aln.js → 8NsknIT2.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{x9G_hzyY.js → BND_-4Kp.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BylOuP6i.js → BdfTHw61.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DOlTMNyt.js → BfHVnyNT.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BjkP1AHA.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BuuNVL9G.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{O-EABkf9.js → BzKGpnl4.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CCx7Ho51.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{l7KrR96u.js → CH6P3X75.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D5-A_Ffd.js → CR2upx_Q.js} +2 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWPZrTTJ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{C8I8rFJQ.js → Cs1XmhiF.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{CDnpyLsT.js → CwPowJfP.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxFKfZ9T.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cxevwdid.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DjfY96ND.js → D4whDBUi.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6r9vr07.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DA6bFLPR.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DEgUu98i.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DGTPl6Gk.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DKGxBSlK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQXoLcsF.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DQe_kdRt.js +92 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DcY4jgG3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bu7uvVrG.js → RmD8FzRo.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/V-MnMC1X.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bsi3UGy5.js → keKYsoph.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BVr6DYqP.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.u7zsVvqp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.Da2agmdd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{1.B4rNYwVp.js → 1.B11tVRJV.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.l30Zud4h.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CgKPGcAP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C8HLK8mj.js +857 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CWHpKonm.js → 3.CLvg3QcJ.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.OUWOLQeV.js → 4.BQhDtXUI.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.-6XqWX5G.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.uBV1Lhat.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.BXsgoQZh.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.BkbcnUs8.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{9.CPu3CiBc.js → 9.Bkrv-Vww.js} +1 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/clustering.worker-DKqeLtG0.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/workers/search.worker-vNSty3B0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
- lightly_studio/dist_lightly_studio_view_app/index.html +14 -14
- lightly_studio/examples/example.py +13 -12
- lightly_studio/examples/example_coco.py +13 -0
- lightly_studio/examples/example_metadata.py +83 -98
- lightly_studio/examples/example_selection.py +7 -19
- lightly_studio/examples/example_split_work.py +12 -36
- lightly_studio/examples/{example_v2.py → example_yolo.py} +3 -4
- lightly_studio/export/export_dataset.py +65 -0
- lightly_studio/export/lightly_studio_label_input.py +120 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +5 -26
- lightly_studio/metadata/compute_typicality.py +67 -0
- lightly_studio/models/annotation/annotation_base.py +18 -20
- lightly_studio/models/annotation/instance_segmentation.py +8 -8
- lightly_studio/models/annotation/object_detection.py +4 -4
- lightly_studio/models/dataset.py +6 -2
- lightly_studio/models/sample.py +10 -3
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +2 -1
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +15 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +2 -3
- lightly_studio/resolvers/annotation_resolver/create_many.py +3 -3
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +1 -1
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +7 -3
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +19 -1
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +0 -1
- lightly_studio/resolvers/annotations/annotations_filter.py +1 -11
- lightly_studio/resolvers/dataset_resolver.py +10 -0
- lightly_studio/resolvers/embedding_model_resolver.py +22 -0
- lightly_studio/resolvers/sample_resolver.py +53 -9
- lightly_studio/resolvers/tag_resolver.py +23 -0
- lightly_studio/selection/mundig.py +7 -10
- lightly_studio/selection/select.py +55 -46
- lightly_studio/selection/select_via_db.py +23 -19
- lightly_studio/selection/selection_config.py +10 -4
- lightly_studio/services/annotations_service/__init__.py +12 -0
- lightly_studio/services/annotations_service/create_annotation.py +63 -0
- lightly_studio/services/annotations_service/delete_annotation.py +22 -0
- lightly_studio/services/annotations_service/update_annotation.py +21 -32
- lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
- lightly_studio-0.3.3.dist-info/METADATA +814 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.3.dist-info}/RECORD +130 -113
- lightly_studio/api/db.py +0 -133
- lightly_studio/api/routes/api/annotation_task.py +0 -38
- lightly_studio/api/routes/api/metrics.py +0 -80
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +0 -93
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +0 -3
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +0 -2
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +0 -6
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +0 -1
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +0 -1
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +0 -268
- lightly_studio/models/annotation_task.py +0 -28
- lightly_studio/resolvers/annotation_resolver/create.py +0 -19
- lightly_studio/resolvers/annotation_task_resolver.py +0 -31
- lightly_studio-0.3.1.dist-info/METADATA +0 -520
- /lightly_studio/{metrics → core/dataset_query}/__init__.py +0 -0
- /lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{OpenSans- → OpenSans-Medium.DVUZMR_6.ttf} +0 -0
- {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.3.dist-info}/WHEEL +0 -0
|
@@ -20,3 +20,18 @@ def get_all(session: Session) -> list[AnnotationLabelTable]:
|
|
|
20
20
|
select(AnnotationLabelTable).order_by(col(AnnotationLabelTable.created_at).asc())
|
|
21
21
|
).all()
|
|
22
22
|
return list(labels) if labels else []
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_all_sorted_alphabetically(session: Session) -> list[AnnotationLabelTable]:
|
|
26
|
+
"""Retrieve all annotation labels sorted alphabetically.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
session (Session): The database session.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
list[AnnotationLabelTable]: A list of annotation labels.
|
|
33
|
+
"""
|
|
34
|
+
labels = session.exec(
|
|
35
|
+
select(AnnotationLabelTable).order_by(col(AnnotationLabelTable.annotation_label_name).asc())
|
|
36
|
+
).all()
|
|
37
|
+
return list(labels) if labels else []
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from lightly_studio.resolvers.annotation_resolver.count_annotations_by_dataset import (
|
|
4
4
|
count_annotations_by_dataset,
|
|
5
5
|
)
|
|
6
|
-
from lightly_studio.resolvers.annotation_resolver.create import create
|
|
7
6
|
from lightly_studio.resolvers.annotation_resolver.create_many import create_many
|
|
8
7
|
from lightly_studio.resolvers.annotation_resolver.delete_annotation import (
|
|
9
8
|
delete_annotation,
|
|
@@ -12,7 +11,7 @@ from lightly_studio.resolvers.annotation_resolver.delete_annotations import (
|
|
|
12
11
|
delete_annotations,
|
|
13
12
|
)
|
|
14
13
|
from lightly_studio.resolvers.annotation_resolver.get_all import get_all
|
|
15
|
-
from lightly_studio.resolvers.annotation_resolver.get_by_id import get_by_id
|
|
14
|
+
from lightly_studio.resolvers.annotation_resolver.get_by_id import get_by_id, get_by_ids
|
|
16
15
|
from lightly_studio.resolvers.annotation_resolver.update_annotation_label import (
|
|
17
16
|
update_annotation_label,
|
|
18
17
|
)
|
|
@@ -22,12 +21,12 @@ from lightly_studio.resolvers.annotation_resolver.update_bounding_box import (
|
|
|
22
21
|
|
|
23
22
|
__all__ = [
|
|
24
23
|
"count_annotations_by_dataset",
|
|
25
|
-
"create",
|
|
26
24
|
"create_many",
|
|
27
25
|
"delete_annotation",
|
|
28
26
|
"delete_annotations",
|
|
29
27
|
"get_all",
|
|
30
28
|
"get_by_id",
|
|
29
|
+
"get_by_ids",
|
|
31
30
|
"update_annotation_label",
|
|
32
31
|
"update_bounding_box",
|
|
33
32
|
]
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
from sqlmodel import Session
|
|
8
9
|
|
|
@@ -24,7 +25,7 @@ from lightly_studio.models.annotation.semantic_segmentation import (
|
|
|
24
25
|
def create_many(
|
|
25
26
|
session: Session,
|
|
26
27
|
annotations: list[AnnotationCreate],
|
|
27
|
-
) -> Sequence[
|
|
28
|
+
) -> Sequence[UUID]:
|
|
28
29
|
"""Create many annotations with object detection details in bulk."""
|
|
29
30
|
# Step 1: Create all base annotations
|
|
30
31
|
base_annotations = []
|
|
@@ -37,7 +38,6 @@ def create_many(
|
|
|
37
38
|
db_base_annotation = AnnotationBaseTable(
|
|
38
39
|
annotation_label_id=annotation_create.annotation_label_id,
|
|
39
40
|
annotation_type=annotation_create.annotation_type,
|
|
40
|
-
annotation_task_id=annotation_create.annotation_task_id,
|
|
41
41
|
confidence=annotation_create.confidence,
|
|
42
42
|
dataset_id=annotation_create.dataset_id,
|
|
43
43
|
sample_id=annotation_create.sample_id,
|
|
@@ -93,4 +93,4 @@ def create_many(
|
|
|
93
93
|
# Commit everything
|
|
94
94
|
session.commit()
|
|
95
95
|
|
|
96
|
-
return base_annotations
|
|
96
|
+
return [annotation.annotation_id for annotation in base_annotations]
|
|
@@ -26,7 +26,7 @@ def delete_annotation(
|
|
|
26
26
|
annotation_id=annotation_id,
|
|
27
27
|
)
|
|
28
28
|
if not annotation:
|
|
29
|
-
|
|
29
|
+
raise ValueError(f"Annotation {annotation_id} not found")
|
|
30
30
|
if annotation.object_detection_details:
|
|
31
31
|
session.delete(annotation.object_detection_details)
|
|
32
32
|
if annotation.instance_segmentation_details:
|
|
@@ -18,14 +18,12 @@ from lightly_studio.resolvers.annotations.annotations_filter import (
|
|
|
18
18
|
|
|
19
19
|
def delete_annotations(
|
|
20
20
|
session: Session,
|
|
21
|
-
annotation_task_ids: list[UUID] | None,
|
|
22
21
|
annotation_label_ids: list[UUID] | None,
|
|
23
22
|
) -> None:
|
|
24
23
|
"""Delete all annotations and their tag links using filters.
|
|
25
24
|
|
|
26
25
|
Args:
|
|
27
26
|
session: Database session.
|
|
28
|
-
annotation_task_ids: List of annotation task IDs to filter by.
|
|
29
27
|
annotation_label_ids: List of annotation label IDs to filter by.
|
|
30
28
|
"""
|
|
31
29
|
# Find annotation_ids to delete
|
|
@@ -33,9 +31,15 @@ def delete_annotations(
|
|
|
33
31
|
session,
|
|
34
32
|
filters=AnnotationsFilter(
|
|
35
33
|
annotation_label_ids=annotation_label_ids,
|
|
36
|
-
annotation_task_ids=annotation_task_ids,
|
|
37
34
|
),
|
|
38
35
|
).annotations
|
|
36
|
+
for annotation in annotations:
|
|
37
|
+
if annotation.object_detection_details:
|
|
38
|
+
session.delete(annotation.object_detection_details)
|
|
39
|
+
if annotation.instance_segmentation_details:
|
|
40
|
+
session.delete(annotation.instance_segmentation_details)
|
|
41
|
+
if annotation.semantic_segmentation_details:
|
|
42
|
+
session.delete(annotation.semantic_segmentation_details)
|
|
39
43
|
annotation_ids = [annotation.annotation_id for annotation in annotations]
|
|
40
44
|
# TODO(Horatiu, 06/2025): Check if there is a way to delete the links
|
|
41
45
|
# automatically using SQLModel/SQLAlchemy.
|
|
@@ -2,9 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Sequence
|
|
5
6
|
from uuid import UUID
|
|
6
7
|
|
|
7
|
-
from sqlmodel import Session, select
|
|
8
|
+
from sqlmodel import Session, col, select
|
|
8
9
|
|
|
9
10
|
from lightly_studio.models.annotation.annotation_base import (
|
|
10
11
|
AnnotationBaseTable,
|
|
@@ -16,3 +17,20 @@ def get_by_id(session: Session, annotation_id: UUID) -> AnnotationBaseTable | No
|
|
|
16
17
|
return session.exec(
|
|
17
18
|
select(AnnotationBaseTable).where(AnnotationBaseTable.annotation_id == annotation_id)
|
|
18
19
|
).one_or_none()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_by_ids(session: Session, annotation_ids: Sequence[UUID]) -> Sequence[AnnotationBaseTable]:
|
|
23
|
+
"""Retrieve multiple annotations by their IDs.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
session: The database session to use for the query.
|
|
27
|
+
annotation_ids: A list of annotation IDs to retrieve.
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
A list of annotations matching the provided IDs.
|
|
31
|
+
"""
|
|
32
|
+
return session.exec(
|
|
33
|
+
select(AnnotationBaseTable).where(
|
|
34
|
+
col(AnnotationBaseTable.annotation_id).in_(annotation_ids)
|
|
35
|
+
)
|
|
36
|
+
).all()
|
|
@@ -111,7 +111,6 @@ def update_annotation_label(
|
|
|
111
111
|
annotation_id=annotation_copy.annotation_id,
|
|
112
112
|
annotation_label_id=annotation_copy.annotation_label_id,
|
|
113
113
|
annotation_type=annotation_copy.annotation_type,
|
|
114
|
-
annotation_task_id=annotation_copy.annotation_task_id,
|
|
115
114
|
confidence=annotation_copy.confidence,
|
|
116
115
|
created_at=annotation_copy.created_at,
|
|
117
116
|
dataset_id=annotation_copy.dataset_id,
|
|
@@ -7,8 +7,7 @@ from uuid import UUID
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
from sqlmodel import col
|
|
9
9
|
|
|
10
|
-
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
|
|
11
|
-
from lightly_studio.models.annotation_task import AnnotationType
|
|
10
|
+
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable, AnnotationType
|
|
12
11
|
from lightly_studio.models.sample import SampleTable
|
|
13
12
|
from lightly_studio.models.tag import TagTable
|
|
14
13
|
from lightly_studio.type_definitions import QueryType
|
|
@@ -30,9 +29,6 @@ class AnnotationsFilter(BaseModel):
|
|
|
30
29
|
default=None,
|
|
31
30
|
description="List of sample tag UUIDs to filter annotations by",
|
|
32
31
|
)
|
|
33
|
-
annotation_task_ids: list[UUID] | None = Field(
|
|
34
|
-
default=None, description="List of annotation task UUIDs"
|
|
35
|
-
)
|
|
36
32
|
|
|
37
33
|
def apply(
|
|
38
34
|
self,
|
|
@@ -51,12 +47,6 @@ class AnnotationsFilter(BaseModel):
|
|
|
51
47
|
if self.dataset_ids:
|
|
52
48
|
query = query.where(col(AnnotationBaseTable.dataset_id).in_(self.dataset_ids))
|
|
53
49
|
|
|
54
|
-
# Filter by annotation task
|
|
55
|
-
if self.annotation_task_ids:
|
|
56
|
-
query = query.where(
|
|
57
|
-
col(AnnotationBaseTable.annotation_task_id).in_(self.annotation_task_ids)
|
|
58
|
-
)
|
|
59
|
-
|
|
60
50
|
# Filter by annotation label
|
|
61
51
|
if self.annotation_label_ids:
|
|
62
52
|
query = query.where(
|
|
@@ -67,6 +67,16 @@ def get_by_id(session: Session, dataset_id: UUID) -> DatasetTable | None:
|
|
|
67
67
|
).one_or_none()
|
|
68
68
|
|
|
69
69
|
|
|
70
|
+
def get_by_name(session: Session, name: str) -> DatasetTable | None:
|
|
71
|
+
"""Retrieve a single dataset by name."""
|
|
72
|
+
datasets = session.exec(select(DatasetTable).where(DatasetTable.name == name)).all()
|
|
73
|
+
if len(datasets) == 0:
|
|
74
|
+
return None
|
|
75
|
+
if len(datasets) > 1:
|
|
76
|
+
raise ValueError(f"Cannot retrieve a dataset, found multiple with name '{name}'.")
|
|
77
|
+
return datasets[0]
|
|
78
|
+
|
|
79
|
+
|
|
70
80
|
def update(session: Session, dataset_id: UUID, dataset_data: DatasetCreate) -> DatasetTable:
|
|
71
81
|
"""Update an existing dataset."""
|
|
72
82
|
dataset = get_by_id(session=session, dataset_id=dataset_id)
|
|
@@ -21,6 +21,28 @@ def create(session: Session, embedding_model: EmbeddingModelCreate) -> Embedding
|
|
|
21
21
|
return db_embedding_model
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
def get_or_create(session: Session, embedding_model: EmbeddingModelCreate) -> EmbeddingModelTable:
|
|
25
|
+
"""Retrieve an existing EmbeddingModel by hash or create a new one if it does not exist."""
|
|
26
|
+
db_model = get_by_model_hash(
|
|
27
|
+
session=session, embedding_model_hash=embedding_model.embedding_model_hash
|
|
28
|
+
)
|
|
29
|
+
if db_model is None:
|
|
30
|
+
return create(session=session, embedding_model=embedding_model)
|
|
31
|
+
|
|
32
|
+
# Validate that the existing model matches the provided data.
|
|
33
|
+
if (
|
|
34
|
+
db_model.name != embedding_model.name
|
|
35
|
+
or db_model.parameter_count_in_mb != embedding_model.parameter_count_in_mb
|
|
36
|
+
or db_model.embedding_dimension != embedding_model.embedding_dimension
|
|
37
|
+
# TODO(Michal, 09/2025): Allow same model for different datasets.
|
|
38
|
+
or db_model.dataset_id != embedding_model.dataset_id
|
|
39
|
+
):
|
|
40
|
+
raise ValueError(
|
|
41
|
+
"An embedding model with the same hash but different parameters already exists."
|
|
42
|
+
)
|
|
43
|
+
return db_model
|
|
44
|
+
|
|
45
|
+
|
|
24
46
|
def get_all_by_dataset_id(session: Session, dataset_id: UUID) -> list[EmbeddingModelTable]:
|
|
25
47
|
"""Retrieve all embedding models."""
|
|
26
48
|
embedding_models = session.exec(
|
|
@@ -7,9 +7,11 @@ from datetime import datetime, timezone
|
|
|
7
7
|
from uuid import UUID
|
|
8
8
|
|
|
9
9
|
from pydantic import BaseModel
|
|
10
|
+
from sqlalchemy.orm import joinedload, selectinload
|
|
10
11
|
from sqlmodel import Session, col, func, select
|
|
11
12
|
from sqlmodel.sql.expression import Select
|
|
12
13
|
|
|
14
|
+
from lightly_studio.api.routes.api.validators import Paginated
|
|
13
15
|
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
|
|
14
16
|
from lightly_studio.models.annotation_label import AnnotationLabelTable
|
|
15
17
|
from lightly_studio.models.embedding_model import EmbeddingModelTable
|
|
@@ -36,6 +38,22 @@ def create_many(session: Session, samples: list[SampleCreate]) -> list[SampleTab
|
|
|
36
38
|
return db_samples
|
|
37
39
|
|
|
38
40
|
|
|
41
|
+
def filter_new_paths(session: Session, file_paths_abs: list[str]) -> tuple[list[str], list[str]]:
|
|
42
|
+
"""Return a) file_path_abs that do not already exist in the database and b) those that do."""
|
|
43
|
+
existing_file_paths_abs = set(
|
|
44
|
+
session.exec(
|
|
45
|
+
select(col(SampleTable.file_path_abs)).where(
|
|
46
|
+
col(SampleTable.file_path_abs).in_(file_paths_abs)
|
|
47
|
+
)
|
|
48
|
+
).all()
|
|
49
|
+
)
|
|
50
|
+
file_paths_abs_set = set(file_paths_abs)
|
|
51
|
+
return (
|
|
52
|
+
list(file_paths_abs_set - existing_file_paths_abs), # paths that are not in the DB
|
|
53
|
+
list(file_paths_abs_set & existing_file_paths_abs), # paths that are already in the DB
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
39
57
|
def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTable | None:
|
|
40
58
|
"""Retrieve a single sample by ID."""
|
|
41
59
|
return session.exec(
|
|
@@ -45,6 +63,13 @@ def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTabl
|
|
|
45
63
|
).one_or_none()
|
|
46
64
|
|
|
47
65
|
|
|
66
|
+
def count_by_dataset_id(session: Session, dataset_id: UUID) -> int:
|
|
67
|
+
"""Count the number of samples in a dataset."""
|
|
68
|
+
return session.exec(
|
|
69
|
+
select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
|
|
70
|
+
).one()
|
|
71
|
+
|
|
72
|
+
|
|
48
73
|
def get_many_by_id(session: Session, sample_ids: list[UUID]) -> list[SampleTable]:
|
|
49
74
|
"""Retrieve multiple samples by their IDs.
|
|
50
75
|
|
|
@@ -63,19 +88,33 @@ class GetAllSamplesByDatasetIdResult(BaseModel):
|
|
|
63
88
|
|
|
64
89
|
samples: Sequence[SampleTable]
|
|
65
90
|
total_count: int
|
|
91
|
+
next_cursor: int | None = None
|
|
66
92
|
|
|
67
93
|
|
|
68
94
|
def get_all_by_dataset_id( # noqa: PLR0913
|
|
69
95
|
session: Session,
|
|
70
96
|
dataset_id: UUID,
|
|
71
|
-
|
|
72
|
-
limit: int | None = None,
|
|
97
|
+
pagination: Paginated | None = None,
|
|
73
98
|
filters: SampleFilter | None = None,
|
|
74
99
|
text_embedding: list[float] | None = None,
|
|
75
100
|
sample_ids: list[UUID] | None = None,
|
|
76
101
|
) -> GetAllSamplesByDatasetIdResult:
|
|
77
102
|
"""Retrieve samples for a specific dataset with optional filtering."""
|
|
78
|
-
samples_query =
|
|
103
|
+
samples_query = (
|
|
104
|
+
select(SampleTable)
|
|
105
|
+
.options(
|
|
106
|
+
selectinload(SampleTable.annotations).options(
|
|
107
|
+
joinedload(AnnotationBaseTable.annotation_label),
|
|
108
|
+
joinedload(AnnotationBaseTable.object_detection_details),
|
|
109
|
+
joinedload(AnnotationBaseTable.instance_segmentation_details),
|
|
110
|
+
joinedload(AnnotationBaseTable.semantic_segmentation_details),
|
|
111
|
+
),
|
|
112
|
+
selectinload(SampleTable.tags),
|
|
113
|
+
# Ignore type checker error below as it's a false positive caused by TYPE_CHECKING.
|
|
114
|
+
joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
|
|
115
|
+
)
|
|
116
|
+
.where(SampleTable.dataset_id == dataset_id)
|
|
117
|
+
)
|
|
79
118
|
total_count_query = (
|
|
80
119
|
select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
|
|
81
120
|
)
|
|
@@ -120,15 +159,20 @@ def get_all_by_dataset_id( # noqa: PLR0913
|
|
|
120
159
|
col(SampleTable.created_at).asc(), col(SampleTable.sample_id).asc()
|
|
121
160
|
)
|
|
122
161
|
|
|
123
|
-
#
|
|
124
|
-
if
|
|
125
|
-
samples_query = samples_query.offset(offset)
|
|
126
|
-
|
|
127
|
-
|
|
162
|
+
# Apply pagination if provided
|
|
163
|
+
if pagination is not None:
|
|
164
|
+
samples_query = samples_query.offset(pagination.offset).limit(pagination.limit)
|
|
165
|
+
|
|
166
|
+
total_count = session.exec(total_count_query).one()
|
|
167
|
+
|
|
168
|
+
next_cursor = None
|
|
169
|
+
if pagination and pagination.offset + pagination.limit < total_count:
|
|
170
|
+
next_cursor = pagination.offset + pagination.limit
|
|
128
171
|
|
|
129
172
|
return GetAllSamplesByDatasetIdResult(
|
|
130
173
|
samples=session.exec(samples_query).all(),
|
|
131
|
-
total_count=
|
|
174
|
+
total_count=total_count,
|
|
175
|
+
next_cursor=next_cursor,
|
|
132
176
|
)
|
|
133
177
|
|
|
134
178
|
|
|
@@ -274,3 +274,26 @@ def remove_annotation_ids_from_tag_id(
|
|
|
274
274
|
session.commit()
|
|
275
275
|
session.refresh(tag)
|
|
276
276
|
return tag
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_or_create_sample_tag_by_name(
|
|
280
|
+
session: Session,
|
|
281
|
+
dataset_id: UUID,
|
|
282
|
+
tag_name: str,
|
|
283
|
+
) -> TagTable:
|
|
284
|
+
"""Get an existing sample tag by name or create a new one if it doesn't exist.
|
|
285
|
+
|
|
286
|
+
Args:
|
|
287
|
+
session: Database session for executing queries.
|
|
288
|
+
dataset_id: The dataset ID to search/create the tag for.
|
|
289
|
+
tag_name: Name of the tag to get or create.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
The existing or newly created sample tag.
|
|
293
|
+
"""
|
|
294
|
+
existing_tag = get_by_name(session=session, tag_name=tag_name, dataset_id=dataset_id)
|
|
295
|
+
if existing_tag:
|
|
296
|
+
return existing_tag
|
|
297
|
+
|
|
298
|
+
new_tag = TagCreate(name=tag_name, dataset_id=dataset_id, kind="sample")
|
|
299
|
+
return create(session=session, tag=new_tag)
|
|
@@ -10,29 +10,26 @@ from typing import Iterable
|
|
|
10
10
|
# Or remove the type ignore once typing stubs were added manually.
|
|
11
11
|
import lightly_mundig # type: ignore[import-untyped]
|
|
12
12
|
import numpy as np
|
|
13
|
-
|
|
13
|
+
|
|
14
|
+
from lightly_studio.dataset.env import LIGHTLY_STUDIO_LICENSE_KEY
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class Mundig:
|
|
17
|
-
"""Python
|
|
18
|
+
"""Python interface for the Mundig selection algorithm.
|
|
18
19
|
|
|
19
20
|
This class provides a Python interface to the lightly_mundig Rust library
|
|
20
|
-
for sample selection.
|
|
21
|
+
for sample selection. It allows combining different selection strategies
|
|
22
|
+
such as diversity and weighting.
|
|
21
23
|
"""
|
|
22
24
|
|
|
23
25
|
def __init__(self) -> None:
|
|
24
26
|
"""Initialize the Mundig selection interface."""
|
|
25
|
-
|
|
26
|
-
env = Env()
|
|
27
|
-
env.read_env()
|
|
28
|
-
license_key = env.str("LIGHTLY_STUDIO_LICENSE_KEY", default=None)
|
|
29
|
-
if license_key is None:
|
|
27
|
+
if LIGHTLY_STUDIO_LICENSE_KEY is None:
|
|
30
28
|
raise ValueError(
|
|
31
29
|
"LIGHTLY_STUDIO_LICENSE_KEY environment variable is not set. "
|
|
32
30
|
"Please set it to your LightlyStudio license key."
|
|
33
31
|
)
|
|
34
|
-
|
|
35
|
-
self.mundig = lightly_mundig.Selection(token=license_key)
|
|
32
|
+
self.mundig = lightly_mundig.Selection(token=LIGHTLY_STUDIO_LICENSE_KEY)
|
|
36
33
|
|
|
37
34
|
self.n_input_samples: int | None = None
|
|
38
35
|
|
|
@@ -1,96 +1,105 @@
|
|
|
1
|
-
"""Provides the user python interface to selection."""
|
|
1
|
+
"""Provides the user python interface to selection bound to sample ids."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from typing import Final
|
|
5
7
|
from uuid import UUID
|
|
6
8
|
|
|
7
9
|
from sqlmodel import Session
|
|
8
10
|
|
|
9
|
-
from lightly_studio.resolvers.samples_filter import SampleFilter
|
|
10
11
|
from lightly_studio.selection.select_via_db import select_via_database
|
|
11
12
|
from lightly_studio.selection.selection_config import (
|
|
12
13
|
EmbeddingDiversityStrategy,
|
|
14
|
+
MetadataWeightingStrategy,
|
|
13
15
|
SelectionConfig,
|
|
14
16
|
SelectionStrategy,
|
|
15
17
|
)
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class Selection:
|
|
19
|
-
"""
|
|
21
|
+
"""Selection interface for candidate sample ids."""
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
# dataset_view = ...
|
|
29
|
-
# dataset_view.select.diverse(...)
|
|
30
|
-
#
|
|
31
|
-
# See https://docs.google.com/document/d/1ZRICdFmfJmxUBy3FFoeUWsAgsCNWDHg8CK5MJiGmX74/edit?tab=t.kbfvnrepsuf#bookmark=id.8klhhwr5q4dp
|
|
32
|
-
|
|
33
|
-
def __init__(self, dataset_id: UUID, session: Session):
|
|
34
|
-
"""Creates the interface to run selection.
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
dataset_id: UUID,
|
|
26
|
+
session: Session,
|
|
27
|
+
input_sample_ids: Iterable[UUID],
|
|
28
|
+
) -> None:
|
|
29
|
+
"""Create the selection interface.
|
|
35
30
|
|
|
36
31
|
Args:
|
|
37
|
-
dataset_id:
|
|
38
|
-
session:
|
|
32
|
+
dataset_id: Dataset in which the selection is performed.
|
|
33
|
+
session: Database session to resolve selection dependencies.
|
|
34
|
+
input_sample_ids: Candidate sample ids considered for selection.
|
|
35
|
+
The iterable is consumed immediately to capture a stable snapshot.
|
|
36
|
+
"""
|
|
37
|
+
self._dataset_id: Final[UUID] = dataset_id
|
|
38
|
+
self._session: Final[Session] = session
|
|
39
|
+
self._input_sample_ids: list[UUID] = list(input_sample_ids)
|
|
39
40
|
|
|
41
|
+
def metadata_weighting(
|
|
42
|
+
self,
|
|
43
|
+
n_samples_to_select: int,
|
|
44
|
+
selection_result_tag_name: str,
|
|
45
|
+
metadata_key: str,
|
|
46
|
+
) -> None:
|
|
47
|
+
"""Select a subset based on numeric metadata weights.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
n_samples_to_select: Number of samples to select.
|
|
51
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
52
|
+
metadata_key: Metadata key used as weights (float or int values).
|
|
40
53
|
"""
|
|
41
|
-
|
|
42
|
-
self.
|
|
54
|
+
strategy = MetadataWeightingStrategy(metadata_key=metadata_key)
|
|
55
|
+
self.multi_strategies(
|
|
56
|
+
n_samples_to_select=n_samples_to_select,
|
|
57
|
+
selection_result_tag_name=selection_result_tag_name,
|
|
58
|
+
selection_strategies=[strategy],
|
|
59
|
+
)
|
|
43
60
|
|
|
44
61
|
def diverse(
|
|
45
62
|
self,
|
|
46
63
|
n_samples_to_select: int,
|
|
47
64
|
selection_result_tag_name: str,
|
|
48
65
|
embedding_model_name: str | None = None,
|
|
49
|
-
sample_filter: SampleFilter | None = None,
|
|
50
66
|
) -> None:
|
|
51
|
-
"""
|
|
67
|
+
"""Select a diverse subset using embeddings.
|
|
52
68
|
|
|
53
69
|
Args:
|
|
54
|
-
n_samples_to_select:
|
|
55
|
-
selection_result_tag_name:
|
|
56
|
-
embedding_model_name:
|
|
57
|
-
|
|
58
|
-
If None, assert that there is only one embedding model and uses it.
|
|
59
|
-
sample_filter: An optional filter to apply to the samples.
|
|
70
|
+
n_samples_to_select: Number of samples to select.
|
|
71
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
72
|
+
embedding_model_name: Optional embedding model name. If None, uses the only
|
|
73
|
+
available model or raises if multiple exist.
|
|
60
74
|
"""
|
|
61
75
|
strategy = EmbeddingDiversityStrategy(embedding_model_name=embedding_model_name)
|
|
62
|
-
|
|
63
|
-
dataset_id=self.dataset_id,
|
|
76
|
+
self.multi_strategies(
|
|
64
77
|
n_samples_to_select=n_samples_to_select,
|
|
65
78
|
selection_result_tag_name=selection_result_tag_name,
|
|
66
|
-
|
|
67
|
-
strategies=[strategy],
|
|
79
|
+
selection_strategies=[strategy],
|
|
68
80
|
)
|
|
69
|
-
select_via_database(session=self.session, config=selection_config)
|
|
70
81
|
|
|
71
82
|
def multi_strategies(
|
|
72
83
|
self,
|
|
73
84
|
n_samples_to_select: int,
|
|
74
85
|
selection_result_tag_name: str,
|
|
75
86
|
selection_strategies: list[SelectionStrategy],
|
|
76
|
-
sample_filter: SampleFilter | None = None,
|
|
77
87
|
) -> None:
|
|
78
|
-
"""Select a subset
|
|
88
|
+
"""Select a subset based on multiple strategies.
|
|
79
89
|
|
|
80
90
|
Args:
|
|
81
|
-
n_samples_to_select:
|
|
82
|
-
selection_result_tag_name:
|
|
83
|
-
selection_strategies:
|
|
84
|
-
Selection strategies to use for the selection. They can be created after
|
|
85
|
-
importing them from `lightly_studio.selection.selection_config`.
|
|
86
|
-
sample_filter: An optional filter to apply to the samples.
|
|
87
|
-
|
|
91
|
+
n_samples_to_select: Number of samples to select.
|
|
92
|
+
selection_result_tag_name: Tag name for the selection result.
|
|
93
|
+
selection_strategies: Strategies to compose for selection.
|
|
88
94
|
"""
|
|
89
95
|
config = SelectionConfig(
|
|
90
|
-
dataset_id=self.
|
|
96
|
+
dataset_id=self._dataset_id,
|
|
91
97
|
n_samples_to_select=n_samples_to_select,
|
|
92
98
|
selection_result_tag_name=selection_result_tag_name,
|
|
93
|
-
sample_filter=sample_filter,
|
|
94
99
|
strategies=selection_strategies,
|
|
95
100
|
)
|
|
96
|
-
select_via_database(
|
|
101
|
+
select_via_database(
|
|
102
|
+
session=self._session,
|
|
103
|
+
config=config,
|
|
104
|
+
input_sample_ids=self._input_sample_ids,
|
|
105
|
+
)
|
|
@@ -3,29 +3,33 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import datetime
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
from sqlmodel import Session
|
|
8
9
|
|
|
9
10
|
from lightly_studio.models.tag import TagCreate
|
|
10
11
|
from lightly_studio.resolvers import (
|
|
11
12
|
embedding_model_resolver,
|
|
13
|
+
metadata_resolver,
|
|
12
14
|
sample_embedding_resolver,
|
|
13
|
-
sample_resolver,
|
|
14
15
|
tag_resolver,
|
|
15
16
|
)
|
|
16
17
|
from lightly_studio.selection.mundig import Mundig
|
|
17
18
|
from lightly_studio.selection.selection_config import (
|
|
18
19
|
EmbeddingDiversityStrategy,
|
|
20
|
+
MetadataWeightingStrategy,
|
|
19
21
|
SelectionConfig,
|
|
20
22
|
)
|
|
21
23
|
|
|
22
24
|
|
|
23
|
-
def select_via_database(
|
|
24
|
-
|
|
25
|
+
def select_via_database(
|
|
26
|
+
session: Session, config: SelectionConfig, input_sample_ids: list[UUID]
|
|
27
|
+
) -> None:
|
|
28
|
+
"""Run selection using the provided candidate sample ids.
|
|
25
29
|
|
|
26
|
-
First resolves the selection config to
|
|
30
|
+
First resolves the selection config to concrete database values.
|
|
27
31
|
Then calls Mundig to run the selection with pure values.
|
|
28
|
-
|
|
32
|
+
Finally creates a tag for the selected set.
|
|
29
33
|
"""
|
|
30
34
|
# Check if the tag name is already used
|
|
31
35
|
existing_tag = tag_resolver.get_by_name(
|
|
@@ -40,18 +44,7 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
|
|
|
40
44
|
)
|
|
41
45
|
raise ValueError(msg)
|
|
42
46
|
|
|
43
|
-
|
|
44
|
-
# the latter is implemented.
|
|
45
|
-
# See https://linear.app/lightly/issue/LIG-7292/story-python-ui-mvp1-without-datasetquery-and-sample
|
|
46
|
-
samples = sample_resolver.get_all_by_dataset_id(
|
|
47
|
-
session,
|
|
48
|
-
limit=None,
|
|
49
|
-
dataset_id=config.dataset_id,
|
|
50
|
-
filters=config.sample_filter,
|
|
51
|
-
).samples
|
|
52
|
-
sample_ids = [s.sample_id for s in samples]
|
|
53
|
-
|
|
54
|
-
n_samples_to_select = min(config.n_samples_to_select, len(sample_ids))
|
|
47
|
+
n_samples_to_select = min(config.n_samples_to_select, len(input_sample_ids))
|
|
55
48
|
if n_samples_to_select == 0:
|
|
56
49
|
print("No samples available for selection.")
|
|
57
50
|
return
|
|
@@ -66,16 +59,27 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
|
|
|
66
59
|
).embedding_model_id
|
|
67
60
|
embedding_tables = sample_embedding_resolver.get_by_sample_ids(
|
|
68
61
|
session=session,
|
|
69
|
-
sample_ids=
|
|
62
|
+
sample_ids=input_sample_ids,
|
|
70
63
|
embedding_model_id=embedding_model_id,
|
|
71
64
|
)
|
|
72
65
|
embeddings = [e.embedding for e in embedding_tables]
|
|
73
66
|
mundig.add_diversity(embeddings=embeddings, strength=strat.strength)
|
|
67
|
+
elif isinstance(strat, MetadataWeightingStrategy):
|
|
68
|
+
key = strat.metadata_key
|
|
69
|
+
weights = []
|
|
70
|
+
for sample_id in input_sample_ids:
|
|
71
|
+
weight = metadata_resolver.get_value_for_sample(session, sample_id, key)
|
|
72
|
+
if not isinstance(weight, (float, int)):
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Metadata {key} is not a number, only numbers can be used as weights"
|
|
75
|
+
)
|
|
76
|
+
weights.append(float(weight))
|
|
77
|
+
mundig.add_weighting(weights, strength=strat.strength)
|
|
74
78
|
else:
|
|
75
79
|
raise ValueError(f"Selection strategy of type {type(strat)} is unknown.")
|
|
76
80
|
|
|
77
81
|
selected_indices = mundig.run(n_samples=n_samples_to_select)
|
|
78
|
-
selected_sample_ids = [
|
|
82
|
+
selected_sample_ids = [input_sample_ids[i] for i in selected_indices]
|
|
79
83
|
|
|
80
84
|
datetime_str = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
|
|
81
85
|
tag_description = f"Selected at {datetime_str} UTC"
|