tau-fibrils-yolo 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ from tau_fibrils_yolo._widget import YoloDetectorWidget
2
+ from tau_fibrils_yolo.predict import FibrilsDetector
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False

if not TYPE_CHECKING:
    # At runtime the alias is only a placeholder for the annotations below.
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

version = "0.0.1"
__version__ = version
version_tuple = (0, 0, 1)
__version_tuple__ = version_tuple
@@ -0,0 +1,262 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import napari
4
+ import napari.layers
5
+ from napari.utils.notifications import show_info
6
+ from napari.qt.threading import thread_worker
7
+ from PyQt5.QtCore import Qt
8
+ from qtpy.QtWidgets import (
9
+ QComboBox,
10
+ QGridLayout,
11
+ QLabel,
12
+ QProgressBar,
13
+ QPushButton,
14
+ QWidget,
15
+ QSizePolicy,
16
+ QDoubleSpinBox,
17
+ QFileDialog,
18
+ )
19
+ from tau_fibrils_yolo.predict import FibrilsDetector
20
+ from tau_fibrils_yolo.postprocess import (
21
+ boxes_kernel_density_map,
22
+ minimum_neighbor_distance_filter,
23
+ outside_image_border_filter,
24
+ overlapping_obb_filter,
25
+ )
26
+ from tau_fibrils_yolo.crossover_distance import crossover_distance_measurement
27
+
28
+
29
class YoloDetectorWidget(QWidget):
    """Napari dock widget running the YOLO Tau-fibril detector on an image layer.

    Detection runs in a background thread and boxes are streamed into a
    "Tau fibrils (Yolo)" shapes layer as tiles are processed. When the run
    finishes, the boxes are filtered, a hidden density-map image layer and a
    "Crossover distance" line layer are added, and the per-fibril measurements
    become available for CSV export.
    """

    def __init__(self, napari_viewer):
        super().__init__()

        self.viewer = napari_viewer
        self.predictor = FibrilsDetector()

        # Layers created by the last detection run (None until a run creates them).
        self.shapes_layer = None
        self.crossover_shapes_layer = None
        # Per-fibril measurements of the last completed run (exported by _save_csv).
        self.df = None
        # Image data the current / last detection run operates on.
        self.selected_image = None

        # Layout
        grid_layout = QGridLayout()
        grid_layout.setAlignment(Qt.AlignTop)
        self.setLayout(grid_layout)

        # Image
        self.cb_image = QComboBox()
        self.cb_image.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        grid_layout.addWidget(QLabel("Image", self), 0, 0)
        grid_layout.addWidget(self.cb_image, 0, 1)

        # Rescale factor
        self.bx_rescale = QDoubleSpinBox()
        self.bx_rescale.setMinimum(0.01)
        self.bx_rescale.setMaximum(100.0)
        self.bx_rescale.setSingleStep(0.05)
        self.bx_rescale.setValue(1.0)
        grid_layout.addWidget(QLabel("Rescale factor", self), 2, 0)
        grid_layout.addWidget(self.bx_rescale, 2, 1)

        # Compute button
        self.btn = QPushButton("Detect fibrils", self)
        self.btn.clicked.connect(self._start_detection)
        grid_layout.addWidget(self.btn, 3, 0, 1, 2)

        # Progress bar
        self.pbar = QProgressBar(self, minimum=0, maximum=1)
        self.pbar.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        grid_layout.addWidget(self.pbar, 4, 0, 1, 2)

        self.cd_label = QLabel("Crossover distance: -", self)
        grid_layout.addWidget(self.cd_label, 5, 0, 1, 2)

        export_btn = QPushButton("Export results (.csv)", self)
        export_btn.clicked.connect(self._save_csv)
        grid_layout.addWidget(export_btn, 6, 0, 1, 2)

        # Keep the image combo box in sync with the viewer's layers (also on renames),
        # and forget our layers when the user removes them.
        self.viewer.layers.events.inserted.connect(
            lambda e: e.value.events.name.connect(self._on_layer_change)
        )
        self.viewer.layers.events.inserted.connect(self._on_layer_change)
        self.viewer.layers.events.removed.connect(self._on_layer_change)
        self.viewer.layers.events.removed.connect(self._reset_shapes_layer)
        self._on_layer_change(None)

    def _on_layer_change(self, e):
        """Repopulate the image combo box with the viewer's 2D/3D image layers."""
        self.cb_image.clear()
        for x in self.viewer.layers:
            if isinstance(x, napari.layers.Image) and x.data.ndim in [2, 3]:
                self.cb_image.addItem(x.name, x.data)

    def _reset_shapes_layer(self, e):
        """Drop references to our layers once they are no longer in the viewer."""
        remaining_layers = e.sources[0]
        if self.shapes_layer is not None and self.shapes_layer not in remaining_layers:
            self.shapes_layer = None
        if (
            self.crossover_shapes_layer is not None
            and self.crossover_shapes_layer not in remaining_layers
        ):
            self.crossover_shapes_layer = None

    @thread_worker
    def _prediction_thread(self, rescale_factor):
        """Run the detector tile by tile in a worker thread.

        Yields ``(k, boxes, probas)`` for every tile, then returns all boxes and
        probabilities (as arrays) together with ``rescale_factor``.
        """
        all_boxes = []
        all_probas = []
        for k, boxes, probas in self.predictor.yield_predictions(
            self.selected_image, rescale_factor
        ):
            all_boxes.extend(boxes)
            all_probas.extend(probas)
            yield k, boxes, probas

        # Reshape so that "no detection at all" gives a (0, 4, 2) array, not shape (0,).
        all_boxes = np.array(all_boxes).reshape(-1, 4, 2)
        all_probas = np.array(all_probas)

        return all_boxes, all_probas, rescale_factor

    def _start_detection(self):
        """Start a detection run on the image currently selected in the combo box."""
        self.selected_image = self.cb_image.currentData()
        if self.selected_image is None:
            show_info("Please select an image.")
            return

        # Empty the layers of the previous run; new ones are created by this run.
        # NOTE(review): the emptied layers stay in the viewer; viewer.layers.remove()
        # would delete them instead.
        if self.shapes_layer is not None:
            self.shapes_layer.data = []
            self.shapes_layer = None

        if self.crossover_shapes_layer is not None:
            self.crossover_shapes_layer.data = []
            self.crossover_shapes_layer = None

        rescale_factor = self.bx_rescale.value()

        worker = self._prediction_thread(rescale_factor)

        worker.yielded.connect(self._update_viewer)
        # payload[0] is the 0-based index of the tile that was just processed.
        worker.yielded.connect(lambda payload: self.pbar.setValue(payload[0] + 1))
        worker.returned.connect(lambda _: self.pbar.setMaximum(1))
        worker.returned.connect(self._load_in_viewer)
        # Re-enable the button when the worker stops.
        worker.finished.connect(lambda: self.btn.setEnabled(True))

        n_crops = self.predictor.n_crops(self.selected_image, rescale_factor)
        self.pbar.setMaximum(int(n_crops))
        self.pbar.setValue(0)
        # One run at a time: a concurrent run would stream into the same layer.
        self.btn.setEnabled(False)
        worker.start()

    def _update_viewer(self, payload):
        """Append one tile's raw detections to the provisional boxes layer."""
        _, boxes, probas = payload

        if len(boxes) == 0:
            return

        if self.shapes_layer is None:
            shape_kwargs = {
                "shape_type": "rectangle",
                "name": "Tau fibrils (Yolo)",
                "face_color": "transparent",
                "opacity": 1.0,
                "edge_width": 1,
                "edge_color": "#ff0000",
                "properties": {
                    "probability": probas,
                },
            }

            self.shapes_layer = self.viewer.add_shapes(boxes, **shape_kwargs)
        else:
            # Update the layer data
            current_data = self.shapes_layer.data
            current_data.extend(boxes)
            current_probas = list(self.shapes_layer.properties.get("probability"))
            current_probas.extend(probas)
            self.shapes_layer.data = current_data
            # Must be a dict with one value per shape (not a tuple wrapping a dict).
            self.shapes_layer.properties = {"probability": current_probas}
            self.shapes_layer.refresh()

    def _load_in_viewer(self, payload):
        """Callback from thread returning: filter, measure and display the detections."""
        boxes, probas, rescale_factor = payload

        # Nothing was detected in any tile: skip the filters, which expect boxes.
        if len(boxes) == 0:
            show_info("No fibrils were detected.")
            return

        filt = outside_image_border_filter(boxes)
        boxes = boxes[filt]
        probas = probas[filt]

        filt = overlapping_obb_filter(boxes)
        boxes = boxes[filt]
        probas = probas[filt]

        filt = minimum_neighbor_distance_filter(boxes, rescale_factor)
        boxes = boxes[filt]
        probas = probas[filt]

        if len(boxes) == 0:
            show_info("No fibrils were detected.")
            return

        # Measure on the image the detection actually ran on: the combo box is
        # repopulated whenever a layer is inserted, so its current item may differ.
        image = self.selected_image
        if image.ndim == 3:
            # Same channel selection as the detector's preprocessing.
            image = image[..., 0]

        density_map = boxes_kernel_density_map(boxes, image)

        density_layer = self.viewer.add_image(
            density_map, colormap="inferno", opacity=0.5, blending="additive"
        )
        density_layer.visible = False

        probas = list(probas)

        # Get the lines and crossover distances
        distances = []
        line_data = []
        centers_x = []
        centers_y = []
        lengths = []
        widths = []
        for box in boxes:
            distance, line_points, center, length, width = (
                crossover_distance_measurement(box, image)
            )
            distances.append(distance)
            line_data.append(line_points)
            centers_x.append(center[0])
            centers_y.append(center[1])
            lengths.append(length)
            widths.append(width)

        self.df = pd.DataFrame(
            {
                "detection_probability": probas,
                "crossover_distance": distances,
                "length": lengths,
                "width": widths,
                "center_x": centers_x,
                "center_y": centers_y,
            }
        )

        lines_props = {
            "shape_type": "line",
            "name": "Crossover distance",
            "opacity": 1.0,
            "edge_width": 1,
            "edge_color": "probability",
            "properties": {"probability": probas, "distance": distances},
        }
        self.crossover_shapes_layer = self.viewer.add_shapes(line_data, **lines_props)

        box_properties = {
            "probability": probas,
            "distance": distances,
            "length": lengths,
            "width": widths,
            "center_x": centers_x,
            "center_y": centers_y,
        }
        if self.shapes_layer is None:
            # The provisional layer was removed by the user while detection ran.
            self.shapes_layer = self.viewer.add_shapes(
                boxes,
                shape_type="rectangle",
                name="Tau fibrils (Yolo)",
                face_color="transparent",
                opacity=1.0,
                edge_width=1,
                properties=box_properties,
            )
        else:
            self.shapes_layer.data = boxes
            # Must be a dict with one value per shape (not a tuple wrapping a dict).
            self.shapes_layer.properties = box_properties
        self.shapes_layer.edge_color = "#00ff00"
        self.shapes_layer.refresh()

    def _save_csv(self):
        """Ask for a file name and write the last run's measurements as CSV."""
        if self.df is None:
            show_info("No detection data found.")
            return

        filename, _ = QFileDialog.getSaveFileName(self, "Save as CSV", ".", "*.csv")
        if not filename:
            # The dialog was cancelled; do not write a file named ".csv".
            return
        if not filename.endswith(".csv"):
            filename += ".csv"

        self.df.to_csv(filename)
        show_info(f"Saved {filename}!")
@@ -0,0 +1,102 @@
1
+ from tau_fibrils_yolo.predict import FibrilsDetector
2
+ from tau_fibrils_yolo.crossover_distance import crossover_distance_measurement
3
+ import tifffile
4
+ from pathlib import Path
5
+ import argparse
6
+ import glob
7
+ import pandas as pd
8
+
9
+
10
def process_input_file_predict(input_image_file, predictor, rescale_factor):
    """Detect fibrils in one TIF image and save the per-fibril measurements as CSV.

    The CSV is written next to the input image as ``<stem>_results.csv``, one
    row per detected fibril.

    Args:
        input_image_file: Path to the TIF image to process.
        predictor: Object with a ``predict(image, rescale_factor)`` method
            returning ``(boxes, probas)``, e.g. a ``FibrilsDetector``.
        rescale_factor: Forwarded unchanged to ``predictor.predict``.
    """
    image = tifffile.imread(input_image_file)

    boxes, probas = predictor.predict(image, rescale_factor)

    # Each measurement is (distance, line_points, center, length, width).
    measurements = [crossover_distance_measurement(box, image) for box in boxes]

    df = pd.DataFrame(
        {
            "detection_probability": probas,
            "crossover_distance": [m[0] for m in measurements],
            "length": [m[3] for m in measurements],
            "width": [m[4] for m in measurements],
            "center_x": [m[2][0] for m in measurements],
            "center_y": [m[2][1] for m in measurements],
        }
    )

    input_path = Path(input_image_file)
    out_file_name = input_path.with_name(f"{input_path.stem}_results.csv")

    df.to_csv(out_file_name)

    print("Saved results to ", out_file_name)
49
+
50
+
51
def cli_predict_image():
    """Command-line entry point for model inference."""
    parser = argparse.ArgumentParser(description="Use this command to run inference.")
    parser.add_argument(
        "-i", type=str, required=True, help="Input image. Must be a TIF image file."
    )
    parser.add_argument(
        "-r", type=float, required=False, default=1.0, help="Rescale factor."
    )
    args = parser.parse_args()

    # The detector is built after argument parsing so that "--help" or a usage
    # error does not trigger the model download.
    process_input_file_predict(args.i, FibrilsDetector(), args.r)
75
+
76
+
77
def cli_predict_folder():
    """Command-line entry point for batch inference on the TIF images of a folder.

    Every non-hidden file directly inside the folder whose extension is ``.tif``
    or ``.tiff`` (case-insensitive) is processed in sorted order, and a
    ``<stem>_results.csv`` file is written next to each image.
    """
    parser = argparse.ArgumentParser(
        description="Use this command to run inference in batch on a given folder."
    )
    parser.add_argument(
        "-i",
        type=str,
        required=True,
        help="Input folder. Must contain suitable TIF image files.",
    )
    parser.add_argument(
        "-r",
        type=float,
        required=False,
        default=1.0,
        help="Rescale factor.",
    )
    args = parser.parse_args()

    input_folder = Path(args.i)
    rescale_factor = args.r

    if not input_folder.is_dir():
        # Exits with a usage message instead of silently processing nothing.
        parser.error(f"Input folder not found: {input_folder}")

    image_files = sorted(
        path
        for path in input_folder.iterdir()
        if path.is_file()
        and not path.name.startswith(".")  # like glob("*"), skip hidden files
        and path.suffix.lower() in (".tif", ".tiff")
    )
    if not image_files:
        print(f"No TIF images found in {input_folder}.")
        return

    # Created only once images were found, since it downloads/loads the model.
    predictor = FibrilsDetector()

    for input_image_file in image_files:
        process_input_file_predict(str(input_image_file), predictor, rescale_factor)
@@ -0,0 +1,49 @@
1
+ import numpy as np
2
+ from numpy.linalg import norm
3
+ from scipy.fft import fft
4
+ from scipy.fft import fftfreq
5
+ from skimage.measure import profile_line
6
+ from skimage.exposure import rescale_intensity
7
+
8
+
9
def crossover_distance_measurement(box, image):
    """Measure the crossover distance of the fibril enclosed by an oriented box.

    The long side of the box is taken as the fibril axis. An intensity profile
    is sampled along the segment joining the midpoints of the two short sides
    and averaged over the box width. The crossover distance is the period of
    the strongest strictly positive frequency in that profile's Fourier
    transform.

    Args:
        box: The four corners of the oriented box, in consecutive order around
            the rectangle, as coordinates in the order ``profile_line`` uses for
            ``image`` (row, col). Presumably a (4, 2) array — TODO confirm
            against callers.
        image: The 2D image the box was detected in.

    Returns:
        A tuple ``(distance, line_points, center, line_length, line_width)``:
        - the crossover distance in pixels (``nan`` if the profile is too short
          to contain a non-zero frequency);
        - a (2, 2) array holding the end points of a segment of length
          ``distance``, centred on the box and aligned with the fibril axis
          (for visualization);
        - the box centre;
        - the length of the box's long side;
        - the length of the box's short side.
    """
    p0, p1, p2, p3 = box

    # The corners are consecutive, so p0-p1 and p0-p3 are the two sides and p2
    # is the corner opposite p0. Relabel so that p0-p1 is the short side
    # (fibril width) and p0-p3 the long side (fibril axis). Comparing against
    # the diagonal p2 - p0 would never trigger: a diagonal is never shorter
    # than a side.
    if norm(p3 - p0) < norm(p1 - p0):
        p0, p3, p2, p1 = box

    # Midpoints of the two short sides: the profile runs along the fibril axis.
    p5 = p0 + (p1 - p0) / 2
    p6 = p3 + (p2 - p3) / 2
    axis = p6 - p5
    center = p5 + axis / 2

    line_width = norm(p1 - p0)
    line_length = norm(axis)

    # Compute the intensity profile along the axis, averaged over the width
    # (at least 1 px wide so that degenerate boxes still yield a profile).
    pixel_values = profile_line(
        image, src=p5, dst=p6, linewidth=max(1, int(line_width)), order=1
    )

    # Normalize the pixel values
    pixel_values = rescale_intensity(pixel_values, out_range=(-1, 1))

    n_samples = len(pixel_values)
    fft_values = fft(pixel_values)
    freqs = fftfreq(n_samples)

    # Keep strictly positive frequencies. This drops the zero-frequency (mean)
    # term and the mirrored negative half, and keeps ``freqs`` aligned with
    # ``fft_values`` because both are indexed with the same mask.
    positive = freqs > 0
    if not np.any(positive):
        # Profile too short to contain any oscillation.
        return np.nan, np.array([center, center]), center, line_length, line_width

    main_freq = freqs[positive][np.argmax(np.abs(fft_values[positive]))]

    # profile_line includes both end points, so consecutive samples are
    # line_length / (n_samples - 1) pixels apart; main_freq is in cycles/sample.
    # (n_samples >= 3 here, because a positive frequency exists.)
    sample_spacing = line_length / (n_samples - 1)
    distance = sample_spacing / main_freq

    # Line for visualization, centred on the box and aligned with its axis.
    axis_unit = axis / norm(axis)
    line_points = np.array(
        [center - distance / 2 * axis_unit, center + distance / 2 * axis_unit]
    )

    return distance, line_points, center, line_length, line_width
@@ -0,0 +1,10 @@
1
+ name: tau-fibrils-yolo
2
+ display_name: Tau Fibrils Yolo Detector
3
+ contributions:
4
+ commands:
5
+ - id: tau-fibrils-yolo.start
6
+ title: Tau fibrils detection
7
+ python_name: tau_fibrils_yolo:YoloDetectorWidget
8
+ widgets:
9
+ - command: tau-fibrils-yolo.start
10
+ display_name: Tau fibrils detection
@@ -0,0 +1,77 @@
1
+ import numpy as np
2
+ from skimage.exposure import rescale_intensity
3
+ import cv2
4
+ from scipy.ndimage import zoom
5
+ from scipy.spatial import distance_matrix
6
+ from sklearn.neighbors import KernelDensity
7
+
8
+
9
def minimum_neighbor_distance_filter(boxes, rescale_factor, min_dist_px=150):
    """Return a boolean mask keeping boxes that have a close neighbour.

    A box is kept (True) when the centre of its nearest other box is at most
    ``min_dist_px / rescale_factor`` away from its own centre, so isolated
    detections are discarded.

    Args:
        boxes: Box corner coordinates, presumably of shape (n, 4, 2) — TODO
            confirm against callers.
        rescale_factor: ``min_dist_px`` is divided by this factor before the
            comparison.
        min_dist_px: Maximum nearest-neighbour distance for a box to be kept.

    Returns:
        Boolean array with one entry per box. With fewer than two boxes no box
        has a neighbour, so every entry is False (and the result is empty for
        empty input).
    """
    if len(boxes) < 2:
        return np.zeros(len(boxes), dtype=bool)

    max_dist = min_dist_px / rescale_factor
    box_centers = boxes_center_coordinates(boxes)
    dist_matrix = distance_matrix(box_centers, box_centers)
    # A box is not its own neighbour.
    np.fill_diagonal(dist_matrix, np.inf)
    return dist_matrix.min(axis=1) <= max_dist
18
+
19
+
20
def overlapping_obb_filter(bounding_boxes):
    """Return the indices of the boxes to keep after greedy overlap removal.

    Boxes are visited in order. A box is kept only if it has a zero
    intersection area (per ``cv2.intersectConvexConvex``) with every box kept
    before it, so the earlier of two overlapping boxes wins.
    """
    kept_indices = []
    for index, candidate in enumerate(bounding_boxes):
        overlaps_kept_box = any(
            cv2.intersectConvexConvex(candidate, bounding_boxes[other])[0] > 0
            for other in kept_indices
        )
        if not overlaps_kept_box:
            kept_indices.append(index)
    return kept_indices
35
+
36
+
37
def outside_image_border_filter(boxes, image_shape=None):
    """Return a boolean mask that is True for boxes lying inside the image.

    A box is rejected (False) when any of its corner coordinates is negative.
    When ``image_shape`` is given, a box is also rejected when a coordinate
    along the last axis position ``a`` is ``>= image_shape[a]``. Only the first
    ``boxes.shape[-1]`` entries of ``image_shape`` are used, so a trailing
    channel dimension is ignored.

    Args:
        boxes: Box corner coordinates, presumably of shape (n, 4, 2) — TODO
            confirm against callers.
        image_shape: Optional image shape for the upper-bound check.

    Returns:
        Boolean array with one entry per box (empty for empty input).
    """
    boxes = np.asarray(boxes)
    if len(boxes) == 0:
        # np.array([]) has shape (0,), which has no axis 2 to reduce over.
        return np.zeros(0, dtype=bool)

    outside = (boxes < 0).any(axis=2).any(axis=1)
    if image_shape is not None:
        upper = np.asarray(image_shape[: boxes.shape[-1]])
        outside |= (boxes >= upper).any(axis=2).any(axis=1)
    return ~outside
40
+
41
+
42
def boxes_center_coordinates(boxes):
    """Return the centre of each box's axis-aligned extent.

    For every box, the centre along each coordinate is the midpoint between
    the smallest and largest corner value. The result has one row per box;
    column 0 holds the centre of ``boxes[..., 0]`` and column 1 that of
    ``boxes[..., 1]``.
    """
    first, second = boxes[..., 0], boxes[..., 1]
    mid_first = (first.max(axis=1) + first.min(axis=1)) / 2
    mid_second = (second.max(axis=1) + second.min(axis=1)) / 2
    return np.vstack((mid_first, mid_second)).T
51
+
52
+
53
def boxes_kernel_density_map(boxes, image, gaussian_sigma_px=50, downscale_factor=8):
    """Kernel density estimate of the box centres, returned at the image's size.

    The Gaussian KDE is evaluated on a grid downscaled by ``downscale_factor``
    (for speed), upsampled with linear interpolation to exactly the image's 2D
    shape, and rescaled to the range [0, 1].

    Args:
        boxes: Box corner coordinates in the image's (row, col) axis order;
            presumably of shape (n, 4, 2) — TODO confirm against callers.
        image: The image the boxes belong to. Only its first two dimensions are
            used, so a trailing channel axis is ignored.
        gaussian_sigma_px: KDE bandwidth. NOTE(review): it is applied on the
            downscaled grid, so in full-resolution pixels the effective sigma is
            ``gaussian_sigma_px * downscale_factor`` — confirm this is intended.
        downscale_factor: Downscaling applied to the evaluation grid.

    Returns:
        Array of shape ``image.shape[:2]`` with values in [0, 1].
    """
    full_shape = np.array(image.shape[:2])
    # At least one grid cell per axis, even for images smaller than the factor.
    grid_shape = np.maximum(full_shape // downscale_factor, 1)

    x_grid, y_grid = np.meshgrid(
        np.linspace(0, grid_shape[1] - 1, grid_shape[1]),
        np.linspace(0, grid_shape[0] - 1, grid_shape[0]),
    )

    # (row, col) grid points, matching the order of the box centre coordinates.
    grid_points = np.vstack([y_grid.ravel(), x_grid.ravel()]).T

    kde = KernelDensity(
        bandwidth=gaussian_sigma_px,
        kernel="gaussian",
        algorithm="ball_tree",
    )

    kde.fit(boxes_center_coordinates(boxes) / downscale_factor)

    density_map = np.exp(kde.score_samples(grid_points)).reshape(grid_shape)
    # Per-axis zoom factors, so the output has exactly the image's 2D shape
    # (an integer zoom loses the remainder when a size isn't divisible).
    density_map = zoom(density_map, zoom=full_shape / grid_shape, order=1)
    density_map = rescale_intensity(density_map, out_range=(0, 1))

    return density_map
@@ -0,0 +1,168 @@
1
+ import os
2
+ import numpy as np
3
+ from skimage.exposure import rescale_intensity
4
+ from skimage.transform import rescale
5
+ from ultralytics import YOLO
6
+ import pooch
7
+
8
+ from tau_fibrils_yolo.postprocess import (
9
+ outside_image_border_filter,
10
+ minimum_neighbor_distance_filter,
11
+ overlapping_obb_filter,
12
+ )
13
+
14
+
15
# Cache directory for the downloaded weights: "$XDG_DATA_HOME/.yolo" when that
# variable is set, otherwise "~/.yolo" (with "~" expanded to the home folder).
_DATA_HOME = os.getenv("XDG_DATA_HOME", "~")
MODEL_PATH = os.path.expanduser(os.path.join(_DATA_HOME, ".yolo"))
16
+
17
+
18
def retreive_model():
    """Download the model weights from Zenodo into ``MODEL_PATH``.

    ``pooch`` verifies the file against its known MD5 hash and skips the
    download when a valid copy is already present.

    Returns:
        The local path of the weights file, as returned by ``pooch.retrieve``.
    """
    # Raw string: "\ " is an invalid escape sequence (a SyntaxWarning since
    # Python 3.12). The printed text is unchanged.
    print(
        r" /!\ This model is on the Zenodo Sandbox (use it only for testing purposes) /!\ "
    )
    return pooch.retrieve(
        url="https://sandbox.zenodo.org/records/99113/files/100ep.pt",
        known_hash="md5:2fc4be1e4feae93f75e335856be3083d",
        path=MODEL_PATH,
        progressbar=True,
        fname="yolo_fibrils_100ep.pt",
    )
30
+
31
+
32
def pad_image(image, tile_size_px, overlap_px):
    """Zero-pad a 2D image so that a grid of overlapping square tiles covers it.

    Neighbouring tiles of size ``tile_size_px`` overlap by ``overlap_px``
    pixels, and the padded size per axis is ``n_tiles * stride + overlap_px``
    with ``stride = tile_size_px - overlap_px``. The image is placed at an
    offset inside the padded canvas.

    Returns:
        ``(padded_image, (row_offset, col_offset), (n_tiles_rows, n_tiles_cols))``
        where the offsets are the position of the original image inside the
        padded one.
    """
    stride = tile_size_px - overlap_px

    def _axis_layout(extent):
        # Offset of the image, number of tiles and padded size along one axis.
        total_pad = extent % stride + overlap_px
        n_tiles = np.ceil((extent + total_pad - overlap_px) / stride).astype(int)
        return total_pad // 2, n_tiles, n_tiles * stride + overlap_px

    rows, cols = image.shape
    row_offset, n_tile_rows, padded_rows = _axis_layout(rows)
    col_offset, n_tile_cols, padded_cols = _axis_layout(cols)

    padded = np.zeros((padded_rows, padded_cols))
    padded[row_offset : row_offset + rows, col_offset : col_offset + cols] = image

    return padded, (row_offset, col_offset), (n_tile_rows, n_tile_cols)
55
+
56
+
57
def image_tile_generator(image, imgsz, overlap_px):
    """Yield square ``imgsz`` tiles of the zero-padded image with their origins.

    Tiles overlap by ``overlap_px`` pixels and are produced tile-row by
    tile-row. Each origin is the (row, col) coordinate, in the unpadded
    ``image``, of the tile's top-left pixel; it is negative for tiles starting
    inside the padding.
    """
    padded, (offset_r, offset_c), (n_rows, n_cols) = pad_image(image, imgsz, overlap_px)
    stride = imgsz - overlap_px
    for i in range(n_rows):
        top = i * stride
        for j in range(n_cols):
            left = j * stride
            tile = padded[top : top + imgsz, left : left + imgsz]
            yield tile, (top - offset_r, left - offset_c)
72
+
73
+
74
def to_rgb(arr):
    """Stack a single-channel array three times along a new last axis."""
    channel = arr[..., np.newaxis]
    return np.concatenate([channel, channel, channel], axis=-1)
76
+
77
+
78
def preprocess_image(image, rescale_factor):
    """Prepare an image for the model: single channel, 0-255 range, rescaled.

    For a 3D image only the first entry of the last axis is kept (the last
    axis is assumed to be the channel axis). The intensities are then
    stretched to 0-255 and cast to uint8. Finally, the image is resized by
    ``rescale_factor`` with cubic interpolation, preserving the value range.
    """
    # Make sure the image is single-channel
    if image.ndim == 3:
        image = image[..., 0]

    # Applied to every input (not only 2D ones), so 3D inputs reach the model
    # in the same 0-255 uint8 range as 2D inputs.
    image = rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)

    # Rescale the image to make it match the target resolution.
    image = rescale(image, rescale_factor, order=3, preserve_range=True)

    return image
89
+
90
+
91
def predict_generator(image, model, imgsz, rescale_factor=1.0, overlap_px=None):
    """Run ``model`` tile by tile over ``image`` and yield the detections per tile.

    The image is preprocessed (single channel, 0-255, rescaled by
    ``rescale_factor``), split into overlapping ``imgsz`` tiles, and each tile
    is passed to ``model.predict`` as a 3-channel array.

    Yields:
        ``(k, boxes, probabilities)`` for the k-th tile (0-based). The box
        corners are in original-image coordinates (tile offset added, rescaling
        undone) and in (row, col) order.
    """
    # 10% overlap by default
    overlap = imgsz // 10 if overlap_px is None else overlap_px

    prepared = preprocess_image(image, rescale_factor)
    tiles = image_tile_generator(prepared, imgsz, overlap)

    for tile_index, (tile, (offset_rows, offset_cols)) in enumerate(tiles):
        result = model.predict(
            source=to_rgb(tile),
            conf=0.05,  # Confidence threshold for detections.
            iou=0.1,  # Intersection over union threshold.
            max_det=500,  # Max detections per imgsz x imgsz tile
            augment=False,
            imgsz=imgsz,  # Square resizing
        )[0]

        probabilities = result.obb.conf.cpu().numpy()
        corners = result.obb.xyxyxyxy.cpu().numpy()

        # Column 0 receives the tile's column offset and column 1 its row
        # offset; both are then mapped back to the un-rescaled image. Updated
        # in place so the array keeps its dtype.
        corners[..., 0] = (corners[..., 0] + offset_cols) / rescale_factor
        corners[..., 1] = (corners[..., 1] + offset_rows) / rescale_factor

        # Invert X-Y to (row, col) order.
        yield tile_index, corners[..., ::-1], probabilities
125
+
126
+
127
class FibrilsDetector:
    """Tiled YOLO oriented-bounding-box detector for Tau fibrils.

    Instantiating the class retrieves the model weights (downloading them if
    needed) and loads them with ultralytics ``YOLO``.
    """

    def __init__(self):
        retreive_model()
        self.model = YOLO(os.path.join(MODEL_PATH, "yolo_fibrils_100ep.pt"))
        # Square tile size fed to the model and overlap between tiles (10%).
        self.imgsz = 640
        self.overlap_px = self.imgsz // 10

    def n_crops(self, image: np.ndarray, rescale_factor: float = 1.0):
        """Return the number of tiles processed for ``image`` at ``rescale_factor``."""
        preprocessed = preprocess_image(image, rescale_factor)
        return np.prod(pad_image(preprocessed, self.imgsz, self.overlap_px)[-1])

    def predict(self, image: np.ndarray, rescale_factor: float = 1.0):
        """Detect fibrils in ``image`` and return the filtered boxes and probabilities.

        Three filters are applied in order:
        - boxes with a negative corner coordinate are removed;
        - boxes overlapping an earlier kept box are removed;
        - boxes without a close neighbour are removed.

        Returns:
            ``(boxes, probas)`` — box corners as an (n, 4, 2) array and their
            probabilities as an (n,) array. Both are empty (n = 0) when nothing
            is detected.
        """
        boxes = []
        probas = []
        for _, b, p in self.yield_predictions(image, rescale_factor):
            boxes.extend(b)
            probas.extend(p)
        # Reshape so that an empty result is (0, 4, 2) rather than (0,).
        boxes = np.array(boxes).reshape(-1, 4, 2)
        probas = np.array(probas)

        filters = (
            outside_image_border_filter,
            overlapping_obb_filter,
            lambda b: minimum_neighbor_distance_filter(b, rescale_factor),
        )
        for make_filter in filters:
            if len(boxes) == 0:
                # Nothing left; avoid running the filters on empty input.
                break
            filt = make_filter(boxes)
            boxes = boxes[filt]
            probas = probas[filt]

        return boxes, probas

    def yield_predictions(self, image: np.ndarray, rescale_factor: float = 1.0):
        """Yield ``(k, boxes, probas)`` per processed tile (``k`` is 0-based)."""
        # Pass the overlap explicitly so the tiling always matches n_crops().
        yield from predict_generator(
            image, self.model, self.imgsz, rescale_factor, overlap_px=self.overlap_px
        )
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2024, EPFL.
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ * Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ * Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ * Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,173 @@
1
+ Metadata-Version: 2.1
2
+ Name: tau-fibrils-yolo
3
+ Version: 0.0.1
4
+ Summary: YoloV8 model for the detection of Tau fibrils in Cryo-EM images.
5
+ Author-email: Mallory Wittwer <mallory.wittwer@epfl.ch>
6
+ License: BSD 3-Clause License
7
+
8
+ Copyright (c) 2024, EPFL.
9
+
10
+ Redistribution and use in source and binary forms, with or without
11
+ modification, are permitted provided that the following conditions are met:
12
+
13
+ * Redistributions of source code must retain the above copyright notice, this
14
+ list of conditions and the following disclaimer.
15
+
16
+ * Redistributions in binary form must reproduce the above copyright notice,
17
+ this list of conditions and the following disclaimer in the documentation
18
+ and/or other materials provided with the distribution.
19
+
20
+ * Neither the name of the copyright holder nor the names of its
21
+ contributors may be used to endorse or promote products derived from
22
+ this software without specific prior written permission.
23
+
24
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
+ Project-URL: homepage, https://github.com/EPFL-Center-for-Imaging/tau_fibrils_yolo
35
+ Project-URL: repository, https://github.com/EPFL-Center-for-Imaging/tau_fibrils_yolo
36
+ Classifier: Development Status :: 2 - Pre-Alpha
37
+ Classifier: License :: OSI Approved :: BSD License
38
+ Classifier: Programming Language :: Python :: 3
39
+ Classifier: Programming Language :: Python :: 3.9
40
+ Classifier: Programming Language :: Python :: 3.10
41
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
42
+ Requires-Python: >=3.9
43
+ Description-Content-Type: text/markdown
44
+ License-File: LICENSE
45
+ Requires-Dist: napari[all]>=0.4.16
46
+ Requires-Dist: qtpy
47
+ Requires-Dist: magicgui
48
+ Requires-Dist: numpy
49
+ Requires-Dist: pandas
50
+ Requires-Dist: tifffile
51
+ Requires-Dist: pooch==1.8.0
52
+ Requires-Dist: scikit-image
53
+ Requires-Dist: scikit-learn
54
+ Requires-Dist: ultralytics
55
+ Requires-Dist: opencv-contrib-python-headless
56
+
57
+ ![EPFL Center for Imaging logo](https://imaging.epfl.ch/resources/logo-for-gitlab.svg)
58
+ # 🧬 Tau Fibrils Yolo - Object detection in Cryo-EM images
59
+
60
+ ![screenshot](assets/screenshot.png)
61
+
62
+ We provide a [YoloV8](https://docs.ultralytics.com/) model for the detection of oriented bounding boxes (OBBs) of Tau fibrils in Cryo-EM images.
63
+
64
+ [[`Installation`](#installation)] [[`Model`](#model)] [[`Usage`](#usage)]
65
+
66
+ This project is part of a collaboration between the [EPFL Center for Imaging](https://imaging.epfl.ch/) and the [Laboratory of Biological Electron Microscopy](https://www.lbem.ch/).
67
+
68
+ ## Installation
69
+
70
+ ### As a standalone app
71
+
72
+ Soon.
73
+
74
+ ### As a Python package
75
+
76
+ We recommend performing the installation in a clean Python environment. Install our package from PyPI:
77
+
78
+ ```sh
79
+ pip install tau-fibrils-yolo
80
+ ```
81
+
82
+ or from the repository:
83
+
84
+ ```sh
85
+ pip install git+https://gitlab.com/center-for-imaging/tau-fibrils-object-detection.git
86
+ ```
87
+
88
+ or clone the repository and install with:
89
+
90
+ ```sh
91
+ git clone https://gitlab.com/center-for-imaging/tau-fibrils-object-detection.git
92
+ cd tau-fibrils-object-detection
93
+ pip install -e .
94
+ ```
95
+
96
+ ## Model
97
+
98
+ The model weights (6.5 MB) are automatically downloaded from [this repository on Zenodo](https://sandbox.zenodo.org/records/99113) the first time you run inference. The model files are saved in the `.yolo` directory of your home folder (or of `$XDG_DATA_HOME`, if that environment variable is set).
99
+
100
+ ## Usage
101
+
102
+ **In Napari**
103
+
104
+ To use our model in Napari, start the viewer with
105
+
106
+ ```sh
107
+ napari -w tau-fibrils-yolo
108
+ ```
109
+
110
+ or open the Napari menu bar and select `Plugins > Tau fibrils detection`.
111
+
112
+ Open an image using `File > Open files` or drag-and-drop an image into the viewer window.
113
+
114
+ **Sample data**: To test the model, you can run it on our provided sample image. In Napari, open the image from `File > Open Sample > [TODO - add a sample image]`.
115
+
116
+
117
+ **As a library**
118
+
119
+ You can run the model to detect fibrils in an image (represented as a numpy array).
120
+
121
+ ```py
122
+ from tau_fibrils_yolo import FibrilsDetector
123
+
124
+ detector = FibrilsDetector()
125
+
126
+ boxes, probabilities = detector.predict(your_image, rescale_factor=1.0)
127
+ ```
128
+
129
+ **As a CLI**
130
+
131
+ Run inference on an image from the command-line. For example:
132
+
133
+ ```sh
134
+ tau_fibrils_predict_image -i /path/to/folder/image_001.tif
135
+ ```
136
+
137
+ The command saves the detection results (one row per detected fibril) as a CSV file next to the image:
138
+
139
+ ```
140
+ folder/
141
+ ├── image_001.tif
142
+ ├── image_001_results.csv
143
+ ```
144
+
145
+ Optionally, you can use the `-r` flag to also rescale the image by a given factor.
146
+
147
+ To run inference in batch on all TIF images in a folder, use:
148
+
149
+ ```sh
150
+ tau_fibrils_predict_folder -i /path/to/folder/
151
+ ```
152
+
153
+ This will produce:
154
+
155
+ ```
156
+ folder/
157
+ ├── image_001.tif
158
+ ├── image_001_results.csv
159
+ ├── image_002.tif
160
+ ├── image_002_results.csv
161
+ ```
162
+
163
+ ## Issues
164
+
165
+ If you encounter any problems, please file an issue along with a detailed description.
166
+
167
+ ## License
168
+
169
+ This model is licensed under the [BSD-3](LICENSE) license.
170
+
171
+ ## Acknowledgements
172
+
173
+ We would particularly like to thank **Valentin Vuillon** for annotating the images on which this model was trained, and for developing the preliminary code that laid the foundation for this image analysis project. The repository containing his original version of the project can be found [here](https://gitlab.com/epfl-center-for-imaging/automated-analysis-tau-fibrils-project).
@@ -0,0 +1,14 @@
1
+ tau_fibrils_yolo/__init__.py,sha256=IKx6Hj3WAfsNmxnW5nJS7ibx59PyLKBsAzvvWqXQyeA,108
2
+ tau_fibrils_yolo/_version.py,sha256=pMnmqZnpVmaqR5nqHztNWzbbtb1oy5bPN_v7uhOH8K8,411
3
+ tau_fibrils_yolo/_widget.py,sha256=Eqwm0_A86OX0sJ4TZ_mtXcHfdvxIuOYfFu32y-GKVgM,8746
4
+ tau_fibrils_yolo/cli.py,sha256=DtrKlSja2Gedziv_kSgiHztn-A-kK0DWiX0wiQJJ1y0,2724
5
+ tau_fibrils_yolo/crossover_distance.py,sha256=_FfsSoLL-eCjWXleqKXFk6y6u9nx7qyW58r4RGdElkk,1425
6
+ tau_fibrils_yolo/napari.yaml,sha256=CLac43pJqbklLzjkaEhVEf6aHoGeRjS-CoHPhkcaPE0,303
7
+ tau_fibrils_yolo/postprocess.py,sha256=2e1b_JlAFtCiuWB3yGn_9cVI33GPqtDrPeiDtoyatYE,2609
8
+ tau_fibrils_yolo/predict.py,sha256=hMViSnyX9HSTDpbB-nYrQ77F_IrChWiPdXrwqC2E8KA,5473
9
+ tau_fibrils_yolo-0.0.1.dist-info/LICENSE,sha256=4OTqIt2xVP2wCy4QAOyXVURea7L4FKWFW6drH-hAUKU,1483
10
+ tau_fibrils_yolo-0.0.1.dist-info/METADATA,sha256=d7S6ShQACIQ6MUGTlrOjCXzVLYM9PS8j1Wg1OyTfRcA,6041
11
+ tau_fibrils_yolo-0.0.1.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
12
+ tau_fibrils_yolo-0.0.1.dist-info/entry_points.txt,sha256=rf5DNzu6OGSQ2sZa1YqAvR6I-wnQ_zexP0EyUB4ZxtA,221
13
+ tau_fibrils_yolo-0.0.1.dist-info/top_level.txt,sha256=9s_o7Fja2NxRfdxnPwkMXQljx6D233h3z0JpV4a4bTE,17
14
+ tau_fibrils_yolo-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (73.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,6 @@
1
+ [console_scripts]
2
+ tau_fibrils_predict_folder = tau_fibrils_yolo.cli:cli_predict_folder
3
+ tau_fibrils_predict_image = tau_fibrils_yolo.cli:cli_predict_image
4
+
5
+ [napari.manifest]
6
+ tau_fibrils_yolo = tau_fibrils_yolo:napari.yaml
@@ -0,0 +1 @@
1
+ tau_fibrils_yolo