lightly-studio 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. (The original page linked to more details here; the link target was not preserved in this text.)

Files changed (133)
  1. lightly_studio/__init__.py +4 -4
  2. lightly_studio/api/app.py +1 -1
  3. lightly_studio/api/routes/api/annotation.py +6 -16
  4. lightly_studio/api/routes/api/annotation_label.py +2 -5
  5. lightly_studio/api/routes/api/annotation_task.py +4 -5
  6. lightly_studio/api/routes/api/classifier.py +2 -5
  7. lightly_studio/api/routes/api/dataset.py +2 -3
  8. lightly_studio/api/routes/api/dataset_tag.py +2 -3
  9. lightly_studio/api/routes/api/metadata.py +2 -4
  10. lightly_studio/api/routes/api/metrics.py +2 -6
  11. lightly_studio/api/routes/api/sample.py +5 -13
  12. lightly_studio/api/routes/api/settings.py +2 -6
  13. lightly_studio/api/routes/images.py +6 -6
  14. lightly_studio/core/add_samples.py +383 -0
  15. lightly_studio/core/dataset.py +250 -362
  16. lightly_studio/core/dataset_query/__init__.py +0 -0
  17. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  18. lightly_studio/core/dataset_query/dataset_query.py +211 -0
  19. lightly_studio/core/dataset_query/field.py +113 -0
  20. lightly_studio/core/dataset_query/field_expression.py +79 -0
  21. lightly_studio/core/dataset_query/match_expression.py +23 -0
  22. lightly_studio/core/dataset_query/order_by.py +79 -0
  23. lightly_studio/core/dataset_query/sample_field.py +28 -0
  24. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  25. lightly_studio/core/sample.py +159 -32
  26. lightly_studio/core/start_gui.py +35 -0
  27. lightly_studio/dataset/edge_embedding_generator.py +13 -8
  28. lightly_studio/dataset/embedding_generator.py +2 -3
  29. lightly_studio/dataset/embedding_manager.py +74 -6
  30. lightly_studio/dataset/fsspec_lister.py +275 -0
  31. lightly_studio/dataset/loader.py +49 -30
  32. lightly_studio/dataset/mobileclip_embedding_generator.py +6 -4
  33. lightly_studio/db_manager.py +145 -0
  34. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BBm0IWdq.css +1 -0
  35. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BNTuXSAe.css +1 -0
  36. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/2O287xak.js +3 -0
  37. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{O-EABkf9.js → 7YNGEs1C.js} +1 -1
  38. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BBoGk9hq.js +1 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRnH9v23.js +92 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bg1Y5eUZ.js +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DOlTMNyt.js → BqBqV92V.js} +1 -1
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C0JiMuYn.js +1 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DjfY96ND.js → C98Hk3r5.js} +1 -1
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{r64xT6ao.js → CG0dMCJi.js} +1 -1
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{C8I8rFJQ.js → Ccq4ZD0B.js} +1 -1
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cpy-nab_.js +1 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bu7uvVrG.js → Crk-jcvV.js} +1 -1
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cs31G8Qn.js +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CsKrY2zA.js +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{x9G_hzyY.js → Cur71c3O.js} +1 -1
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CzgC3GFB.js +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D8GZDMNN.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DFRh-Spp.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BylOuP6i.js → DRZO-E-T.js} +1 -1
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{l7KrR96u.js → DcGCxgpH.js} +1 -1
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bsi3UGy5.js → Df3aMO5B.js} +1 -1
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{hQVEETDE.js → DkR_EZ_B.js} +1 -1
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqUGznj_.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KpAtIldw.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/M1Q1F7bw.js +4 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{CDnpyLsT.js → OH7-C_mc.js} +1 -1
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D6su9Aln.js → gLNdjSzu.js} +1 -1
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/i0ZZ4z06.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BI-EA5gL.js +2 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.CcsRl3cZ.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.BbO4Zc3r.js +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{1.B4rNYwVp.js → 1._I9GR805.js} +1 -1
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.J2RBFrSr.js +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.Cmqj25a-.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C45iKJHA.js +6 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CWHpKonm.js → 3.w9g4AcAx.js} +1 -1
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.OUWOLQeV.js → 4.BBI8KwnD.js} +1 -1
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.huHuxdiF.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CrbkRPam.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.FomEdhD6.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cb_ADSLk.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{9.CPu3CiBc.js → 9.CajIG5ce.js} +1 -1
  78. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
  79. lightly_studio/dist_lightly_studio_view_app/index.html +14 -14
  80. lightly_studio/examples/example.py +13 -12
  81. lightly_studio/examples/example_coco.py +13 -0
  82. lightly_studio/examples/example_metadata.py +83 -98
  83. lightly_studio/examples/example_selection.py +7 -19
  84. lightly_studio/examples/example_split_work.py +12 -36
  85. lightly_studio/examples/{example_v2.py → example_yolo.py} +3 -4
  86. lightly_studio/models/annotation/annotation_base.py +7 -8
  87. lightly_studio/models/annotation/instance_segmentation.py +8 -8
  88. lightly_studio/models/annotation/object_detection.py +4 -4
  89. lightly_studio/models/dataset.py +6 -2
  90. lightly_studio/models/sample.py +10 -3
  91. lightly_studio/resolvers/dataset_resolver.py +10 -0
  92. lightly_studio/resolvers/embedding_model_resolver.py +22 -0
  93. lightly_studio/resolvers/sample_resolver.py +53 -9
  94. lightly_studio/resolvers/tag_resolver.py +23 -0
  95. lightly_studio/selection/select.py +55 -46
  96. lightly_studio/selection/select_via_db.py +23 -19
  97. lightly_studio/selection/selection_config.py +6 -3
  98. lightly_studio/services/annotations_service/__init__.py +4 -0
  99. lightly_studio/services/annotations_service/update_annotation.py +21 -32
  100. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  101. lightly_studio-0.3.2.dist-info/METADATA +689 -0
  102. {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/RECORD +104 -91
  103. lightly_studio/api/db.py +0 -133
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +0 -1
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +0 -1
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +0 -1
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +0 -1
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +0 -1
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +0 -1
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +0 -1
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +0 -93
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +0 -3
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +0 -4
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +0 -1
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +0 -1
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +0 -1
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +0 -1
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +0 -1
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +0 -1
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +0 -1
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +0 -2
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +0 -1
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +0 -1
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +0 -1
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +0 -1
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +0 -6
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +0 -1
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +0 -1
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +0 -1
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +0 -1
  131. lightly_studio-0.3.1.dist-info/METADATA +0 -520
  132. /lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{OpenSans- → OpenSans-Medium.DVUZMR_6.ttf} +0 -0
  133. {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/WHEEL +0 -0
@@ -21,6 +21,28 @@ def create(session: Session, embedding_model: EmbeddingModelCreate) -> Embedding
21
21
  return db_embedding_model
22
22
 
23
23
 
24
+ def get_or_create(session: Session, embedding_model: EmbeddingModelCreate) -> EmbeddingModelTable:
25
+ """Retrieve an existing EmbeddingModel by hash or create a new one if it does not exist."""
26
+ db_model = get_by_model_hash(
27
+ session=session, embedding_model_hash=embedding_model.embedding_model_hash
28
+ )
29
+ if db_model is None:
30
+ return create(session=session, embedding_model=embedding_model)
31
+
32
+ # Validate that the existing model matches the provided data.
33
+ if (
34
+ db_model.name != embedding_model.name
35
+ or db_model.parameter_count_in_mb != embedding_model.parameter_count_in_mb
36
+ or db_model.embedding_dimension != embedding_model.embedding_dimension
37
+ # TODO(Michal, 09/2025): Allow same model for different datasets.
38
+ or db_model.dataset_id != embedding_model.dataset_id
39
+ ):
40
+ raise ValueError(
41
+ "An embedding model with the same hash but different parameters already exists."
42
+ )
43
+ return db_model
44
+
45
+
24
46
  def get_all_by_dataset_id(session: Session, dataset_id: UUID) -> list[EmbeddingModelTable]:
25
47
  """Retrieve all embedding models."""
26
48
  embedding_models = session.exec(
@@ -7,9 +7,11 @@ from datetime import datetime, timezone
7
7
  from uuid import UUID
8
8
 
9
9
  from pydantic import BaseModel
10
+ from sqlalchemy.orm import joinedload, selectinload
10
11
  from sqlmodel import Session, col, func, select
11
12
  from sqlmodel.sql.expression import Select
12
13
 
14
+ from lightly_studio.api.routes.api.validators import Paginated
13
15
  from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
14
16
  from lightly_studio.models.annotation_label import AnnotationLabelTable
15
17
  from lightly_studio.models.embedding_model import EmbeddingModelTable
@@ -36,6 +38,22 @@ def create_many(session: Session, samples: list[SampleCreate]) -> list[SampleTab
36
38
  return db_samples
37
39
 
38
40
 
41
+ def filter_new_paths(session: Session, file_paths_abs: list[str]) -> tuple[list[str], list[str]]:
42
+ """Return a) file_path_abs that do not already exist in the database and b) those that do."""
43
+ existing_file_paths_abs = set(
44
+ session.exec(
45
+ select(col(SampleTable.file_path_abs)).where(
46
+ col(SampleTable.file_path_abs).in_(file_paths_abs)
47
+ )
48
+ ).all()
49
+ )
50
+ file_paths_abs_set = set(file_paths_abs)
51
+ return (
52
+ list(file_paths_abs_set - existing_file_paths_abs), # paths that are not in the DB
53
+ list(file_paths_abs_set & existing_file_paths_abs), # paths that are already in the DB
54
+ )
55
+
56
+
39
57
  def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTable | None:
40
58
  """Retrieve a single sample by ID."""
41
59
  return session.exec(
@@ -45,6 +63,13 @@ def get_by_id(session: Session, dataset_id: UUID, sample_id: UUID) -> SampleTabl
45
63
  ).one_or_none()
46
64
 
47
65
 
66
+ def count_by_dataset_id(session: Session, dataset_id: UUID) -> int:
67
+ """Count the number of samples in a dataset."""
68
+ return session.exec(
69
+ select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
70
+ ).one()
71
+
72
+
48
73
  def get_many_by_id(session: Session, sample_ids: list[UUID]) -> list[SampleTable]:
49
74
  """Retrieve multiple samples by their IDs.
50
75
 
@@ -63,19 +88,33 @@ class GetAllSamplesByDatasetIdResult(BaseModel):
63
88
 
64
89
  samples: Sequence[SampleTable]
65
90
  total_count: int
91
+ next_cursor: int | None = None
66
92
 
67
93
 
68
94
  def get_all_by_dataset_id( # noqa: PLR0913
69
95
  session: Session,
70
96
  dataset_id: UUID,
71
- offset: int = 0,
72
- limit: int | None = None,
97
+ pagination: Paginated | None = None,
73
98
  filters: SampleFilter | None = None,
74
99
  text_embedding: list[float] | None = None,
75
100
  sample_ids: list[UUID] | None = None,
76
101
  ) -> GetAllSamplesByDatasetIdResult:
77
102
  """Retrieve samples for a specific dataset with optional filtering."""
78
- samples_query = select(SampleTable).where(SampleTable.dataset_id == dataset_id)
103
+ samples_query = (
104
+ select(SampleTable)
105
+ .options(
106
+ selectinload(SampleTable.annotations).options(
107
+ joinedload(AnnotationBaseTable.annotation_label),
108
+ joinedload(AnnotationBaseTable.object_detection_details),
109
+ joinedload(AnnotationBaseTable.instance_segmentation_details),
110
+ joinedload(AnnotationBaseTable.semantic_segmentation_details),
111
+ ),
112
+ selectinload(SampleTable.tags),
113
+ # Ignore type checker error below as it's a false positive caused by TYPE_CHECKING.
114
+ joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
115
+ )
116
+ .where(SampleTable.dataset_id == dataset_id)
117
+ )
79
118
  total_count_query = (
80
119
  select(func.count()).select_from(SampleTable).where(SampleTable.dataset_id == dataset_id)
81
120
  )
@@ -120,15 +159,20 @@ def get_all_by_dataset_id( # noqa: PLR0913
120
159
  col(SampleTable.created_at).asc(), col(SampleTable.sample_id).asc()
121
160
  )
122
161
 
123
- # paginate query when offset or limit are set/positive
124
- if offset > 0:
125
- samples_query = samples_query.offset(offset)
126
- if limit is not None:
127
- samples_query = samples_query.limit(limit)
162
+ # Apply pagination if provided
163
+ if pagination is not None:
164
+ samples_query = samples_query.offset(pagination.offset).limit(pagination.limit)
165
+
166
+ total_count = session.exec(total_count_query).one()
167
+
168
+ next_cursor = None
169
+ if pagination and pagination.offset + pagination.limit < total_count:
170
+ next_cursor = pagination.offset + pagination.limit
128
171
 
129
172
  return GetAllSamplesByDatasetIdResult(
130
173
  samples=session.exec(samples_query).all(),
131
- total_count=session.exec(total_count_query).one(),
174
+ total_count=total_count,
175
+ next_cursor=next_cursor,
132
176
  )
133
177
 
134
178
 
@@ -274,3 +274,26 @@ def remove_annotation_ids_from_tag_id(
274
274
  session.commit()
275
275
  session.refresh(tag)
276
276
  return tag
277
+
278
+
279
+ def get_or_create_sample_tag_by_name(
280
+ session: Session,
281
+ dataset_id: UUID,
282
+ tag_name: str,
283
+ ) -> TagTable:
284
+ """Get an existing sample tag by name or create a new one if it doesn't exist.
285
+
286
+ Args:
287
+ session: Database session for executing queries.
288
+ dataset_id: The dataset ID to search/create the tag for.
289
+ tag_name: Name of the tag to get or create.
290
+
291
+ Returns:
292
+ The existing or newly created sample tag.
293
+ """
294
+ existing_tag = get_by_name(session=session, tag_name=tag_name, dataset_id=dataset_id)
295
+ if existing_tag:
296
+ return existing_tag
297
+
298
+ new_tag = TagCreate(name=tag_name, dataset_id=dataset_id, kind="sample")
299
+ return create(session=session, tag=new_tag)
@@ -1,96 +1,105 @@
1
- """Provides the user python interface to selection."""
1
+ """Provides the user python interface to selection bound to sample ids."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Iterable
6
+ from typing import Final
5
7
  from uuid import UUID
6
8
 
7
9
  from sqlmodel import Session
8
10
 
9
- from lightly_studio.resolvers.samples_filter import SampleFilter
10
11
  from lightly_studio.selection.select_via_db import select_via_database
11
12
  from lightly_studio.selection.selection_config import (
12
13
  EmbeddingDiversityStrategy,
14
+ MetadataWeightingStrategy,
13
15
  SelectionConfig,
14
16
  SelectionStrategy,
15
17
  )
16
18
 
17
19
 
18
20
  class Selection:
19
- """User selection interface for the dataset."""
21
+ """Selection interface for candidate sample ids."""
20
22
 
21
- # TODO(Malte, 08/2025): Create this class within the DatasetView.
22
- # Then the arguments can be passed directly from the DatasetView.
23
- # Example:
24
- # class DatasetView:
25
- # def __init__(self, dataset_id: UUID, session: Session):
26
- # self.select = Select(dataset_id, session)
27
- # User interface:
28
- # dataset_view = ...
29
- # dataset_view.select.diverse(...)
30
- #
31
- # See https://docs.google.com/document/d/1ZRICdFmfJmxUBy3FFoeUWsAgsCNWDHg8CK5MJiGmX74/edit?tab=t.kbfvnrepsuf#bookmark=id.8klhhwr5q4dp
32
-
33
- def __init__(self, dataset_id: UUID, session: Session):
34
- """Creates the interface to run selection.
23
+ def __init__(
24
+ self,
25
+ dataset_id: UUID,
26
+ session: Session,
27
+ input_sample_ids: Iterable[UUID],
28
+ ) -> None:
29
+ """Create the selection interface.
35
30
 
36
31
  Args:
37
- dataset_id: The ID of the dataset to select from.
38
- session: The database session to use for selection.
32
+ dataset_id: Dataset in which the selection is performed.
33
+ session: Database session to resolve selection dependencies.
34
+ input_sample_ids: Candidate sample ids considered for selection.
35
+ The iterable is consumed immediately to capture a stable snapshot.
36
+ """
37
+ self._dataset_id: Final[UUID] = dataset_id
38
+ self._session: Final[Session] = session
39
+ self._input_sample_ids: list[UUID] = list(input_sample_ids)
39
40
 
41
+ def metadata_weighting(
42
+ self,
43
+ n_samples_to_select: int,
44
+ selection_result_tag_name: str,
45
+ metadata_key: str,
46
+ ) -> None:
47
+ """Select a subset based on numeric metadata weights.
48
+
49
+ Args:
50
+ n_samples_to_select: Number of samples to select.
51
+ selection_result_tag_name: Tag name for the selection result.
52
+ metadata_key: Metadata key used as weights (float or int values).
40
53
  """
41
- self.dataset_id = dataset_id
42
- self.session = session
54
+ strategy = MetadataWeightingStrategy(metadata_key=metadata_key)
55
+ self.multi_strategies(
56
+ n_samples_to_select=n_samples_to_select,
57
+ selection_result_tag_name=selection_result_tag_name,
58
+ selection_strategies=[strategy],
59
+ )
43
60
 
44
61
  def diverse(
45
62
  self,
46
63
  n_samples_to_select: int,
47
64
  selection_result_tag_name: str,
48
65
  embedding_model_name: str | None = None,
49
- sample_filter: SampleFilter | None = None,
50
66
  ) -> None:
51
- """Selects a diverse subset of the dataset.
67
+ """Select a diverse subset using embeddings.
52
68
 
53
69
  Args:
54
- n_samples_to_select: The number of samples to select.
55
- selection_result_tag_name: The tag name to use for the selection result.
56
- embedding_model_name:
57
- The name of the embedding model to use.
58
- If None, assert that there is only one embedding model and uses it.
59
- sample_filter: An optional filter to apply to the samples.
70
+ n_samples_to_select: Number of samples to select.
71
+ selection_result_tag_name: Tag name for the selection result.
72
+ embedding_model_name: Optional embedding model name. If None, uses the only
73
+ available model or raises if multiple exist.
60
74
  """
61
75
  strategy = EmbeddingDiversityStrategy(embedding_model_name=embedding_model_name)
62
- selection_config = SelectionConfig(
63
- dataset_id=self.dataset_id,
76
+ self.multi_strategies(
64
77
  n_samples_to_select=n_samples_to_select,
65
78
  selection_result_tag_name=selection_result_tag_name,
66
- sample_filter=sample_filter,
67
- strategies=[strategy],
79
+ selection_strategies=[strategy],
68
80
  )
69
- select_via_database(session=self.session, config=selection_config)
70
81
 
71
82
  def multi_strategies(
72
83
  self,
73
84
  n_samples_to_select: int,
74
85
  selection_result_tag_name: str,
75
86
  selection_strategies: list[SelectionStrategy],
76
- sample_filter: SampleFilter | None = None,
77
87
  ) -> None:
78
- """Select a subset of the dataset based on multiple selection strategies.
88
+ """Select a subset based on multiple strategies.
79
89
 
80
90
  Args:
81
- n_samples_to_select: The number of samples to select.
82
- selection_result_tag_name: The tag name to use for the selection result.
83
- selection_strategies:
84
- Selection strategies to use for the selection. They can be created after
85
- importing them from `lightly_studio.selection.selection_config`.
86
- sample_filter: An optional filter to apply to the samples.
87
-
91
+ n_samples_to_select: Number of samples to select.
92
+ selection_result_tag_name: Tag name for the selection result.
93
+ selection_strategies: Strategies to compose for selection.
88
94
  """
89
95
  config = SelectionConfig(
90
- dataset_id=self.dataset_id,
96
+ dataset_id=self._dataset_id,
91
97
  n_samples_to_select=n_samples_to_select,
92
98
  selection_result_tag_name=selection_result_tag_name,
93
- sample_filter=sample_filter,
94
99
  strategies=selection_strategies,
95
100
  )
96
- select_via_database(session=self.session, config=config)
101
+ select_via_database(
102
+ session=self._session,
103
+ config=config,
104
+ input_sample_ids=self._input_sample_ids,
105
+ )
@@ -3,29 +3,33 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import datetime
6
+ from uuid import UUID
6
7
 
7
8
  from sqlmodel import Session
8
9
 
9
10
  from lightly_studio.models.tag import TagCreate
10
11
  from lightly_studio.resolvers import (
11
12
  embedding_model_resolver,
13
+ metadata_resolver,
12
14
  sample_embedding_resolver,
13
- sample_resolver,
14
15
  tag_resolver,
15
16
  )
16
17
  from lightly_studio.selection.mundig import Mundig
17
18
  from lightly_studio.selection.selection_config import (
18
19
  EmbeddingDiversityStrategy,
20
+ MetadataWeightingStrategy,
19
21
  SelectionConfig,
20
22
  )
21
23
 
22
24
 
23
- def select_via_database(session: Session, config: SelectionConfig) -> None:
24
- """Runs selection and all database interactions of it.
25
+ def select_via_database(
26
+ session: Session, config: SelectionConfig, input_sample_ids: list[UUID]
27
+ ) -> None:
28
+ """Run selection using the provided candidate sample ids.
25
29
 
26
- First resolves the selection config to actual database values.
30
+ First resolves the selection config to concrete database values.
27
31
  Then calls Mundig to run the selection with pure values.
28
- Last creates a tag for the selected set.
32
+ Finally creates a tag for the selected set.
29
33
  """
30
34
  # Check if the tag name is already used
31
35
  existing_tag = tag_resolver.get_by_name(
@@ -40,18 +44,7 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
40
44
  )
41
45
  raise ValueError(msg)
42
46
 
43
- # TODO(Malte, 08/2025): Use a DatasetQuery instead of SampleFilter once
44
- # the latter is implemented.
45
- # See https://linear.app/lightly/issue/LIG-7292/story-python-ui-mvp1-without-datasetquery-and-sample
46
- samples = sample_resolver.get_all_by_dataset_id(
47
- session,
48
- limit=None,
49
- dataset_id=config.dataset_id,
50
- filters=config.sample_filter,
51
- ).samples
52
- sample_ids = [s.sample_id for s in samples]
53
-
54
- n_samples_to_select = min(config.n_samples_to_select, len(sample_ids))
47
+ n_samples_to_select = min(config.n_samples_to_select, len(input_sample_ids))
55
48
  if n_samples_to_select == 0:
56
49
  print("No samples available for selection.")
57
50
  return
@@ -66,16 +59,27 @@ def select_via_database(session: Session, config: SelectionConfig) -> None:
66
59
  ).embedding_model_id
67
60
  embedding_tables = sample_embedding_resolver.get_by_sample_ids(
68
61
  session=session,
69
- sample_ids=sample_ids,
62
+ sample_ids=input_sample_ids,
70
63
  embedding_model_id=embedding_model_id,
71
64
  )
72
65
  embeddings = [e.embedding for e in embedding_tables]
73
66
  mundig.add_diversity(embeddings=embeddings, strength=strat.strength)
67
+ elif isinstance(strat, MetadataWeightingStrategy):
68
+ key = strat.metadata_key
69
+ weights = []
70
+ for sample_id in input_sample_ids:
71
+ weight = metadata_resolver.get_value_for_sample(session, sample_id, key)
72
+ if not isinstance(weight, (float, int)):
73
+ raise ValueError(
74
+ f"Metadata {key} is not a number, only numbers can be used as weights"
75
+ )
76
+ weights.append(float(weight))
77
+ mundig.add_weighting(weights, strength=strat.strength)
74
78
  else:
75
79
  raise ValueError(f"Selection strategy of type {type(strat)} is unknown.")
76
80
 
77
81
  selected_indices = mundig.run(n_samples=n_samples_to_select)
78
- selected_sample_ids = [sample_ids[i] for i in selected_indices]
82
+ selected_sample_ids = [input_sample_ids[i] for i in selected_indices]
79
83
 
80
84
  datetime_str = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
81
85
  tag_description = f"Selected at {datetime_str} UTC"
@@ -6,14 +6,11 @@ from uuid import UUID
6
6
 
7
7
  from pydantic import BaseModel
8
8
 
9
- from lightly_studio.resolvers.samples_filter import SampleFilter
10
-
11
9
 
12
10
  class SelectionConfig(BaseModel):
13
11
  """Configuration for the selection process."""
14
12
 
15
13
  dataset_id: UUID
16
- sample_filter: SampleFilter | None = None
17
14
  n_samples_to_select: int
18
15
  selection_result_tag_name: str
19
16
  strategies: list[SelectionStrategy]
@@ -29,3 +26,9 @@ class EmbeddingDiversityStrategy(SelectionStrategy):
29
26
  """Selection strategy based on embedding diversity."""
30
27
 
31
28
  embedding_model_name: str | None
29
+
30
+
31
+ class MetadataWeightingStrategy(SelectionStrategy):
32
+ """Selection strategy based on metadata weighting."""
33
+
34
+ metadata_key: str
@@ -6,6 +6,9 @@ from lightly_studio.services.annotations_service.get_annotation_by_id import (
6
6
  from lightly_studio.services.annotations_service.update_annotation import (
7
7
  update_annotation,
8
8
  )
9
+ from lightly_studio.services.annotations_service.update_annotation_bounding_box import (
10
+ update_annotation_bounding_box,
11
+ )
9
12
  from lightly_studio.services.annotations_service.update_annotation_label import (
10
13
  update_annotation_label,
11
14
  )
@@ -16,6 +19,7 @@ from lightly_studio.services.annotations_service.update_annotations import (
16
19
  __all__ = [
17
20
  "get_annotation_by_id",
18
21
  "update_annotation",
22
+ "update_annotation_bounding_box",
19
23
  "update_annotation_label",
20
24
  "update_annotations",
21
25
  ]
@@ -10,6 +10,7 @@ from sqlmodel import Session
10
10
  from lightly_studio.models.annotation.annotation_base import (
11
11
  AnnotationBaseTable,
12
12
  )
13
+ from lightly_studio.resolvers.annotation_resolver.update_bounding_box import BoundingBoxCoordinates
13
14
  from lightly_studio.services import annotations_service
14
15
 
15
16
 
@@ -18,11 +19,8 @@ class AnnotationUpdate(BaseModel):
18
19
 
19
20
  annotation_id: UUID
20
21
  dataset_id: UUID
21
- label_name: str | None
22
- x: int | None = None
23
- y: int | None = None
24
- width: int | None = None
25
- height: int | None = None
22
+ label_name: str | None = None
23
+ bounding_box: BoundingBoxCoordinates | None = None
26
24
 
27
25
 
28
26
  def update_annotation(session: Session, annotation_update: AnnotationUpdate) -> AnnotationBaseTable:
@@ -36,30 +34,21 @@ def update_annotation(session: Session, annotation_update: AnnotationUpdate) ->
36
34
  The updated annotation.
37
35
 
38
36
  """
39
- if annotation_update.label_name is None:
40
- raise ValueError("Label name must be provided for updating annotation")
41
-
42
- # comment this out for now so e2e tests will pass
43
- # todo: uncomment after passing bbox coordinates on update from frontend
44
- # annotation=get_annotation_by_id(session,annotation_update.annotation_id)
45
- # if annotation.annotation_type in (
46
- # AnnotationType.OBJECT_DETECTION,
47
- # AnnotationType.INSTANCE_SEGMENTATION,
48
- # ) and any(
49
- # [
50
- # annotation_update.x is None,
51
- # annotation_update.y is None,
52
- # annotation_update.width is None,
53
- # annotation_update.height is None,
54
- # ]
55
- # ):
56
- # raise ValueError(
57
- # "All bounding box coordinates (x, y, width, height) "
58
- # "must be provided for updating this annotation type"
59
- # )
60
-
61
- return annotations_service.update_annotation_label(
62
- session,
63
- annotation_update.annotation_id,
64
- annotation_update.label_name,
65
- )
37
+ result = None
38
+ if annotation_update.label_name is not None:
39
+ result = annotations_service.update_annotation_label(
40
+ session,
41
+ annotation_update.annotation_id,
42
+ annotation_update.label_name,
43
+ )
44
+
45
+ if annotation_update.bounding_box is not None:
46
+ result = annotations_service.update_annotation_bounding_box(
47
+ session,
48
+ annotation_update.annotation_id,
49
+ bounding_box=annotation_update.bounding_box,
50
+ )
51
+
52
+ if result is None:
53
+ raise ValueError("No updates provided for the annotation.")
54
+ return result
@@ -0,0 +1,36 @@
1
+ """Update the bounding box of an annotation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from uuid import UUID
6
+
7
+ from sqlmodel import Session
8
+
9
+ from lightly_studio.models.annotation.annotation_base import (
10
+ AnnotationBaseTable,
11
+ )
12
+ from lightly_studio.resolvers import (
13
+ annotation_resolver,
14
+ )
15
+ from lightly_studio.resolvers.annotation_resolver.update_bounding_box import BoundingBoxCoordinates
16
+
17
+
18
+ def update_annotation_bounding_box(
19
+ session: Session, annotation_id: UUID, bounding_box: BoundingBoxCoordinates
20
+ ) -> AnnotationBaseTable:
21
+ """Update the bounding box of an annotation.
22
+
23
+ Args:
24
+ session: Database session for executing the operation.
25
+ annotation_id: UUID of the annotation to update.
26
+ bounding_box: New bounding box coordinates to assign to the annotation.
27
+
28
+ Returns:
29
+ The updated annotation with the new bounding box assigned.
30
+
31
+ """
32
+ return annotation_resolver.update_bounding_box(
33
+ session,
34
+ annotation_id,
35
+ bounding_box,
36
+ )