samgis_core 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- samgis_core/prediction_api/sam_onnx.py +85 -7
- samgis_core/utilities/fastapi_logger.py +4 -2
- samgis_core/utilities/type_hints.py +17 -0
- {samgis_core-1.0.8.dist-info → samgis_core-1.1.0.dist-info}/METADATA +1 -1
- {samgis_core-1.0.8.dist-info → samgis_core-1.1.0.dist-info}/RECORD +6 -6
- {samgis_core-1.0.8.dist-info → samgis_core-1.1.0.dist-info}/WHEEL +0 -0
@@ -32,7 +32,7 @@ from onnxruntime import get_available_providers, InferenceSession
|
|
32
32
|
|
33
33
|
from samgis_core import app_logger, MODEL_FOLDER
|
34
34
|
from samgis_core.utilities.constants import DEFAULT_INPUT_SHAPE, MODEL_ENCODER_NAME, MODEL_DECODER_NAME
|
35
|
-
from samgis_core.utilities.type_hints import PIL_Image, list_dict, tuple_ndarr_int
|
35
|
+
from samgis_core.utilities.type_hints import PIL_Image, list_dict, tuple_ndarr_int, EmbeddingDict, EmbeddingImage
|
36
36
|
|
37
37
|
|
38
38
|
class SegmentAnythingONNX:
|
@@ -114,7 +114,7 @@ class SegmentAnythingONNX:
|
|
114
114
|
return coords
|
115
115
|
|
116
116
|
def run_decoder(
|
117
|
-
|
117
|
+
self, image_embedding, original_size, transform_matrix, prompt
|
118
118
|
):
|
119
119
|
"""Run decoder"""
|
120
120
|
input_points, input_labels = self.get_input_points(prompt)
|
@@ -124,8 +124,8 @@ class SegmentAnythingONNX:
|
|
124
124
|
[input_points, np_array([[0.0, 0.0]])], axis=0
|
125
125
|
)[None, :, :]
|
126
126
|
onnx_label = concatenate([input_labels, np_array([-1])], axis=0)[
|
127
|
-
|
128
|
-
|
127
|
+
None, :
|
128
|
+
].astype(float32)
|
129
129
|
onnx_coord = self.apply_coords(
|
130
130
|
onnx_coord, self.input_size, self.target_size
|
131
131
|
).astype(float32)
|
@@ -186,14 +186,15 @@ class SegmentAnythingONNX:
|
|
186
186
|
)
|
187
187
|
except Exception as e_warp_affine1:
|
188
188
|
app_logger.error(f"e_warp_affine1 mask shape:{mask.shape}, dtype:{mask.dtype}.")
|
189
|
-
app_logger.error(
|
189
|
+
app_logger.error(
|
190
|
+
f"e_warp_affine1 transform_matrix:{transform_matrix}, [:2] {transform_matrix[:2]}.")
|
190
191
|
app_logger.error(f"e_warp_affine1 original_size:{original_size}.")
|
191
192
|
raise e_warp_affine1
|
192
193
|
batch_masks.append(mask)
|
193
194
|
output_masks.append(batch_masks)
|
194
195
|
return np_array(output_masks)
|
195
196
|
|
196
|
-
def encode(self, cv_image):
|
197
|
+
def encode(self, cv_image: ndarray) -> EmbeddingImage:
|
197
198
|
"""
|
198
199
|
Calculate embedding and metadata for a single image.
|
199
200
|
"""
|
@@ -251,7 +252,7 @@ class SegmentAnythingONNX:
|
|
251
252
|
|
252
253
|
def get_raster_inference(
|
253
254
|
img: PIL_Image or ndarray, prompt: list_dict, models_instance: SegmentAnythingONNX, model_name: str
|
254
|
-
|
255
|
+
) -> tuple_ndarr_int:
|
255
256
|
"""
|
256
257
|
Get inference output for a given image using a SegmentAnythingONNX model
|
257
258
|
|
@@ -275,6 +276,56 @@ def get_raster_inference(
|
|
275
276
|
f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
|
276
277
|
embedding = models_instance.encode(np_img)
|
277
278
|
app_logger.debug(f"embedding created, running predict_masks with prompt {prompt}...")
|
279
|
+
return get_raster_inference_using_existing_embedding(embedding, prompt, models_instance)
|
280
|
+
|
281
|
+
|
282
|
+
def get_inference_embedding(
        img: PIL_Image or ndarray, models_instance: SegmentAnythingONNX, model_name: str, embedding_key: str,
        embedding_dict: EmbeddingDict) -> EmbeddingDict:
    """Add the embedding for ``img`` to ``embedding_dict`` under ``embedding_key`` if it is not there yet.

    The dict is mutated in place and also returned, so callers may use either reference.
    When ``embedding_key`` is already present the image is NOT encoded again: the cached
    entry is trusted as-is, even if ``img`` differs from the image it was computed from.

    Args:
        img: input image, a PIL Image or an ndarray (converted with np_array before encoding)
        models_instance: SegmentAnythingONNX instance model, whose ``encode`` creates the embedding
        model_name: model name string (only used in log messages)
        embedding_key: embedding id
        embedding_dict: embedding dict object, updated in place on a cache miss

    Returns:
        the same embedding dict, now guaranteed to contain ``embedding_key``
    """
    if embedding_key in embedding_dict:
        app_logger.info("found embedding in dict...")
        return embedding_dict

    # NOTE(review): entries are never evicted here; bounding the dict size is left to the caller.
    np_img = np_array(img)
    app_logger.info(f"prepare embedding using img type {type(np_img)}.")
    app_logger.debug(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
    # np_array returns an ndarray, which always exposes .shape: the former try/except was dead code
    app_logger.debug(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
    app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
                    f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
    embedding_dict[embedding_key] = models_instance.encode(np_img)
    return embedding_dict
|
312
|
+
|
313
|
+
|
314
|
+
def get_raster_inference_using_existing_embedding(
|
315
|
+
embedding: dict, prompt: list_dict, models_instance: SegmentAnythingONNX) -> tuple_ndarr_int:
|
316
|
+
"""
|
317
|
+
Get inference output for a given image using a SegmentAnythingONNX model, using an existing embedding instead of a
|
318
|
+
new ndarray or PIL image
|
319
|
+
|
320
|
+
Args:
|
321
|
+
embedding: dict
|
322
|
+
prompt: list of prompt dict
|
323
|
+
models_instance: SegmentAnythingONNX instance model
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
raster prediction mask, prediction number
|
327
|
+
"""
|
328
|
+
app_logger.info(f"using existing embedding of type {type(embedding)}.")
|
278
329
|
inference_out = models_instance.predict_masks(embedding, prompt)
|
279
330
|
len_inference_out = len(inference_out[0, :, :, :])
|
280
331
|
app_logger.info(f"Created {len_inference_out} prediction_masks,"
|
@@ -285,3 +336,30 @@ def get_raster_inference(
|
|
285
336
|
f" => mask shape:{mask.shape}, {mask.dtype}.")
|
286
337
|
mask[m > 0.0] = 255
|
287
338
|
return mask, len_inference_out
|
339
|
+
|
340
|
+
|
341
|
+
def get_raster_inference_with_embedding_from_dict(
        img: PIL_Image or ndarray, prompt: list_dict, models_instance: SegmentAnythingONNX, model_name: str,
        embedding_key: str, embedding_dict: dict) -> tuple_ndarr_int:
    """
    Run a SegmentAnythingONNX prediction, reusing the image embedding cached in ``embedding_dict``.

    The embedding is looked up under ``embedding_key``. If it is missing, it is first computed
    from ``img`` and stored into ``embedding_dict`` (in place); that is why ``img`` is still required.

    Args:
        img: input image (PIL Image or ndarray), only encoded when the key is missing
        prompt: list of prompt dict
        models_instance: SegmentAnythingONNX instance model
        model_name: model name string
        embedding_key: embedding id
        embedding_dict: embedding dict object, updated in place on a cache miss

    Returns:
        raster prediction mask, prediction number
    """
    app_logger.info(f"handling embedding using key {embedding_key}.")
    updated_dict = get_inference_embedding(img, models_instance, model_name, embedding_key, embedding_dict)
    app_logger.info(f"getting embedding with key {embedding_key} from dict...")
    cached_embedding = updated_dict[embedding_key]
    # note: this message says "created" even when the embedding was reused from the dict
    app_logger.info(
        f"embedding created ({len(updated_dict)} keys in embedding dict), running predict_masks with prompt {prompt}.")
    return get_raster_inference_using_existing_embedding(cached_embedding, prompt, models_instance)
|
@@ -1,8 +1,10 @@
|
|
1
1
|
import loguru
|
2
2
|
|
3
3
|
|
4
|
-
|
5
|
-
|
4
|
+
# Default loguru format used by setup_logging: timestamp, level, source file, function,
# the request id read from the record's bound extra data, then the message (note the trailing space).
# NOTE(review): records logged without extra["request_id"] bound presumably fail to format — confirm callers always bind it.
format_string = "{time} - {level} - {file} - {function} - ({extra[request_id]}) {message} "
|
5
|
+
|
6
|
+
|
7
|
+
def setup_logging(debug: bool = False, formatter: str = format_string) -> loguru.logger:
|
6
8
|
"""
|
7
9
|
Create a logging instance with log string formatter.
|
8
10
|
|
@@ -1,7 +1,11 @@
|
|
1
1
|
"""custom type hints"""
|
2
|
+
from enum import Enum
|
3
|
+
from typing import TypedDict
|
4
|
+
|
2
5
|
from PIL.Image import Image
|
3
6
|
from numpy import ndarray
|
4
7
|
|
8
|
+
|
5
9
|
# Shorthand aliases for common container type hints.
dict_str_int = dict[str, int]
|
6
10
|
# NOTE(review): dict takes two type parameters; dict[str] fixes only the key type and is rejected by
# type checkers — confirm the intended value type.
dict_str = dict[str]
|
7
11
|
# NOTE(review): lowercase `any` is the builtin function, not typing.Any, so type checkers will not read
# this as "values of any type".
dict_str_any = dict[str, any]
|
@@ -14,3 +18,16 @@ llist_float = list[list_float]
|
|
14
18
|
# NOTE(review): tuple[float] denotes a 1-element tuple; use tuple[float, ...] if a variable-length
# tuple of floats is intended.
tuple_float = tuple[float]
|
15
19
|
# NOTE(review): `any` here is the builtin function, not typing.Any.
tuple_float_any = tuple[float, any]
|
16
20
|
# Re-export of PIL.Image.Image; sam_onnx imports it from here as PIL_Image.
PIL_Image = Image
|
21
|
+
|
22
|
+
|
23
|
+
class StrEnum(str, Enum):
    """Enum base class whose members are also ``str`` instances.

    Subclass it to declare string-valued constants that compare equal to plain strings
    (``Member == "value"``). Unlike ``enum.StrEnum`` from Python 3.11+, ``str(member)`` goes
    through ``Enum.__str__`` and yields ``"ClassName.MEMBER"``; use ``.value`` to get the raw string.
    """
|
25
|
+
|
26
|
+
|
27
|
+
class EmbeddingImage(TypedDict):
    """Result of ``SegmentAnythingONNX.encode``: an image embedding plus the metadata the decoder needs.

    The field names mirror the ``run_decoder`` parameters of the same name.
    """
    # encoder output for the image — exact shape/dtype not visible here, TODO confirm
    image_embedding: ndarray
    # size of the image before preprocessing; presumably (height, width) — verify against encode()
    original_size: tuple_int
    # transform between original-image and model-input coordinates, used by the decoder when warping
    # masks back (it logs transform_matrix[:2] on warp failure); exact shape TODO confirm
    transform_matrix: ndarray
|
31
|
+
|
32
|
+
|
33
|
+
# Cache mapping a caller-chosen embedding key (id) to the EmbeddingImage computed for it.
EmbeddingDict = dict[str, EmbeddingImage]
|
@@ -1,13 +1,13 @@
|
|
1
1
|
samgis_core/__init__.py,sha256=1kFX8G22dxNz23J7uOYl-SMWOe4W1olssc-5zKAVwSc,351
|
2
2
|
samgis_core/__version__.py,sha256=x-8LeQkljky-ao1MyT98tSQE1xNvJp3LelqsAdCb9Og,94
|
3
3
|
samgis_core/prediction_api/__init__.py,sha256=_jUZhspS26ygiSzBJywfZ4fQf9X7oC8w6oxvhd9S2hQ,57
|
4
|
-
samgis_core/prediction_api/sam_onnx.py,sha256=
|
4
|
+
samgis_core/prediction_api/sam_onnx.py,sha256=G5jJUyidLC_-9VrVmQRUHQpKQHW-vnv_EC7ppNap9VI,15172
|
5
5
|
samgis_core/utilities/__init__.py,sha256=nL9pzdB4SdEF8m5gCbtlVCtdGLg9JjPm-FNxKBsIBZA,32
|
6
6
|
samgis_core/utilities/constants.py,sha256=645W57jrUNbnujgUwCietfr-rECENDXLGmHeD2YoSwg,157
|
7
|
-
samgis_core/utilities/fastapi_logger.py,sha256=
|
7
|
+
samgis_core/utilities/fastapi_logger.py,sha256=yt8Vj1viyE-Kry1_iy5p4saZvGEEmRfuTEMHYXgnjqk,650
|
8
8
|
samgis_core/utilities/serialize.py,sha256=iWi_m_7vL7Hmr3Lb5yASXU5DtoUU18M_JdwiRSOs9sI,2749
|
9
|
-
samgis_core/utilities/type_hints.py,sha256=
|
9
|
+
samgis_core/utilities/type_hints.py,sha256=R-xfFbiGz12tuHqChlVI9GguR9Q7pxceFkCi_OmscM4,659
|
10
10
|
samgis_core/utilities/utilities.py,sha256=fP3cnxIYULYoCBft54EAivUR_fMcCDz2Z9AMpAO0zZ0,2434
|
11
|
-
samgis_core-1.0.
|
12
|
-
samgis_core-1.0.
|
13
|
-
samgis_core-1.0.
|
11
|
+
samgis_core-1.1.0.dist-info/METADATA,sha256=Zj7E1tpVyMtVbYH4uKYCTNqdYPrjuzgabs878EkNONw,892
|
12
|
+
samgis_core-1.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
13
|
+
samgis_core-1.1.0.dist-info/RECORD,,
|
File without changes
|