labelr 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- labelr/annotate.py +2 -2
- labelr/main.py +29 -225
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/METADATA +10 -37
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/RECORD +8 -9
- labelr/triton/object_detection.py +0 -241
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/LICENSE +0 -0
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/WHEEL +0 -0
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/entry_points.txt +0 -0
- {labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/top_level.txt +0 -0
labelr/annotate.py
CHANGED
|
@@ -4,17 +4,17 @@ import string
|
|
|
4
4
|
from openfoodfacts.utils import get_logger
|
|
5
5
|
|
|
6
6
|
try:
|
|
7
|
+
from openfoodfacts.ml.object_detection import ObjectDetectionRawResult
|
|
7
8
|
from ultralytics.engine.results import Results
|
|
8
9
|
except ImportError:
|
|
9
10
|
pass
|
|
10
11
|
|
|
11
|
-
from labelr.triton.object_detection import ObjectDetectionResult
|
|
12
12
|
|
|
13
13
|
logger = get_logger(__name__)
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def format_annotation_results_from_triton(
|
|
17
|
-
objects: list[
|
|
17
|
+
objects: list["ObjectDetectionRawResult"], image_width: int, image_height: int
|
|
18
18
|
):
|
|
19
19
|
"""Format annotation results from a Triton object detection model into
|
|
20
20
|
Label Studio format."""
|
labelr/main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Annotated
|
|
1
|
+
from typing import Annotated
|
|
2
2
|
|
|
3
3
|
import typer
|
|
4
4
|
from openfoodfacts.utils import get_logger
|
|
@@ -6,7 +6,6 @@ from openfoodfacts.utils import get_logger
|
|
|
6
6
|
from labelr.apps import datasets as dataset_app
|
|
7
7
|
from labelr.apps import projects as project_app
|
|
8
8
|
from labelr.apps import users as user_app
|
|
9
|
-
from labelr.config import LABEL_STUDIO_DEFAULT_URL
|
|
10
9
|
|
|
11
10
|
app = typer.Typer(pretty_exceptions_show_locals=False)
|
|
12
11
|
|
|
@@ -14,243 +13,48 @@ logger = get_logger()
|
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
@app.command()
|
|
17
|
-
def
|
|
16
|
+
def predict(
|
|
18
17
|
model_name: Annotated[
|
|
19
18
|
str, typer.Option(help="Name of the object detection model to run")
|
|
20
19
|
],
|
|
20
|
+
label_names: Annotated[list[str], typer.Argument(help="List of label names")],
|
|
21
21
|
image_url: Annotated[str, typer.Option(help="URL of the image to process")],
|
|
22
22
|
triton_uri: Annotated[
|
|
23
23
|
str, typer.Option(help="URI (host+port) of the Triton Inference Server")
|
|
24
24
|
],
|
|
25
|
-
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
model = ObjectDetectionModelRegistry.get(model_name)
|
|
32
|
-
image = get_image_from_url(image_url)
|
|
33
|
-
output = model.detect_from_image(image, triton_uri=triton_uri)
|
|
34
|
-
results = output.select(threshold=threshold)
|
|
35
|
-
|
|
36
|
-
for result in results:
|
|
37
|
-
typer.echo(result)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
# Temporary scripts
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@app.command()
|
|
44
|
-
def skip_rotated_images(
|
|
45
|
-
api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
|
|
46
|
-
project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
|
|
47
|
-
updated_by: Annotated[
|
|
48
|
-
Optional[int], typer.Option(help="User ID to declare as annotator")
|
|
49
|
-
] = None,
|
|
50
|
-
label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
|
|
51
|
-
):
|
|
52
|
-
import requests
|
|
53
|
-
import tqdm
|
|
54
|
-
from label_studio_sdk.client import LabelStudio
|
|
55
|
-
from label_studio_sdk.types.task import Task
|
|
56
|
-
from openfoodfacts.ocr import OCRResult
|
|
57
|
-
|
|
58
|
-
session = requests.Session()
|
|
59
|
-
ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
|
|
60
|
-
|
|
61
|
-
task: Task
|
|
62
|
-
for task in tqdm.tqdm(
|
|
63
|
-
ls.tasks.list(project=project_id, fields="all"), desc="tasks"
|
|
64
|
-
):
|
|
65
|
-
if any(annotation["was_cancelled"] for annotation in task.annotations):
|
|
66
|
-
continue
|
|
67
|
-
|
|
68
|
-
assert task.total_annotations == 1, (
|
|
69
|
-
"Task has multiple annotations (%s)" % task.id
|
|
70
|
-
)
|
|
71
|
-
task_id = task.id
|
|
72
|
-
|
|
73
|
-
annotation = task.annotations[0]
|
|
74
|
-
annotation_id = annotation["id"]
|
|
75
|
-
|
|
76
|
-
ocr_url = task.data["image_url"].replace(".jpg", ".json")
|
|
77
|
-
ocr_result = OCRResult.from_url(ocr_url, session=session, error_raise=False)
|
|
78
|
-
|
|
79
|
-
if ocr_result is None:
|
|
80
|
-
logger.warning("No OCR result for task: %s", task_id)
|
|
81
|
-
continue
|
|
82
|
-
|
|
83
|
-
orientation_result = ocr_result.get_orientation()
|
|
84
|
-
|
|
85
|
-
if orientation_result is None:
|
|
86
|
-
# logger.info("No orientation for task: %s", task_id)
|
|
87
|
-
continue
|
|
88
|
-
|
|
89
|
-
orientation = orientation_result.orientation.name
|
|
90
|
-
if orientation != "up":
|
|
91
|
-
logger.info(
|
|
92
|
-
"Skipping rotated image for task: %s (orientation: %s)",
|
|
93
|
-
task_id,
|
|
94
|
-
orientation,
|
|
95
|
-
)
|
|
96
|
-
ls.annotations.update(
|
|
97
|
-
id=annotation_id,
|
|
98
|
-
was_cancelled=True,
|
|
99
|
-
updated_by=updated_by,
|
|
100
|
-
)
|
|
101
|
-
elif orientation == "up":
|
|
102
|
-
logger.debug("Keeping annotation for task: %s", task_id)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@app.command()
|
|
106
|
-
def fix_label(
|
|
107
|
-
api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
|
|
108
|
-
project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
|
|
109
|
-
label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
|
|
110
|
-
):
|
|
111
|
-
import tqdm
|
|
112
|
-
from label_studio_sdk.client import LabelStudio
|
|
113
|
-
from label_studio_sdk.types.task import Task
|
|
114
|
-
|
|
115
|
-
ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
|
|
116
|
-
|
|
117
|
-
task: Task
|
|
118
|
-
for task in tqdm.tqdm(
|
|
119
|
-
ls.tasks.list(project=project_id, fields="all"), desc="tasks"
|
|
120
|
-
):
|
|
121
|
-
for prediction in task.predictions:
|
|
122
|
-
updated = False
|
|
123
|
-
if "result" in prediction:
|
|
124
|
-
for result in prediction["result"]:
|
|
125
|
-
value = result["value"]
|
|
126
|
-
if "rectanglelabels" in value and value["rectanglelabels"] != [
|
|
127
|
-
"price-tag"
|
|
128
|
-
]:
|
|
129
|
-
value["rectanglelabels"] = ["price-tag"]
|
|
130
|
-
updated = True
|
|
131
|
-
|
|
132
|
-
if updated:
|
|
133
|
-
print(f"Updating prediction {prediction['id']}, task {task.id}")
|
|
134
|
-
ls.predictions.update(prediction["id"], result=prediction["result"])
|
|
135
|
-
|
|
136
|
-
for annotation in task.annotations:
|
|
137
|
-
updated = False
|
|
138
|
-
if "result" in annotation:
|
|
139
|
-
for result in annotation["result"]:
|
|
140
|
-
value = result["value"]
|
|
141
|
-
if "rectanglelabels" in value and value["rectanglelabels"] != [
|
|
142
|
-
"price-tag"
|
|
143
|
-
]:
|
|
144
|
-
value["rectanglelabels"] = ["price-tag"]
|
|
145
|
-
updated = True
|
|
146
|
-
|
|
147
|
-
if updated:
|
|
148
|
-
print(f"Updating annotation {annotation['id']}, task {task.id}")
|
|
149
|
-
ls.annotations.update(annotation["id"], result=annotation["result"])
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
@app.command()
|
|
153
|
-
def select_price_tag_images(
|
|
154
|
-
api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
|
|
155
|
-
project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
|
|
156
|
-
label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
|
|
157
|
-
):
|
|
158
|
-
import typing
|
|
159
|
-
from pathlib import Path
|
|
160
|
-
from typing import Any
|
|
161
|
-
from urllib.parse import urlparse
|
|
162
|
-
|
|
163
|
-
import requests
|
|
164
|
-
import tqdm
|
|
165
|
-
from label_studio_sdk.client import LabelStudio
|
|
166
|
-
from label_studio_sdk.types.task import Task
|
|
167
|
-
|
|
168
|
-
session = requests.Session()
|
|
169
|
-
ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
|
|
170
|
-
|
|
171
|
-
proof_paths = (Path(__file__).parent / "proof.txt").read_text().splitlines()
|
|
172
|
-
task: Task
|
|
173
|
-
for task in tqdm.tqdm(
|
|
174
|
-
ls.tasks.list(project=project_id, include="data,id"), desc="tasks"
|
|
175
|
-
):
|
|
176
|
-
data = typing.cast(dict[str, Any], task.data)
|
|
177
|
-
|
|
178
|
-
if "is_raw_product_shelf" in data:
|
|
179
|
-
continue
|
|
180
|
-
image_url = data["image_url"]
|
|
181
|
-
file_path = urlparse(image_url).path.replace("/img/", "")
|
|
182
|
-
r = session.get(
|
|
183
|
-
f"https://robotoff.openfoodfacts.org/api/v1/images/predict?image_url={image_url}&models=price_proof_classification",
|
|
184
|
-
)
|
|
185
|
-
|
|
186
|
-
if r.status_code != 200:
|
|
187
|
-
print(
|
|
188
|
-
f"Failed to get prediction for {image_url}, error: {r.text} (status: {r.status_code})"
|
|
189
|
-
)
|
|
190
|
-
continue
|
|
191
|
-
|
|
192
|
-
prediction = r.json()["predictions"]["price_proof_classification"][0]["label"]
|
|
193
|
-
|
|
194
|
-
is_raw_preduct_shelf = False
|
|
195
|
-
if prediction in ("PRICE_TAG", "SHELF"):
|
|
196
|
-
is_raw_preduct_shelf = file_path in proof_paths
|
|
197
|
-
|
|
198
|
-
ls.tasks.update(
|
|
199
|
-
task.id,
|
|
200
|
-
data={
|
|
201
|
-
**data,
|
|
202
|
-
"is_raw_product_shelf": "true" if is_raw_preduct_shelf else "false",
|
|
203
|
-
},
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
@app.command()
|
|
208
|
-
def add_predicted_category(
|
|
209
|
-
api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
|
|
210
|
-
project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
|
|
211
|
-
label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
|
|
25
|
+
image_size: Annotated[
|
|
26
|
+
int, typer.Option(help="Size of the image the model expects")
|
|
27
|
+
] = 640,
|
|
28
|
+
threshold: Annotated[float, typer.Option(help="Detection threshold")] = 0.5,
|
|
29
|
+
triton_model_version: str = "1",
|
|
212
30
|
):
|
|
31
|
+
"""Predict objects in an image using an object detection model served by
|
|
32
|
+
Triton."""
|
|
213
33
|
import typing
|
|
214
|
-
from typing import Any
|
|
215
|
-
|
|
216
|
-
import requests
|
|
217
|
-
import tqdm
|
|
218
|
-
from label_studio_sdk.client import LabelStudio
|
|
219
|
-
from label_studio_sdk.types.task import Task
|
|
220
|
-
|
|
221
|
-
session = requests.Session()
|
|
222
|
-
ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
|
|
223
34
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
):
|
|
228
|
-
data = typing.cast(dict[str, Any], task.data)
|
|
35
|
+
from openfoodfacts.ml.object_detection import ObjectDetector
|
|
36
|
+
from openfoodfacts.utils import get_image_from_url
|
|
37
|
+
from PIL import Image
|
|
229
38
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
f"https://robotoff.openfoodfacts.org/api/v1/images/predict?image_url={image_url}&models=price_proof_classification",
|
|
235
|
-
)
|
|
39
|
+
model = ObjectDetector(
|
|
40
|
+
model_name=model_name, label_names=label_names, image_size=image_size
|
|
41
|
+
)
|
|
42
|
+
image = typing.cast(Image.Image | None, get_image_from_url(image_url))
|
|
236
43
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
)
|
|
241
|
-
continue
|
|
44
|
+
if image is None:
|
|
45
|
+
logger.error("Failed to download image from URL: %s", image_url)
|
|
46
|
+
raise typer.Abort()
|
|
242
47
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
48
|
+
output = model.detect_from_image(
|
|
49
|
+
image,
|
|
50
|
+
triton_uri=triton_uri,
|
|
51
|
+
model_version=triton_model_version,
|
|
52
|
+
threshold=threshold,
|
|
53
|
+
)
|
|
54
|
+
results = output.to_list()
|
|
246
55
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
data={
|
|
250
|
-
**data,
|
|
251
|
-
"predicted_category": predicted_category,
|
|
252
|
-
},
|
|
253
|
-
)
|
|
56
|
+
for result in results:
|
|
57
|
+
typer.echo(result)
|
|
254
58
|
|
|
255
59
|
|
|
256
60
|
app.add_typer(user_app.app, name="users", help="Manage Label Studio users")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: labelr
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Add your description here
|
|
5
5
|
Requires-Python: >=3.10
|
|
6
6
|
Description-Content-Type: text/markdown
|
|
@@ -16,6 +16,7 @@ Provides-Extra: ultralytics
|
|
|
16
16
|
Requires-Dist: ultralytics>=8.3.49; extra == "ultralytics"
|
|
17
17
|
Provides-Extra: triton
|
|
18
18
|
Requires-Dist: tritonclient>=2.52.0; extra == "triton"
|
|
19
|
+
Requires-Dist: openfoodfacts[ml]>=2.3.4; extra == "triton"
|
|
19
20
|
|
|
20
21
|
# Labelr
|
|
21
22
|
|
|
@@ -36,50 +37,22 @@ It currently allows to:
|
|
|
36
37
|
## Installation
|
|
37
38
|
|
|
38
39
|
Python 3.10 or higher is required to run this CLI.
|
|
39
|
-
You need to install the CLI manually for now, there is no project published on Pypi.
|
|
40
|
-
To do so:
|
|
41
40
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
### Pip
|
|
45
|
-
|
|
46
|
-
Create the virtualenv:
|
|
41
|
+
To install the CLI, simply run:
|
|
47
42
|
|
|
48
43
|
```bash
|
|
49
|
-
|
|
50
|
-
source labelr/bin/activate
|
|
51
|
-
```
|
|
52
|
-
### Conda
|
|
53
|
-
|
|
54
|
-
With conda:
|
|
55
|
-
```bash
|
|
56
|
-
conda create -n labelr python=3.12
|
|
57
|
-
conda activate labelr
|
|
58
|
-
```
|
|
59
|
-
|
|
60
|
-
Then, clone the repository and install the requirements:
|
|
61
|
-
|
|
62
|
-
```bash
|
|
63
|
-
git clone git@github.com:openfoodfacts/openfoodfacts-ai.git
|
|
64
|
-
```
|
|
65
|
-
|
|
66
|
-
```bash
|
|
67
|
-
pip install -r requirements.txt
|
|
44
|
+
pip install labelr
|
|
68
45
|
```
|
|
46
|
+
We recommend installing the CLI in a virtual environment, created either with Python's built-in `venv` module (then installing with pip) or with conda.
|
|
69
47
|
|
|
70
|
-
|
|
48
|
+
There are two optional dependencies that you can install to use the CLI:
|
|
49
|
+
- `ultralytics`: pre-annotate object detection datasets with an ultralytics model (yolo, yolo-world)
|
|
50
|
+
- `triton`: pre-annotate object detection datasets using a model served by a Triton inference server
|
|
71
51
|
|
|
72
|
-
|
|
73
|
-
alias labelr='${VIRTUALENV_DIR}/bin/python3 ${PROJECT_PATH}/main.py'
|
|
74
|
-
```
|
|
75
|
-
or if you are using conda:
|
|
76
|
-
```bash
|
|
77
|
-
alias labelr='${CONDA_PREFIX}/bin/python3 ${PROJECT_PATH}/main.py'
|
|
78
|
-
```
|
|
52
|
+
To install the optional dependencies, you can run:
|
|
79
53
|
|
|
80
|
-
with `${VIRTUALENV_DIR}` the path to the virtual environment where you installed the CLI and `${PROJECT_PATH}` the path to the root of the project, for example:
|
|
81
54
|
```bash
|
|
82
|
-
|
|
55
|
+
pip install labelr[ultralytics,triton]
|
|
83
56
|
```
|
|
84
57
|
|
|
85
58
|
## Usage
|
|
@@ -1,20 +1,19 @@
|
|
|
1
1
|
labelr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
labelr/__main__.py,sha256=G4e95-IfhI-lOmkOBP6kQ8wl1x_Fl7dZlLOYr90K83c,66
|
|
3
|
-
labelr/annotate.py,sha256=
|
|
3
|
+
labelr/annotate.py,sha256=aphaxyGvKVTjB4DQvj00HpX-X8Xz70UHoKSf4QFWaO4,3456
|
|
4
4
|
labelr/check.py,sha256=3wK6mE0UsKvoBNm0_lyWhCMq7gxkv5r50pvO70damXY,2476
|
|
5
5
|
labelr/config.py,sha256=3RXF_NdkSuHvfVMGMlYmjlw45fU77zQkLX7gmZq7NxM,64
|
|
6
6
|
labelr/export.py,sha256=tcOmVnOdJidWfNouNWoQ4OJgHMbbG-bLFHkId9huiS0,10170
|
|
7
|
-
labelr/main.py,sha256=
|
|
7
|
+
labelr/main.py,sha256=gQ8I287mpLy3HIUWqZUyoLAfPwkphwOIzut7hEbH8tY,2135
|
|
8
8
|
labelr/sample.py,sha256=cpzvgZWVU6GzwD35tqGKEFVKAgqQbSHlWW6IL9FG15Q,5918
|
|
9
9
|
labelr/types.py,sha256=CahqnkLnGj23Jg0X9nftK7Jiorq50WYQqR8u9Ln4E-k,281
|
|
10
10
|
labelr/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
labelr/apps/datasets.py,sha256=DXU8XZx0iEHDI5SvUeI8atCKSUmj9YJwO6xTgMZDgEI,7936
|
|
12
12
|
labelr/apps/projects.py,sha256=HpulSciBVTk1sSR1uXjtHytny9t-rN8wiaQ5llNBX6Y,12420
|
|
13
13
|
labelr/apps/users.py,sha256=twQSlpHxE0hrYkgrJpEFbK8lYfWnpJr8vyfLHLtdAUU,909
|
|
14
|
-
labelr/
|
|
15
|
-
labelr-0.
|
|
16
|
-
labelr-0.
|
|
17
|
-
labelr-0.
|
|
18
|
-
labelr-0.
|
|
19
|
-
labelr-0.
|
|
20
|
-
labelr-0.1.0.dist-info/RECORD,,
|
|
14
|
+
labelr-0.2.0.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
15
|
+
labelr-0.2.0.dist-info/METADATA,sha256=nxbEiMBsVEQS71pzZ39uLL_GCVebIB71wyxvFsueGcU,5960
|
|
16
|
+
labelr-0.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
17
|
+
labelr-0.2.0.dist-info/entry_points.txt,sha256=OACukVeR_2z54i8yQuWqqk_jdEHlyTwmTFOFBmxPp1k,43
|
|
18
|
+
labelr-0.2.0.dist-info/top_level.txt,sha256=bjZo50aGZhXIcZYpYOX4sdAQcamxh8nwfEh7A9RD_Ag,7
|
|
19
|
+
labelr-0.2.0.dist-info/RECORD,,
|
|
@@ -1,241 +0,0 @@
|
|
|
1
|
-
import dataclasses
|
|
2
|
-
import functools
|
|
3
|
-
import logging
|
|
4
|
-
import time
|
|
5
|
-
from typing import Any, Optional
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
from PIL import Image
|
|
9
|
-
|
|
10
|
-
try:
|
|
11
|
-
import grpc
|
|
12
|
-
from tritonclient.grpc import service_pb2, service_pb2_grpc
|
|
13
|
-
from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
|
|
14
|
-
except ImportError:
|
|
15
|
-
pass
|
|
16
|
-
|
|
17
|
-
logger = logging.getLogger(__name__)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
JSONType = dict[str, Any]
|
|
21
|
-
|
|
22
|
-
OBJECT_DETECTION_MODEL_VERSION = {
|
|
23
|
-
"nutriscore": "tf-nutriscore-1.0",
|
|
24
|
-
"nutrition_table": "tf-nutrition-table-1.0",
|
|
25
|
-
"universal_logo_detector": "tf-universal-logo-detector-1.0",
|
|
26
|
-
}
|
|
27
|
-
|
|
28
|
-
LABELS = {
|
|
29
|
-
"nutriscore": [
|
|
30
|
-
"NULL",
|
|
31
|
-
"nutriscore-a",
|
|
32
|
-
"nutriscore-b",
|
|
33
|
-
"nutriscore-c",
|
|
34
|
-
"nutriscore-d",
|
|
35
|
-
"nutriscore-e",
|
|
36
|
-
],
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
OBJECT_DETECTION_IMAGE_MAX_SIZE = (1024, 1024)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@functools.cache
|
|
43
|
-
def get_triton_inference_stub(
|
|
44
|
-
triton_uri: str,
|
|
45
|
-
) -> "GRPCInferenceServiceStub":
|
|
46
|
-
"""Return a gRPC stub for Triton Inference Server.
|
|
47
|
-
|
|
48
|
-
:param triton_uri: URI of the Triton Inference Server
|
|
49
|
-
:return: gRPC stub for Triton Inference Server
|
|
50
|
-
"""
|
|
51
|
-
triton_uri = triton_uri
|
|
52
|
-
channel = grpc.insecure_channel(triton_uri)
|
|
53
|
-
return service_pb2_grpc.GRPCInferenceServiceStub(channel)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def convert_image_to_array(image: Image.Image) -> np.ndarray:
|
|
57
|
-
"""Convert a PIL Image into a numpy array.
|
|
58
|
-
|
|
59
|
-
The image is converted to RGB if needed before generating the array.
|
|
60
|
-
|
|
61
|
-
:param image: the input image
|
|
62
|
-
:return: the generated numpy array of shape (width, height, 3)
|
|
63
|
-
"""
|
|
64
|
-
if image.mode != "RGB":
|
|
65
|
-
image = image.convert("RGB")
|
|
66
|
-
|
|
67
|
-
(im_width, im_height) = image.size
|
|
68
|
-
|
|
69
|
-
return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@dataclasses.dataclass
|
|
73
|
-
class ObjectDetectionResult:
|
|
74
|
-
bounding_box: tuple[int, int, int, int]
|
|
75
|
-
score: float
|
|
76
|
-
label: str
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@dataclasses.dataclass
|
|
80
|
-
class ObjectDetectionRawResult:
|
|
81
|
-
num_detections: int
|
|
82
|
-
detection_boxes: np.ndarray
|
|
83
|
-
detection_scores: np.ndarray
|
|
84
|
-
detection_classes: np.ndarray
|
|
85
|
-
label_names: list[str]
|
|
86
|
-
detection_masks: Optional[np.ndarray] = None
|
|
87
|
-
boxed_image: Optional[Image.Image] = None
|
|
88
|
-
|
|
89
|
-
def select(self, threshold: Optional[float] = None) -> list[ObjectDetectionResult]:
|
|
90
|
-
if threshold is None:
|
|
91
|
-
threshold = 0.5
|
|
92
|
-
|
|
93
|
-
box_masks = self.detection_scores > threshold
|
|
94
|
-
selected_boxes = self.detection_boxes[box_masks]
|
|
95
|
-
selected_scores = self.detection_scores[box_masks]
|
|
96
|
-
selected_classes = self.detection_classes[box_masks]
|
|
97
|
-
|
|
98
|
-
results = []
|
|
99
|
-
for bounding_box, score, label in zip(
|
|
100
|
-
selected_boxes, selected_scores, selected_classes
|
|
101
|
-
):
|
|
102
|
-
label_int = int(label)
|
|
103
|
-
label_str = self.label_names[label_int]
|
|
104
|
-
if label_str is not None:
|
|
105
|
-
result = ObjectDetectionResult(
|
|
106
|
-
bounding_box=tuple(bounding_box.tolist()), # type: ignore
|
|
107
|
-
score=float(score),
|
|
108
|
-
label=label_str,
|
|
109
|
-
)
|
|
110
|
-
results.append(result)
|
|
111
|
-
|
|
112
|
-
return results
|
|
113
|
-
|
|
114
|
-
def to_json(self, threshold: Optional[float] = None) -> list[JSONType]:
|
|
115
|
-
return [dataclasses.asdict(r) for r in self.select(threshold)]
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def resize_image(image: Image.Image, max_size: tuple[int, int]) -> Image.Image:
|
|
119
|
-
width, height = image.size
|
|
120
|
-
max_width, max_height = max_size
|
|
121
|
-
|
|
122
|
-
if width > max_width or height > max_height:
|
|
123
|
-
new_image = image.copy()
|
|
124
|
-
new_image.thumbnail((max_width, max_height))
|
|
125
|
-
return new_image
|
|
126
|
-
|
|
127
|
-
return image
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
class RemoteModel:
|
|
131
|
-
def __init__(self, name: str, label_names: list[str]):
|
|
132
|
-
self.name: str = name
|
|
133
|
-
self.label_names = label_names
|
|
134
|
-
|
|
135
|
-
def detect_from_image(
|
|
136
|
-
self,
|
|
137
|
-
image: Image.Image,
|
|
138
|
-
triton_uri: str,
|
|
139
|
-
) -> ObjectDetectionRawResult:
|
|
140
|
-
"""Run object detection model on an image.
|
|
141
|
-
|
|
142
|
-
:param image: the input Pillow image
|
|
143
|
-
:param triton_uri: URI of the Triton Inference Server.
|
|
144
|
-
:return: the detection result
|
|
145
|
-
"""
|
|
146
|
-
resized_image = resize_image(image, OBJECT_DETECTION_IMAGE_MAX_SIZE)
|
|
147
|
-
image_array = convert_image_to_array(resized_image)
|
|
148
|
-
grpc_stub = get_triton_inference_stub(triton_uri)
|
|
149
|
-
request = service_pb2.ModelInferRequest()
|
|
150
|
-
request.model_name = self.name
|
|
151
|
-
|
|
152
|
-
image_input = service_pb2.ModelInferRequest().InferInputTensor()
|
|
153
|
-
image_input.name = "inputs"
|
|
154
|
-
image_input.datatype = "UINT8"
|
|
155
|
-
image_input.shape.extend([1, image_array.shape[0], image_array.shape[1], 3])
|
|
156
|
-
request.inputs.extend([image_input])
|
|
157
|
-
|
|
158
|
-
for output_name in (
|
|
159
|
-
"num_detections",
|
|
160
|
-
"detection_classes",
|
|
161
|
-
"detection_scores",
|
|
162
|
-
"detection_boxes",
|
|
163
|
-
):
|
|
164
|
-
output = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
|
|
165
|
-
output.name = output_name
|
|
166
|
-
request.outputs.extend([output])
|
|
167
|
-
|
|
168
|
-
request.raw_input_contents.extend([image_array.tobytes()])
|
|
169
|
-
start_time = time.monotonic()
|
|
170
|
-
response = grpc_stub.ModelInfer(request)
|
|
171
|
-
logger.debug(
|
|
172
|
-
"Inference time for %s: %s", self.name, time.monotonic() - start_time
|
|
173
|
-
)
|
|
174
|
-
|
|
175
|
-
if len(response.outputs) != 4:
|
|
176
|
-
raise Exception(f"expected 4 output, got {len(response.outputs)}")
|
|
177
|
-
|
|
178
|
-
if len(response.raw_output_contents) != 4:
|
|
179
|
-
raise Exception(
|
|
180
|
-
f"expected 4 raw output content, got {len(response.raw_output_contents)}"
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
output_index = {output.name: i for i, output in enumerate(response.outputs)}
|
|
184
|
-
num_detections = (
|
|
185
|
-
np.frombuffer(
|
|
186
|
-
response.raw_output_contents[output_index["num_detections"]],
|
|
187
|
-
dtype=np.float32,
|
|
188
|
-
)
|
|
189
|
-
.reshape((1, 1))
|
|
190
|
-
.astype(int)[0][0] # type: ignore
|
|
191
|
-
)
|
|
192
|
-
detection_scores = np.frombuffer(
|
|
193
|
-
response.raw_output_contents[output_index["detection_scores"]],
|
|
194
|
-
dtype=np.float32,
|
|
195
|
-
).reshape((1, -1))[0]
|
|
196
|
-
detection_classes = (
|
|
197
|
-
np.frombuffer(
|
|
198
|
-
response.raw_output_contents[output_index["detection_classes"]],
|
|
199
|
-
dtype=np.float32,
|
|
200
|
-
)
|
|
201
|
-
.reshape((1, -1))
|
|
202
|
-
.astype(int) # type: ignore
|
|
203
|
-
)[0]
|
|
204
|
-
detection_boxes = np.frombuffer(
|
|
205
|
-
response.raw_output_contents[output_index["detection_boxes"]],
|
|
206
|
-
dtype=np.float32,
|
|
207
|
-
).reshape((1, -1, 4))[0]
|
|
208
|
-
|
|
209
|
-
result = ObjectDetectionRawResult(
|
|
210
|
-
num_detections=num_detections,
|
|
211
|
-
detection_classes=detection_classes,
|
|
212
|
-
detection_boxes=detection_boxes,
|
|
213
|
-
detection_scores=detection_scores,
|
|
214
|
-
detection_masks=None,
|
|
215
|
-
label_names=self.label_names,
|
|
216
|
-
)
|
|
217
|
-
|
|
218
|
-
return result
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
class ObjectDetectionModelRegistry:
|
|
222
|
-
models: dict[str, RemoteModel] = {}
|
|
223
|
-
_loaded = False
|
|
224
|
-
|
|
225
|
-
@classmethod
|
|
226
|
-
def get_available_models(cls) -> list[str]:
|
|
227
|
-
cls.load_all()
|
|
228
|
-
return list(cls.models.keys())
|
|
229
|
-
|
|
230
|
-
@classmethod
|
|
231
|
-
def load(cls, name: str) -> RemoteModel:
|
|
232
|
-
label_names = LABELS[name]
|
|
233
|
-
model = RemoteModel(name, label_names)
|
|
234
|
-
cls.models[name] = model
|
|
235
|
-
return model
|
|
236
|
-
|
|
237
|
-
@classmethod
|
|
238
|
-
def get(cls, name: str) -> RemoteModel:
|
|
239
|
-
if name not in cls.models:
|
|
240
|
-
cls.load(name)
|
|
241
|
-
return cls.models[name]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|