labelr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
labelr/sample.py ADDED
@@ -0,0 +1,186 @@
+ import logging
+ import random
+ import string
+
+ import datasets
+ from openfoodfacts.images import download_image
+
+ logger = logging.getLogger(__name__)
+
+
+ def format_annotation_results_from_hf(
+     objects: dict, image_width: int, image_height: int
+ ):
+     """Format annotation results from an HF object detection dataset into Label
+     Studio format."""
+     annotation_results = []
+     for i in range(len(objects["bbox"])):
+         bbox = objects["bbox"][i]
+         # category_id = objects["category_id"][i]
+         category_name = objects["category_name"][i]
+         # These are relative coordinates (between 0.0 and 1.0)
+         y_min, x_min, y_max, x_max = bbox
+         # Make sure the coordinates are within the image boundaries,
+         # and convert them to percentages
+         y_min = min(max(0, y_min), 1.0) * 100
+         x_min = min(max(0, x_min), 1.0) * 100
+         y_max = min(max(0, y_max), 1.0) * 100
+         x_max = min(max(0, x_max), 1.0) * 100
+         x = x_min
+         y = y_min
+         width = x_max - x_min
+         height = y_max - y_min
+
+         id_ = "".join(random.choices(string.ascii_letters + string.digits, k=10))
+         annotation_results.append(
+             {
+                 "id": id_,
+                 "type": "rectanglelabels",
+                 "from_name": "label",
+                 "to_name": "image",
+                 "original_width": image_width,
+                 "original_height": image_height,
+                 "image_rotation": 0,
+                 "value": {
+                     "rotation": 0,
+                     "x": x,
+                     "y": y,
+                     "width": width,
+                     "height": height,
+                     "rectanglelabels": [category_name],
+                 },
+             },
+         )
+     return annotation_results
+
+
+ def format_object_detection_sample_from_hf(hf_sample: dict, split: str) -> dict:
+     hf_meta = hf_sample["meta"]
+     objects = hf_sample["objects"]
+     image_width = hf_sample["width"]
+     image_height = hf_sample["height"]
+     annotation_results = format_annotation_results_from_hf(
+         objects, image_width, image_height
+     )
+     return {
+         "data": {
+             "image_id": hf_sample["image_id"],
+             "image_url": hf_meta["image_url"],
+             "batch": "null",
+             "split": split,
+             "meta": {
+                 "width": image_width,
+                 "height": image_height,
+                 "barcode": hf_meta["barcode"],
+                 "off_image_id": hf_meta["off_image_id"],
+             },
+         },
+         "predictions": [{"result": annotation_results}],
+     }
+
+
+ def format_object_detection_sample_to_ls(
+     image_id: str,
+     image_url: str,
+     width: int,
+     height: int,
+     extra_meta: dict | None = None,
+ ) -> dict:
+     """Format an object detection sample in Label Studio format.
+
+     Args:
+         image_id: The image ID.
+         image_url: The URL of the image.
+         width: The width of the image.
+         height: The height of the image.
+         extra_meta: Extra metadata to include in the sample.
+     """
+     extra_meta = extra_meta or {}
+     return {
+         "data": {
+             "image_id": image_id,
+             "image_url": image_url,
+             "batch": "null",
+             "meta": {
+                 "width": width,
+                 "height": height,
+                 **extra_meta,
+             },
+         },
+     }
+
+
+ def format_object_detection_sample_to_hf(
+     task_data: dict, annotations: list[dict], category_names: list[str]
+ ) -> dict | None:
+     if len(annotations) > 1:
+         logger.info("More than one annotation found, skipping")
+         return None
+     elif len(annotations) == 0:
+         logger.info("No annotation found, skipping")
+         return None
+
+     annotation = annotations[0]
+     bboxes = []
+     bbox_category_ids = []
+     bbox_category_names = []
+
+     for annotation_result in annotation["result"]:
+         if annotation_result["type"] != "rectanglelabels":
+             raise ValueError("Invalid annotation type: %s" % annotation_result["type"])
+
+         value = annotation_result["value"]
+         x_min = value["x"] / 100
+         y_min = value["y"] / 100
+         width = value["width"] / 100
+         height = value["height"] / 100
+         x_max = x_min + width
+         y_max = y_min + height
+         bboxes.append([y_min, x_min, y_max, x_max])
+         category_name = value["rectanglelabels"][0]
+         bbox_category_names.append(category_name)
+         bbox_category_ids.append(category_names.index(category_name))
+
+     image_url = task_data["image_url"]
+     image = download_image(image_url, error_raise=False)
+     if image is None:
+         logger.error("Failed to download image: %s", image_url)
+         return None
+
+     return {
+         "image_id": task_data["image_id"],
+         "image": image,
+         "width": task_data["meta"]["width"],
+         "height": task_data["meta"]["height"],
+         "meta": {
+             "barcode": task_data["meta"]["barcode"],
+             "off_image_id": task_data["meta"]["off_image_id"],
+             "image_url": image_url,
+         },
+         "objects": {
+             "bbox": bboxes,
+             "category_id": bbox_category_ids,
+             "category_name": bbox_category_names,
+         },
+     }
+
+
+ # The HuggingFace Dataset features
+ HF_DS_FEATURES = datasets.Features(
+     {
+         "image_id": datasets.Value("string"),
+         "image": datasets.features.Image(),
+         "width": datasets.Value("int64"),
+         "height": datasets.Value("int64"),
+         "meta": {
+             "barcode": datasets.Value("string"),
+             "off_image_id": datasets.Value("string"),
+             "image_url": datasets.Value("string"),
+         },
+         "objects": {
+             "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
+             "category_id": datasets.Sequence(datasets.Value("int64")),
+             "category_name": datasets.Sequence(datasets.Value("string")),
+         },
+     }
+ )
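format_annotation_results_from_hf above clamps HF-style relative boxes (y_min, x_min, y_max, x_max, each in [0, 1]) and rescales them to the percentage-based x/y/width/height geometry that Label Studio rectangle labels use. A minimal sketch of that conversion, using a hypothetical objects payload rather than a real dataset row:

from labelr.sample import format_annotation_results_from_hf

# One box covering the central quarter of an 800x600 image,
# in (y_min, x_min, y_max, x_max) relative coordinates.
objects = {
    "bbox": [[0.25, 0.25, 0.75, 0.75]],
    "category_id": [1],
    "category_name": ["nutriscore-a"],
}
results = format_annotation_results_from_hf(objects, image_width=800, image_height=600)
# results[0]["value"] == {"rotation": 0, "x": 25.0, "y": 25.0,
#                         "width": 50.0, "height": 50.0,
#                         "rectanglelabels": ["nutriscore-a"]}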
@@ -0,0 +1,241 @@
+ import dataclasses
+ import functools
+ import logging
+ import time
+ from typing import Any, Optional
+
+ import numpy as np
+ from PIL import Image
+
+ # grpc and tritonclient are optional dependencies: when they are not
+ # installed, the names below are simply left undefined and Triton-based
+ # inference is unavailable.
+ try:
+     import grpc
+     from tritonclient.grpc import service_pb2, service_pb2_grpc
+     from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
+ except ImportError:
+     pass
+
+ logger = logging.getLogger(__name__)
+
+
+ JSONType = dict[str, Any]
+
+ OBJECT_DETECTION_MODEL_VERSION = {
+     "nutriscore": "tf-nutriscore-1.0",
+     "nutrition_table": "tf-nutrition-table-1.0",
+     "universal_logo_detector": "tf-universal-logo-detector-1.0",
+ }
+
+ LABELS = {
+     "nutriscore": [
+         "NULL",
+         "nutriscore-a",
+         "nutriscore-b",
+         "nutriscore-c",
+         "nutriscore-d",
+         "nutriscore-e",
+     ],
+ }
+
+ OBJECT_DETECTION_IMAGE_MAX_SIZE = (1024, 1024)
+
+
+ @functools.cache
+ def get_triton_inference_stub(
+     triton_uri: str,
+ ) -> "GRPCInferenceServiceStub":
+     """Return a gRPC stub for Triton Inference Server.
+
+     :param triton_uri: URI of the Triton Inference Server
+     :return: gRPC stub for Triton Inference Server
+     """
+     channel = grpc.insecure_channel(triton_uri)
+     return service_pb2_grpc.GRPCInferenceServiceStub(channel)
+
+
+ def convert_image_to_array(image: Image.Image) -> np.ndarray:
+     """Convert a PIL Image into a numpy array.
+
+     The image is converted to RGB if needed before generating the array.
+
+     :param image: the input image
+     :return: the generated numpy array of shape (height, width, 3)
+     """
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     (im_width, im_height) = image.size
+
+     return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
+
+
+ @dataclasses.dataclass
+ class ObjectDetectionResult:
+     bounding_box: tuple[int, int, int, int]
+     score: float
+     label: str
+
+
+ @dataclasses.dataclass
+ class ObjectDetectionRawResult:
+     num_detections: int
+     detection_boxes: np.ndarray
+     detection_scores: np.ndarray
+     detection_classes: np.ndarray
+     label_names: list[str]
+     detection_masks: Optional[np.ndarray] = None
+     boxed_image: Optional[Image.Image] = None
+
+     def select(self, threshold: Optional[float] = None) -> list[ObjectDetectionResult]:
+         if threshold is None:
+             threshold = 0.5
+
+         box_masks = self.detection_scores > threshold
+         selected_boxes = self.detection_boxes[box_masks]
+         selected_scores = self.detection_scores[box_masks]
+         selected_classes = self.detection_classes[box_masks]
+
+         results = []
+         for bounding_box, score, label in zip(
+             selected_boxes, selected_scores, selected_classes
+         ):
+             label_int = int(label)
+             label_str = self.label_names[label_int]
+             if label_str is not None:
+                 result = ObjectDetectionResult(
+                     bounding_box=tuple(bounding_box.tolist()),  # type: ignore
+                     score=float(score),
+                     label=label_str,
+                 )
+                 results.append(result)
+
+         return results
+
+     def to_json(self, threshold: Optional[float] = None) -> list[JSONType]:
+         return [dataclasses.asdict(r) for r in self.select(threshold)]
+
+
+ def resize_image(image: Image.Image, max_size: tuple[int, int]) -> Image.Image:
+     width, height = image.size
+     max_width, max_height = max_size
+
+     if width > max_width or height > max_height:
+         new_image = image.copy()
+         new_image.thumbnail((max_width, max_height))
+         return new_image
+
+     return image
+
+
+ class RemoteModel:
+     def __init__(self, name: str, label_names: list[str]):
+         self.name: str = name
+         self.label_names = label_names
+
+     def detect_from_image(
+         self,
+         image: Image.Image,
+         triton_uri: str,
+     ) -> ObjectDetectionRawResult:
+         """Run object detection model on an image.
+
+         :param image: the input Pillow image
+         :param triton_uri: URI of the Triton Inference Server.
+         :return: the detection result
+         """
+         resized_image = resize_image(image, OBJECT_DETECTION_IMAGE_MAX_SIZE)
+         image_array = convert_image_to_array(resized_image)
+         grpc_stub = get_triton_inference_stub(triton_uri)
+         request = service_pb2.ModelInferRequest()
+         request.model_name = self.name
+
+         image_input = service_pb2.ModelInferRequest().InferInputTensor()
+         image_input.name = "inputs"
+         image_input.datatype = "UINT8"
+         image_input.shape.extend([1, image_array.shape[0], image_array.shape[1], 3])
+         request.inputs.extend([image_input])
+
+         for output_name in (
+             "num_detections",
+             "detection_classes",
+             "detection_scores",
+             "detection_boxes",
+         ):
+             output = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
+             output.name = output_name
+             request.outputs.extend([output])
+
+         request.raw_input_contents.extend([image_array.tobytes()])
+         start_time = time.monotonic()
+         response = grpc_stub.ModelInfer(request)
+         logger.debug(
+             "Inference time for %s: %s", self.name, time.monotonic() - start_time
+         )
+
+         if len(response.outputs) != 4:
+             raise Exception(f"expected 4 outputs, got {len(response.outputs)}")
+
+         if len(response.raw_output_contents) != 4:
+             raise Exception(
+                 f"expected 4 raw output contents, got {len(response.raw_output_contents)}"
+             )
182
+
183
+ output_index = {output.name: i for i, output in enumerate(response.outputs)}
184
+ num_detections = (
185
+ np.frombuffer(
186
+ response.raw_output_contents[output_index["num_detections"]],
187
+ dtype=np.float32,
188
+ )
189
+ .reshape((1, 1))
190
+ .astype(int)[0][0] # type: ignore
191
+ )
192
+ detection_scores = np.frombuffer(
193
+ response.raw_output_contents[output_index["detection_scores"]],
194
+ dtype=np.float32,
195
+ ).reshape((1, -1))[0]
196
+ detection_classes = (
197
+ np.frombuffer(
198
+ response.raw_output_contents[output_index["detection_classes"]],
199
+ dtype=np.float32,
200
+ )
201
+ .reshape((1, -1))
202
+ .astype(int) # type: ignore
203
+ )[0]
204
+ detection_boxes = np.frombuffer(
205
+ response.raw_output_contents[output_index["detection_boxes"]],
206
+ dtype=np.float32,
207
+ ).reshape((1, -1, 4))[0]
208
+
209
+ result = ObjectDetectionRawResult(
210
+ num_detections=num_detections,
211
+ detection_classes=detection_classes,
212
+ detection_boxes=detection_boxes,
213
+ detection_scores=detection_scores,
214
+ detection_masks=None,
215
+ label_names=self.label_names,
216
+ )
217
+
218
+ return result
219
+
220
+
+ class ObjectDetectionModelRegistry:
+     models: dict[str, RemoteModel] = {}
+     _loaded = False
+
+     @classmethod
+     def load_all(cls) -> None:
+         # Load every model declared in LABELS once; later calls are no-ops.
+         if cls._loaded:
+             return
+         for name in LABELS:
+             cls.load(name)
+         cls._loaded = True
+
+     @classmethod
+     def get_available_models(cls) -> list[str]:
+         cls.load_all()
+         return list(cls.models.keys())
+
+     @classmethod
+     def load(cls, name: str) -> RemoteModel:
+         label_names = LABELS[name]
+         model = RemoteModel(name, label_names)
+         cls.models[name] = model
+         return model
+
+     @classmethod
+     def get(cls, name: str) -> RemoteModel:
+         if name not in cls.models:
+             cls.load(name)
+         return cls.models[name]
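Together, RemoteModel and ObjectDetectionModelRegistry form a small client-side inference path: look a model up by name, send the (resized) image to Triton over gRPC, then filter the raw detections by score. A minimal usage sketch, assuming the optional grpc/tritonclient dependencies are installed, a Triton server reachable at a hypothetical localhost:8001, and a local product.jpg image:

from PIL import Image

image = Image.open("product.jpg")  # hypothetical local image
model = ObjectDetectionModelRegistry.get("nutriscore")
raw_result = model.detect_from_image(image, triton_uri="localhost:8001")
# Keep only detections scoring above 0.8 (select defaults to 0.5)
for detection in raw_result.select(threshold=0.8):
    print(detection.label, detection.score, detection.bounding_box)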
labelr/types.py ADDED
@@ -0,0 +1,16 @@
+ import enum
+
+
+ class ExportSource(str, enum.Enum):
+     hf = "hf"
+     ls = "ls"
+
+
+ class ExportDestination(str, enum.Enum):
+     hf = "hf"
+     ultralytics = "ultralytics"
+
+
+ class TaskType(str, enum.Enum):
+     object_detection = "object_detection"
+     classification = "classification"
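Because each of these enums also subclasses str, members compare equal to their raw string values and can be constructed from them, which keeps them convenient as CLI choices or in serialized configs. For example:

assert ExportSource("hf") is ExportSource.hf
assert ExportDestination.ultralytics == "ultralytics"
assert TaskType.object_detection.value == "object_detection"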