lightly-studio 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lightly-studio might be problematic. Click here for more details.

Files changed (219) hide show
  1. lightly_studio/__init__.py +11 -0
  2. lightly_studio/api/__init__.py +0 -0
  3. lightly_studio/api/app.py +110 -0
  4. lightly_studio/api/cache.py +77 -0
  5. lightly_studio/api/db.py +133 -0
  6. lightly_studio/api/db_tables.py +32 -0
  7. lightly_studio/api/features.py +7 -0
  8. lightly_studio/api/routes/api/annotation.py +233 -0
  9. lightly_studio/api/routes/api/annotation_label.py +90 -0
  10. lightly_studio/api/routes/api/annotation_task.py +38 -0
  11. lightly_studio/api/routes/api/classifier.py +387 -0
  12. lightly_studio/api/routes/api/dataset.py +182 -0
  13. lightly_studio/api/routes/api/dataset_tag.py +257 -0
  14. lightly_studio/api/routes/api/exceptions.py +96 -0
  15. lightly_studio/api/routes/api/features.py +17 -0
  16. lightly_studio/api/routes/api/metadata.py +37 -0
  17. lightly_studio/api/routes/api/metrics.py +80 -0
  18. lightly_studio/api/routes/api/sample.py +196 -0
  19. lightly_studio/api/routes/api/settings.py +45 -0
  20. lightly_studio/api/routes/api/status.py +19 -0
  21. lightly_studio/api/routes/api/text_embedding.py +48 -0
  22. lightly_studio/api/routes/api/validators.py +17 -0
  23. lightly_studio/api/routes/healthz.py +13 -0
  24. lightly_studio/api/routes/images.py +104 -0
  25. lightly_studio/api/routes/webapp.py +51 -0
  26. lightly_studio/api/server.py +82 -0
  27. lightly_studio/core/__init__.py +0 -0
  28. lightly_studio/core/dataset.py +523 -0
  29. lightly_studio/core/sample.py +77 -0
  30. lightly_studio/core/start_gui.py +15 -0
  31. lightly_studio/dataset/__init__.py +0 -0
  32. lightly_studio/dataset/edge_embedding_generator.py +144 -0
  33. lightly_studio/dataset/embedding_generator.py +91 -0
  34. lightly_studio/dataset/embedding_manager.py +163 -0
  35. lightly_studio/dataset/env.py +16 -0
  36. lightly_studio/dataset/file_utils.py +35 -0
  37. lightly_studio/dataset/loader.py +622 -0
  38. lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
  39. lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
  40. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
  41. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
  42. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
  43. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
  44. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
  45. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
  46. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
  47. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
  48. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
  49. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
  50. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
  51. lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
  52. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
  53. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
  54. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
  55. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
  56. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
  57. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
  58. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
  59. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
  60. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
  61. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
  62. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
  63. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
  64. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
  65. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
  66. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
  67. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
  68. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
  69. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
  70. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
  71. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
  72. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
  73. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
  74. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
  75. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
  76. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
  77. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
  78. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
  79. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
  80. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
  81. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
  82. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
  83. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
  84. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
  85. lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
  86. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
  87. lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
  88. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
  89. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
  90. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
  91. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
  92. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
  93. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
  94. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
  95. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
  96. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
  97. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
  98. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
  99. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
  100. lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
  101. lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
  102. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
  103. lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
  104. lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
  105. lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
  106. lightly_studio/examples/example.py +23 -0
  107. lightly_studio/examples/example_metadata.py +338 -0
  108. lightly_studio/examples/example_selection.py +39 -0
  109. lightly_studio/examples/example_split_work.py +67 -0
  110. lightly_studio/examples/example_v2.py +21 -0
  111. lightly_studio/export_schema.py +18 -0
  112. lightly_studio/few_shot_classifier/__init__.py +0 -0
  113. lightly_studio/few_shot_classifier/classifier.py +80 -0
  114. lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
  115. lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
  116. lightly_studio/metadata/complex_metadata.py +47 -0
  117. lightly_studio/metadata/gps_coordinate.py +41 -0
  118. lightly_studio/metadata/metadata_protocol.py +17 -0
  119. lightly_studio/metrics/__init__.py +0 -0
  120. lightly_studio/metrics/detection/__init__.py +0 -0
  121. lightly_studio/metrics/detection/map.py +268 -0
  122. lightly_studio/models/__init__.py +1 -0
  123. lightly_studio/models/annotation/__init__.py +0 -0
  124. lightly_studio/models/annotation/annotation_base.py +171 -0
  125. lightly_studio/models/annotation/instance_segmentation.py +56 -0
  126. lightly_studio/models/annotation/links.py +17 -0
  127. lightly_studio/models/annotation/object_detection.py +47 -0
  128. lightly_studio/models/annotation/semantic_segmentation.py +44 -0
  129. lightly_studio/models/annotation_label.py +47 -0
  130. lightly_studio/models/annotation_task.py +28 -0
  131. lightly_studio/models/classifier.py +20 -0
  132. lightly_studio/models/dataset.py +84 -0
  133. lightly_studio/models/embedding_model.py +30 -0
  134. lightly_studio/models/metadata.py +208 -0
  135. lightly_studio/models/sample.py +180 -0
  136. lightly_studio/models/sample_embedding.py +37 -0
  137. lightly_studio/models/settings.py +60 -0
  138. lightly_studio/models/tag.py +96 -0
  139. lightly_studio/py.typed +0 -0
  140. lightly_studio/resolvers/__init__.py +7 -0
  141. lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
  142. lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
  143. lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
  144. lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
  145. lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
  146. lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
  147. lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
  148. lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
  149. lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
  150. lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
  151. lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
  152. lightly_studio/resolvers/annotation_resolver/create.py +19 -0
  153. lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
  154. lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
  155. lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
  156. lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
  157. lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
  158. lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
  159. lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
  160. lightly_studio/resolvers/annotation_task_resolver.py +31 -0
  161. lightly_studio/resolvers/annotations/__init__.py +1 -0
  162. lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
  163. lightly_studio/resolvers/dataset_resolver.py +278 -0
  164. lightly_studio/resolvers/embedding_model_resolver.py +100 -0
  165. lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
  166. lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
  167. lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
  168. lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
  169. lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
  170. lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
  171. lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
  172. lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
  173. lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
  174. lightly_studio/resolvers/sample_resolver.py +249 -0
  175. lightly_studio/resolvers/samples_filter.py +81 -0
  176. lightly_studio/resolvers/settings_resolver.py +58 -0
  177. lightly_studio/resolvers/tag_resolver.py +276 -0
  178. lightly_studio/selection/README.md +6 -0
  179. lightly_studio/selection/mundig.py +105 -0
  180. lightly_studio/selection/select.py +96 -0
  181. lightly_studio/selection/select_via_db.py +93 -0
  182. lightly_studio/selection/selection_config.py +31 -0
  183. lightly_studio/services/annotations_service/__init__.py +21 -0
  184. lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
  185. lightly_studio/services/annotations_service/update_annotation.py +65 -0
  186. lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
  187. lightly_studio/services/annotations_service/update_annotations.py +29 -0
  188. lightly_studio/setup_logging.py +19 -0
  189. lightly_studio/type_definitions.py +19 -0
  190. lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
  191. lightly_studio/vendor/LICENSE +31 -0
  192. lightly_studio/vendor/LICENSE_weights_data +50 -0
  193. lightly_studio/vendor/README.md +5 -0
  194. lightly_studio/vendor/__init__.py +1 -0
  195. lightly_studio/vendor/mobileclip/__init__.py +96 -0
  196. lightly_studio/vendor/mobileclip/clip.py +77 -0
  197. lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
  198. lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
  199. lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
  200. lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
  201. lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
  202. lightly_studio/vendor/mobileclip/logger.py +154 -0
  203. lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
  204. lightly_studio/vendor/mobileclip/models/mci.py +933 -0
  205. lightly_studio/vendor/mobileclip/models/vit.py +433 -0
  206. lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
  207. lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
  208. lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
  209. lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
  210. lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
  211. lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
  212. lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
  213. lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
  214. lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
  215. lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
  216. lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
  217. lightly_studio-0.3.1.dist-info/METADATA +520 -0
  218. lightly_studio-0.3.1.dist-info/RECORD +219 -0
  219. lightly_studio-0.3.1.dist-info/WHEEL +4 -0
@@ -0,0 +1,489 @@
1
+ """RandomForest classifier implementations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ import pickle
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import sklearn # type: ignore[import-untyped]
13
+ from sklearn.ensemble import ( # type: ignore[import-untyped]
14
+ RandomForestClassifier,
15
+ )
16
+ from sklearn.tree import ( # type: ignore[import-untyped]
17
+ DecisionTreeClassifier,
18
+ )
19
+ from sklearn.utils import validation # type: ignore[import-untyped]
20
+ from typing_extensions import assert_never
21
+
22
+ from .classifier import AnnotatedEmbedding, ExportType, FewShotClassifier
23
+
24
# Version of the file format used for exporting and importing classifiers.
# It is written into the export metadata, and both loaders in this module
# compare the stored value against this constant and raise ValueError on a
# mismatch. Increment it whenever the export format changes.
FILE_FORMAT_VERSION = "1.0.0"
28
+
29
+
30
@dataclass
class ModelExportMetadata:
    """Traceability record that travels with every exported classifier.

    The fields capture what is needed to identify the exported model and to
    check that a file is compatible with the code that loads it.
    """

    # Name of the classifier.
    name: str
    # Version of the export file format; loaders compare it against
    # FILE_FORMAT_VERSION.
    file_format_version: str
    # Kind of model stored in the export (e.g. "RandomForest").
    model_type: str
    # Creation timestamp, stored as a string.
    created_at: str
    # Ordered class labels; the order defines the order of predicted
    # probabilities.
    class_names: list[str]
    # Length of the embeddings the model expects as input.
    num_input_features: int
    # Number of trees in the forest.
    num_estimators: int
    # Hash and name of the model that produced the embeddings.
    embedding_model_hash: str
    embedding_model_name: str
    # Version of scikit-learn in use when the export was created.
    sklearn_version: str
44
+
45
+
46
@dataclass
class InnerNode:
    """Decision (non-leaf) node of an exported tree.

    Every field defaults to zero so that a placeholder node can be created
    during tree construction and filled in afterwards.
    """

    # Position in the embedding of the value compared at this node.
    feature_index: int = 0
    # Embeddings whose value is <= threshold continue to the left child,
    # all others to the right child.
    threshold: float = 0.0
    # Child references: a value >= 0 is an index into the tree's inner
    # nodes, a negative value -k - 1 refers to leaf k.
    left_child: int = 0
    right_child: int = 0
57
+
58
+
59
@dataclass
class LeafNode:
    """Terminal node of an exported decision tree."""

    # One probability per class, ordered like the classifier's full class
    # list.
    class_probabilities: list[float]
64
+
65
+
66
@dataclass
class ExportedTree:
    """A single decision tree in the exported representation."""

    # Decision nodes; traversal starts at index 0.
    inner_nodes: list[InnerNode]
    # Terminal nodes; inner nodes refer to leaf k as the child value -k - 1.
    leaf_nodes: list[LeafNode]
72
+
73
+
74
@dataclass
class RandomForestExport:
    """Container holding an exported RandomForest model."""

    # Traceability information describing the exported model.
    metadata: ModelExportMetadata
    # One exported tree per estimator of the forest.
    trees: list[ExportedTree]
80
+
81
+
82
class RandomForest(FewShotClassifier):
    """Few-shot classifier backed by a scikit-learn RandomForestClassifier.

    Labels are encoded as their index in ``classes`` before fitting, and
    predictions are always reported over the full ``classes`` list, including
    classes that did not occur in the training data.
    """

    def __init__(
        self,
        name: str,
        classes: list[str],
        embedding_model_name: str,
        embedding_model_hash: str,
    ) -> None:
        """Initialize the RandomForestClassifier with predefined classes.

        Args:
            name: Name of the classifier.
            classes: Ordered list of class labels that will be used for training
                and predictions. The order of this list determines the order of
                probability values in predictions.
            embedding_model_name: Name of the model used for creating the
                embeddings.
            embedding_model_hash: Hash of the model used for creating the
                embeddings.
            Note: embedding_model_name and embedding_model_hash are used for
                traceability in the exported model metadata.

        Raises:
            ValueError: If classes list is empty.
        """
        if not classes:
            raise ValueError("Class list cannot be empty.")

        # Fix the random seed for reproducibility.
        self._model = RandomForestClassifier(class_weight="balanced", random_state=42)
        self.name = name
        self.classes = classes
        # The model is trained on integer indices rather than label strings.
        self._class_to_index = {label: idx for idx, label in enumerate(classes)}
        self._embedding_model_name = embedding_model_name
        self.embedding_model_hash = embedding_model_hash

    def train(self, annotated_embeddings: list[AnnotatedEmbedding]) -> None:
        """Trains a classifier using the provided input.

        Args:
            annotated_embeddings: A list of annotated embeddings to train the
                classifier.

        Raises:
            ValueError: If annotated_embeddings is empty or contains invalid
                classes.
        """
        if not annotated_embeddings:
            raise ValueError("annotated_embeddings cannot be empty.")

        # Extract embeddings and labels.
        embeddings = [ae.embedding for ae in annotated_embeddings]
        labels = [ae.annotation for ae in annotated_embeddings]

        # Validate that all labels are in predefined classes.
        invalid_labels = set(labels) - set(self.classes)
        if invalid_labels:
            raise ValueError(f"Found labels not in predefined classes: {invalid_labels}")

        # Convert to NumPy arrays and encode labels as class indices.
        embeddings_np = np.array(embeddings)
        labels_encoded = [self._class_to_index[label] for label in labels]

        # Train the RandomForestClassifier.
        self._model.fit(embeddings_np, labels_encoded)

    def predict(self, embeddings: list[list[float]]) -> list[list[float]]:
        """Predicts the classification scores for a list of embeddings.

        Args:
            embeddings: A list of embeddings, where each embedding is a list of
                floats.

        Returns:
            A list of lists, where each inner list represents the probability
            distribution over classes for the corresponding input embedding.
            Each value in the inner list corresponds to the likelihood of the
            embedding belonging to a specific class; classes the model was not
            trained on get probability 0.0.
            If embeddings is empty, returns an empty list.
        """
        # len() rather than truthiness, so array-like inputs are handled too.
        if len(embeddings) == 0:
            return []

        # Convert embeddings to a NumPy array.
        embeddings_np = np.array(embeddings)

        # The model only knows the class indices that occurred during training.
        trained_classes: list[int] = self._model.classes_

        # Get raw probabilities from model (one column per trained class).
        raw_probabilities = self._model.predict_proba(embeddings_np)

        num_classes = len(self.classes)
        full_probabilities: list[list[float]] = []
        for raw_probs in raw_probabilities:
            # Initialize zeros for all possible classes.
            full_probs = [0.0] * num_classes
            # Map probabilities to their correct positions. float() keeps the
            # result made of plain Python floats instead of NumPy scalars.
            for trained_class, prob in zip(trained_classes, raw_probs):
                full_probs[trained_class] = float(prob)
            full_probabilities.append(full_probs)
        return full_probabilities

    def export(
        self,
        export_path: Path | None = None,
        buffer: io.BytesIO | None = None,
        export_type: ExportType = "sklearn",
    ) -> None:
        """Exports the classifier to a specified file or buffer.

        The export is written with pickle. If both buffer and export_path are
        given, only the buffer is written. If neither is given, nothing is
        written.

        Args:
            export_path: The full file path where the export will be saved.
                Missing parent directories are created.
            buffer: A BytesIO buffer to save the export to.
            export_type: The type of export. Options are:
                "sklearn": Exports the RandomForestClassifier instance
                    together with the metadata.
                "lightly": Exports the model in raw format with metadata
                    and tree details.
        """
        metadata = ModelExportMetadata(
            name=self.name,
            file_format_version=FILE_FORMAT_VERSION,
            model_type="RandomForest",
            created_at=str(datetime.now(timezone.utc).isoformat()),
            class_names=self.classes,
            num_input_features=self._model.n_features_in_,
            num_estimators=len(self._model.estimators_),
            embedding_model_hash=self.embedding_model_hash,
            embedding_model_name=self._embedding_model_name,
            sklearn_version=sklearn.__version__,
        )

        # Build the payload for the requested format first; the write logic
        # below is identical for both formats.
        payload: object
        if export_type == "sklearn":
            # Combine the model and metadata into a single dictionary.
            payload = {
                "model": self._model,
                "metadata": metadata,
            }
        elif export_type == "lightly":
            payload = _export_random_forest_model(
                model=self._model,
                metadata=metadata,
                all_classes=self.classes,
            )
        else:
            assert_never(export_type)

        if buffer is not None:
            pickle.dump(payload, buffer)
        elif export_path is not None:
            # Save to the specified file path, ensuring parent dirs exist.
            export_path.parent.mkdir(parents=True, exist_ok=True)
            with open(export_path, "wb") as f:
                pickle.dump(payload, f)
        # NOTE(review): with neither destination the call is a silent no-op,
        # as before; consider raising ValueError once callers are checked.

    def is_trained(self) -> bool:
        """Checks if the classifier is trained.

        Returns:
            True if the classifier is trained, False otherwise.
        """
        try:
            validation.check_is_fitted(self._model)
        except sklearn.exceptions.NotFittedError:
            return False
        return True
260
+
261
+
262
def load_random_forest_classifier(
    classifier_path: Path | None, buffer: io.BytesIO | None
) -> RandomForest:
    """Loads a RandomForest classifier from a file or a buffer.

    Args:
        classifier_path: The path to the exported classifier file.
        buffer: A BytesIO buffer containing the exported classifier.
        If both path and buffer are provided, the path will be used.

    Returns:
        A fully initialized RandomForest classifier instance.

    Raises:
        FileNotFoundError: If the classifier_path does not exist.
        ValueError: If neither classifier_path nor buffer is provided, if the
            file is not a valid 'sklearn' pickled export or if the
            version/format mismatches.
    """
    # SECURITY: pickle can execute arbitrary code while loading. Only load
    # exports that come from a trusted source.
    if classifier_path is not None:
        if not classifier_path.exists():
            raise FileNotFoundError(f"The file {classifier_path} does not exist.")

        with open(classifier_path, "rb") as f:
            export_data = pickle.load(f)
    elif buffer is not None:
        export_data = pickle.load(buffer)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError("Either classifier_path or buffer must be provided.")

    # An 'sklearn' export is a dict with "model" and "metadata" entries; any
    # other object (e.g. a 'lightly' export) is rejected here.
    if not isinstance(export_data, dict):
        raise ValueError("The loaded file does not contain a valid model or metadata.")

    model = export_data.get("model")
    metadata = export_data.get("metadata")

    if model is None or metadata is None:
        raise ValueError("The loaded file does not contain a valid model or metadata.")

    if not isinstance(metadata, ModelExportMetadata):
        raise ValueError("The loaded file does not contain a valid model or metadata.")

    if metadata.file_format_version != FILE_FORMAT_VERSION:
        raise ValueError(
            f"File format version mismatch. Expected '{FILE_FORMAT_VERSION}', "
            f"got '{metadata.file_format_version}'."
        )

    instance = RandomForest(
        name=metadata.name,
        classes=metadata.class_names,
        embedding_model_name=metadata.embedding_model_name,
        embedding_model_hash=metadata.embedding_model_hash,
    )
    # Replace the freshly created, untrained model with the loaded one.
    instance._model = model  # noqa: SLF001
    return instance
309
+
310
+
311
def _export_random_forest_model(
    model: RandomForestClassifier,
    metadata: ModelExportMetadata,
    all_classes: list[str],
) -> RandomForestExport:
    """Builds a RandomForestExport from a sk-learn RandomForestClassifier.

    Args:
        model: The trained random forest model to export.
        metadata: Metadata to attach to the export.
        all_classes: Full list of all class labels.

    Returns:
        RandomForestExport: The export object holding the metadata and one
        exported tree per estimator of the forest.
    """
    # Class indices the forest actually saw during training.
    fitted_class_indices: list[int] = model.classes_

    exported_trees = []
    for estimator in model.estimators_:
        exported_trees.append(_export_single_tree(estimator, fitted_class_indices, all_classes))

    return RandomForestExport(metadata=metadata, trees=exported_trees)
330
+
331
+
332
def load_lightly_random_forest(path: Path | None, buffer: io.BytesIO | None) -> RandomForestExport:
    """Loads a Lightly exported RandomForest model from a file or buffer.

    Args:
        path: The path to the exported classifier file.
        buffer: A BytesIO buffer containing the exported classifier.
        If both path and buffer are provided, the path will be used.

    Returns:
        A RandomForestExport instance.

    Raises:
        ValueError: If neither path nor buffer is provided, if the file is
            not a valid RandomForestExport or if the version/format
            mismatches.
    """
    # SECURITY: pickle can execute arbitrary code while loading. Only load
    # exports that come from a trusted source.
    if path is not None:
        with open(path, "rb") as f:
            data = pickle.load(f)
    elif buffer is not None:
        data = pickle.load(buffer)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError("Either path or buffer must be provided.")

    if not isinstance(data, RandomForestExport):
        raise ValueError("Loaded object is not a RandomForestExport instance.")

    if data.metadata.file_format_version != FILE_FORMAT_VERSION:
        raise ValueError(
            f"File format version mismatch. Expected '{FILE_FORMAT_VERSION}', "
            f"got '{data.metadata.file_format_version}'."
        )
    return data
362
+
363
+
364
def predict_with_lightly_random_forest(
    model: RandomForestExport, embeddings: list[list[float]]
) -> list[list[float]]:
    """Computes class probabilities for embeddings with an exported forest.

    Each embedding is run through every tree of the model and the per-tree
    probabilities are averaged.

    Args:
        model: A RandomForestExport instance containing the trees and metadata.
        embeddings: A list of embeddings.

    Returns:
        One probability distribution over the classes per input embedding,
        in input order.

    Raises:
        ValueError: If an embedding's length differs from the number of
            input features recorded in the model metadata.
    """
    n_features = model.metadata.num_input_features
    results: list[list[float]] = []

    for vector in embeddings:
        if len(vector) != n_features:
            raise ValueError(
                f"Embedding has wrong dimensionality: expected {n_features},got {len(vector)}"
            )

        # Collect the vote of every tree for this embedding.
        per_tree: list[list[float]] = []
        for tree in model.trees:
            per_tree.append(_predict_tree_probs(tree, vector))

        # Average across trees (axis 0), keeping one value per class.
        averaged = np.mean(per_tree, axis=0)
        results.append(averaged.tolist())

    return results
398
+
399
+
400
def _export_single_tree(
    tree: DecisionTreeClassifier,
    trained_classes: list[int],
    all_classes: list[str],
) -> ExportedTree:
    """Translates one sk-learn decision tree into an ExportedTree.

    Inner nodes and leaves are numbered separately, in the order in which
    they appear in the sk-learn tree. A child reference >= 0 is an index into
    the inner nodes; leaf k is referenced as -k - 1.

    Args:
        tree: The decision tree to convert.
        trained_classes: Indices of the classes the tree was trained on.
        all_classes: Full list of all class labels.

    Returns:
        ExportedTree: The tree with explicit inner-node and leaf lists. Leaf
        probabilities cover all classes; untrained classes get 0.0.
    """
    sk_tree = tree.tree_
    lefts = sk_tree.children_left
    rights = sk_tree.children_right
    n_all_classes = len(all_classes)

    leaves: list[LeafNode] = []
    # sk-learn node ids of the inner nodes, in export order.
    inner_ids: list[int] = []
    # Maps sk-learn node id -> child reference used in the export.
    export_index = {}

    # First pass: number every node and build the leaves.
    for node_id in range(sk_tree.node_count):
        # A node whose two child ids are equal is treated as a leaf.
        if lefts[node_id] != rights[node_id]:
            export_index[node_id] = len(inner_ids)
            inner_ids.append(node_id)
            continue

        # value[node_id] has shape [1, n_classes]; [0] yields the 1D array
        # of per-class weights.
        weights = sk_tree.value[node_id][0]
        weight_sum = sum(weights)
        if weight_sum > 0:
            probs = (weights / weight_sum).tolist()
        else:
            probs = [0.0] * len(weights)

        # Spread the probabilities over the full class list.
        padded = [0.0] * n_all_classes
        for class_idx, p in zip(trained_classes, probs):
            padded[class_idx] = p

        export_index[node_id] = -len(leaves) - 1
        leaves.append(LeafNode(class_probabilities=padded))

    # Second pass: all references are known, so inner nodes can be built.
    inners: list[InnerNode] = [
        InnerNode(
            feature_index=int(sk_tree.feature[node_id]),
            threshold=float(sk_tree.threshold[node_id]),
            left_child=export_index[lefts[node_id]],
            right_child=export_index[rights[node_id]],
        )
        for node_id in inner_ids
    ]

    return ExportedTree(inner_nodes=inners, leaf_nodes=leaves)
466
+
467
+
468
def _predict_tree_probs(tree: ExportedTree, embedding: list[float]) -> list[float]:
    """Returns the class probabilities a single tree assigns to an embedding.

    Args:
        tree: The ExportedTree to traverse.
        embedding: A single embedding.

    Returns:
        The class probabilities of the leaf reached by the embedding. A tree
        without inner nodes yields the probabilities of its first leaf.
    """
    inner = tree.inner_nodes
    leaves = tree.leaf_nodes

    if not inner:
        return leaves[0].class_probabilities

    # Non-negative values index inner nodes; a negative value encodes a leaf.
    cursor = 0
    while cursor >= 0:
        split = inner[cursor]
        goes_left = embedding[split.feature_index] <= split.threshold
        cursor = split.left_child if goes_left else split.right_child

    # Decode the leaf reference: leaf k is stored as -k - 1.
    return leaves[-(cursor + 1)].class_probabilities
@@ -0,0 +1,47 @@
1
+ """Complex metadata types that can be stored in JSON columns."""
2
+
3
+ from typing import Any, Dict, Type
4
+
5
+ from lightly_studio.metadata.gps_coordinate import GPSCoordinate
6
+ from lightly_studio.metadata.metadata_protocol import ComplexMetadata
7
+
8
# Registry of complex metadata types, keyed by the type name used in the
# schema (e.g. "gps_coordinate"). deserialize_complex_metadata looks up the
# class here and calls its from_dict to rebuild the object from a stored dict.
COMPLEX_METADATA_TYPES: Dict[str, Type[ComplexMetadata]] = {
    "gps_coordinate": GPSCoordinate,
}
12
+
13
+
14
def serialize_complex_metadata(value: Any) -> Any:
    """Prepare a metadata value for storage in a JSON column.

    Args:
        value: The metadata value to serialize.

    Returns:
        ``value.as_dict()`` when ``value`` satisfies the ComplexMetadata
        protocol; otherwise ``value`` itself, unchanged.
    """
    return value.as_dict() if isinstance(value, ComplexMetadata) else value
28
+
29
+
30
def deserialize_complex_metadata(value: Any, expected_type: str) -> Any:
    """Rebuild a complex metadata object from its JSON-stored form.

    Args:
        value: The stored value to deserialize.
        expected_type: Type name from the schema (e.g., "gps_coordinate").

    Returns:
        The object produced by the registered class's ``from_dict`` when
        ``value`` is a dict and ``expected_type`` is a registered type name.
        In every other case, including a failed conversion, ``value`` is
        returned unchanged.
    """
    # Only dicts with a known, non-empty type name are candidates.
    if not expected_type or not isinstance(value, dict):
        return value
    if expected_type not in COMPLEX_METADATA_TYPES:
        return value

    try:
        return COMPLEX_METADATA_TYPES[expected_type].from_dict(value)
    except (KeyError, TypeError):
        # Best effort: a dict that does not fit the type is passed through.
        return value
@@ -0,0 +1,41 @@
1
+ """GPS coordinate representation for complex metadata."""
2
+
3
+ from typing import Dict
4
+
5
+
6
class GPSCoordinate:
    """Represents a GPS coordinate.

    Instances compare by value: two coordinates are equal when both their
    latitude and longitude are equal. This makes a coordinate rebuilt with
    ``from_dict(coord.as_dict())`` compare equal to the original.
    """

    def __init__(self, lat: float, lon: float):
        """Initialize GPS coordinate.

        Args:
            lat: Latitude in decimal degrees.
            lon: Longitude in decimal degrees.
        """
        self.lat = lat
        self.lon = lon

    def __repr__(self) -> str:
        """String representation of the GPS coordinate."""
        return f"GPSCoordinate(lat={self.lat}, lon={self.lon})"

    def __eq__(self, other: object) -> bool:
        """Compare two coordinates by latitude and longitude.

        Args:
            other: Object to compare against.

        Returns:
            True if ``other`` is a GPSCoordinate with the same ``lat`` and
            ``lon``; NotImplemented for any other type so Python can fall
            back to the other operand's comparison.
        """
        if not isinstance(other, GPSCoordinate):
            return NotImplemented
        return (self.lat, self.lon) == (other.lat, other.lon)

    def __hash__(self) -> int:
        """Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone would make instances unhashable, so the
        hash is provided explicitly. Do not mutate ``lat``/``lon`` while the
        instance is used as a dict key or set member.
        """
        return hash((self.lat, self.lon))

    def as_dict(self) -> Dict[str, float]:
        """Convert the GPSCoordinate to a dictionary.

        Returns:
            Dictionary with 'lat' and 'lon' keys.
        """
        return {"lat": self.lat, "lon": self.lon}

    @classmethod
    def from_dict(cls, data: Dict[str, float]) -> "GPSCoordinate":
        """Create a GPSCoordinate from a dictionary.

        Args:
            data: Dictionary with 'lat' and 'lon' keys.

        Returns:
            A GPSCoordinate instance.

        Raises:
            KeyError: If 'lat' or 'lon' is missing from ``data``.
        """
        return cls(lat=data["lat"], lon=data["lon"])
@@ -0,0 +1,17 @@
1
+ """Protocol for complex metadata types that can be stored in JSON columns."""
2
+
3
+ from typing import Any, Dict, Protocol, runtime_checkable
4
+
5
+
6
@runtime_checkable
class ComplexMetadata(Protocol):
    """Structural interface for metadata values that round-trip through JSON.

    Any class that provides both ``as_dict`` and ``from_dict`` satisfies this
    protocol. It is runtime-checkable, so ``isinstance`` can be used to test
    whether an object offers these members.
    """

    def as_dict(self) -> Dict[str, Any]:
        """Return a dictionary representation suitable for JSON storage."""
        ...

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ComplexMetadata":
        """Build an instance from its dictionary representation."""
        ...
File without changes
File without changes