lightly-studio 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (133) hide show
  1. lightly_studio/__init__.py +4 -4
  2. lightly_studio/api/app.py +1 -1
  3. lightly_studio/api/routes/api/annotation.py +6 -16
  4. lightly_studio/api/routes/api/annotation_label.py +2 -5
  5. lightly_studio/api/routes/api/annotation_task.py +4 -5
  6. lightly_studio/api/routes/api/classifier.py +2 -5
  7. lightly_studio/api/routes/api/dataset.py +2 -3
  8. lightly_studio/api/routes/api/dataset_tag.py +2 -3
  9. lightly_studio/api/routes/api/metadata.py +2 -4
  10. lightly_studio/api/routes/api/metrics.py +2 -6
  11. lightly_studio/api/routes/api/sample.py +5 -13
  12. lightly_studio/api/routes/api/settings.py +2 -6
  13. lightly_studio/api/routes/images.py +6 -6
  14. lightly_studio/core/add_samples.py +383 -0
  15. lightly_studio/core/dataset.py +250 -362
  16. lightly_studio/core/dataset_query/__init__.py +0 -0
  17. lightly_studio/core/dataset_query/boolean_expression.py +67 -0
  18. lightly_studio/core/dataset_query/dataset_query.py +211 -0
  19. lightly_studio/core/dataset_query/field.py +113 -0
  20. lightly_studio/core/dataset_query/field_expression.py +79 -0
  21. lightly_studio/core/dataset_query/match_expression.py +23 -0
  22. lightly_studio/core/dataset_query/order_by.py +79 -0
  23. lightly_studio/core/dataset_query/sample_field.py +28 -0
  24. lightly_studio/core/dataset_query/tags_expression.py +46 -0
  25. lightly_studio/core/sample.py +159 -32
  26. lightly_studio/core/start_gui.py +35 -0
  27. lightly_studio/dataset/edge_embedding_generator.py +13 -8
  28. lightly_studio/dataset/embedding_generator.py +2 -3
  29. lightly_studio/dataset/embedding_manager.py +74 -6
  30. lightly_studio/dataset/fsspec_lister.py +275 -0
  31. lightly_studio/dataset/loader.py +49 -30
  32. lightly_studio/dataset/mobileclip_embedding_generator.py +6 -4
  33. lightly_studio/db_manager.py +145 -0
  34. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BBm0IWdq.css +1 -0
  35. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.BNTuXSAe.css +1 -0
  36. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/2O287xak.js +3 -0
  37. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{O-EABkf9.js → 7YNGEs1C.js} +1 -1
  38. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BBoGk9hq.js +1 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BRnH9v23.js +92 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bg1Y5eUZ.js +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DOlTMNyt.js → BqBqV92V.js} +1 -1
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C0JiMuYn.js +1 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{DjfY96ND.js → C98Hk3r5.js} +1 -1
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{r64xT6ao.js → CG0dMCJi.js} +1 -1
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{C8I8rFJQ.js → Ccq4ZD0B.js} +1 -1
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cpy-nab_.js +1 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bu7uvVrG.js → Crk-jcvV.js} +1 -1
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Cs31G8Qn.js +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CsKrY2zA.js +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{x9G_hzyY.js → Cur71c3O.js} +1 -1
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CzgC3GFB.js +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D8GZDMNN.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DFRh-Spp.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{BylOuP6i.js → DRZO-E-T.js} +1 -1
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{l7KrR96u.js → DcGCxgpH.js} +1 -1
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{Bsi3UGy5.js → Df3aMO5B.js} +1 -1
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{hQVEETDE.js → DkR_EZ_B.js} +1 -1
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DqUGznj_.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/KpAtIldw.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/M1Q1F7bw.js +4 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{CDnpyLsT.js → OH7-C_mc.js} +1 -1
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/{D6su9Aln.js → gLNdjSzu.js} +1 -1
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/i0ZZ4z06.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.BI-EA5gL.js +2 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.CcsRl3cZ.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.BbO4Zc3r.js +1 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{1.B4rNYwVp.js → 1._I9GR805.js} +1 -1
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.J2RBFrSr.js +1 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.Cmqj25a-.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.C45iKJHA.js +6 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{3.CWHpKonm.js → 3.w9g4AcAx.js} +1 -1
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{4.OUWOLQeV.js → 4.BBI8KwnD.js} +1 -1
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.huHuxdiF.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.CrbkRPam.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.FomEdhD6.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cb_ADSLk.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/{9.CPu3CiBc.js → 9.CajIG5ce.js} +1 -1
  78. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -1
  79. lightly_studio/dist_lightly_studio_view_app/index.html +14 -14
  80. lightly_studio/examples/example.py +13 -12
  81. lightly_studio/examples/example_coco.py +13 -0
  82. lightly_studio/examples/example_metadata.py +83 -98
  83. lightly_studio/examples/example_selection.py +7 -19
  84. lightly_studio/examples/example_split_work.py +12 -36
  85. lightly_studio/examples/{example_v2.py → example_yolo.py} +3 -4
  86. lightly_studio/models/annotation/annotation_base.py +7 -8
  87. lightly_studio/models/annotation/instance_segmentation.py +8 -8
  88. lightly_studio/models/annotation/object_detection.py +4 -4
  89. lightly_studio/models/dataset.py +6 -2
  90. lightly_studio/models/sample.py +10 -3
  91. lightly_studio/resolvers/dataset_resolver.py +10 -0
  92. lightly_studio/resolvers/embedding_model_resolver.py +22 -0
  93. lightly_studio/resolvers/sample_resolver.py +53 -9
  94. lightly_studio/resolvers/tag_resolver.py +23 -0
  95. lightly_studio/selection/select.py +55 -46
  96. lightly_studio/selection/select_via_db.py +23 -19
  97. lightly_studio/selection/selection_config.py +6 -3
  98. lightly_studio/services/annotations_service/__init__.py +4 -0
  99. lightly_studio/services/annotations_service/update_annotation.py +21 -32
  100. lightly_studio/services/annotations_service/update_annotation_bounding_box.py +36 -0
  101. lightly_studio-0.3.2.dist-info/METADATA +689 -0
  102. {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/RECORD +104 -91
  103. lightly_studio/api/db.py +0 -133
  104. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +0 -1
  105. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +0 -1
  106. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +0 -1
  107. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +0 -1
  108. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +0 -1
  109. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +0 -1
  110. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +0 -1
  111. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +0 -93
  112. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +0 -3
  113. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +0 -4
  114. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +0 -1
  115. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +0 -1
  116. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +0 -1
  117. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +0 -1
  118. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +0 -1
  119. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +0 -1
  120. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +0 -1
  121. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +0 -2
  122. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +0 -1
  123. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +0 -1
  124. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +0 -1
  125. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +0 -1
  126. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +0 -6
  127. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +0 -1
  128. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +0 -1
  129. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +0 -1
  130. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +0 -1
  131. lightly_studio-0.3.1.dist-info/METADATA +0 -520
  132. /lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/{OpenSans- → OpenSans-Medium.DVUZMR_6.ttf} +0 -0
  133. {lightly_studio-0.3.1.dist-info → lightly_studio-0.3.2.dist-info}/WHEEL +0 -0
@@ -2,44 +2,39 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from dataclasses import dataclass
6
5
  from pathlib import Path
7
- from typing import Iterable
6
+ from typing import Iterable, Iterator
8
7
  from uuid import UUID
9
8
 
10
- import PIL
11
9
  from labelformat.formats import (
12
10
  COCOInstanceSegmentationInput,
13
11
  COCOObjectDetectionInput,
14
12
  YOLOv8ObjectDetectionInput,
15
13
  )
16
- from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
17
- from labelformat.model.bounding_box import BoundingBoxFormat
18
- from labelformat.model.image import Image
19
14
  from labelformat.model.instance_segmentation import (
20
- ImageInstanceSegmentation,
21
15
  InstanceSegmentationInput,
22
16
  )
23
- from labelformat.model.multipolygon import MultiPolygon
24
17
  from labelformat.model.object_detection import (
25
- ImageObjectDetection,
26
18
  ObjectDetectionInput,
27
19
  )
28
- from sqlmodel import Session
29
- from tqdm import tqdm
30
-
31
- from lightly_studio.api.db import db_manager
32
- from lightly_studio.models.annotation.annotation_base import AnnotationCreate
33
- from lightly_studio.models.annotation_label import AnnotationLabelCreate
20
+ from sqlmodel import Session, select
21
+
22
+ from lightly_studio import db_manager
23
+ from lightly_studio.api import features
24
+ from lightly_studio.core import add_samples
25
+ from lightly_studio.core.dataset_query.dataset_query import DatasetQuery
26
+ from lightly_studio.core.dataset_query.match_expression import MatchExpression
27
+ from lightly_studio.core.dataset_query.order_by import OrderByExpression
28
+ from lightly_studio.core.sample import Sample
29
+ from lightly_studio.dataset import fsspec_lister
30
+ from lightly_studio.dataset.embedding_manager import EmbeddingManagerProvider
34
31
  from lightly_studio.models.annotation_task import (
35
32
  AnnotationTaskTable,
36
33
  AnnotationType,
37
34
  )
38
35
  from lightly_studio.models.dataset import DatasetCreate, DatasetTable
39
- from lightly_studio.models.sample import SampleCreate, SampleTable
36
+ from lightly_studio.models.sample import SampleTable
40
37
  from lightly_studio.resolvers import (
41
- annotation_label_resolver,
42
- annotation_resolver,
43
38
  annotation_task_resolver,
44
39
  dataset_resolver,
45
40
  sample_resolver,
@@ -47,97 +42,201 @@ from lightly_studio.resolvers import (
47
42
  from lightly_studio.type_definitions import PathLike
48
43
 
49
44
  # Constants
50
- ANNOTATION_BATCH_SIZE = 64 # Number of annotations to process in a single batch
51
- SAMPLE_BATCH_SIZE = 32 # Number of samples to process in a single batch
52
-
45
+ DEFAULT_DATASET_NAME = "default_dataset"
53
46
 
54
- @dataclass
55
- class AnnotationProcessingContext:
56
- """Context for processing annotations for a single sample."""
57
-
58
- dataset_id: UUID
59
- sample_id: UUID
60
- label_map: dict[int, UUID]
61
- annotation_task_id: UUID
47
+ _SliceType = slice # to avoid shadowing built-in slice in type annotations
62
48
 
63
49
 
64
50
  class Dataset:
65
51
  """A LightlyStudio Dataset.
66
52
 
67
- Represents a dataset in LightlyStudio.
68
-
69
- Args:
70
- name: The name of the dataset. If None, a default name will be assigned.
53
+ Keeps a reference to the underlying DatasetTable.
71
54
  """
72
55
 
73
- def __init__(self, name: str | None = None) -> None:
56
+ def __init__(self, dataset: DatasetTable) -> None:
74
57
  """Initialize a LightlyStudio Dataset."""
58
+ self._inner = dataset
59
+ # TODO(Michal, 09/2025): Do not store the session. Instead, use the
60
+ # dataset object session.
61
+ self.session = db_manager.persistent_session()
62
+
63
+ @staticmethod
64
+ def create(name: str | None = None) -> Dataset:
65
+ """Create a new dataset."""
66
+ if name is None:
67
+ name = DEFAULT_DATASET_NAME
68
+
69
+ dataset = dataset_resolver.create(
70
+ session=db_manager.persistent_session(),
71
+ dataset=DatasetCreate(name=name, directory=""),
72
+ )
73
+ return Dataset(dataset=dataset)
74
+
75
+ @staticmethod
76
+ def load(name: str | None = None) -> Dataset:
77
+ """Load an existing dataset."""
75
78
  if name is None:
76
79
  name = "default_dataset"
77
- self.name = name
78
- self.session = db_manager.persistent_session()
79
- # Create dataset.
80
- self._dataset = dataset_resolver.create(
81
- session=self.session,
82
- dataset=DatasetCreate(
83
- name=self.name,
84
- directory="", # The directory is not used at the moment
85
- ),
80
+
81
+ dataset = dataset_resolver.get_by_name(session=db_manager.persistent_session(), name=name)
82
+ if dataset is None:
83
+ raise ValueError(f"Dataset with name '{name}' not found.")
84
+
85
+ return Dataset(dataset=dataset)
86
+
87
+ @staticmethod
88
+ def load_or_create(name: str | None = None) -> Dataset:
89
+ """Create a new dataset or load an existing one."""
90
+ if name is None:
91
+ name = "default_dataset"
92
+
93
+ dataset = dataset_resolver.get_by_name(session=db_manager.persistent_session(), name=name)
94
+ if dataset is None:
95
+ return Dataset.create(name=name)
96
+
97
+ return Dataset(dataset=dataset)
98
+
99
+ def __iter__(self) -> Iterator[Sample]:
100
+ """Iterate over samples in the dataset."""
101
+ for sample in self.session.exec(
102
+ select(SampleTable).where(SampleTable.dataset_id == self.dataset_id)
103
+ ):
104
+ yield Sample(inner=sample)
105
+
106
+ def get_sample(self, sample_id: UUID) -> Sample:
107
+ """Get a single sample from the dataset by its ID.
108
+
109
+ Args:
110
+ sample_id: The UUID of the sample to retrieve.
111
+
112
+ Returns:
113
+ A single SampleTable object.
114
+
115
+ Raises:
116
+ IndexError: If no sample is found with the given sample_id.
117
+ """
118
+ sample = sample_resolver.get_by_id(
119
+ self.session, dataset_id=self.dataset_id, sample_id=sample_id
86
120
  )
87
121
 
122
+ if sample is None:
123
+ raise IndexError(f"No sample found for sample_id: {sample_id}")
124
+ return Sample(inner=sample)
125
+
88
126
  @property
89
127
  def dataset_id(self) -> UUID:
90
128
  """Get the dataset ID."""
91
- return self._dataset.dataset_id
129
+ return self._inner.dataset_id
130
+
131
+ @property
132
+ def name(self) -> str:
133
+ """Get the dataset name."""
134
+ return self._inner.name
135
+
136
+ def query(self) -> DatasetQuery:
137
+ """Create a DatasetQuery for this dataset.
138
+
139
+ Returns:
140
+ A DatasetQuery instance for querying samples in this dataset.
141
+ """
142
+ return DatasetQuery(dataset=self._inner, session=self.session)
143
+
144
+ def match(self, match_expression: MatchExpression) -> DatasetQuery:
145
+ """Create a query on the dataset and store a field condition for filtering.
146
+
147
+ Args:
148
+ match_expression: Defines the filter.
149
+
150
+ Returns:
151
+ DatasetQuery for method chaining.
152
+ """
153
+ return self.query().match(match_expression)
154
+
155
+ def order_by(self, *order_by: OrderByExpression) -> DatasetQuery:
156
+ """Create a query on the dataset and store ordering expressions.
157
+
158
+ Args:
159
+ order_by: One or more ordering expressions. They are applied in order.
160
+ E.g. first ordering by sample width and then by sample file_name will
161
+ only order the samples with the same sample width by file_name.
162
+
163
+ Returns:
164
+ DatasetQuery for method chaining.
165
+ """
166
+ return self.query().order_by(*order_by)
167
+
168
+ def slice(self, offset: int = 0, limit: int | None = None) -> DatasetQuery:
169
+ """Create a query on the dataset and apply offset and limit to results.
170
+
171
+ Args:
172
+ offset: Number of items to skip from beginning (default: 0).
173
+ limit: Maximum number of items to return (None = no limit).
174
+
175
+ Returns:
176
+ DatasetQuery for method chaining.
177
+ """
178
+ return self.query().slice(offset, limit)
179
+
180
+ def __getitem__(self, key: _SliceType) -> DatasetQuery:
181
+ """Create a query on the dataset and enable bracket notation for slicing.
182
+
183
+ Args:
184
+ key: A slice object (e.g., [10:20], [:50], [100:]).
185
+
186
+ Returns:
187
+ DatasetQuery with slice applied.
188
+
189
+ Raises:
190
+ TypeError: If key is not a slice object.
191
+ ValueError: If slice contains unsupported features or conflicts with existing slice.
192
+ """
193
+ return self.query()[key]
92
194
 
93
195
  def add_samples_from_path(
94
196
  self,
95
197
  path: PathLike,
96
- recursive: bool = True,
97
- allowed_extensions: Iterable[str] = {
98
- ".png",
99
- ".jpg",
100
- ".jpeg",
101
- ".gif",
102
- ".webp",
103
- ".bmp",
104
- ".tiff",
105
- },
198
+ allowed_extensions: Iterable[str] | None = None,
199
+ embed: bool = True,
106
200
  ) -> None:
107
201
  """Adding samples from the specified path to the dataset.
108
202
 
109
203
  Args:
110
204
  path: Path to the folder containing the images to add.
111
- recursive: If True, search for images recursively in subfolders.
112
205
  allowed_extensions: An iterable container of allowed image file
113
206
  extensions.
207
+ embed: If True, generate embeddings for the newly added samples.
114
208
  """
115
- path = Path(path).absolute() if isinstance(path, str) else path.absolute()
116
- if not path.exists() or not path.is_dir():
117
- raise ValueError(f"Provided path is not a valid directory: {path}")
118
-
119
209
  # Collect image file paths.
120
- allowed_extensions_set = {ext.lower() for ext in allowed_extensions}
121
- image_paths = []
122
- path_iter = path.rglob("*") if recursive else path.glob("*")
123
- for path in path_iter:
124
- if path.is_file() and path.suffix.lower() in allowed_extensions_set:
125
- image_paths.append(path)
210
+ if allowed_extensions:
211
+ allowed_extensions_set = {ext.lower() for ext in allowed_extensions}
212
+ else:
213
+ allowed_extensions_set = None
214
+ image_paths = list(
215
+ fsspec_lister.iter_files_from_path(
216
+ path=str(path), allowed_extensions=allowed_extensions_set
217
+ )
218
+ )
126
219
  print(f"Found {len(image_paths)} images in {path}.")
127
220
 
128
221
  # Process images.
129
- _load_into_dataset_from_paths(
222
+ created_sample_ids = add_samples.load_into_dataset_from_paths(
130
223
  session=self.session,
131
224
  dataset_id=self.dataset_id,
132
225
  image_paths=image_paths,
133
226
  )
134
227
 
228
+ if embed:
229
+ _generate_embeddings(
230
+ session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
231
+ )
232
+
135
233
  def add_samples_from_labelformat(
136
234
  self,
137
235
  input_labels: ObjectDetectionInput | InstanceSegmentationInput,
138
236
  images_path: PathLike,
139
237
  is_prediction: bool = True,
140
238
  task_name: str | None = None,
239
+ embed: bool = True,
141
240
  ) -> None:
142
241
  """Load a dataset from a labelformat object and store in database.
143
242
 
@@ -147,9 +246,7 @@ class Dataset:
147
246
  is_prediction: Whether the task is for prediction or labels.
148
247
  task_name: Optional name for the annotation task. If None, a
149
248
  default name is generated.
150
-
151
- Returns:
152
- DatasetTable: The created dataset table entry.
249
+ embed: If True, generate embeddings for the newly added samples.
153
250
  """
154
251
  if isinstance(images_path, str):
155
252
  images_path = Path(images_path)
@@ -174,7 +271,7 @@ class Dataset:
174
271
  ),
175
272
  )
176
273
 
177
- _load_into_dataset(
274
+ created_sample_ids = add_samples.load_into_dataset_from_labelformat(
178
275
  session=self.session,
179
276
  dataset_id=self.dataset_id,
180
277
  input_labels=input_labels,
@@ -182,25 +279,33 @@ class Dataset:
182
279
  annotation_task_id=new_annotation_task.annotation_task_id,
183
280
  )
184
281
 
185
- def from_yolo(
282
+ if embed:
283
+ _generate_embeddings(
284
+ session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
285
+ )
286
+
287
+ def add_samples_from_yolo(
186
288
  self,
187
- data_yaml_path: str,
289
+ data_yaml: PathLike,
188
290
  input_split: str = "train",
189
291
  task_name: str | None = None,
190
- ) -> DatasetTable:
292
+ embed: bool = True,
293
+ ) -> None:
191
294
  """Load a dataset in YOLO format and store in DB.
192
295
 
193
296
  Args:
194
- data_yaml_path: Path to the YOLO data.yaml file.
297
+ data_yaml: Path to the YOLO data.yaml file.
195
298
  input_split: The split to load (e.g., 'train', 'val').
196
299
  task_name: Optional name for the annotation task. If None, a
197
300
  default name is generated.
198
-
199
- Returns:
200
- DatasetTable: The created dataset table entry.
301
+ embed: If True, generate embeddings for the newly added samples.
201
302
  """
202
- data_yaml = Path(data_yaml_path).absolute()
203
- dataset_name = data_yaml.parent.name
303
+ if isinstance(data_yaml, str):
304
+ data_yaml = Path(data_yaml)
305
+ data_yaml = data_yaml.absolute()
306
+
307
+ if not data_yaml.is_file() or data_yaml.suffix != ".yaml":
308
+ raise FileNotFoundError(f"YOLO data yaml file not found: '{data_yaml}'")
204
309
 
205
310
  if task_name is None:
206
311
  task_name = f"Loaded from YOLO: {data_yaml.name} ({input_split} split)"
@@ -210,314 +315,97 @@ class Dataset:
210
315
  input_file=data_yaml,
211
316
  input_split=input_split,
212
317
  )
213
- img_dir = label_input._images_dir() # noqa: SLF001
318
+ images_path = label_input._images_dir() # noqa: SLF001
214
319
 
215
- return self.from_labelformat( # type: ignore[no-any-return,attr-defined]
320
+ self.add_samples_from_labelformat(
216
321
  input_labels=label_input,
217
- dataset_name=dataset_name,
218
- img_dir=str(img_dir),
322
+ images_path=images_path,
219
323
  is_prediction=False,
220
324
  task_name=task_name,
325
+ embed=embed,
221
326
  )
222
327
 
223
- def from_coco_object_detections(
328
+ def add_samples_from_coco(
224
329
  self,
225
- annotations_json_path: str,
226
- img_dir: str,
330
+ annotations_json: PathLike,
331
+ images_path: PathLike,
227
332
  task_name: str | None = None,
228
- ) -> DatasetTable:
333
+ annotation_type: AnnotationType = AnnotationType.BBOX,
334
+ embed: bool = True,
335
+ ) -> None:
229
336
  """Load a dataset in COCO Object Detection format and store in DB.
230
337
 
231
338
  Args:
232
- annotations_json_path: Path to the COCO annotations JSON file.
233
- img_dir: Path to the folder containing the images.
339
+ annotations_json: Path to the COCO annotations JSON file.
340
+ images_path: Path to the folder containing the images.
234
341
  task_name: Optional name for the annotation task. If None, a
235
342
  default name is generated.
236
-
237
- Returns:
238
- DatasetTable: The created dataset table entry.
343
+ annotation_type: The type of annotation to be loaded (e.g., 'ObjectDetection',
344
+ 'InstanceSegmentation').
345
+ embed: If True, generate embeddings for the newly added samples.
239
346
  """
240
- annotations_json = Path(annotations_json_path)
241
- dataset_name = annotations_json.parent.name
242
-
243
- if task_name is None:
244
- task_name = f"Loaded from COCO Object Detection: {annotations_json.name}"
245
-
246
- label_input = COCOObjectDetectionInput(
247
- input_file=annotations_json,
248
- )
249
- img_dir_path = Path(img_dir).absolute()
347
+ if isinstance(annotations_json, str):
348
+ annotations_json = Path(annotations_json)
349
+ annotations_json = annotations_json.absolute()
250
350
 
251
- return self.from_labelformat( # type: ignore[no-any-return, attr-defined]
252
- input_labels=label_input,
253
- dataset_name=dataset_name,
254
- img_dir=str(img_dir_path),
255
- is_prediction=False,
256
- task_name=task_name,
257
- )
351
+ if not annotations_json.is_file() or annotations_json.suffix != ".json":
352
+ raise FileNotFoundError(f"COCO annotations json file not found: '{annotations_json}'")
258
353
 
259
- def from_coco_instance_segmentations(
260
- self,
261
- annotations_json_path: str,
262
- img_dir: str,
263
- task_name: str | None = None,
264
- ) -> DatasetTable:
265
- """Load a dataset in COCO Instance Segmentation format and store in DB.
354
+ label_input: COCOObjectDetectionInput | COCOInstanceSegmentationInput
266
355
 
267
- Args:
268
- annotations_json_path: Path to the COCO annotations JSON file.
269
- img_dir: Path to the folder containing the images.
270
- task_name: Optional name for the annotation task. If None, a
271
- default name is generated.
272
-
273
- Returns:
274
- DatasetTable: The created dataset table entry.
275
- """
276
- annotations_json = Path(annotations_json_path)
277
- dataset_name = annotations_json.parent.name
356
+ if annotation_type == AnnotationType.BBOX:
357
+ label_input = COCOObjectDetectionInput(
358
+ input_file=annotations_json,
359
+ )
360
+ task_name_default = f"Loaded from COCO Object Detection: {annotations_json.name}"
361
+ elif annotation_type == AnnotationType.INSTANCE_SEGMENTATION:
362
+ label_input = COCOInstanceSegmentationInput(
363
+ input_file=annotations_json,
364
+ )
365
+ task_name_default = f"Loaded from COCO Instance Segmentation: {annotations_json.name}"
366
+ else:
367
+ raise ValueError(f"Invalid annotation type: {annotation_type}")
278
368
 
279
369
  if task_name is None:
280
- task_name = f"Loaded from COCO Instance Segmentation: {annotations_json.name}"
370
+ task_name = task_name_default
281
371
 
282
- label_input = COCOInstanceSegmentationInput(
283
- input_file=annotations_json,
284
- )
285
- img_dir_path = Path(img_dir).absolute()
372
+ images_path = Path(images_path).absolute()
286
373
 
287
- return self.from_labelformat( # type: ignore[no-any-return,attr-defined]
374
+ self.add_samples_from_labelformat(
288
375
  input_labels=label_input,
289
- dataset_name=dataset_name,
290
- img_dir=str(img_dir_path),
376
+ images_path=images_path,
291
377
  is_prediction=False,
292
378
  task_name=task_name,
379
+ embed=embed,
293
380
  )
294
381
 
295
- @staticmethod
296
- def load_from_db(name: str, db_path: PathLike) -> Dataset:
297
- """Load a dataset from the database.
298
382
 
299
- Returns:
300
- Dataset: The loaded dataset.
301
- """
302
- raise NotImplementedError
303
-
304
-
305
- def _load_into_dataset_from_paths(
306
- dataset_id: UUID,
307
- session: Session,
308
- image_paths: Iterable[Path],
309
- ) -> None:
310
- samples_to_create: list[SampleCreate] = []
311
-
312
- for image_path in tqdm(
313
- image_paths,
314
- desc="Processing images",
315
- unit=" images",
316
- ):
317
- try:
318
- image = PIL.Image.open(image_path)
319
- width, height = image.size
320
- image.close()
321
- except (FileNotFoundError, PIL.UnidentifiedImageError, OSError):
322
- continue
323
-
324
- sample = SampleCreate(
325
- file_name=image_path.name,
326
- file_path_abs=str(image_path),
327
- width=width,
328
- height=height,
329
- dataset_id=dataset_id,
330
- )
331
- samples_to_create.append(sample)
332
-
333
- # Process batch when it reaches SAMPLE_BATCH_SIZE
334
- if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
335
- _ = sample_resolver.create_many(session=session, samples=samples_to_create)
336
- samples_to_create = []
337
-
338
- # Handle remaining samples
339
- if samples_to_create:
340
- _ = sample_resolver.create_many(session=session, samples=samples_to_create)
341
-
342
-
343
- def _load_into_dataset(
344
- session: Session,
345
- dataset_id: UUID,
346
- input_labels: ObjectDetectionInput | InstanceSegmentationInput,
347
- images_path: Path,
348
- annotation_task_id: UUID,
349
- ) -> None:
350
- """Store a loaded dataset in database."""
351
- # Create label mapping
352
- label_map = _create_label_map(session=session, input_labels=input_labels)
353
-
354
- annotations_to_create: list[AnnotationCreate] = []
355
- sample_ids: list[UUID] = []
356
- samples_to_create: list[SampleCreate] = []
357
- samples_image_data: list[
358
- tuple[SampleCreate, ImageInstanceSegmentation | ImageObjectDetection]
359
- ] = []
360
-
361
- for image_data in tqdm(input_labels.get_labels(), desc="Processing images", unit=" images"):
362
- image: Image = image_data.image # type: ignore[attr-defined]
363
-
364
- typed_image_data: ImageInstanceSegmentation | ImageObjectDetection = image_data # type: ignore[assignment]
365
- sample = SampleCreate(
366
- file_name=str(image.filename),
367
- file_path_abs=str(images_path / image.filename),
368
- width=image.width,
369
- height=image.height,
370
- dataset_id=dataset_id,
371
- )
372
- samples_to_create.append(sample)
373
- samples_image_data.append((sample, typed_image_data))
374
-
375
- if len(samples_to_create) >= SAMPLE_BATCH_SIZE:
376
- stored_samples = sample_resolver.create_many(session=session, samples=samples_to_create)
377
- _process_batch_annotations(
378
- session=session,
379
- stored_samples=stored_samples,
380
- samples_data=samples_image_data,
381
- dataset_id=dataset_id,
382
- label_map=label_map,
383
- annotation_task_id=annotation_task_id,
384
- annotations_to_create=annotations_to_create,
385
- sample_ids=sample_ids,
386
- )
387
- samples_to_create.clear()
388
- samples_image_data.clear()
389
-
390
- if samples_to_create:
391
- stored_samples = sample_resolver.create_many(session=session, samples=samples_to_create)
392
- _process_batch_annotations(
393
- session=session,
394
- stored_samples=stored_samples,
395
- samples_data=samples_image_data,
396
- dataset_id=dataset_id,
397
- label_map=label_map,
398
- annotation_task_id=annotation_task_id,
399
- annotations_to_create=annotations_to_create,
400
- sample_ids=sample_ids,
401
- )
402
-
403
- # Insert any remaining annotations
404
- if annotations_to_create:
405
- annotation_resolver.create_many(session=session, annotations=annotations_to_create)
406
-
407
-
408
- def _create_label_map(
409
- session: Session,
410
- input_labels: ObjectDetectionInput | InstanceSegmentationInput,
411
- ) -> dict[int, UUID]:
412
- """Create a mapping of category IDs to annotation label IDs."""
413
- label_map = {}
414
- for category in tqdm(
415
- input_labels.get_categories(),
416
- desc="Processing categories",
417
- unit=" categories",
418
- ):
419
- label = AnnotationLabelCreate(annotation_label_name=category.name)
420
- stored_label = annotation_label_resolver.create(session=session, label=label)
421
- label_map[category.id] = stored_label.annotation_label_id
422
- return label_map
423
-
424
-
425
- def _process_object_detection_annotations(
426
- context: AnnotationProcessingContext,
427
- image_data: ImageObjectDetection,
428
- ) -> list[AnnotationCreate]:
429
- """Process object detection annotations for a single image."""
430
- new_annotations = []
431
- for obj in image_data.objects:
432
- box = obj.box.to_format(BoundingBoxFormat.XYWH)
433
- x, y, width, height = box
434
-
435
- new_annotations.append(
436
- AnnotationCreate(
437
- dataset_id=context.dataset_id,
438
- sample_id=context.sample_id,
439
- annotation_label_id=context.label_map[obj.category.id],
440
- annotation_type="object_detection",
441
- x=x,
442
- y=y,
443
- width=width,
444
- height=height,
445
- confidence=obj.confidence,
446
- annotation_task_id=context.annotation_task_id,
447
- )
448
- )
449
- return new_annotations
450
-
451
-
452
- def _process_instance_segmentation_annotations(
453
- context: AnnotationProcessingContext,
454
- image_data: ImageInstanceSegmentation,
455
- ) -> list[AnnotationCreate]:
456
- """Process instance segmentation annotations for a single image."""
457
- new_annotations = []
458
- for obj in image_data.objects:
459
- segmentation_rle: None | list[int] = None
460
- if isinstance(obj.segmentation, MultiPolygon):
461
- box = obj.segmentation.bounding_box().to_format(BoundingBoxFormat.XYWH)
462
- elif isinstance(obj.segmentation, BinaryMaskSegmentation):
463
- box = obj.segmentation.bounding_box.to_format(BoundingBoxFormat.XYWH)
464
- segmentation_rle = obj.segmentation._rle_row_wise # noqa: SLF001
465
- else:
466
- raise ValueError(f"Unsupported segmentation type: {type(obj.segmentation)}")
467
-
468
- x, y, width, height = box
469
-
470
- new_annotations.append(
471
- AnnotationCreate(
472
- dataset_id=context.dataset_id,
473
- sample_id=context.sample_id,
474
- annotation_label_id=context.label_map[obj.category.id],
475
- annotation_type="instance_segmentation",
476
- x=x,
477
- y=y,
478
- width=width,
479
- height=height,
480
- segmentation_mask=segmentation_rle,
481
- annotation_task_id=context.annotation_task_id,
482
- )
483
- )
484
- return new_annotations
485
-
486
-
487
- def _process_batch_annotations( # noqa: PLR0913
488
- session: Session,
489
- stored_samples: list[SampleTable],
490
- samples_data: list[tuple[SampleCreate, ImageInstanceSegmentation | ImageObjectDetection]],
491
- dataset_id: UUID,
492
- label_map: dict[int, UUID],
493
- annotation_task_id: UUID,
494
- annotations_to_create: list[AnnotationCreate],
495
- sample_ids: list[UUID],
496
- ) -> None:
497
- """Process annotations for a batch of samples."""
498
- for stored_sample, (_, img_data) in zip(stored_samples, samples_data):
499
- sample_ids.append(stored_sample.sample_id)
500
-
501
- context = AnnotationProcessingContext(
502
- dataset_id=dataset_id,
503
- sample_id=stored_sample.sample_id,
504
- label_map=label_map,
505
- annotation_task_id=annotation_task_id,
506
- )
383
+ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
384
+ """Generate and store embeddings for samples.
507
385
 
508
- if isinstance(img_data, ImageInstanceSegmentation):
509
- new_annotations = _process_instance_segmentation_annotations(
510
- context=context, image_data=img_data
511
- )
512
- elif isinstance(img_data, ImageObjectDetection):
513
- new_annotations = _process_object_detection_annotations(
514
- context=context, image_data=img_data
515
- )
516
- else:
517
- raise ValueError(f"Unsupported annotation type: {type(img_data)}")
518
-
519
- annotations_to_create.extend(new_annotations)
520
-
521
- if len(annotations_to_create) >= ANNOTATION_BATCH_SIZE:
522
- annotation_resolver.create_many(session=session, annotations=annotations_to_create)
523
- annotations_to_create.clear()
386
+ Args:
387
+ session: Database session for resolver operations.
388
+ dataset_id: The ID of the dataset to associate with the embedding model.
389
+ sample_ids: List of sample IDs to generate embeddings for.
390
+ """
391
+ if not sample_ids:
392
+ return
393
+
394
+ embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
395
+ model_id = embedding_manager.load_or_get_default_model(
396
+ session=session,
397
+ dataset_id=dataset_id,
398
+ )
399
+ if model_id is None:
400
+ print("No embedding model loaded. Skipping embedding generation.")
401
+ return
402
+
403
+ embedding_manager.embed_images(
404
+ session=session,
405
+ sample_ids=sample_ids,
406
+ embedding_model_id=model_id,
407
+ )
408
+
409
+ # Mark the embedding search feature as enabled.
410
+ if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
411
+ features.lightly_studio_active_features.append("embeddingSearchEnabled")