lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (219) hide show
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,38 @@
1
+ """API endpoints for annotation tasks."""
2
+
3
+ from typing import List
4
+ from uuid import UUID
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, status
7
+ from sqlmodel import Session
8
+
9
+ from lightly_studio.api.db import get_session
10
+ from lightly_studio.models.annotation_task import AnnotationTaskTable
11
+ from lightly_studio.resolvers import annotation_task_resolver
12
+
13
# Router that groups the read-only annotation-task endpoints.
router = APIRouter(tags=["annotationtasks"], prefix="/annotationtasks")
14
+
15
+
16
@router.get("/", response_model=List[AnnotationTaskTable])
def get_annotation_tasks(
    session: Session = Depends(get_session),  # noqa: B008
) -> List[AnnotationTaskTable]:
    """List every annotation task known to the resolver.

    Args:
        session: Database session injected by FastAPI.

    Returns:
        All annotation tasks returned by ``annotation_task_resolver.get_all``.
    """
    tasks = annotation_task_resolver.get_all(session=session)
    return tasks
22
+
23
+
24
@router.get("/{annotation_task_id}", response_model=AnnotationTaskTable)
def get_annotation_task(
    annotation_task_id: UUID,
    session: Session = Depends(get_session),  # noqa: B008
) -> AnnotationTaskTable:
    """Fetch a single annotation task by its ID.

    Args:
        annotation_task_id: ID of the task to look up.
        session: Database session injected by FastAPI.

    Returns:
        The matching annotation task.

    Raises:
        HTTPException: 404 when no task with ``annotation_task_id`` exists.
    """
    found = annotation_task_resolver.get_by_id(
        session=session, annotation_task_id=annotation_task_id
    )
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Annotation task with ID {annotation_task_id} not found",
    )
@@ -0,0 +1,387 @@
1
+ """This module contains the API routes for managing classifiers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ from pathlib import Path
7
+ from uuid import UUID
8
+
9
+ from fastapi import APIRouter, Depends, UploadFile
10
+ from fastapi.responses import StreamingResponse
11
+ from pydantic import BaseModel
12
+ from sqlmodel import Session
13
+ from typing_extensions import Annotated
14
+
15
+ from lightly_studio.api.db import get_session
16
+ from lightly_studio.few_shot_classifier.classifier import (
17
+ ExportType,
18
+ )
19
+ from lightly_studio.few_shot_classifier.classifier_manager import (
20
+ ClassifierManagerProvider,
21
+ )
22
+ from lightly_studio.models.classifier import EmbeddingClassifier
23
+
24
# Router collecting all classifier endpoints defined in this module.
classifier_router = APIRouter()

# Reusable parameter type that injects a database session into a route.
SessionDep = Annotated[Session, Depends(get_session)]
26
+
27
+
28
class GetNegativeSamplesRequest(BaseModel):
    """Body of a request asking for negative samples for classifier training."""

    # Samples the user selected as positives; forwarded as ``selected_samples``.
    positive_sample_ids: list[UUID]
    # Dataset handed to the classifier manager when picking negatives.
    dataset_id: UUID
33
+
34
+
35
class GetNegativeSamplesResponse(BaseModel):
    """Body of the reply carrying negative samples for classifier training."""

    # IDs of the samples chosen as negatives.
    negative_sample_ids: list[UUID]
39
+
40
+
41
@classifier_router.post("/classifiers/get_negative_samples")
def get_negative_samples(
    request: GetNegativeSamplesRequest, session: SessionDep
) -> GetNegativeSamplesResponse:
    """Pick negative samples to pair with the given positives for training.

    Args:
        request: Positive sample IDs plus the dataset to work on.
        session: Database session.

    Returns:
        The IDs of the negative samples chosen by the classifier manager.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    negatives = manager.provide_negative_samples(
        session=session,
        dataset_id=request.dataset_id,
        selected_samples=request.positive_sample_ids,
    )
    # The manager hands back sample objects; the API only exposes their IDs.
    return GetNegativeSamplesResponse(
        negative_sample_ids=[negative.sample_id for negative in negatives]
    )
63
+
64
+
65
class SamplesToRefineResponse(BaseModel):
    """Sample IDs grouped by class name, used for classifier refinement.

    Maps class names to lists of sample IDs. First class gets high confidence
    samples, second class gets low confidence samples.
    """

    # Class name -> IDs of the samples assigned to that class.
    samples: dict[str, list[UUID]]
73
+
74
+
75
@classifier_router.get("/classifiers/{classifier_id}/samples_to_refine")
def samples_to_refine(
    classifier_id: UUID,
    dataset_id: UUID,
    session: SessionDep,
) -> SamplesToRefineResponse:
    """Return samples the user should label next to refine a classifier.

    Args:
        classifier_id: The classifier being refined.
        dataset_id: The dataset to draw samples from.
        session: Database session.

    Returns:
        Sample IDs grouped by class name, as produced by
        ``get_samples_for_fine_tuning``.
    """
    grouped = ClassifierManagerProvider.get_classifier_manager().get_samples_for_fine_tuning(
        session=session, classifier_id=classifier_id, dataset_id=dataset_id
    )
    return SamplesToRefineResponse(samples=grouped)
96
+
97
+
98
@classifier_router.get("/classifiers/{classifier_id}/sample_history")
def sample_history(
    classifier_id: UUID,
) -> SamplesToRefineResponse:
    """Return the annotated samples that went into training a classifier.

    Args:
        classifier_id: The classifier whose training history is requested.

    Returns:
        Sample IDs grouped by class name, taken from the manager's stored
        annotations for this classifier.
    """
    grouped = ClassifierManagerProvider.get_classifier_manager().get_annotations(
        classifier_id=classifier_id
    )
    return SamplesToRefineResponse(samples=grouped)
113
+
114
+
115
@classifier_router.post(
    "/classifiers/{classifier_id}/commit_temp_classifier",
)
def commit_temp_classifier(
    classifier_id: UUID,
) -> None:
    """Ask the classifier manager to commit a temporary classifier.

    Args:
        classifier_id: The classifier to commit.
    """
    ClassifierManagerProvider.get_classifier_manager().commit_temp_classifier(
        classifier_id=classifier_id
    )
131
+
132
+
133
@classifier_router.delete(
    "/classifiers/{classifier_id}/drop_temp_classifier",
)
def drop_temp_classifier(
    classifier_id: UUID,
) -> None:
    """Ask the classifier manager to discard a temporary classifier.

    Args:
        classifier_id: The classifier to drop.
    """
    ClassifierManagerProvider.get_classifier_manager().drop_temp_classifier(
        classifier_id=classifier_id
    )
149
+
150
+
151
class SaveClassifierRequest(BaseModel):
    """Request body naming the file a classifier should be saved to."""

    # NOTE(review): no route in this module uses this model; kept for
    # compatibility with any external importers.
    file_path: str
155
+
156
+
157
def _attachment_header(filename: str) -> str:
    """Build a ``Content-Disposition: attachment`` value safe for any filename.

    The filename here comes from a user-chosen classifier name. Putting it raw
    into the header breaks on a double quote or line break, and it fails for
    characters that cannot be encoded in the latin-1 header charset. The plain
    ``filename`` parameter therefore gets an ASCII-only fallback, and the exact
    name travels in the RFC 6266 / RFC 5987 ``filename*`` parameter.

    Args:
        filename: The desired download filename (may contain any Unicode).

    Returns:
        A header value such as
        ``attachment; filename="a.pkl"; filename*=UTF-8''a.pkl``.
    """
    from urllib.parse import quote

    ascii_fallback = "".join(
        char if char.isascii() and char.isprintable() and char not in '"\\' else "_"
        for char in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded}"


@classifier_router.post(
    "/classifiers/{classifier_id}/save_classifier_to_file/{export_type}",
)
def save_classifier_to_file(
    classifier_id: UUID,
    export_type: ExportType,
) -> StreamingResponse:
    """Serialize a classifier and send it back as a file download.

    Args:
        classifier_id: The ID of the classifier.
        export_type: The type of export (e.g., "sklearn", "lightly").

    Returns:
        StreamingResponse containing the pickled classifier file, with a
        ``Content-Disposition`` header naming it ``<classifier_name>.pkl``.
    """
    classifier_manager = ClassifierManagerProvider.get_classifier_manager()
    # Serialize into memory, then rewind so the response streams from byte 0.
    buffer = io.BytesIO()
    classifier_manager.save_classifier_to_buffer(
        classifier_id=classifier_id, buffer=buffer, export_type=export_type
    )
    buffer.seek(0)

    classifier = classifier_manager.get_classifier_by_id(classifier_id=classifier_id)
    filename = f"{classifier.classifier_name}.pkl"
    headers = {
        "Content-Disposition": _attachment_header(filename),
        # Let browser JS read the filename from the cross-origin response.
        "Access-Control-Expose-Headers": "Content-Disposition",
    }
    # Content-Type is set once via media_type rather than duplicated in headers.
    return StreamingResponse(buffer, headers=headers, media_type="application/octet-stream")
191
+
192
+
193
class LoadClassifierRequest(BaseModel):
    """Request body naming a classifier file to load."""

    # Filesystem path; converted to ``pathlib.Path`` by the load route.
    file_path: str
197
+
198
+
199
class LoadClassifierResponse(BaseModel):
    """Reply identifying a classifier that was loaded from a file."""

    # ID assigned to the freshly loaded classifier.
    classifier_id: UUID
203
+
204
+
205
@classifier_router.post(
    "/classifiers/load_classifier_from_file",
)
def load_classifier_from_file(
    request: LoadClassifierRequest,
    session: SessionDep,
) -> LoadClassifierResponse:
    """Load a classifier from a file path on the server.

    NOTE(review): the path is chosen by the client. If the manager
    deserializes it with pickle (presumably, given the ``.pkl`` export),
    loading an untrusted file can execute arbitrary code. Confirm, and only
    expose this endpoint to trusted users.

    Args:
        request: Holds the path of the file to load.
        session: Database session.

    Returns:
        Response carrying the ID of the loaded classifier.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    source = Path(request.file_path)
    loaded = manager.load_classifier_from_file(session=session, file_path=source)
    return LoadClassifierResponse(classifier_id=loaded.classifier_id)
226
+
227
+
228
@classifier_router.post(
    "/classifiers/load_classifier_from_buffer",
)
def load_classifier_from_buffer(
    file: UploadFile,
    session: SessionDep,
) -> UUID:
    """Load a classifier from an uploaded file.

    NOTE(review): if the manager unpickles this buffer (presumably, given the
    ``.pkl`` export), an uploaded file can execute arbitrary code. Confirm,
    and only expose this endpoint to trusted users.

    Args:
        file: The uploaded classifier file.
        session: Database session.

    Returns:
        The ID of the loaded classifier.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    # The whole upload is read into memory before being handed over.
    payload = io.BytesIO(file.file.read())
    loaded = manager.load_classifier_from_buffer(session=session, buffer=payload)
    return loaded.classifier_id
252
+
253
+
254
@classifier_router.post(
    "/classifiers/{classifier_id}/train_classifier",
)
def train_classifier(
    classifier_id: UUID,
    session: SessionDep,
) -> None:
    """Ask the classifier manager to train a classifier.

    Args:
        classifier_id: The classifier to train.
        session: Database session.
    """
    ClassifierManagerProvider.get_classifier_manager().train_classifier(
        session=session, classifier_id=classifier_id
    )
272
+
273
+
274
class UpdateAnnotationsRequest(BaseModel):
    """Request body carrying new annotations for a classifier."""

    # Class name -> IDs of the samples annotated with that class.
    annotations: dict[str, list[UUID]]
278
+
279
+
280
@classifier_router.post(
    "/classifiers/{classifier_id}/update_annotations",
)
def update_classifiers_annotations(
    classifier_id: UUID,
    request: UpdateAnnotationsRequest,
) -> None:
    """Replace or extend a classifier's annotations via the classifier manager.

    Args:
        classifier_id: The classifier whose annotations change.
        request: The new annotations, grouped by class name.
    """
    new_annotations = request.annotations
    ClassifierManagerProvider.get_classifier_manager().update_classifiers_annotations(
        classifier_id=classifier_id,
        new_annotations=new_annotations,
    )
302
+
303
+
304
class CreateClassifierRequest(BaseModel):
    """Request body describing a classifier to create."""

    # Human-readable name of the classifier.
    name: str
    # Names of the classes the classifier distinguishes.
    class_list: list[str]
    # Dataset the classifier is created for.
    dataset_id: UUID
310
+
311
+
312
class CreateClassifierResponse(BaseModel):
    """Reply describing a newly created classifier."""

    # Name stored on the created classifier.
    name: str
    # Classifier ID serialized as a string.
    classifier_id: str
317
+
318
+
319
@classifier_router.post("/classifiers/create")
def create_classifier(
    request: CreateClassifierRequest, session: SessionDep
) -> CreateClassifierResponse:
    """Create a classifier from a name, a class list and a dataset.

    Args:
        request: Name, class list and dataset ID for the new classifier.
        session: Database session.

    Returns:
        The stored name and the stringified ID of the new classifier.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    created = manager.create_classifier(
        session=session,
        name=request.name,
        class_list=request.class_list,
        dataset_id=request.dataset_id,
    )
    # Report the name held by the underlying few-shot classifier.
    stored_name = created.few_shot_classifier.name
    return CreateClassifierResponse(
        name=stored_name,
        classifier_id=str(created.classifier_id),
    )
344
+
345
+
346
class GetAllClassifiersResponse(BaseModel):
    """Reply listing the classifiers held by the classifier manager."""

    # One entry per classifier returned by the manager.
    classifiers: list[EmbeddingClassifier]
350
+
351
+
352
@classifier_router.get("/classifiers/get_all_classifiers")
def get_all_classifiers() -> GetAllClassifiersResponse:
    """List the classifiers currently held by the classifier manager.

    Returns:
        Response wrapping the ``EmbeddingClassifier`` entries returned by
        ``get_all_classifiers`` on the manager.
    """
    manager = ClassifierManagerProvider.get_classifier_manager()
    return GetAllClassifiersResponse(classifiers=manager.get_all_classifiers())
362
+
363
+
364
@classifier_router.post(
    "/classifiers/{classifier_id}/run_on_dataset/{dataset_id}",
)
def run_classifier_route(
    classifier_id: UUID,
    dataset_id: UUID,
    session: SessionDep,
) -> None:
    """Apply a classifier to every sample of a dataset via the manager.

    Args:
        classifier_id: The classifier to run.
        dataset_id: The dataset to run it on.
        session: Database session.
    """
    ClassifierManagerProvider.get_classifier_manager().run_classifier(
        session=session,
        classifier_id=classifier_id,
        dataset_id=dataset_id,
    )
@@ -0,0 +1,182 @@
1
+ """This module contains the API routes for managing datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import datetime, timezone
6
+ from typing import List
7
+ from uuid import UUID
8
+
9
+ from fastapi import APIRouter, Depends, HTTPException, Path, Query
10
+ from fastapi.responses import PlainTextResponse
11
+ from pydantic import BaseModel
12
+ from sqlmodel import Field, Session
13
+ from typing_extensions import Annotated
14
+
15
+ from lightly_studio.api.db import get_session
16
+ from lightly_studio.api.routes.api.status import HTTP_STATUS_NOT_FOUND
17
+ from lightly_studio.api.routes.api.validators import Paginated
18
+ from lightly_studio.models.dataset import (
19
+ DatasetCreate,
20
+ DatasetTable,
21
+ DatasetView,
22
+ )
23
+ from lightly_studio.resolvers import dataset_resolver
24
+ from lightly_studio.resolvers.dataset_resolver import (
25
+ ExportFilter,
26
+ )
27
+
28
# Router collecting all dataset endpoints defined in this module.
dataset_router = APIRouter()

# Reusable parameter type that injects a database session into a route.
SessionDep = Annotated[Session, Depends(get_session)]
30
+
31
+
32
def get_and_validate_dataset_id(
    session: SessionDep,
    dataset_id: UUID,
) -> DatasetTable:
    """Resolve a ``dataset_id`` route parameter to its dataset row.

    Used as a FastAPI dependency by the ``/datasets/{dataset_id}`` routes so
    that each of them gets a validated ``DatasetTable``.

    Args:
        session: Database session.
        dataset_id: ID of the dataset to look up.

    Returns:
        The dataset with the given ID.

    Raises:
        HTTPException: ``HTTP_STATUS_NOT_FOUND`` when no dataset has this ID.
    """
    dataset = dataset_resolver.get_by_id(session=session, dataset_id=dataset_id)
    if dataset is None:
        raise HTTPException(
            status_code=HTTP_STATUS_NOT_FOUND,
            detail=f"Dataset with ID {dataset_id} not found.",
        )
    return dataset
44
+
45
+
46
@dataset_router.post(
    "/datasets",
    response_model=DatasetView,
    status_code=201,
)
def create_dataset(
    dataset_input: DatasetCreate,
    session: SessionDep,
) -> DatasetTable:
    """Store a new dataset and return it.

    The stored row is serialized through ``DatasetView`` and answered with
    HTTP 201.
    """
    created = dataset_resolver.create(session=session, dataset=dataset_input)
    return created
57
+
58
+
59
@dataset_router.get("/datasets", response_model=List[DatasetView])
def read_datasets(
    session: SessionDep,
    paginated: Annotated[Paginated, Query()],
) -> list[DatasetTable]:
    """Return one page of datasets, using the offset/limit query parameters."""
    offset, limit = paginated.offset, paginated.limit
    return dataset_resolver.get_all(session=session, offset=offset, limit=limit)
66
+
67
+
68
@dataset_router.get("/datasets/{dataset_id}")
def read_dataset(
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
) -> DatasetTable:
    """Return one dataset, resolved by ``get_and_validate_dataset_id``.

    A missing dataset is rejected with a 404 by that dependency before this
    function body runs, so the value here is always a real row.
    """
    return dataset
78
+
79
+
80
@dataset_router.put("/datasets/{dataset_id}")
def update_dataset(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    dataset_input: DatasetCreate,
) -> DatasetTable:
    """Overwrite an existing dataset with the submitted fields.

    Returns:
        The dataset as returned by ``dataset_resolver.update``.
    """
    target_id = dataset.dataset_id
    return dataset_resolver.update(
        session=session, dataset_id=target_id, dataset_data=dataset_input
    )
96
+
97
+
98
@dataset_router.delete("/datasets/{dataset_id}")
def delete_dataset(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
) -> dict[str, str]:
    """Remove a dataset and confirm with ``{"status": "deleted"}``."""
    dataset_resolver.delete(session=session, dataset_id=dataset.dataset_id)
    confirmation = {"status": "deleted"}
    return confirmation
110
+
111
+
112
class ExportBody(BaseModel):
    """Body selecting which samples an export includes or excludes.

    Each filter may name sample IDs or tag IDs (see ``ExportFilter``);
    an omitted filter is ``None``.
    """

    include: ExportFilter | None = Field(
        default=None, description="include filter for sample_ids or tag_ids"
    )
    exclude: ExportFilter | None = Field(
        default=None, description="exclude filter for sample_ids or tag_ids"
    )
121
+
122
+
123
def _attachment_header(filename: str) -> str:
    """Build a ``Content-Disposition: attachment`` value safe for any filename.

    Dataset names are user-chosen. Unquoted, a space or semicolon truncates the
    filename. A double quote or line break corrupts the header, and characters
    outside latin-1 cannot be encoded as a header value. The plain ``filename``
    parameter therefore gets a quoted ASCII-only fallback, and the exact name
    travels in the RFC 6266 / RFC 5987 ``filename*`` parameter.

    Args:
        filename: The desired download filename (may contain any Unicode).

    Returns:
        A header value such as
        ``attachment; filename="a.txt"; filename*=UTF-8''a.txt``.
    """
    from urllib.parse import quote

    ascii_fallback = "".join(
        char if char.isascii() and char.isprintable() and char not in '"\\' else "_"
        for char in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded}"


# This endpoint should be a GET, however due to the potential huge size
# of sample_ids, it is a POST request to avoid URL length limitations.
# A body with a GET request is supported by fastAPI however it has undefined
# behavior: https://fastapi.tiangolo.com/tutorial/body/
@dataset_router.post(
    "/datasets/{dataset_id}/export",
)
def export_dataset_to_absolute_paths(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    body: ExportBody,
) -> PlainTextResponse:
    """Export the filtered samples of a dataset as a downloadable text file.

    The file holds one line per entry returned by ``dataset_resolver.export``
    (absolute paths, per the route name), joined with ``\\n``.

    Args:
        session: Database session.
        dataset: The dataset to export, resolved from the path parameter.
        body: Optional include/exclude filters.

    Returns:
        A plain-text response with a ``Content-Disposition`` attachment header
        named ``<dataset name>_exported_<UTC timestamp>.txt``.
    """
    exported = dataset_resolver.export(
        session=session,
        dataset_id=dataset.dataset_id,
        include=body.include,
        exclude=body.exclude,
    )
    response = PlainTextResponse("\n".join(exported))

    # str(datetime) contains spaces and colons, which truncate an unquoted
    # header value and are invalid in Windows filenames; use a compact UTC stamp.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    filename = f"{dataset.name}_exported_{timestamp}.txt"
    # Let browser JS read the filename from the cross-origin response.
    response.headers["Access-Control-Expose-Headers"] = "Content-Disposition"
    response.headers["Content-Disposition"] = _attachment_header(filename)
    return response
157
+
158
+
@dataset_router.post(
    "/datasets/{dataset_id}/export/stats",
)
def export_dataset_stats(
    session: SessionDep,
    dataset: Annotated[
        DatasetTable,
        Path(title="Dataset Id"),
        Depends(get_and_validate_dataset_id),
    ],
    body: ExportBody,
) -> int:
    """Count the samples that an export with the given filters would contain.

    Accepts the same ``ExportBody`` as ``/datasets/{dataset_id}/export``. It is
    a POST for the same reason: the filter lists may be too large for a URL.

    Args:
        session: Database session.
        dataset: The dataset to inspect, resolved from the path parameter.
        body: Optional include/exclude filters.

    Returns:
        The count reported by ``dataset_resolver.get_filtered_samples_count``.
    """
    return dataset_resolver.get_filtered_samples_count(
        session=session,
        dataset_id=dataset.dataset_id,
        include=body.include,
        exclude=body.exclude,
    )
+ )