labelr-0.1.0-py3-none-any.whl → labelr-0.2.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
labelr/annotate.py CHANGED
@@ -4,17 +4,17 @@ import string
  from openfoodfacts.utils import get_logger
 
  try:
+     from openfoodfacts.ml.object_detection import ObjectDetectionRawResult
      from ultralytics.engine.results import Results
  except ImportError:
      pass
 
- from labelr.triton.object_detection import ObjectDetectionResult
 
  logger = get_logger(__name__)
 
 
  def format_annotation_results_from_triton(
-     objects: list[ObjectDetectionResult], image_width: int, image_height: int
+     objects: list["ObjectDetectionRawResult"], image_width: int, image_height: int
  ):
      """Format annotation results from a Triton object detection model into
      Label Studio format."""
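
The hunk above moves the optional import inside the existing `try`/`except ImportError` guard and quotes the type annotation, so `labelr/annotate.py` still imports when the `triton` extra (which now pulls in `openfoodfacts[ml]`) is absent. A minimal sketch of that pattern — the `heavy_dep` names are hypothetical, not part of labelr:

```python
try:
    # Optional dependency: only present when the corresponding extra is installed.
    from heavy_dep import HeavyResult  # hypothetical optional import
except ImportError:
    pass


def process(results: list["HeavyResult"]) -> None:
    # The quoted name is a string, never resolved at import time, so this
    # module loads even when heavy_dep is missing; only calling process()
    # with real HeavyResult objects actually requires the dependency.
    for result in results:
        print(result)
```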
labelr/main.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Annotated, Optional
+ from typing import Annotated
 
  import typer
  from openfoodfacts.utils import get_logger
@@ -6,7 +6,6 @@ from openfoodfacts.utils import get_logger
  from labelr.apps import datasets as dataset_app
  from labelr.apps import projects as project_app
  from labelr.apps import users as user_app
- from labelr.config import LABEL_STUDIO_DEFAULT_URL
 
  app = typer.Typer(pretty_exceptions_show_locals=False)
 
@@ -14,243 +13,48 @@ logger = get_logger()
 
 
  @app.command()
- def predict_object(
+ def predict(
      model_name: Annotated[
          str, typer.Option(help="Name of the object detection model to run")
      ],
+     label_names: Annotated[list[str], typer.Argument(help="List of label names")],
      image_url: Annotated[str, typer.Option(help="URL of the image to process")],
      triton_uri: Annotated[
          str, typer.Option(help="URI (host+port) of the Triton Inference Server")
      ],
-     threshold: float = 0.5,
- ):
-     from openfoodfacts.utils import get_image_from_url
-
-     from labelr.triton.object_detection import ObjectDetectionModelRegistry
-
-     model = ObjectDetectionModelRegistry.get(model_name)
-     image = get_image_from_url(image_url)
-     output = model.detect_from_image(image, triton_uri=triton_uri)
-     results = output.select(threshold=threshold)
-
-     for result in results:
-         typer.echo(result)
-
-
- # Temporary scripts
-
-
- @app.command()
- def skip_rotated_images(
-     api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
-     project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
-     updated_by: Annotated[
-         Optional[int], typer.Option(help="User ID to declare as annotator")
-     ] = None,
-     label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
- ):
-     import requests
-     import tqdm
-     from label_studio_sdk.client import LabelStudio
-     from label_studio_sdk.types.task import Task
-     from openfoodfacts.ocr import OCRResult
-
-     session = requests.Session()
-     ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
-
-     task: Task
-     for task in tqdm.tqdm(
-         ls.tasks.list(project=project_id, fields="all"), desc="tasks"
-     ):
-         if any(annotation["was_cancelled"] for annotation in task.annotations):
-             continue
-
-         assert task.total_annotations == 1, (
-             "Task has multiple annotations (%s)" % task.id
-         )
-         task_id = task.id
-
-         annotation = task.annotations[0]
-         annotation_id = annotation["id"]
-
-         ocr_url = task.data["image_url"].replace(".jpg", ".json")
-         ocr_result = OCRResult.from_url(ocr_url, session=session, error_raise=False)
-
-         if ocr_result is None:
-             logger.warning("No OCR result for task: %s", task_id)
-             continue
-
-         orientation_result = ocr_result.get_orientation()
-
-         if orientation_result is None:
-             # logger.info("No orientation for task: %s", task_id)
-             continue
-
-         orientation = orientation_result.orientation.name
-         if orientation != "up":
-             logger.info(
-                 "Skipping rotated image for task: %s (orientation: %s)",
-                 task_id,
-                 orientation,
-             )
-             ls.annotations.update(
-                 id=annotation_id,
-                 was_cancelled=True,
-                 updated_by=updated_by,
-             )
-         elif orientation == "up":
-             logger.debug("Keeping annotation for task: %s", task_id)
-
-
- @app.command()
- def fix_label(
-     api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
-     project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
-     label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
- ):
-     import tqdm
-     from label_studio_sdk.client import LabelStudio
-     from label_studio_sdk.types.task import Task
-
-     ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
-
-     task: Task
-     for task in tqdm.tqdm(
-         ls.tasks.list(project=project_id, fields="all"), desc="tasks"
-     ):
-         for prediction in task.predictions:
-             updated = False
-             if "result" in prediction:
-                 for result in prediction["result"]:
-                     value = result["value"]
-                     if "rectanglelabels" in value and value["rectanglelabels"] != [
-                         "price-tag"
-                     ]:
-                         value["rectanglelabels"] = ["price-tag"]
-                         updated = True
-
-             if updated:
-                 print(f"Updating prediction {prediction['id']}, task {task.id}")
-                 ls.predictions.update(prediction["id"], result=prediction["result"])
-
-         for annotation in task.annotations:
-             updated = False
-             if "result" in annotation:
-                 for result in annotation["result"]:
-                     value = result["value"]
-                     if "rectanglelabels" in value and value["rectanglelabels"] != [
-                         "price-tag"
-                     ]:
-                         value["rectanglelabels"] = ["price-tag"]
-                         updated = True
-
-             if updated:
-                 print(f"Updating annotation {annotation['id']}, task {task.id}")
-                 ls.annotations.update(annotation["id"], result=annotation["result"])
-
-
- @app.command()
- def select_price_tag_images(
-     api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
-     project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
-     label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
- ):
-     import typing
-     from pathlib import Path
-     from typing import Any
-     from urllib.parse import urlparse
-
-     import requests
-     import tqdm
-     from label_studio_sdk.client import LabelStudio
-     from label_studio_sdk.types.task import Task
-
-     session = requests.Session()
-     ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
-
-     proof_paths = (Path(__file__).parent / "proof.txt").read_text().splitlines()
-     task: Task
-     for task in tqdm.tqdm(
-         ls.tasks.list(project=project_id, include="data,id"), desc="tasks"
-     ):
-         data = typing.cast(dict[str, Any], task.data)
-
-         if "is_raw_product_shelf" in data:
-             continue
-         image_url = data["image_url"]
-         file_path = urlparse(image_url).path.replace("/img/", "")
-         r = session.get(
-             f"https://robotoff.openfoodfacts.org/api/v1/images/predict?image_url={image_url}&models=price_proof_classification",
-         )
-
-         if r.status_code != 200:
-             print(
-                 f"Failed to get prediction for {image_url}, error: {r.text} (status: {r.status_code})"
-             )
-             continue
-
-         prediction = r.json()["predictions"]["price_proof_classification"][0]["label"]
-
-         is_raw_preduct_shelf = False
-         if prediction in ("PRICE_TAG", "SHELF"):
-             is_raw_preduct_shelf = file_path in proof_paths
-
-         ls.tasks.update(
-             task.id,
-             data={
-                 **data,
-                 "is_raw_product_shelf": "true" if is_raw_preduct_shelf else "false",
-             },
-         )
-
-
- @app.command()
- def add_predicted_category(
-     api_key: Annotated[str, typer.Option(envvar="LABEL_STUDIO_API_KEY")],
-     project_id: Annotated[int, typer.Option(help="Label Studio project ID")],
-     label_studio_url: str = LABEL_STUDIO_DEFAULT_URL,
+     image_size: Annotated[
+         int, typer.Option(help="Size of the image the model expects")
+     ] = 640,
+     threshold: Annotated[float, typer.Option(help="Detection threshold")] = 0.5,
+     triton_model_version: str = "1",
  ):
+     """Predict objects in an image using an object detection model served by
+     Triton."""
      import typing
-     from typing import Any
-
-     import requests
-     import tqdm
-     from label_studio_sdk.client import LabelStudio
-     from label_studio_sdk.types.task import Task
-
-     session = requests.Session()
-     ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
 
-     task: Task
-     for task in tqdm.tqdm(
-         ls.tasks.list(project=project_id, include="data,id"), desc="tasks"
-     ):
-         data = typing.cast(dict[str, Any], task.data)
+     from openfoodfacts.ml.object_detection import ObjectDetector
+     from openfoodfacts.utils import get_image_from_url
+     from PIL import Image
 
-         if "predicted_category" in data:
-             continue
-         image_url = data["image_url"]
-         r = session.get(
-             f"https://robotoff.openfoodfacts.org/api/v1/images/predict?image_url={image_url}&models=price_proof_classification",
-         )
+     model = ObjectDetector(
+         model_name=model_name, label_names=label_names, image_size=image_size
+     )
+     image = typing.cast(Image.Image | None, get_image_from_url(image_url))
 
-         if r.status_code != 200:
-             print(
-                 f"Failed to get prediction for {image_url}, error: {r.text} (status: {r.status_code})"
-             )
-             continue
+     if image is None:
+         logger.error("Failed to download image from URL: %s", image_url)
+         raise typer.Abort()
 
-         predicted_category = r.json()["predictions"]["price_proof_classification"][0][
-             "label"
-         ]
+     output = model.detect_from_image(
+         image,
+         triton_uri=triton_uri,
+         model_version=triton_model_version,
+         threshold=threshold,
+     )
+     results = output.to_list()
 
-         ls.tasks.update(
-             task.id,
-             data={
-                 **data,
-                 "predicted_category": predicted_category,
-             },
-         )
+     for result in results:
+         typer.echo(result)
 
 
  app.add_typer(user_app.app, name="users", help="Manage Label Studio users")
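
The temporary Label Studio scripts are removed, and `predict_object` becomes `predict`, with `label_names` promoted to a positional argument. Going by the new signature, an invocation would look roughly like this (assuming the `labelr` console entry point; the model name, label, and image URL below are placeholders):

```bash
labelr predict price-tag \
    --model-name price_tag_detection \
    --image-url https://example.com/proof.jpg \
    --triton-uri localhost:8001 \
    --image-size 640 \
    --threshold 0.5
```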
{labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: labelr
- Version: 0.1.0
+ Version: 0.2.0
  Summary: Add your description here
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
@@ -16,6 +16,7 @@ Provides-Extra: ultralytics
  Requires-Dist: ultralytics>=8.3.49; extra == "ultralytics"
  Provides-Extra: triton
  Requires-Dist: tritonclient>=2.52.0; extra == "triton"
+ Requires-Dist: openfoodfacts[ml]>=2.3.4; extra == "triton"
 
  # Labelr
 
@@ -36,50 +37,22 @@ It currently allows to:
  ## Installation
 
  Python 3.10 or higher is required to run this CLI.
- You need to install the CLI manually for now, there is no project published on Pypi.
- To do so:
 
- We recommend to install the CLI in a virtual environment. You can either use pip or conda for that.
-
- ### Pip
-
- Create the virtualenv:
+ To install the CLI, simply run:
 
  ```bash
- python3 -m venv labelr
- source labelr/bin/activate
- ```
- ### Conda
-
- With conda:
- ```bash
- conda create -n labelr python=3.12
- conda activate labelr
- ```
-
- Then, clone the repository and install the requirements:
-
- ```bash
- git clone git@github.com:openfoodfacts/openfoodfacts-ai.git
- ```
-
- ```bash
- pip install -r requirements.txt
+ pip install labelr
  ```
+ We recommend to install the CLI in a virtual environment. You can either use pip or conda for that.
 
- We assume in the following that you have installed the CLI in a virtual environment, and defined the following alias in your shell configuration file (e.g. `.bashrc` or `.zshrc`):
+ There are two optional dependencies that you can install to use the CLI:
+ - `ultralytics`: pre-annotate object detection datasets with an ultralytics model (yolo, yolo-world)
+ - `triton`: pre-annotate object detection datasets using a model served by a Triton inference server
 
- ```bash
- alias labelr='${VIRTUALENV_DIR}/bin/python3 ${PROJECT_PATH}/main.py'
- ```
- or if you are using conda:
- ```bash
- alias labelr='${CONDA_PREFIX}/bin/python3 ${PROJECT_PATH}/main.py'
- ```
+ To install the optional dependencies, you can run:
 
- with `${VIRTUALENV_DIR}` the path to the virtual environment where you installed the CLI and `${PROJECT_PATH}` the path to the root of the project, for example:
  ```bash
- ${PROJECT_PATH} = /home/user/openfoodfacts-ai/ml_utils/ml_utils_cli
+ pip install labelr[ultralytics,triton]
  ```
 
  ## Usage
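
A practical note on the new extras install line added in this hunk: in shells such as zsh, square brackets are glob characters, so the spec generally needs quoting:

```bash
pip install "labelr[ultralytics,triton]"
```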
{labelr-0.1.0.dist-info → labelr-0.2.0.dist-info}/RECORD CHANGED
@@ -1,20 +1,19 @@
  labelr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  labelr/__main__.py,sha256=G4e95-IfhI-lOmkOBP6kQ8wl1x_Fl7dZlLOYr90K83c,66
- labelr/annotate.py,sha256=8O9SO2thevo_Aa6etIUxCz2xJVXB4MwuSHj4jxz8sqQ,3441
+ labelr/annotate.py,sha256=aphaxyGvKVTjB4DQvj00HpX-X8Xz70UHoKSf4QFWaO4,3456
  labelr/check.py,sha256=3wK6mE0UsKvoBNm0_lyWhCMq7gxkv5r50pvO70damXY,2476
  labelr/config.py,sha256=3RXF_NdkSuHvfVMGMlYmjlw45fU77zQkLX7gmZq7NxM,64
  labelr/export.py,sha256=tcOmVnOdJidWfNouNWoQ4OJgHMbbG-bLFHkId9huiS0,10170
- labelr/main.py,sha256=1_cZoJLBMpUV-lnaKb1XaVff4XxWjpIUZbSNQh44tPE,8715
+ labelr/main.py,sha256=gQ8I287mpLy3HIUWqZUyoLAfPwkphwOIzut7hEbH8tY,2135
  labelr/sample.py,sha256=cpzvgZWVU6GzwD35tqGKEFVKAgqQbSHlWW6IL9FG15Q,5918
  labelr/types.py,sha256=CahqnkLnGj23Jg0X9nftK7Jiorq50WYQqR8u9Ln4E-k,281
  labelr/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  labelr/apps/datasets.py,sha256=DXU8XZx0iEHDI5SvUeI8atCKSUmj9YJwO6xTgMZDgEI,7936
  labelr/apps/projects.py,sha256=HpulSciBVTk1sSR1uXjtHytny9t-rN8wiaQ5llNBX6Y,12420
  labelr/apps/users.py,sha256=twQSlpHxE0hrYkgrJpEFbK8lYfWnpJr8vyfLHLtdAUU,909
- labelr/triton/object_detection.py,sha256=QKUOWiYFH72omyZH4SdbA56JDiVA_e_N8YCSQarkzWQ,7409
- labelr-0.1.0.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
- labelr-0.1.0.dist-info/METADATA,sha256=tBsu8c-LehNqjPNiCG3XjRLboQNeq2RSy9JZiv4v9Dc,6528
- labelr-0.1.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- labelr-0.1.0.dist-info/entry_points.txt,sha256=OACukVeR_2z54i8yQuWqqk_jdEHlyTwmTFOFBmxPp1k,43
- labelr-0.1.0.dist-info/top_level.txt,sha256=bjZo50aGZhXIcZYpYOX4sdAQcamxh8nwfEh7A9RD_Ag,7
- labelr-0.1.0.dist-info/RECORD,,
+ labelr-0.2.0.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ labelr-0.2.0.dist-info/METADATA,sha256=nxbEiMBsVEQS71pzZ39uLL_GCVebIB71wyxvFsueGcU,5960
+ labelr-0.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ labelr-0.2.0.dist-info/entry_points.txt,sha256=OACukVeR_2z54i8yQuWqqk_jdEHlyTwmTFOFBmxPp1k,43
+ labelr-0.2.0.dist-info/top_level.txt,sha256=bjZo50aGZhXIcZYpYOX4sdAQcamxh8nwfEh7A9RD_Ag,7
+ labelr-0.2.0.dist-info/RECORD,,
labelr/triton/object_detection.py DELETED
@@ -1,241 +0,0 @@
- import dataclasses
- import functools
- import logging
- import time
- from typing import Any, Optional
-
- import numpy as np
- from PIL import Image
-
- try:
-     import grpc
-     from tritonclient.grpc import service_pb2, service_pb2_grpc
-     from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
- except ImportError:
-     pass
-
- logger = logging.getLogger(__name__)
-
-
- JSONType = dict[str, Any]
-
- OBJECT_DETECTION_MODEL_VERSION = {
-     "nutriscore": "tf-nutriscore-1.0",
-     "nutrition_table": "tf-nutrition-table-1.0",
-     "universal_logo_detector": "tf-universal-logo-detector-1.0",
- }
-
- LABELS = {
-     "nutriscore": [
-         "NULL",
-         "nutriscore-a",
-         "nutriscore-b",
-         "nutriscore-c",
-         "nutriscore-d",
-         "nutriscore-e",
-     ],
- }
-
- OBJECT_DETECTION_IMAGE_MAX_SIZE = (1024, 1024)
-
-
- @functools.cache
- def get_triton_inference_stub(
-     triton_uri: str,
- ) -> "GRPCInferenceServiceStub":
-     """Return a gRPC stub for Triton Inference Server.
-
-     :param triton_uri: URI of the Triton Inference Server
-     :return: gRPC stub for Triton Inference Server
-     """
-     triton_uri = triton_uri
-     channel = grpc.insecure_channel(triton_uri)
-     return service_pb2_grpc.GRPCInferenceServiceStub(channel)
-
-
- def convert_image_to_array(image: Image.Image) -> np.ndarray:
-     """Convert a PIL Image into a numpy array.
-
-     The image is converted to RGB if needed before generating the array.
-
-     :param image: the input image
-     :return: the generated numpy array of shape (width, height, 3)
-     """
-     if image.mode != "RGB":
-         image = image.convert("RGB")
-
-     (im_width, im_height) = image.size
-
-     return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
-
-
- @dataclasses.dataclass
- class ObjectDetectionResult:
-     bounding_box: tuple[int, int, int, int]
-     score: float
-     label: str
-
-
- @dataclasses.dataclass
- class ObjectDetectionRawResult:
-     num_detections: int
-     detection_boxes: np.ndarray
-     detection_scores: np.ndarray
-     detection_classes: np.ndarray
-     label_names: list[str]
-     detection_masks: Optional[np.ndarray] = None
-     boxed_image: Optional[Image.Image] = None
-
-     def select(self, threshold: Optional[float] = None) -> list[ObjectDetectionResult]:
-         if threshold is None:
-             threshold = 0.5
-
-         box_masks = self.detection_scores > threshold
-         selected_boxes = self.detection_boxes[box_masks]
-         selected_scores = self.detection_scores[box_masks]
-         selected_classes = self.detection_classes[box_masks]
-
-         results = []
-         for bounding_box, score, label in zip(
-             selected_boxes, selected_scores, selected_classes
-         ):
-             label_int = int(label)
-             label_str = self.label_names[label_int]
-             if label_str is not None:
-                 result = ObjectDetectionResult(
-                     bounding_box=tuple(bounding_box.tolist()),  # type: ignore
-                     score=float(score),
-                     label=label_str,
-                 )
-                 results.append(result)
-
-         return results
-
-     def to_json(self, threshold: Optional[float] = None) -> list[JSONType]:
-         return [dataclasses.asdict(r) for r in self.select(threshold)]
-
-
- def resize_image(image: Image.Image, max_size: tuple[int, int]) -> Image.Image:
-     width, height = image.size
-     max_width, max_height = max_size
-
-     if width > max_width or height > max_height:
-         new_image = image.copy()
-         new_image.thumbnail((max_width, max_height))
-         return new_image
-
-     return image
-
-
- class RemoteModel:
-     def __init__(self, name: str, label_names: list[str]):
-         self.name: str = name
-         self.label_names = label_names
-
-     def detect_from_image(
-         self,
-         image: Image.Image,
-         triton_uri: str,
-     ) -> ObjectDetectionRawResult:
-         """Run object detection model on an image.
-
-         :param image: the input Pillow image
-         :param triton_uri: URI of the Triton Inference Server.
-         :return: the detection result
-         """
-         resized_image = resize_image(image, OBJECT_DETECTION_IMAGE_MAX_SIZE)
-         image_array = convert_image_to_array(resized_image)
-         grpc_stub = get_triton_inference_stub(triton_uri)
-         request = service_pb2.ModelInferRequest()
-         request.model_name = self.name
-
-         image_input = service_pb2.ModelInferRequest().InferInputTensor()
-         image_input.name = "inputs"
-         image_input.datatype = "UINT8"
-         image_input.shape.extend([1, image_array.shape[0], image_array.shape[1], 3])
-         request.inputs.extend([image_input])
-
-         for output_name in (
-             "num_detections",
-             "detection_classes",
-             "detection_scores",
-             "detection_boxes",
-         ):
-             output = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
-             output.name = output_name
-             request.outputs.extend([output])
-
-         request.raw_input_contents.extend([image_array.tobytes()])
-         start_time = time.monotonic()
-         response = grpc_stub.ModelInfer(request)
-         logger.debug(
-             "Inference time for %s: %s", self.name, time.monotonic() - start_time
-         )
-
-         if len(response.outputs) != 4:
-             raise Exception(f"expected 4 output, got {len(response.outputs)}")
-
-         if len(response.raw_output_contents) != 4:
-             raise Exception(
-                 f"expected 4 raw output content, got {len(response.raw_output_contents)}"
-             )
-
-         output_index = {output.name: i for i, output in enumerate(response.outputs)}
-         num_detections = (
-             np.frombuffer(
-                 response.raw_output_contents[output_index["num_detections"]],
-                 dtype=np.float32,
-             )
-             .reshape((1, 1))
-             .astype(int)[0][0]  # type: ignore
-         )
-         detection_scores = np.frombuffer(
-             response.raw_output_contents[output_index["detection_scores"]],
-             dtype=np.float32,
-         ).reshape((1, -1))[0]
-         detection_classes = (
-             np.frombuffer(
-                 response.raw_output_contents[output_index["detection_classes"]],
-                 dtype=np.float32,
-             )
-             .reshape((1, -1))
-             .astype(int)  # type: ignore
-         )[0]
-         detection_boxes = np.frombuffer(
-             response.raw_output_contents[output_index["detection_boxes"]],
-             dtype=np.float32,
-         ).reshape((1, -1, 4))[0]
-
-         result = ObjectDetectionRawResult(
-             num_detections=num_detections,
-             detection_classes=detection_classes,
-             detection_boxes=detection_boxes,
-             detection_scores=detection_scores,
-             detection_masks=None,
-             label_names=self.label_names,
-         )
-
-         return result
-
-
- class ObjectDetectionModelRegistry:
-     models: dict[str, RemoteModel] = {}
-     _loaded = False
-
-     @classmethod
-     def get_available_models(cls) -> list[str]:
-         cls.load_all()
-         return list(cls.models.keys())
-
-     @classmethod
-     def load(cls, name: str) -> RemoteModel:
-         label_names = LABELS[name]
-         model = RemoteModel(name, label_names)
-         cls.models[name] = model
-         return model
-
-     @classmethod
-     def get(cls, name: str) -> RemoteModel:
-         if name not in cls.models:
-             cls.load(name)
-         return cls.models[name]
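
The deleted `ObjectDetectionRawResult.select()` was a plain score-threshold filter over parallel arrays; that role now falls to the `openfoodfacts[ml]` dependency. For reference, a self-contained sketch of the same filtering logic with dummy data (the scores, boxes, and labels below are made up):

```python
import numpy as np

# Dummy tensors shaped like the deleted class's fields.
detection_scores = np.array([0.9, 0.4, 0.7], dtype=np.float32)
detection_classes = np.array([1, 2, 1], dtype=np.int64)
detection_boxes = np.array(
    [[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.3, 0.3, 0.7, 0.7]],
    dtype=np.float32,
)
label_names = ["NULL", "nutriscore-a", "nutriscore-b"]

# Same logic as select(threshold=0.5): build a boolean mask over the score
# array, then index the parallel arrays with it.
mask = detection_scores > 0.5
for box, score, cls in zip(
    detection_boxes[mask], detection_scores[mask], detection_classes[mask]
):
    print(label_names[int(cls)], float(score), box.tolist())
# Keeps the 0.9 and 0.7 detections; the 0.4 one is dropped.
```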