frontveg 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
frontveg/__init__.py CHANGED
@@ -3,15 +3,7 @@ try:
 except ImportError:
     __version__ = "unknown"
 from ._widget import (
-    ExampleQWidget,
-    ImageThreshold,
-    threshold_autogenerate_widget,
-    threshold_magic_widget,
+    vegetation,
 )
 
-__all__ = (
-    "ExampleQWidget",
-    "ImageThreshold",
-    "threshold_autogenerate_widget",
-    "threshold_magic_widget",
-)
+__all__ = ("vegetation",)
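After this change the package exports a single public symbol. A minimal usage sketch (hedged: `vegetation` is a `magicgui` `magic_factory`, so calling the factory is what constructs the widget instance):

    from frontveg import vegetation

    widget = vegetation()  # the factory call builds the napari widget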
frontveg/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.1'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.3.0'
+__version_tuple__ = version_tuple = (0, 3, 0)
frontveg/_widget.py CHANGED
@@ -31,99 +31,116 @@ Replace code below according to your needs.
 
 from typing import TYPE_CHECKING
 
+import numpy as np
+import torch
 from magicgui import magic_factory
-from magicgui.widgets import CheckBox, Container, create_widget
-from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget
-from skimage.util import img_as_float
+from PIL import Image
+from scipy import ndimage
+from transformers import pipeline
 
 if TYPE_CHECKING:
     import napari
+from frontveg.utils import frontground_part, ground_dino, sam2
 
-
-# Uses the `autogenerate: true` flag in the plugin manifest
-# to indicate it should be wrapped as a magicgui to autogenerate
-# a widget.
-def threshold_autogenerate_widget(
-    img: "napari.types.ImageData",
-    threshold: "float",
-) -> "napari.types.LabelsData":
-    return img_as_float(img) > threshold
+pipe = pipeline(
+    task="depth-estimation", model="depth-anything/Depth-Anything-V2-Large-hf"
+)
 
 
-# the magic_factory decorator lets us customize aspects of our widget
-# we specify a widget type for the threshold parameter
-# and use auto_call=True so the function is called whenever
-# the value of a parameter changes
-@magic_factory(
-    threshold={"widget_type": "FloatSlider", "max": 1}, auto_call=True
-)
-def threshold_magic_widget(
-    img_layer: "napari.layers.Image", threshold: "float"
+@magic_factory(call_button="Run")
+def vegetation(
+    input_data: "napari.types.ImageData",
 ) -> "napari.types.LabelsData":
-    return img_as_float(img_layer.data) > threshold
-
-
-# if we want even more control over our widget, we can use
-# magicgui `Container`
-class ImageThreshold(Container):
-    def __init__(self, viewer: "napari.viewer.Viewer"):
-        super().__init__()
-        self._viewer = viewer
-        # use create_widget to generate widgets from type annotations
-        self._image_layer_combo = create_widget(
-            label="Image", annotation="napari.layers.Image"
+    device = "cuda"
+
+    if input_data.ndim == 4:
+        output_data = np.zeros(
+            (input_data.shape[0], input_data.shape[1], input_data.shape[2]),
+            dtype="uint8",
         )
-        self._threshold_slider = create_widget(
-            label="Threshold", annotation=float, widget_type="FloatSlider"
+        INPUT = []
+        for i in range(len(input_data)):
+            rgb_data = input_data[i, :, :, :].compute()
+            image = Image.fromarray(rgb_data)
+            INPUT.append(image)
+    else:
+        output_data = np.zeros(
+            (1, input_data.shape[0], input_data.shape[1]), dtype="uint8"
        )
-        self._threshold_slider.min = 0
-        self._threshold_slider.max = 1
-        # use magicgui widgets directly
-        self._invert_checkbox = CheckBox(text="Keep pixels below threshold")
-
-        # connect your own callbacks
-        self._threshold_slider.changed.connect(self._threshold_im)
-        self._invert_checkbox.changed.connect(self._threshold_im)
-
-        # append into/extend the container with your widgets
-        self.extend(
-            [
-                self._image_layer_combo,
-                self._threshold_slider,
-                self._invert_checkbox,
-            ]
+        rgb_data = input_data
+        image = Image.fromarray(rgb_data)
+        INPUT = [image]
+    depth = pipe(INPUT)
+    n = len(depth)
+
+    model, processor = ground_dino()
+    predictor, text_labels = sam2()
+
+    for i in range(n):
+        depth_pred = depth[i]["depth"]
+        msks_depth = np.array(depth_pred)
+        msks_front = frontground_part(msks_depth)
+        msks_front = msks_front.astype(np.uint8) * 255
+
+        image = INPUT[i]
+        inputs = processor(
+            images=image, text=text_labels, return_tensors="pt"
+        ).to(device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        results = processor.post_process_grounded_object_detection(
+            outputs,
+            inputs.input_ids,
+            box_threshold=0.4,
+            text_threshold=0.3,
+            target_sizes=[image.size[::-1]],
         )
 
-    def _threshold_im(self):
-        image_layer = self._image_layer_combo.value
-        if image_layer is None:
-            return
-
-        image = img_as_float(image_layer.data)
-        name = image_layer.name + "_thresholded"
-        threshold = self._threshold_slider.value
-        if self._invert_checkbox.value:
-            thresholded = image < threshold
+        # Retrieve the first image result
+        result = results[0]
+        for box, score, labels in zip(
+            result["boxes"], result["scores"], result["labels"], strict=False
+        ):
+            box = [round(x, 2) for x in box.tolist()]
+            print(
+                f"Detected {labels} with confidence {round(score.item(), 3)} at location {box}"
+            )
+        if len(result["boxes"]) == 0:
+            masks = np.zeros(image.size[::-1], dtype="uint8")
         else:
-            thresholded = image > threshold
-        if name in self._viewer.layers:
-            self._viewer.layers[name].data = thresholded
+            with (
+                torch.inference_mode(),
+                torch.autocast("cuda", dtype=torch.bfloat16),
+            ):
+                predictor.set_image(image)
+                masks_sam, _, _ = predictor.predict(
+                    box=result["boxes"],
+                    point_labels=result["labels"],
+                    multimask_output=False,
+                )
+            if masks_sam.ndim == 4:
+                masks = np.sum(masks_sam, axis=0)
+                masks = masks[0, :, :]
+            else:
+                masks = masks_sam[0, :, :]
+
+        msks_veg = masks.astype(np.uint8) * 255
+
+        mask1 = msks_front.copy()  # mask 1
+        mask2 = msks_veg.copy()  # mask 2
+        mask2 = ndimage.binary_fill_holes(mask2)  # fill holes
+        mask1 = (mask1 > 0).astype(np.uint8)  # convert to binary
+        mask2 = (mask2 > 0).astype(np.uint8)  # convert to binary
+        if len(np.unique(mask2)) == 2:
+            intersection = (
+                mask1 & mask2
+            )  # intersection: pixels that are 1 in both masks
+            intersection = intersection > 0
         else:
-            self._viewer.add_labels(thresholded, name=name)
-
-
-class ExampleQWidget(QWidget):
-    # your QWidget.__init__ can optionally request the napari viewer instance
-    # use a type annotation of 'napari.viewer.Viewer' for any parameter
-    def __init__(self, viewer: "napari.viewer.Viewer"):
-        super().__init__()
-        self.viewer = viewer
-
-        btn = QPushButton("Click me!")
-        btn.clicked.connect(self._on_click)
-
-        self.setLayout(QHBoxLayout())
-        self.layout().addWidget(btn)
-
-    def _on_click(self):
-        print("napari has", len(self.viewer.layers), "layers")
+            intersection = mask1.copy()
+        intersection = (intersection * 255).astype(
+            np.uint8
+        )  # for a 0/255 mask (e.g. for OpenCV)
+        output_data[i, :, :] = intersection
+    return output_data
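The new `vegetation` widget chains four stages: Depth-Anything-V2 depth estimation, a histogram-valley foreground threshold (`frontground_part`), Grounding DINO zero-shot detection, and SAM2 mask prediction, then intersects the foreground and vegetation masks. A minimal, self-contained sketch of that final intersection step, with hypothetical `front` and `veg` arrays standing in for `msks_front` and `msks_veg`:

    import numpy as np
    from scipy import ndimage

    front = np.random.rand(64, 64) > 0.5  # stand-in for the depth foreground mask
    veg = np.random.rand(64, 64) > 0.5    # stand-in for the SAM2 vegetation mask

    veg = ndimage.binary_fill_holes(veg)  # close holes in the vegetation mask
    front = (front > 0).astype(np.uint8)  # binarize both masks
    veg = (veg > 0).astype(np.uint8)
    if len(np.unique(veg)) == 2:          # vegetation was actually detected
        labels = ((front & veg) * 255).astype(np.uint8)
    else:                                 # fall back to the depth mask alone
        labels = (front * 255).astype(np.uint8)
    print(labels.shape, labels.max())

The fallback mirrors the widget's behavior: when the detected-vegetation mask is empty or all-ones, the depth foreground mask is returned unchanged.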
frontveg/napari.yaml CHANGED
@@ -6,25 +6,9 @@ visibility: public
 categories: ["Annotation", "Segmentation", "Acquisition"]
 contributions:
   commands:
-    - id: frontveg.make_container_widget
-      python_name: frontveg:ImageThreshold
-      title: Make threshold Container widget
-    - id: frontveg.make_magic_widget
-      python_name: frontveg:threshold_magic_widget
-      title: Make threshold magic widget
-    - id: frontveg.make_function_widget
-      python_name: frontveg:threshold_autogenerate_widget
-      title: Make threshold function widget
-    - id: frontveg.make_qwidget
-      python_name: frontveg:ExampleQWidget
-      title: Make example QWidget
+    - id: frontveg.vegetation
+      python_name: frontveg:vegetation
+      title: Vegetation plugin
   widgets:
-    - command: frontveg.make_container_widget
-      display_name: Container Threshold
-    - command: frontveg.make_magic_widget
-      display_name: Magic Threshold
-    - command: frontveg.make_function_widget
-      autogenerate: true
-      display_name: Autogenerate Threshold
-    - command: frontveg.make_qwidget
-      display_name: Example QWidget
+    - command: frontveg.vegetation
+      display_name: Frontground vegetation
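The manifest now registers a single command/widget pair. A hedged sketch of verifying that wiring with the `npe2` library (assuming `npe2` and the built package are installed in the environment):

    from npe2 import PluginManifest

    pm = PluginManifest.from_distribution("frontveg")
    print([c.id for c in pm.contributions.commands])           # ['frontveg.vegetation']
    print([w.display_name for w in pm.contributions.widgets])  # ['Frontground vegetation']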
frontveg/utils.py ADDED
@@ -0,0 +1,109 @@
+import os
+from collections import Counter
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.signal import find_peaks
+from tqdm import tqdm
+from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
+
+# CONF = config.get_conf_dict()
+homedir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+# base_dir = CONF['general']['base_directory']
+base_dir = "."
+
+model_id = "IDEA-Research/grounding-dino-tiny"
+device = "cuda"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
+    device
+)
+
+
+def ground_dino():
+    return model, processor
+
+
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+text_labels = ["green region. foliage."]
+
+
+def sam2():
+    return predictor, text_labels
+
+
+def minimum_betw_max(dico_, visua=False):
+    Ax = list(dico_.keys())
+    Ay = list(dico_.values())
+
+    # Fit a polynomial regression
+    x = Ax[1:]
+    y = Ay[1:]
+    degree = 14  # choose the degree according to the curve's complexity
+    coefficients = np.polyfit(x, y, degree)
+    polynomial = np.poly1d(coefficients)
+
+    # Smoothed points for plotting the curve
+    x_fit = np.linspace(min(x), max(x), 500)
+    y_fit = polynomial(x_fit)
+
+    # Detect the maxima
+    peaks, _ = find_peaks(y_fit)
+
+    peak_values = y_fit[peaks]
+    sorted_indices = np.argsort(peak_values)[
+        ::-1
+    ]  # sort in descending order
+    top_two_peaks = peaks[
+        sorted_indices[:2]
+    ]  # indices of the two largest peaks
+
+    # Find the minimum between the two maxima
+    x_min_range = x_fit[top_two_peaks[0] : top_two_peaks[1] + 1]
+    y_min_range = y_fit[top_two_peaks[0] : top_two_peaks[1] + 1]
+    minx = min([top_two_peaks[0], top_two_peaks[1]])
+    maxx = max([top_two_peaks[0], top_two_peaks[1]])
+    x_min_range = x_fit[minx : maxx + 1]
+    y_min_range = y_fit[minx : maxx + 1]
+    min_index = np.argmin(y_min_range)  # index of the minimum within this range
+    x_min = x_min_range[min_index]
+    y_min = y_min_range[min_index]
+
+    if visua:
+        # Plot
+        plt.scatter(x, y, color="blue")
+        plt.plot(x_fit, y_fit, color="red", label="Polynomial regression")
+        plt.scatter(
+            x_fit[top_two_peaks],
+            y_fit[top_two_peaks],
+            color="green",
+            label="Local maximum",
+        )
+        plt.scatter(x_min, y_min, color="orange", s=100, label="Local minimum")
+        plt.legend()
+        plt.xlabel("Depth pixel")
+        plt.ylabel("Count")
+        # plt.title('Approximation and detection of the maximum points')
+        plt.show()
+    return x_min, y_min
+
+
+def frontground_part(depths):
+    depth_one = depths[:, :]
+    n, m = depth_one.shape
+    A = []
+    for i in tqdm(range(n)):
+        for j in range(m):
+            A.append([i, j, depth_one[i, j]])
+    X = np.array(A)
+
+    dico_ = Counter(X[:, 2])
+    min_coord = minimum_betw_max(dico_, visua=False)
+
+    th_ = min_coord[0]
+    msks_depth = depth_one > th_
+    return msks_depth
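`frontground_part` histograms the depth values, fits a degree-14 polynomial to the counts, and splits the image at the valley between the two tallest peaks — a bimodal-histogram threshold in the spirit of Otsu's method. A self-contained sketch of that valley-finding logic on synthetic bimodal data (all names here are illustrative, not part of the package):

    import numpy as np
    from scipy.signal import find_peaks

    rng = np.random.default_rng(0)
    # synthetic depth samples: a far mode (~60) and a near mode (~200)
    values = np.concatenate(
        [rng.normal(60, 8, 5000), rng.normal(200, 10, 5000)]
    ).astype(int)
    levels, counts = np.unique(values, return_counts=True)

    poly = np.poly1d(np.polyfit(levels[1:], counts[1:], 14))  # degree-14 fit, as in utils.py
    x_fit = np.linspace(levels.min(), levels.max(), 500)
    y_fit = poly(x_fit)
    peaks, _ = find_peaks(y_fit)
    top_two = np.sort(peaks[np.argsort(y_fit[peaks])[::-1][:2]])  # two tallest peaks
    valley = top_two[0] + np.argmin(y_fit[top_two[0] : top_two[1] + 1])
    print(f"depth threshold = {x_fit[valley]:.1f}")  # lands between the two modes

Note that importing `frontveg.utils` eagerly downloads and instantiates the Grounding DINO and SAM2 checkpoints at module import time, and the hard-coded `device = "cuda"` assumes a GPU is available.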
frontveg-0.2.1.dist-info/METADATA → frontveg-0.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: frontveg
-Version: 0.2.1
+Version: 0.3.0
 Summary: Segmentation of vegetation located to close to camera
 Author: Herearii Metuarea
 Author-email: herearii.metuarea@univ-angers.fr
@@ -50,13 +50,19 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Image Processing
-Requires-Python: >=3.10
+Requires-Python: ==3.11.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
 Requires-Dist: magicgui
 Requires-Dist: qtpy
 Requires-Dist: scikit-image
+Requires-Dist: transformers==4.51.3
+Requires-Dist: torch>=2.3.1
+Requires-Dist: torchvision>=0.18.1
+Requires-Dist: hydra-core==1.3.2
+Requires-Dist: iopath>=0.1.10
+Requires-Dist: pillow>=9.4.0
 Provides-Extra: testing
 Requires-Dist: tox; extra == "testing"
 Requires-Dist: pytest; extra == "testing"
frontveg-0.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+frontveg/__init__.py,sha256=m4oqJTxKFurizNLN-4HNrqF-hvjQkdyLMbIULtTd1NA,179
+frontveg/_version.py,sha256=9WppBElv1NvAyEcEfKgQ7lPPwwYXq7rnN9aaDi0TN-s,532
+frontveg/_widget.py,sha256=eZS0gv2f8NjMzskHFF0J_zscBWPVstqIQaSXIFhJOG4,5389
+frontveg/napari.yaml,sha256=33HxiAA2If2tjogtnkb5PYfeR8bXLxW4uWKV88kDYKQ,511
+frontveg/utils.py,sha256=3t011wr99KhJ4nW-lRYXoFLjZu_DULQOPcoXAlEKEK0,3214
+frontveg/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+frontveg/_tests/test_widget.py,sha256=jaFBX-JpnaRHDvQvU6QbUhOiL6Aejn2yljq11yl3hmY,2265
+frontveg-0.3.0.dist-info/licenses/LICENSE,sha256=2qUWKx6xVq9efOuuI6lxeftgMSY2njkm5Qy4HXLRQgA,1520
+frontveg-0.3.0.dist-info/METADATA,sha256=5DKXT3bY049syECZGRRsQQn7DiefgHXIRjJ4wsqIJMY,7967
+frontveg-0.3.0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+frontveg-0.3.0.dist-info/entry_points.txt,sha256=VMaRha_yYtIcJAdA0suCmR0of0MZJfUaUn2aKSYtR0I,50
+frontveg-0.3.0.dist-info/top_level.txt,sha256=skkajXDCaVFNYqsXXqsUv6fqlA6Pl-2cLwKJO52ldBI,9
+frontveg-0.3.0.dist-info/RECORD,,
frontveg-0.2.1.dist-info/WHEEL → frontveg-0.3.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.7.1)
+Generator: setuptools (80.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
frontveg-0.2.1.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
-frontveg/__init__.py,sha256=BVJaebsOBu1m--T0j2UjXzD_pG4zBvZ1PHfH3nga9js,373
-frontveg/_version.py,sha256=cTPlZaUCc20I4ZWsDjY35UftpFNRgfDaDBgkWxfIQmg,532
-frontveg/_widget.py,sha256=gyCQpmWr20TvgwkurBZTG5EeBoGEk2GxOEAiS3Zqmpg,4940
-frontveg/napari.yaml,sha256=YTDShC2Rt39ypSM-opRP4lNDOncghjanBNydozPcHvE,1208
-frontveg/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-frontveg/_tests/test_widget.py,sha256=jaFBX-JpnaRHDvQvU6QbUhOiL6Aejn2yljq11yl3hmY,2265
-frontveg-0.2.1.dist-info/licenses/LICENSE,sha256=2qUWKx6xVq9efOuuI6lxeftgMSY2njkm5Qy4HXLRQgA,1520
-frontveg-0.2.1.dist-info/METADATA,sha256=O_Yf9RJNGxSvG_-w4WKW87x48qy0mKqT6BYPVT6eMZ0,7767
-frontveg-0.2.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-frontveg-0.2.1.dist-info/entry_points.txt,sha256=VMaRha_yYtIcJAdA0suCmR0of0MZJfUaUn2aKSYtR0I,50
-frontveg-0.2.1.dist-info/top_level.txt,sha256=skkajXDCaVFNYqsXXqsUv6fqlA6Pl-2cLwKJO52ldBI,9
-frontveg-0.2.1.dist-info/RECORD,,