lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (219) hide show
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,663 @@
1
+ """ClassifierManager implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import io
7
+ from collections.abc import Sequence
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from random import sample
11
+ from uuid import UUID, uuid4
12
+
13
+ from sqlmodel import Session
14
+
15
+ from lightly_studio.few_shot_classifier import random_forest_classifier
16
+ from lightly_studio.few_shot_classifier.classifier import (
17
+ AnnotatedEmbedding,
18
+ ExportType,
19
+ )
20
+ from lightly_studio.few_shot_classifier.random_forest_classifier import (
21
+ RandomForest,
22
+ )
23
+ from lightly_studio.models.annotation.annotation_base import (
24
+ AnnotationCreate,
25
+ )
26
+ from lightly_studio.models.annotation_label import (
27
+ AnnotationLabelCreate,
28
+ )
29
+ from lightly_studio.models.annotation_task import (
30
+ AnnotationTaskTable,
31
+ AnnotationType,
32
+ )
33
+ from lightly_studio.models.classifier import EmbeddingClassifier
34
+ from lightly_studio.models.sample import SampleTable
35
+ from lightly_studio.resolvers import (
36
+ annotation_label_resolver,
37
+ annotation_resolver,
38
+ annotation_task_resolver,
39
+ embedding_model_resolver,
40
+ sample_embedding_resolver,
41
+ sample_resolver,
42
+ )
43
+
44
+ HIGH_CONFIDENCE_THRESHOLD = 0.5
45
+ LOW_CONFIDENCE_THRESHOLD = 0.5
46
+
47
+ HIGH_CONFIDENCE_SAMPLES_NEEDED = 10
48
+ LOW_CONFIDENCE_SAMPLES_NEEDED = 10
49
+
50
+ FSC_ANNOTATION_TASK_PREFIX = "FSC_"
51
+
52
+
53
class ClassifierManagerProvider:
    """Lazily creates and hands out one shared ClassifierManager."""

    # Cached manager; stays None until the first call to
    # get_classifier_manager().
    _instance: ClassifierManager | None = None

    @classmethod
    def get_classifier_manager(cls) -> ClassifierManager:
        """Return the shared ClassifierManager, creating it on first use.

        Returns:
            The ClassifierManager stored on the class. A new one is built
            with no arguments when none has been stored yet.
        """
        manager = cls._instance
        if manager is None:
            manager = ClassifierManager()
            cls._instance = manager
        return manager
71
+
72
+
73
@dataclass
class ClassifierEntry:
    """Bookkeeping record for one classifier held by the manager."""

    # Key under which the entry is stored in the manager.
    classifier_id: UUID

    # TODO(Horatiu, 05/2025): Use FewShotClassifier instead of RandomForest
    # when the interface is ready. Add method to get classifier info.
    few_shot_classifier: RandomForest

    # Training history: maps each class name to the IDs of the samples that
    # were annotated with that class. Keeping it lets fine tuning skip
    # samples that were already used.
    annotations: dict[str, list[UUID]]

    # An entry is inactive from the moment it is created until it is saved;
    # inactive entries are the ones still going through fine tuning.
    is_active: bool = False

    # Optional links to an annotation task and its labels; both default to
    # None.
    annotation_task_id: UUID | None = None
    annotation_label_ids: list[UUID] | None = None
97
+
98
+
99
+ class ClassifierManager:
100
+ """ClassifierManager class.
101
+
102
+ This class manages the lifecycle of a few-shot classifier,
103
+ including training, exporting, and loading the classifier.
104
+ """
105
+
106
+ def __init__(self) -> None:
107
+ """Initialize the ClassifierManager."""
108
+ self._classifiers: dict[UUID, ClassifierEntry] = {}
109
+
110
def create_classifier(
    self,
    session: Session,
    name: str,
    class_list: list[str],
    dataset_id: UUID,
) -> ClassifierEntry:
    """Build a new, inactive classifier and register it with the manager.

    Args:
        session: Database session for resolver operations.
        name: The name of the classifier.
        class_list: List of classes to be used for training.
        dataset_id: The dataset_id to which the samples belong.

    Returns:
        The newly registered ClassifierEntry. It starts inactive and with an
        empty sample list for every class in class_list.

    Raises:
        ValueError: If the dataset has no embedding model, or more than one.
    """
    available_models = embedding_model_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
    )
    model_count = len(available_models)
    if model_count == 0:
        raise ValueError("No embedding model found for the given dataset ID.")
    # TODO(Horatiu, 05/2025): Handle multiple models correctly when
    # available
    if model_count > 1:
        raise ValueError("Multiple embedding models found for the given dataset ID.")

    # Exactly one model is left at this point.
    model = available_models[0]
    forest = RandomForest(
        name=name,
        classes=class_list,
        embedding_model_hash=model.embedding_model_hash,
        embedding_model_name=model.name,
    )

    new_id = uuid4()
    entry = ClassifierEntry(
        classifier_id=new_id,
        few_shot_classifier=forest,
        is_active=False,
        annotations={class_name: [] for class_name in class_list},
    )
    self._classifiers[new_id] = entry
    return entry
155
+
156
def train_classifier(self, session: Session, classifier_id: UUID) -> None:
    """Train a registered classifier on its recorded annotations.

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to train.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or
            if no embedding model matches the classifier's model hash.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    few_shot = entry.few_shot_classifier
    model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=few_shot.embedding_model_hash,
    )
    if model is None:
        raise ValueError(
            "No embedding model found for hash '"
            f"{few_shot.embedding_model_hash}'"
        )

    # Turn the per-class sample IDs into training input for the classifier.
    training_data = _create_annotated_embeddings(
        session=session,
        class_to_sample_ids=entry.annotations,
        embedding_model_id=model.embedding_model_id,
    )
    # Training always starts over: any earlier training is overwritten.
    few_shot.train(training_data)
190
+
191
def commit_temp_classifier(self, classifier_id: UUID) -> None:
    """Mark a trained classifier as active.

    Args:
        classifier_id: The ID of the classifier to activate.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or if the classifier is not yet trained.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    trained = entry.few_shot_classifier.is_trained()
    # Only an explicit False blocks activation.
    if trained is False:
        raise ValueError(f"Classifier with ID {classifier_id} is not trained yet.")

    entry.is_active = True
207
+
208
def drop_temp_classifier(self, classifier_id: UUID) -> None:
    """Discard a classifier that has not been activated.

    Args:
        classifier_id: The ID of the classifier to drop.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or
            if it is active (active classifiers are never dropped here).
    """
    registry = self._classifiers
    entry = registry.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if entry.is_active:
        raise ValueError(f"Classifier with ID {classifier_id} is active and cannot be dropped.")
    registry.pop(classifier_id, None)
223
+
224
def save_classifier_to_file(self, classifier_id: UUID, file_path: Path) -> None:
    """Export an active classifier to a file.

    Args:
        classifier_id: The ID of the classifier to save.
        file_path: The path to save the classifier to.

    Raises:
        ValueError: If the classifier with the given ID does not exist, or
            if it is not active.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not entry.is_active:
        raise ValueError(
            f"Classifier with ID {classifier_id} is not active and cannot be saved."
        )

    # The export always uses the "sklearn" export type.
    few_shot = entry.few_shot_classifier
    few_shot.export(export_path=file_path, export_type="sklearn")
242
+
243
def load_classifier_from_file(self, session: Session, file_path: Path) -> ClassifierEntry:
    """Load a classifier from a file and register it as active.

    Args:
        session: Database session for resolver operations.
        file_path: The path from where to load the classifier.

    Returns:
        The ClassifierEntry created for the loaded classifier. It is stored
        under a fresh ID, is active, and has an empty sample list for every
        class of the loaded classifier.

    Raises:
        ValueError: If no embedding model matches the loaded classifier's
            embedding model hash.
    """
    loaded = random_forest_classifier.load_random_forest_classifier(
        classifier_path=file_path, buffer=None
    )

    # Refuse a classifier whose embedding model is unknown to the database.
    model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=loaded.embedding_model_hash,
    )
    if model is None:
        raise ValueError(
            "No matching embedding model found for the classifier's hash:"
            f"'{loaded.embedding_model_hash}'."
        )

    new_id = uuid4()
    entry = ClassifierEntry(
        classifier_id=new_id,
        few_shot_classifier=loaded,
        is_active=True,
        annotations={class_name: [] for class_name in loaded.classes},
    )
    self._classifiers[new_id] = entry
    return entry
274
+
275
def provide_negative_samples(
    self, session: Session, dataset_id: UUID, selected_samples: list[UUID], limit: int = 10
) -> Sequence[SampleTable]:
    """Fetch samples from the dataset that are not among the selected ones.

    The lookup is delegated to sample_resolver.get_samples_excluding.

    Args:
        session: Database session for resolver operations.
        dataset_id: The dataset_id to pull samples from.
        selected_samples: List of sample UUIDs to exclude.
        limit: Number of negative samples to return.

    Returns:
        The samples returned by the resolver.
    """
    negatives = sample_resolver.get_samples_excluding(
        session=session,
        dataset_id=dataset_id,
        excluded_sample_ids=selected_samples,
        limit=limit,
    )
    return negatives
296
+
297
def update_classifiers_annotations(
    self,
    classifier_id: UUID,
    new_annotations: dict[str, list[UUID]],
) -> None:
    """Merge newly annotated samples into a classifier's history.

    A sample that appears in new_annotations is first removed from every
    class it was recorded under, then added to the class(es) it is now
    listed for. The classifier's annotations dict is updated in place.

    Args:
        classifier_id: The ID of the classifier.
        new_annotations: Dictionary mapping class names to lists of sample
            IDs.

    Raises:
        ValueError: If the classifier doesn't exist, or if new_annotations
            names a class the classifier was not created with.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    current = entry.annotations

    # Reject any class name the classifier does not already know.
    known_classes = set(current.keys())
    invalid_classes = set(new_annotations.keys()) - known_classes
    if invalid_classes:
        raise ValueError(
            f"Cannot add new classes {invalid_classes} to existing"
            f" classifier. Allowed classes are: {known_classes}"
        )

    # Every sample ID mentioned anywhere in the update.
    incoming = set().union(*new_annotations.values())

    for class_name in current:
        # Drop re-annotated samples from their old class, then add the
        # samples now assigned to this class.
        retained = set(current[class_name]) - incoming
        additions = set(new_annotations.get(class_name, []))
        current[class_name] = list(retained | additions)
338
+
339
def get_annotations(self, classifier_id: UUID) -> dict[str, list[UUID]]:
    """Return a copy of the samples recorded for each class.

    Args:
        classifier_id: The ID of the classifier.

    Returns:
        Dictionary mapping class names to lists of sample IDs. It is a deep
        copy, so changing it does not affect the manager's own records.

    Raises:
        ValueError: If the classifier doesn't exist.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is not None:
        return copy.deepcopy(entry.annotations)
    raise ValueError(f"Classifier with ID {classifier_id} not found.")
356
+
357
def get_samples_for_fine_tuning(
    self, session: Session, dataset_id: UUID, classifier_id: UUID
) -> dict[str, list[UUID]]:
    """Pick not-yet-annotated samples to show the user for fine-tuning.

    Every embedding of the dataset (for the classifier's embedding model) is
    scored by the classifier. Samples already present in the classifier's
    annotations are skipped. The rest are split on the first value of their
    prediction:
    - above HIGH_CONFIDENCE_THRESHOLD: high-confidence candidates
    - at or below LOW_CONFIDENCE_THRESHOLD: low-confidence candidates
    Up to HIGH_CONFIDENCE_SAMPLES_NEEDED and LOW_CONFIDENCE_SAMPLES_NEEDED
    samples are then drawn at random from the two groups. A group with
    fewer candidates than that is returned in full.

    Args:
        session: Database session for resolver operations.
        dataset_id: The ID of the dataset to pull samples from.
        classifier_id: The ID of the classifier to use.

    Returns:
        Dictionary mapping class names to sample IDs. The first class from
        the classifier's classes gets the high-confidence draw, the second
        class gets the low-confidence draw.

    Raises:
        ValueError: If the classifier with the given ID does not exist
            or there is no appropriate embedding model.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    # Samples that were already used as annotations must not be offered again.
    already_used = set().union(*entry.annotations.values())

    few_shot = entry.few_shot_classifier
    model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=few_shot.embedding_model_hash,
    )
    if model is None:
        raise ValueError(
            "No embedding model found for hash '"
            f"{few_shot.embedding_model_hash}'"
        )

    # Keep the full records so each prediction can be tied to a sample ID.
    candidates = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=model.embedding_model_id,
    )
    scores = few_shot.predict([candidate.embedding for candidate in candidates])

    confident_ids = []  # first prediction value > HIGH_CONFIDENCE_THRESHOLD
    uncertain_ids = []  # first prediction value <= LOW_CONFIDENCE_THRESHOLD
    for candidate, score in zip(candidates, scores):
        sample_id = candidate.sample_id
        if sample_id in already_used:
            continue
        # NOTE(review): score[0] is presumably the score of the first
        # class - confirm against RandomForest.predict.
        first_score = score[0]
        if first_score > HIGH_CONFIDENCE_THRESHOLD:
            confident_ids.append(sample_id)
        elif first_score <= LOW_CONFIDENCE_THRESHOLD:
            uncertain_ids.append(sample_id)

    class_names = few_shot.classes
    return {
        class_names[0]: sample(
            confident_ids, min(len(confident_ids), HIGH_CONFIDENCE_SAMPLES_NEEDED)
        ),
        class_names[1]: sample(
            uncertain_ids, min(len(uncertain_ids), LOW_CONFIDENCE_SAMPLES_NEEDED)
        ),
    }
430
+
431
def run_classifier(self, session: Session, classifier_id: UUID, dataset_id: UUID) -> None:
    """Predict a class for every embedded sample and store the result as annotations.

    Annotations written by an earlier run of the same classifier (same
    annotation task and labels) are deleted before the new ones are created.

    Args:
        session: Database session for resolver operations.
        classifier_id: The ID of the classifier to run.
        dataset_id: The ID of the dataset to run the classifier on.

    Raises:
        ValueError: If the classifier does not exist or is not active, if no
            embedding model matches the classifier's model hash, if the
            prediction result is empty, or if the classifier has no
            annotation task or annotation labels after they were set up.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")
    if not entry.is_active:
        raise ValueError(f"Classifier with ID {classifier_id} is not active and cannot be used.")

    model_hash = entry.few_shot_classifier.embedding_model_hash
    model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=model_hash,
    )
    if model is None:
        raise ValueError(f"No embedding model found for hash '{model_hash}'")

    # Keep the full rows (not just the vectors) so that every prediction can
    # be mapped back to the sample it belongs to.
    rows = sample_embedding_resolver.get_all_by_dataset_id(
        session=session,
        dataset_id=dataset_id,
        embedding_model_id=model.embedding_model_id,
    )
    predictions = entry.few_shot_classifier.predict([row.embedding for row in rows])
    if len(predictions) == 0:
        raise ValueError(f"Predict returned empty list for classifier:'{classifier_id}'")

    # The annotation task and labels are created on first use and cached on the entry.
    _create_annotation_task_and_labels_for_classifier(
        classifier=entry,
        session=session,
        dataset_id=dataset_id,
    )
    if not entry.annotation_task_id:
        raise ValueError(f"Classifier with ID '{classifier_id}' has no annotation task.")
    if not entry.annotation_label_ids:
        raise ValueError(f"Classifier with ID '{classifier_id}' has no annotation labels")

    # One classification annotation per sample: the label is the class with
    # the highest score and that score is stored as the confidence.
    ranked = (
        (row, scores, scores.index(max(scores))) for row, scores in zip(rows, predictions)
    )
    new_annotations = [
        AnnotationCreate(
            sample_id=row.sample_id,
            annotation_task_id=entry.annotation_task_id,
            dataset_id=dataset_id,
            annotation_label_id=entry.annotation_label_ids[best],
            annotation_type=AnnotationType.CLASSIFICATION,
            confidence=scores[best],
        )
        for row, scores, best in ranked
    ]

    # Replace, rather than append to, what this classifier produced before.
    annotation_resolver.delete_annotations(
        session=session,
        annotation_task_ids=[entry.annotation_task_id],
        annotation_label_ids=entry.annotation_label_ids,
    )
    annotation_resolver.create_many(session=session, annotations=new_annotations)
507
+
508
def get_all_classifiers(self) -> list[EmbeddingClassifier]:
    """List every classifier that is currently active.

    Returns:
        One EmbeddingClassifier (name, ID and class list) per stored
        classifier whose ``is_active`` flag is set.
    """
    active_entries = (entry for entry in self._classifiers.values() if entry.is_active)
    return [
        EmbeddingClassifier(
            classifier_name=entry.few_shot_classifier.name,
            classifier_id=entry.classifier_id,
            class_list=entry.few_shot_classifier.classes,
        )
        for entry in active_entries
    ]
523
+
524
def get_classifier_by_id(self, classifier_id: UUID) -> EmbeddingClassifier:
    """Get a single classifier by its ID.

    The lookup does not check the ``is_active`` flag, so a classifier that is
    stored but inactive is returned as well.

    Args:
        classifier_id: The ID of the classifier to get.

    Raises:
        ValueError: If the classifier with the given ID does not exist.

    Returns:
        EmbeddingClassifier holding the classifier's name, ID and class list.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    few_shot = entry.few_shot_classifier
    return EmbeddingClassifier(
        classifier_name=few_shot.name,
        classifier_id=classifier_id,
        class_list=few_shot.classes,
    )
544
+
545
def save_classifier_to_buffer(
    self, classifier_id: UUID, buffer: io.BytesIO, export_type: ExportType
) -> None:
    """Export a stored classifier into the given buffer.

    Args:
        classifier_id: The ID of the classifier to save.
        buffer: The buffer the classifier is written to.
        export_type: The type of export to perform.

    Raises:
        ValueError: If the classifier with the given ID does not exist.
    """
    entry = self._classifiers.get(classifier_id)
    if entry is None:
        raise ValueError(f"Classifier with ID {classifier_id} not found.")

    # Serialisation is delegated to the few-shot classifier itself.
    few_shot = entry.few_shot_classifier
    few_shot.export(buffer=buffer, export_type=export_type)
562
+
563
def load_classifier_from_buffer(self, session: Session, buffer: io.BytesIO) -> ClassifierEntry:
    """Load a classifier from a buffer and register it under a new ID.

    The classifier is only registered if an embedding model with the same
    model hash exists in the database. It is stored as active, with an empty
    annotation list for each of its classes.

    Args:
        session: Database session for resolver operations.
        buffer: The buffer containing the classifier data.

    Returns:
        The newly registered ClassifierEntry; its ``classifier_id`` is a
        freshly generated UUID.

    Raises:
        ValueError: If no matching embedding model is found for the
            classifier.
    """
    few_shot = random_forest_classifier.load_random_forest_classifier(
        buffer=buffer, classifier_path=None
    )
    embedding_model = embedding_model_resolver.get_by_model_hash(
        session=session,
        embedding_model_hash=few_shot.embedding_model_hash,
    )
    if embedding_model is None:
        raise ValueError(
            "No matching embedding model found for the classifier's hash: "
            f"'{few_shot.embedding_model_hash}'."
        )

    classifier_id = uuid4()
    entry = ClassifierEntry(
        classifier_id=classifier_id,
        few_shot_classifier=few_shot,
        is_active=True,
        annotations={class_name: [] for class_name in few_shot.classes},
    )
    self._classifiers[classifier_id] = entry
    return entry
598
+
599
+
600
def _create_annotation_task_and_labels_for_classifier(
    session: Session,
    dataset_id: UUID,
    classifier: ClassifierEntry,
) -> None:
    """Make sure the classifier has an annotation task and annotation labels.

    Whatever is missing is created through the resolvers and its ID is stored
    on ``classifier``. Values that are already set are left untouched.

    Args:
        session: Database session.
        dataset_id: The dataset ID the created labels are attached to.
        classifier: The classifier entry to update in place.
    """
    few_shot = classifier.few_shot_classifier

    # The task is created only once; later calls reuse the stored ID.
    if classifier.annotation_task_id is None:
        task = annotation_task_resolver.create(
            session=session,
            annotation_task=AnnotationTaskTable(
                name=FSC_ANNOTATION_TASK_PREFIX + few_shot.name,
                annotation_type=AnnotationType.CLASSIFICATION,
                is_prediction=True,
            ),
        )
        classifier.annotation_task_id = task.annotation_task_id

    # One label per class, named "<classifier name>_<class name>", in the same
    # order as the classifier's class list.
    if classifier.annotation_label_ids is None:
        classifier.annotation_label_ids = [
            annotation_label_resolver.create(
                session=session,
                label=AnnotationLabelCreate(
                    dataset_id=dataset_id,
                    annotation_label_name=few_shot.name + "_" + class_name,
                ),
            ).annotation_label_id
            for class_name in few_shot.classes
        ]
638
+
639
+
640
def _create_annotated_embeddings(
    session: Session,
    class_to_sample_ids: dict[str, list[UUID]],
    embedding_model_id: UUID,
) -> list[AnnotatedEmbedding]:
    """Pair the embeddings of the given samples with their class names.

    Args:
        session: Database session.
        class_to_sample_ids: Dictionary mapping class names to sample UUIDs.
        embedding_model_id: The embedding model ID to filter by.

    Returns:
        List of annotated embeddings for training, grouped by class in the
        iteration order of ``class_to_sample_ids``.
    """
    annotated: list[AnnotatedEmbedding] = []
    for class_name, sample_uuids in class_to_sample_ids.items():
        # One resolver lookup per class; every embedding returned is labelled
        # with that class name.
        class_embeddings = sample_embedding_resolver.get_by_sample_ids(
            session=session,
            sample_ids=sample_uuids,
            embedding_model_id=embedding_model_id,
        )
        for row in class_embeddings:
            annotated.append(AnnotatedEmbedding(embedding=row.embedding, annotation=class_name))
    return annotated