samgis_core 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- /dev/null
+++ samgis_core/prediction_api/sam_onnx2.py
@@ -0,0 +1,178 @@
+ """
+ Define a machine learning model executed by ONNX Runtime (https://onnxruntime.ai/)
+ for Segment Anything (https://segment-anything.com).
+ Modified from
+ - https://github.com/vietanhdev/samexporter/
+ - https://github.com/AndreyGermanov/sam_onnx_full_export/
+
+ Copyright (c) 2023 Viet Anh Nguyen, Andrey Germanov
+ Copyright (c) 2024-today Alessandro Trinca Tornidor
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+ from numpy import array as np_array, pad as np_pad, zeros, ndarray, concatenate, float32
+ from onnxruntime import get_available_providers, InferenceSession
+
+ from samgis_core import app_logger
+ from samgis_core.utilities.constants import DEFAULT_INPUT_SHAPE
+ from samgis_core.utilities.type_hints import ListDict, EmbeddingPILImage, PIL_Image
+ from samgis_core.utilities.utilities import convert_ndarray_to_pil, apply_coords
+
+
+ class SegmentAnythingONNX2:
+     """
+     Segmentation model using SegmentAnything.
+     Compatible with onnxruntime 1.17.x and later.
+     """
+
+     def __init__(self, encoder_model_path: str, decoder_model_path: str) -> None:
+         self.target_size = DEFAULT_INPUT_SHAPE[1]
+         self.input_size = DEFAULT_INPUT_SHAPE
+
+         # Load models
+         providers = get_available_providers()
+
+         # Pop TensorRT Runtime due to crashing issues
+         # TODO: Add back when TensorRT backend is stable
+         providers = [p for p in providers if p != "TensorrtExecutionProvider"]
+
+         if providers:
+             app_logger.info(
+                 f"Available providers for ONNXRuntime: {', '.join(providers)}"
+             )
+         else:
+             app_logger.warning("No available providers for ONNXRuntime")
+         self.encoder_session = InferenceSession(
+             encoder_model_path, providers=providers
+         )
+         self.encoder_input_name = self.encoder_session.get_inputs()[0].name
+         app_logger.info(f"encoder_input_name: {self.encoder_input_name}")
+         self.decoder_session = InferenceSession(
+             decoder_model_path, providers=providers
+         )
+
+     @staticmethod
+     def get_input_points(prompt: ListDict):
+         """Get input points and labels from the given prompt list"""
+         points = []
+         labels = []
+         for mark in prompt:
+             if mark["type"] == "point":
+                 points.append(mark["data"])
+                 labels.append(mark["label"])
+             elif mark["type"] == "rectangle":
+                 points.append([mark["data"][0], mark["data"][1]])  # top left
+                 points.append(
+                     [mark["data"][2], mark["data"][3]]
+                 )  # bottom right
+                 labels.append(2)
+                 labels.append(3)
+         points, labels = np_array(points), np_array(labels)
+         return points, labels
+
+     def encode(self, img: PIL_Image | ndarray) -> EmbeddingPILImage:
+         """
+         Calculate embedding and metadata for a single image.
+
+         Args:
+             img: input image to embed
+
+         Returns:
+             embedding image dict useful to store and cache image embeddings
+         """
+         resized_image = self.preprocess_image(img)
+         padded_input_tensor = self.padding_tensor(resized_image)
+
+         # Get the image embedding using the image encoder
+         outputs = self.encoder_session.run(None, {"images": padded_input_tensor})
+         image_embedding = outputs[0]
+         img = convert_ndarray_to_pil(img)
+         return {
+             "image_embedding": image_embedding,
+             "original_size": img.size,
+             "resized_size": resized_image.size
+         }
+
+     def predict_masks(self, embedding: EmbeddingPILImage, prompt: ListDict):
+         """
+         Predict masks for a single image.
+         """
+         input_points, input_labels = self.get_input_points(prompt)
+
+         # Add a batch index, concatenate a padding point, and transform.
+         onnx_coord = concatenate([input_points, np_array([[0.0, 0.0]])], axis=0)[None, :, :]
+         onnx_label = concatenate([input_labels, np_array([-1])], axis=0)[None, :].astype(float32)
+
+         onnx_coord = apply_coords(onnx_coord, embedding)
+         orig_width, orig_height = embedding["original_size"]
+         app_logger.info(f"onnx_coord:{onnx_coord}.")
+
+         # Run the decoder to get the output masks
+         onnx_mask_input = zeros((1, 1, 256, 256), dtype=float32)
+         onnx_has_mask_input = zeros(1, dtype=float32)
+         output_masks, _, _ = self.decoder_session.run(None, {
+             "image_embeddings": embedding["image_embedding"],
+             "point_coords": onnx_coord,
+             "point_labels": onnx_label,
+             "mask_input": onnx_mask_input,
+             "has_mask_input": onnx_has_mask_input,
+             "orig_im_size": np_array([orig_height, orig_width], dtype=float32)
+         })
+         return output_masks
+
+     def preprocess_image(self, img: PIL_Image | ndarray):
+         """Resize the image preserving its aspect ratio, using self.target_size as the long side"""
+         from PIL import Image
+
+         app_logger.info(f"image type:{type(img)}, shape/size:{img.size}.")
+         try:
+             orig_width, orig_height = img.size
+         except TypeError:
+             img = Image.fromarray(img)
+             orig_width, orig_height = img.size
+
+         resized_height = self.target_size
+         resized_width = int(self.target_size / orig_height * orig_width)
+
+         if orig_width > orig_height:
+             resized_width = self.target_size
+             resized_height = int(self.target_size / orig_width * orig_height)
+
+         img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
+         return img
+
+     def padding_tensor(self, img: PIL_Image | ndarray):
+         # Prepare the input tensor from the image
+         tensor_input = np_array(img)
+         resized_width, resized_height = img.size
+
+         # Normalize the input tensor
+         mean = np_array([123.675, 116.28, 103.53])
+         std = np_array([58.395, 57.12, 57.375])
+         tensor_input = (tensor_input - mean) / std
+
+         # Transpose the input tensor to shape (Batch, Channels, Height, Width)
+         tensor_input = tensor_input.transpose(2, 0, 1)[None, :, :, :].astype(float32)
+
+         # Make the image square self.target_size x self.target_size by zero-padding the short side
+         tensor_input = np_pad(tensor_input, ((0, 0), (0, 0), (0, 0), (0, self.target_size - resized_width)))
+         if resized_height < resized_width:
+             tensor_input = np_pad(tensor_input, ((0, 0), (0, 0), (0, self.target_size - resized_height), (0, 0)))
+
+         return tensor_input
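
The new `SegmentAnythingONNX2` replaces the OpenCV affine-warp preprocessing of 1.2.1 with a plain PIL resize (`preprocess_image`) plus zero-padding (`padding_tensor`). A minimal usage sketch follows; the model paths and input image are placeholders, not package defaults:

```python
from PIL import Image

from samgis_core.prediction_api.sam_onnx2 import SegmentAnythingONNX2

# hypothetical paths: point these at your exported SAM encoder/decoder ONNX files
model = SegmentAnythingONNX2("mobile_sam.encoder.onnx", "mobile_sam.decoder.onnx")

img = Image.open("example.png").convert("RGB")  # placeholder input image
embedding = model.encode(img)  # dict: image_embedding, original_size, resized_size

# a single foreground point; label 1 = foreground, 0 = background
# (labels 2/3 are reserved for rectangle corners, see get_input_points above)
prompt = [{"type": "point", "data": [200, 150], "label": 1}]
masks = model.predict_masks(embedding, prompt)  # (1, n_masks, H, W) mask logits
binary_mask = masks[0, 0] > 0.0  # threshold at 0, as the inference helpers below do
```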
--- /dev/null
+++ samgis_core/prediction_api/sam_onnx_inference.py
@@ -0,0 +1,121 @@
+ from numpy import array as np_array, uint8, zeros, ndarray
+
+ from samgis_core import app_logger, MODEL_FOLDER
+ from samgis_core.prediction_api.sam_onnx2 import SegmentAnythingONNX2
+ from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME
+ from samgis_core.utilities.type_hints import ListDict, PIL_Image, TupleNdarrayInt, EmbeddingPILDict
+
+
+ def get_raster_inference(
+         img: PIL_Image | ndarray, prompt: ListDict, models_instance: SegmentAnythingONNX2, model_name: str
+ ) -> TupleNdarrayInt:
+     """
+     Get inference output for a given image using a SegmentAnythingONNX2 model
+
+     Args:
+         img: input PIL Image
+         prompt: list of prompt dicts
+         models_instance: SegmentAnythingONNX2 model instance
+         model_name: model name string
+
+     Returns:
+         raster prediction mask, prediction count
+     """
+     np_img = np_array(img)
+     app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
+     app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
+     try:
+         app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
+     except Exception as e_shape:
+         app_logger.error(f"e_shape:{e_shape}.")
+     app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
+                     f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
+     embedding = models_instance.encode(np_img)
+     app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
+     return get_raster_inference_using_existing_embedding(embedding, prompt, models_instance)
+
+
+ def get_inference_embedding(
+         img: PIL_Image | ndarray, models_instance: SegmentAnythingONNX2, model_name: str, embedding_key: str,
+         embedding_dict: EmbeddingPILDict) -> EmbeddingPILDict:
+     """Add an embedding to the embedding dict if not already present
+
+     Args:
+         img: input PIL Image
+         models_instance: SegmentAnythingONNX2 model instance
+         model_name: model name string
+         embedding_key: embedding id
+         embedding_dict: embedding dict object
+
+     Returns:
+         updated embedding dict
+     """
+     if embedding_key in embedding_dict:
+         app_logger.info("found embedding in dict...")
+     if embedding_key not in embedding_dict:
+         np_img = np_array(img)
+         app_logger.info(f"prepare embedding using img type {type(np_img)}.")
+         app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
+         try:
+             app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
+         except Exception as e_shape:
+             app_logger.error(f"e_shape:{e_shape}.")
+         app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
+                         f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
+         embedding = models_instance.encode(np_img)
+         embedding_dict[embedding_key] = embedding
+     return embedding_dict
+
+
+ def get_raster_inference_using_existing_embedding(
+         embedding: dict, prompt: ListDict, models_instance: SegmentAnythingONNX2) -> TupleNdarrayInt:
+     """
+     Get inference output using a SegmentAnythingONNX2 model and an existing embedding instead of a
+     new ndarray or PIL image
+
+     Args:
+         embedding: embedding dict
+         prompt: list of prompt dicts
+         models_instance: SegmentAnythingONNX2 model instance
+
+     Returns:
+         raster prediction mask, prediction count
+     """
+     app_logger.info(f"using existing embedding of type {type(embedding)}.")
+     inference_out = models_instance.predict_masks(embedding, prompt)
+     len_inference_out = len(inference_out[0, :, :, :])
+     app_logger.info(f"Created {len_inference_out} prediction_masks, "
+                     f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
+     mask = zeros((inference_out.shape[2], inference_out.shape[3]), dtype=uint8)
+     for n, m in enumerate(inference_out[0, :, :, :]):
+         app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
+                          f" => mask shape:{mask.shape}, {mask.dtype}.")
+         mask[m > 0.0] = 255
+     return mask, len_inference_out
+
+
+ def get_raster_inference_with_embedding_from_dict(
+         img: PIL_Image | ndarray, prompt: ListDict, models_instance: SegmentAnythingONNX2, model_name: str,
+         embedding_key: str, embedding_dict: dict) -> TupleNdarrayInt:
+     """
+     Get inference output using a SegmentAnythingONNX2 model, taking the image embedding from the given
+     embedding dict instead of creating a new one. This function still needs the img argument to update
+     the embedding dict if necessary
+
+     Args:
+         img: input PIL Image
+         prompt: list of prompt dicts
+         models_instance: SegmentAnythingONNX2 model instance
+         model_name: model name string
+         embedding_key: embedding id
+         embedding_dict: embedding images dict
+
+     Returns:
+         raster prediction mask, prediction count
+     """
+     app_logger.info(f"handling embedding using key {embedding_key}.")
+     embedding_dict = get_inference_embedding(img, models_instance, model_name, embedding_key, embedding_dict)
+     app_logger.info(f"getting embedding with key {embedding_key} from dict...")
+     embedding = embedding_dict[embedding_key]
+     n_keys = len(embedding_dict)
+     app_logger.info(f"embedding created ({n_keys} keys in embedding dict), running predict_masks with prompt {prompt}.")
+     return get_raster_inference_using_existing_embedding(embedding, prompt, models_instance)
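
The embedding cache is the main workflow addition of this module: `get_raster_inference_with_embedding_from_dict` encodes an image at most once per key, then reuses the stored embedding for every new prompt. A hedged sketch (model paths, image, key, and prompt are illustrative):

```python
from PIL import Image

from samgis_core.prediction_api.sam_onnx2 import SegmentAnythingONNX2
from samgis_core.prediction_api.sam_onnx_inference import get_raster_inference_with_embedding_from_dict

models_instance = SegmentAnythingONNX2("encoder.onnx", "decoder.onnx")  # hypothetical paths
embedding_dict = {}  # keep this alive across requests to reuse encoder outputs

img = Image.open("tile.png").convert("RGB")  # hypothetical raster tile
prompt = [{"type": "rectangle", "data": [10, 10, 120, 90], "label": 0}]

mask, n_masks = get_raster_inference_with_embedding_from_dict(
    img, prompt, models_instance, "mobile_sam", "tile/15/17161/11514", embedding_dict
)
# a second call with the same key skips the (expensive) encoder pass entirely;
# only the much cheaper decoder runs for the new prompt
```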
--- samgis_core/utilities/constants.py
+++ samgis_core/utilities/constants.py
@@ -1,5 +1,6 @@
  """Project constants"""
- DEFAULT_INPUT_SHAPE = 684, 1024
- MODEL_ENCODER_NAME = "mobile_sam.encoder.onnx"
- MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
+ import os

+ DEFAULT_INPUT_SHAPE = 684, 1024
+ MODEL_ENCODER_NAME = os.getenv("MODEL_ENCODER_NAME", "mobile_sam.encoder.onnx")
+ MODEL_DECODER_NAME = os.getenv("MODEL_DECODER_NAME", "mobile_sam.decoder.onnx")
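
Reading the model file names from `os.getenv` lets a deployment swap encoder/decoder pairs without code changes, and the new default decoder (`mobile_sam.decoder.onnx`) now matches the MobileSAM encoder, where 1.2.1 paired it with a ViT-H decoder file name. Note that the lookup happens at import time, so the variables must be set before the module is first imported; a sketch with hypothetical file names:

```python
import os

# must run before the first import of samgis_core.utilities.constants
os.environ["MODEL_ENCODER_NAME"] = "sam_vit_h_4b8939.encoder.onnx"  # hypothetical file name
os.environ["MODEL_DECODER_NAME"] = "sam_vit_h_4b8939.decoder.onnx"  # hypothetical file name

from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME
print(MODEL_ENCODER_NAME, MODEL_DECODER_NAME)
```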
--- /dev/null
+++ samgis_core/utilities/plot_images.py
@@ -0,0 +1,11 @@
+ def helper_imshow_output_expected(img1, img2, title1, title2):
+     from matplotlib import pyplot as plt
+
+     fig, ax = plt.subplot_mosaic([
+         [title1, title2]
+     ], figsize=(15, 10))
+
+     ax[title1].imshow(img1)
+     ax[title2].imshow(img2)
+
+     plt.show()
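
A possible use of the new plotting helper, e.g. to compare a predicted mask against an expected one (the arrays below are synthetic placeholders). Since the two titles double as `subplot_mosaic` keys, they must be distinct:

```python
import numpy as np

from samgis_core.utilities.plot_images import helper_imshow_output_expected

output = np.zeros((64, 64), dtype=np.uint8)        # e.g. a predicted mask
expected = np.full((64, 64), 255, dtype=np.uint8)  # e.g. a ground-truth mask
helper_imshow_output_expected(output, expected, "output", "expected")
```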
--- samgis_core/utilities/type_hints.py
+++ samgis_core/utilities/type_hints.py
@@ -27,6 +27,9 @@ class ListInt(list[int]): pass
  class TupleInt(tuple[int]): pass


+ TupleInt2 = NewType("TupleInt2", tuple[int, int])
+
+
  class TupleNdarrayInt(tuple[ndarray, int]): pass


@@ -55,4 +58,11 @@ class EmbeddingImage(TypedDict):
      transform_matrix: ndarray


+ class EmbeddingPILImage(TypedDict):
+     image_embedding: ndarray
+     original_size: TupleInt2
+     resized_size: TupleInt2
+
+
  EmbeddingDict = dict[str, EmbeddingImage]
+ EmbeddingPILDict = dict[str, EmbeddingPILImage]
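
For reference, a value shaped like the new `EmbeddingPILImage`, as produced by `SegmentAnythingONNX2.encode`; the embedding array shape below is the usual SAM image-encoder output and is stated here as an assumption:

```python
from numpy import float32, zeros

embedding = {
    "image_embedding": zeros((1, 256, 64, 64), dtype=float32),  # assumed SAM embedding shape
    "original_size": (2048, 1368),  # (width, height), i.e. PIL's Image.size ordering
    "resized_size": (1024, 684),    # after preprocess_image, long side == 1024
}
```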
--- samgis_core/utilities/utilities.py
+++ samgis_core/utilities/utilities.py
@@ -1,8 +1,11 @@
  """Various utilities (logger, time benchmark, args dump, numerical and stats info)"""
- import numbers
+ from copy import deepcopy
+
+ from numpy import ndarray, float32

  from samgis_core import app_logger
  from samgis_core.utilities.serialize import serialize
+ from samgis_core.utilities.type_hints import EmbeddingPILImage, PIL_Image


  def _prepare_base64_input(sb):
@@ -14,7 +17,7 @@ def _prepare_base64_input(sb):
          raise ValueError("Argument must be string or bytes")


- def _is_base64(sb: str or bytes):
+ def _is_base64(sb: str | bytes):
      import base64

      try:
@@ -43,7 +46,7 @@ def base64_decode(s):
      return s


- def base64_encode(sb: str or bytes) -> bytes:
+ def base64_encode(sb: str | bytes) -> bytes:
      """
      Encode input strings or bytes as base64

@@ -59,7 +62,7 @@ def base64_encode(sb: str or bytes) -> bytes:
      return base64.b64encode(sb_bytes)


- def hash_calculate(arr) -> str or bytes:
+ def hash_calculate(arr) -> str | bytes:
      """
      Return computed hash from input variable (typically a numpy array).

@@ -92,3 +95,26 @@ def hash_calculate(arr) -> str or bytes:
      else:
          raise ValueError(f"variable 'arr':{arr} of type '{type(arr)}' not yet handled.")
      return b64encode(hash_fn.digest())
+
+
+ def convert_ndarray_to_pil(pil_image: PIL_Image | ndarray):
+     from PIL import Image
+
+     if isinstance(pil_image, ndarray):
+         pil_image = Image.fromarray(pil_image)
+     return pil_image
+
+
+ def apply_coords(coords: ndarray, embedding: EmbeddingPILImage):
+     """
+     Expects a numpy array of length 2 in the final dimension. Requires the
+     original and resized image sizes from the embedding dict, in (width, height) format.
+     """
+     orig_width, orig_height = embedding["original_size"]
+     resized_width, resized_height = embedding["resized_size"]
+     coords = deepcopy(coords).astype(float)
+
+     coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
+     coords[..., 1] = coords[..., 1] * (resized_height / orig_height)
+
+     return coords.astype(float32)
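
The module-level `apply_coords` replaces the method on the old class: instead of recomputing the resized shape from (H, W) and a target length, it reads both sizes directly from the embedding dict. A worked example with illustrative numbers:

```python
from numpy import array as np_array

from samgis_core.utilities.utilities import apply_coords

embedding = {
    "image_embedding": None,        # not used by apply_coords
    "original_size": (2048, 1368),  # (width, height) of the source image
    "resized_size": (1024, 684),    # long side scaled to 1024, aspect ratio kept
}
coords = np_array([[[2048.0, 1368.0], [0.0, 0.0]]])
print(apply_coords(coords, embedding))
# [[[1024. 684.] [0. 0.]]]: x scaled by 1024/2048, y by 684/1368, returned as float32
```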
--- samgis_core-1.2.1.dist-info/METADATA
+++ samgis_core-1.2.2.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: samgis_core
- Version: 1.2.1
+ Version: 1.2.2
  Summary: SamGIS CORE
  License: MIT
  Author: alessandro trinca tornidor
@@ -14,8 +14,7 @@ Requires-Dist: bson (>=0.5.10,<0.6.0)
  Requires-Dist: loguru (>=0.7.2,<0.8.0)
  Requires-Dist: numpy (==1.25.2) ; python_version >= "3.10" and python_version < "3.11"
  Requires-Dist: numpy (>=1.26,<2.0) ; python_version >= "3.11" and python_version < "3.12"
- Requires-Dist: onnxruntime (==1.16.3)
- Requires-Dist: opencv-python-headless (==4.8.1.78)
+ Requires-Dist: onnxruntime (==1.17.3)
  Requires-Dist: pillow (>=10.2.0,<11.0.0)
  Project-URL: Source, https://gitlab.com/aletrn/samgis_core
  Description-Content-Type: text/markdown
--- /dev/null
+++ samgis_core-1.2.2.dist-info/RECORD
@@ -0,0 +1,15 @@
+ samgis_core/__init__.py,sha256=1kFX8G22dxNz23J7uOYl-SMWOe4W1olssc-5zKAVwSc,351
+ samgis_core/__version__.py,sha256=x-8LeQkljky-ao1MyT98tSQE1xNvJp3LelqsAdCb9Og,94
+ samgis_core/prediction_api/__init__.py,sha256=_jUZhspS26ygiSzBJywfZ4fQf9X7oC8w6oxvhd9S2hQ,57
+ samgis_core/prediction_api/sam_onnx2.py,sha256=u0-BgNff0nMzRNk9MOF1-hVMmWJcZr_EFXAp8iF5NBs,7437
+ samgis_core/prediction_api/sam_onnx_inference.py,sha256=EQjJy1RLZn0-m66bdY_T1bR-t6PnNu47JJVphEkgPhE,5539
+ samgis_core/utilities/__init__.py,sha256=nL9pzdB4SdEF8m5gCbtlVCtdGLg9JjPm-FNxKBsIBZA,32
+ samgis_core/utilities/constants.py,sha256=0xBdfGYwCg4O0OXFtTcMVNj-kryjbajcxOZhMVkVP7U,227
+ samgis_core/utilities/fastapi_logger.py,sha256=yt8Vj1viyE-Kry1_iy5p4saZvGEEmRfuTEMHYXgnjqk,650
+ samgis_core/utilities/plot_images.py,sha256=nY_1KW7x_h218MA9leulEKLkoNhLlyEscLSWs0Az2uE,263
+ samgis_core/utilities/serialize.py,sha256=aIjhEoibBpV_gpgOg6LiVxZCWjOkYxlzcboDZLQctJE,2689
+ samgis_core/utilities/type_hints.py,sha256=hAMYXpHgMhYguOPbegCQbVCIexWNSn6S-2Q_nilfesQ,1068
+ samgis_core/utilities/utilities.py,sha256=tRGp-Iw0PoPf5YHDL6Hx2CEdZhqs03hh3YPe_rmoO-E,3286
+ samgis_core-1.2.2.dist-info/METADATA,sha256=ZBYbeGm4TKxV7OVTihqdCXrhRVHSlYkF3zY3KaoFErg,1096
+ samgis_core-1.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ samgis_core-1.2.2.dist-info/RECORD,,
--- samgis_core/prediction_api/sam_onnx.py
+++ /dev/null
@@ -1,371 +0,0 @@
- """
- Define a machine learning model executed by ONNX Runtime (https://onnxruntime.ai/)
- for Segment Anything (https://segment-anything.com).
- Modified from https://github.com/vietanhdev/samexporter/
-
- Copyright (c) 2023 Viet Anh Nguyen
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
- """
- from copy import deepcopy
-
- from cv2 import INTER_LINEAR, warpAffine
- from numpy import array as np_array, uint8, zeros, ndarray
- from numpy import concatenate, float32, linalg, matmul, ones
- from onnxruntime import get_available_providers, InferenceSession
-
- from samgis_core import app_logger, MODEL_FOLDER
- from samgis_core.utilities.constants import DEFAULT_INPUT_SHAPE, MODEL_ENCODER_NAME, MODEL_DECODER_NAME
- from samgis_core.utilities.type_hints import PIL_Image, ListDict, TupleNdarrayInt, EmbeddingDict, EmbeddingImage
-
-
- class SegmentAnythingONNX:
-     """Segmentation model using SegmentAnything"""
-
-     def __init__(self, encoder_model_path, decoder_model_path) -> None:
-         self.target_size = DEFAULT_INPUT_SHAPE[1]
-         self.input_size = DEFAULT_INPUT_SHAPE
-
-         # Load models
-         providers = get_available_providers()
-
-         # Pop TensorRT Runtime due to crashing issues
-         # TODO: Add back when TensorRT backend is stable
-         providers = [p for p in providers if p != "TensorrtExecutionProvider"]
-
-         if providers:
-             app_logger.info(
-                 "Available providers for ONNXRuntime: %s", ", ".join(providers)
-             )
-         else:
-             app_logger.warning("No available providers for ONNXRuntime")
-         self.encoder_session = InferenceSession(
-             encoder_model_path, providers=providers
-         )
-         self.encoder_input_name = self.encoder_session.get_inputs()[0].name
-         self.decoder_session = InferenceSession(
-             decoder_model_path, providers=providers
-         )
-
-     @staticmethod
-     def get_input_points(prompt):
-         """Get input points"""
-         points = []
-         labels = []
-         for mark in prompt:
-             if mark["type"] == "point":
-                 points.append(mark["data"])
-                 labels.append(mark["label"])
-             elif mark["type"] == "rectangle":
-                 points.append([mark["data"][0], mark["data"][1]])  # top left
-                 points.append(
-                     [mark["data"][2], mark["data"][3]]
-                 )  # bottom right
-                 labels.append(2)
-                 labels.append(3)
-         points, labels = np_array(points), np_array(labels)
-         return points, labels
-
-     def run_encoder(self, encoder_inputs):
-         """Run encoder"""
-         output = self.encoder_session.run(None, encoder_inputs)
-         image_embedding = output[0]
-         return image_embedding
-
-     @staticmethod
-     def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
-         """
-         Compute the output size given input size and target long side length.
-         """
-         scale = long_side_length * 1.0 / max(old_h, old_w)
-         new_h, new_w = old_h * scale, old_w * scale
-         new_w = int(new_w + 0.5)
-         new_h = int(new_h + 0.5)
-         return new_h, new_w
-
-     def apply_coords(self, coords: ndarray, original_size, target_length):
-         """
-         Expects a numpy np_array of length 2 in the final dimension. Requires the
-         original image size in (H, W) format.
-         """
-         old_h, old_w = original_size
-         new_h, new_w = self.get_preprocess_shape(
-             original_size[0], original_size[1], target_length
-         )
-         coords = deepcopy(coords).astype(float)
-         coords[..., 0] = coords[..., 0] * (new_w / old_w)
-         coords[..., 1] = coords[..., 1] * (new_h / old_h)
-         return coords
-
-     def run_decoder(
-             self, image_embedding, original_size, transform_matrix, prompt
-     ):
-         """Run decoder"""
-         input_points, input_labels = self.get_input_points(prompt)
-
-         # Add a batch index, concatenate a padding point, and transform.
-         onnx_coord = concatenate(
-             [input_points, np_array([[0.0, 0.0]])], axis=0
-         )[None, :, :]
-         onnx_label = concatenate([input_labels, np_array([-1])], axis=0)[
-             None, :
-         ].astype(float32)
-         onnx_coord = self.apply_coords(
-             onnx_coord, self.input_size, self.target_size
-         ).astype(float32)
-
-         # Apply the transformation matrix to the coordinates.
-         onnx_coord = concatenate(
-             [
-                 onnx_coord,
-                 ones((1, onnx_coord.shape[1], 1), dtype=float32),
-             ],
-             axis=2,
-         )
-         onnx_coord = matmul(onnx_coord, transform_matrix.T)
-         onnx_coord = onnx_coord[:, :, :2].astype(float32)
-
-         # Create an empty mask input and an indicator for no mask.
-         onnx_mask_input = zeros((1, 1, 256, 256), dtype=float32)
-         onnx_has_mask_input = zeros(1, dtype=float32)
-
-         decoder_inputs = {
-             "image_embeddings": image_embedding,
-             "point_coords": onnx_coord,
-             "point_labels": onnx_label,
-             "mask_input": onnx_mask_input,
-             "has_mask_input": onnx_has_mask_input,
-             "orig_im_size": np_array(self.input_size, dtype=float32),
-         }
-         masks, _, _ = self.decoder_session.run(None, decoder_inputs)
-
-         # Transform the masks back to the original image size.
-         inv_transform_matrix = linalg.inv(transform_matrix)
-         transformed_masks = self.transform_masks(
-             masks, original_size, inv_transform_matrix
-         )
-
-         return transformed_masks
-
-     @staticmethod
-     def transform_masks(masks, original_size, transform_matrix):
-         """Transform masks
-         Transform the masks back to the original image size.
-         """
-         output_masks = []
-         for batch in range(masks.shape[0]):
-             batch_masks = []
-             for mask_id in range(masks.shape[1]):
-                 mask = masks[batch, mask_id]
-                 try:
-                     try:
-                         app_logger.debug(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.")
-                     except Exception as e_mask_shape_transform_masks:
-                         app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.")
-                     mask = warpAffine(
-                         mask,
-                         transform_matrix[:2],
-                         (original_size[1], original_size[0]),
-                         flags=INTER_LINEAR,
-                     )
-                 except Exception as e_warp_affine1:
-                     app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.")
-                     app_logger.error(
-                         f"e_warp_affine1 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}.")
-                     app_logger.error(f"e_warp_affine1 original_size:{original_size}.")
-                     raise e_warp_affine1
-                 batch_masks.append(mask)
-             output_masks.append(batch_masks)
-         return np_array(output_masks)
-
-     def encode(self, cv_image: ndarray) -> EmbeddingImage:
-         """
-         Calculate embedding and metadata for a single image.
-
-         Args:
-             cv_image: input image to embed
-
-         Returns:
-             embedding image dict useful to store and cache image embeddings
-         """
-         original_size = cv_image.shape[:2]
-
-         # Calculate a transformation matrix to convert to self.input_size
-         scale_x = self.input_size[1] / cv_image.shape[1]
-         scale_y = self.input_size[0] / cv_image.shape[0]
-         scale = min(scale_x, scale_y)
-         transform_matrix = np_array(
-             [
-                 [scale, 0, 0],
-                 [0, scale, 0],
-                 [0, 0, 1],
-             ]
-         )
-         try:
-             cv_image = warpAffine(
-                 cv_image,
-                 transform_matrix[:2],
-                 (self.input_size[1], self.input_size[0]),
-                 flags=INTER_LINEAR,
-             )
-         except Exception as e_warp_affine2:
-             app_logger.error(f"e_warp_affine2:{e_warp_affine2}.")
-             np_cv_image = np_array(cv_image)
-             app_logger.error(f"e_warp_affine2 cv_image shape:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.")
-             app_logger.error(f"e_warp_affine2 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}")
-             app_logger.error(f"e_warp_affine2 self.input_size:{self.input_size}.")
-             raise e_warp_affine2
-
-         encoder_inputs = {
-             self.encoder_input_name: cv_image.astype(float32),
-         }
-         image_embedding = self.run_encoder(encoder_inputs)
-         return {
-             "image_embedding": image_embedding,
-             "original_size": original_size,
-             "transform_matrix": transform_matrix,
-         }
-
-     def predict_masks(self, embedding, prompt):
-         """
-         Predict masks for a single image.
-         """
-         masks = self.run_decoder(
-             embedding["image_embedding"],
-             embedding["original_size"],
-             embedding["transform_matrix"],
-             prompt,
-         )
-
-         return masks
-
-
- def get_raster_inference(
-         img: PIL_Image or ndarray, prompt: ListDict, models_instance: SegmentAnythingONNX, model_name: str
- ) -> TupleNdarrayInt:
-     """
-     Get inference output for a given image using a SegmentAnythingONNX model
-
-     Args:
-         img: input PIL Image
-         prompt: list of prompt dict
-         models_instance: SegmentAnythingONNX instance model
-         model_name: model name string
-
-     Returns:
-         raster prediction mask, prediction number
-     """
-     np_img = np_array(img)
-     app_logger.info(f"img type {type(np_img)}, prompt:{prompt}.")
-     app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
-     try:
-         app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
-     except Exception as e_shape:
-         app_logger.error(f"e_shape:{e_shape}.")
-     app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
-                     f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
-     embedding = models_instance.encode(np_img)
-     app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
-     return get_raster_inference_using_existing_embedding(embedding, prompt, models_instance)
-
-
- def get_inference_embedding(
-         img: PIL_Image or ndarray, models_instance: SegmentAnythingONNX, model_name: str, embedding_key: str,
-         embedding_dict: EmbeddingDict) -> EmbeddingDict:
-     """add an embedding to the embedding dict if needed
-
-     Args:
-         img: input PIL Image
-         models_instance: SegmentAnythingONNX instance model
-         model_name: model name string
-         embedding_key: embedding id
-         embedding_dict: embedding dict object
-
-     Returns:
-         raster dict
-     """
-     if embedding_key in embedding_dict:
-         app_logger.info("found embedding in dict...")
-     if embedding_key not in embedding_dict:
-         np_img = np_array(img)
-         app_logger.info(f"prepare embedding using img type {type(np_img)}.")
-         app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
-         try:
-             app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
-         except Exception as e_shape:
-             app_logger.error(f"e_shape:{e_shape}.")
-         app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
-                         f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
-         embedding = models_instance.encode(np_img)
-         embedding_dict[embedding_key] = embedding
-     return embedding_dict
-
-
- def get_raster_inference_using_existing_embedding(
-         embedding: dict, prompt: ListDict, models_instance: SegmentAnythingONNX) -> TupleNdarrayInt:
-     """
-     Get inference output for a given image using a SegmentAnythingONNX model, using an existing embedding instead of a
-     new ndarray or PIL image
-
-     Args:
-         embedding: dict
-         prompt: list of prompt dict
-         models_instance: SegmentAnythingONNX instance model
-
-     Returns:
-         raster prediction mask, prediction number
-     """
-     app_logger.info(f"using existing embedding of type {type(embedding)}.")
-     inference_out = models_instance.predict_masks(embedding, prompt)
-     len_inference_out = len(inference_out[0, :, :, :])
-     app_logger.info(f"Created {len_inference_out} prediction_masks,"
-                     f"shape:{inference_out.shape}, dtype:{inference_out.dtype}.")
-     mask = zeros((inference_out.shape[2], inference_out.shape[3]), dtype=uint8)
-     for n, m in enumerate(inference_out[0, :, :, :]):
-         app_logger.debug(f"{n}th of prediction_masks shape {inference_out.shape}"
-                          f" => mask shape:{mask.shape}, {mask.dtype}.")
-         mask[m > 0.0] = 255
-     return mask, len_inference_out
-
-
- def get_raster_inference_with_embedding_from_dict(
-         img: PIL_Image or ndarray, prompt: ListDict, models_instance: SegmentAnythingONNX, model_name: str,
-         embedding_key: str, embedding_dict: dict) -> TupleNdarrayInt:
-     """
-     Get inference output using a SegmentAnythingONNX model, but get the image embedding from the given embedding dict
-     instead of creating a new embedding. This function needs the img argument to update the embedding dict if necessary
-
-     Args:
-         img: input PIL Image
-         prompt: list of prompt dict
-         models_instance: SegmentAnythingONNX instance model
-         model_name: model name string
-         embedding_key: embedding id
-         embedding_dict: embedding images dict
-
-     Returns:
-         raster prediction mask, prediction number
-     """
-     app_logger.info(f"handling embedding using key {embedding_key}.")
-     embedding_dict = get_inference_embedding(img, models_instance, model_name, embedding_key, embedding_dict)
-     app_logger.info(f"getting embedding with key {embedding_key} from dict...")
-     embedding = embedding_dict[embedding_key]
-     n_keys = len(embedding_dict)
-     app_logger.info(f"embedding created ({n_keys} keys in embedding dict), running predict_masks with prompt {prompt}.")
-     return get_raster_inference_using_existing_embedding(embedding, prompt, models_instance)
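
The removed implementation warped inputs with a cv2 affine transform and carried the matrix inside the embedding; the 1.2.2 rewrite resizes with PIL and zero-pads instead, which is what allowed the opencv-python-headless dependency to be dropped from METADATA above. The two preprocessing paths compute nearly the same target shape, differing only in rounding; a quick check against the removed `get_preprocess_shape` (numbers are illustrative):

```python
def get_preprocess_shape(old_h: int, old_w: int, long_side_length: int):
    # as defined in the removed SegmentAnythingONNX above (rounds to nearest)
    scale = long_side_length * 1.0 / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)

print(get_preprocess_shape(684, 1024, 1024))  # (684, 1024): identity for the default shape
print(get_preprocess_shape(1368, 912, 1024))  # (1024, 683), rounded
# the new preprocess_image truncates instead: int(1024 / 1368 * 912) == 682,
# so the short side can differ by one pixel between 1.2.1 and 1.2.2
```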
--- samgis_core-1.2.1.dist-info/RECORD
+++ /dev/null
@@ -1,13 +0,0 @@
- samgis_core/__init__.py,sha256=1kFX8G22dxNz23J7uOYl-SMWOe4W1olssc-5zKAVwSc,351
- samgis_core/__version__.py,sha256=x-8LeQkljky-ao1MyT98tSQE1xNvJp3LelqsAdCb9Og,94
- samgis_core/prediction_api/__init__.py,sha256=_jUZhspS26ygiSzBJywfZ4fQf9X7oC8w6oxvhd9S2hQ,57
- samgis_core/prediction_api/sam_onnx.py,sha256=YEX6UOFKcqPiuF_SCPAcK_gCmxoj4lvTLNNDul2htZ8,15320
- samgis_core/utilities/__init__.py,sha256=nL9pzdB4SdEF8m5gCbtlVCtdGLg9JjPm-FNxKBsIBZA,32
- samgis_core/utilities/constants.py,sha256=645W57jrUNbnujgUwCietfr-rECENDXLGmHeD2YoSwg,157
- samgis_core/utilities/fastapi_logger.py,sha256=yt8Vj1viyE-Kry1_iy5p4saZvGEEmRfuTEMHYXgnjqk,650
- samgis_core/utilities/serialize.py,sha256=aIjhEoibBpV_gpgOg6LiVxZCWjOkYxlzcboDZLQctJE,2689
- samgis_core/utilities/type_hints.py,sha256=iDkWiVB2B0W_GaRhdpypEra2FvQtwC5lNLVZr7a4yhU,845
- samgis_core/utilities/utilities.py,sha256=fP3cnxIYULYoCBft54EAivUR_fMcCDz2Z9AMpAO0zZ0,2434
- samgis_core-1.2.1.dist-info/METADATA,sha256=L0FMNdjcgVtJnGr3j78buKJlry9o340vjIWL0b_TUrQ,1147
- samgis_core-1.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- samgis_core-1.2.1.dist-info/RECORD,,