tau-fibrils-yolo 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
# NOTE(review): package version metadata; the `-` lines are the previous release (0.0.4), the `+` lines the new one (0.0.6).
- __version__ = version = '0.0.4'
16
- __version_tuple__ = version_tuple = (0, 0, 4)
15
+ __version__ = version = '0.0.6'
16
+ __version_tuple__ = version_tuple = (0, 0, 6)
@@ -16,6 +16,8 @@ from qtpy.QtWidgets import (
16
16
  QDoubleSpinBox,
17
17
  QFileDialog,
18
18
  )
19
+ from matplotlib.backends.backend_qt5agg import FigureCanvas
20
+
19
21
  from tau_fibrils_yolo.predict import FibrilsDetector
20
22
  from tau_fibrils_yolo.postprocess import (
21
23
  boxes_kernel_density_map,
@@ -23,7 +25,21 @@ from tau_fibrils_yolo.postprocess import (
23
25
  outside_image_border_filter,
24
26
  overlapping_obb_filter,
25
27
  )
26
- from tau_fibrils_yolo.crossover_distance import crossover_distance_measurement
28
+ from tau_fibrils_yolo.measure import (
29
+ box_measurements,
30
+ line_profile_measurements,
31
+ crossover_distance_measurement,
32
+ )
33
+
34
+ import matplotlib as mpl
35
+
36
# Dark theme for the embedded line-profile plot: white text/ticks/edges on the
# same dark grey (#262930) used for the figure background.
# NOTE(review): this writes matplotlib's process-wide rcParams at import time,
# so it also restyles any other matplotlib figure created in the same process.
mpl.rcParams.update(
    {
        "axes.edgecolor": "white",
        "axes.facecolor": "#262930",
        "axes.labelcolor": "white",
        "savefig.facecolor": "#262930",
        "text.color": "white",
        "xtick.color": "white",
        "ytick.color": "white",
    }
)
27
43
 
28
44
 
29
45
  class YoloDetectorWidget(QWidget):
@@ -74,6 +90,19 @@ class YoloDetectorWidget(QWidget):
74
90
  export_btn.clicked.connect(self._save_csv)
75
91
  grid_layout.addWidget(export_btn, 6, 0, 1, 2)
76
92
 
93
+ # Line profile plot
94
+ self.canvas = FigureCanvas()
95
+ self.canvas.figure.set_tight_layout(True)
96
+ self.canvas.figure.patch.set_facecolor("#262930")
97
+ self.axes = self.canvas.figure.subplots()
98
+ self.axes.cla()
99
+ self.axes.set_ylabel("Intensity")
100
+ self.axes.set_xlabel("Pixels")
101
+ self.canvas.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
102
+ self.canvas.setMaximumSize(500, 300)
103
+ self.canvas.setMinimumSize(300, 150)
104
+ grid_layout.addWidget(self.canvas, 7, 0, 1, 2)
105
+
77
106
  # Setup layer callbacks
78
107
  self.viewer.layers.events.inserted.connect(
79
108
  lambda e: e.value.events.name.connect(self._on_layer_change)
@@ -98,7 +127,9 @@ class YoloDetectorWidget(QWidget):
98
127
  def _prediction_thread(self, rescale_factor):
99
128
  all_boxes = []
100
129
  all_probas = []
101
- for k, boxes, probas in self.predictor.yield_predictions(self.selected_image, rescale_factor):
130
+ for k, boxes, probas in self.predictor.yield_predictions(
131
+ self.selected_image, rescale_factor
132
+ ):
102
133
  all_boxes.extend(boxes)
103
134
  all_probas.extend(probas)
104
135
  yield k, boxes, probas
@@ -148,38 +179,36 @@ class YoloDetectorWidget(QWidget):
148
179
  "opacity": 1.0,
149
180
  "edge_width": 1,
150
181
  "edge_color": "#ff0000",
151
- "properties": {
152
- "probability": probas,
153
- },
154
182
  }
155
183
 
156
184
  self.shapes_layer = self.viewer.add_shapes(boxes, **shape_kwargs)
185
+ self.shapes_layer.mouse_double_click_callbacks.append(
186
+ self._handle_double_click
187
+ )
188
+ self.shapes_layer.mode = "DIRECT"
157
189
  else:
158
190
  # Update the layer data
159
191
  current_data = self.shapes_layer.data
160
192
  current_data.extend(boxes)
161
- current_probas = self.shapes_layer.properties.get("probability")
162
- current_probas = list(current_probas)
163
- current_probas.extend(probas)
164
193
  self.shapes_layer.data = current_data
165
- self.shapes_layer.properties = ({"probability": current_probas},)
166
194
  self.shapes_layer.refresh()
167
195
 
168
196
  def _load_in_viewer(self, payload):
169
197
  """Callback from thread returning."""
170
198
  boxes, probas, rescale_factor = payload
171
199
 
172
- filt = outside_image_border_filter(boxes)
173
- boxes = boxes[filt]
174
- probas = probas[filt]
200
+ if len(boxes):
201
+ filt = outside_image_border_filter(boxes)
202
+ boxes = boxes[filt]
203
+ probas = probas[filt]
175
204
 
176
- filt = overlapping_obb_filter(boxes)
177
- boxes = boxes[filt]
178
- probas = probas[filt]
205
+ filt = overlapping_obb_filter(boxes)
206
+ boxes = boxes[filt]
207
+ probas = probas[filt]
179
208
 
180
- filt = minimum_neighbor_distance_filter(boxes, rescale_factor)
181
- boxes = boxes[filt]
182
- probas = probas[filt]
209
+ filt = minimum_neighbor_distance_filter(boxes, rescale_factor)
210
+ boxes = boxes[filt]
211
+ probas = probas[filt]
183
212
 
184
213
  if len(boxes) == 0:
185
214
  show_info("No fibrils were detected.")
@@ -191,72 +220,75 @@ class YoloDetectorWidget(QWidget):
191
220
  density_map, colormap="inferno", opacity=0.5, blending="additive"
192
221
  )
193
222
  density_layer.visible = False
223
+ self.viewer.layers.selection.active = self.shapes_layer
194
224
 
195
225
  probas = list(probas)
196
226
 
197
227
  # Get the lines and crossover distances
198
- image = self.cb_image.currentData()
199
- distances = []
200
- line_data = []
201
228
  centers_x = []
202
229
  centers_y = []
203
230
  lengths = []
204
231
  widths = []
205
232
  for box in boxes:
206
- distance, line_points, center, length, width = (
207
- crossover_distance_measurement(box, image)
208
- )
209
- distances.append(distance)
210
- line_data.append(line_points)
233
+ center, length, width = box_measurements(box)
211
234
  centers_x.append(center[0])
212
235
  centers_y.append(center[1])
213
236
  lengths.append(length)
214
237
  widths.append(width)
215
238
 
216
- self.df = pd.DataFrame(
217
- {
218
- "detection_probability": probas,
219
- "crossover_distance": distances,
220
- "length": lengths,
221
- "width": widths,
222
- "center_x": centers_x,
223
- "center_y": centers_y,
224
- }
225
- )
226
-
227
- lines_props = {
228
- "shape_type": "line",
229
- "name": "Crossover distance",
230
- "opacity": 1.0,
231
- "edge_width": 1,
232
- "edge_color": "probability",
233
- "properties": {"probability": probas, "distance": distances},
239
+ self.shapes_layer.data = boxes
240
+ self.shapes_layer.properties = {
241
+ "probability": probas,
242
+ "length": lengths,
243
+ "width": widths,
244
+ "center_x": centers_x,
245
+ "center_y": centers_y,
234
246
  }
235
- self.crossover_shapes_layer = self.viewer.add_shapes(line_data, **lines_props)
236
247
 
237
- self.shapes_layer.data = boxes
238
- self.shapes_layer.properties = (
248
+ self.shapes_layer.edge_color = "#00ff00"
249
+ self.shapes_layer.refresh()
250
+
251
+ self.df = pd.DataFrame(
239
252
  {
240
253
  "probability": probas,
241
- "distance": distances,
242
254
  "length": lengths,
243
255
  "width": widths,
244
256
  "center_x": centers_x,
245
257
  "center_y": centers_y,
246
- },
258
+ }
247
259
  )
248
- self.shapes_layer.edge_color = "#00ff00"
249
- self.shapes_layer.refresh()
250
260
 
251
261
  def _save_csv(self):
252
262
  if self.df is None:
253
263
  print("No detection data found.")
254
264
  return
255
265
 
256
- print("Saving CSV.")
257
266
  filename, _ = QFileDialog.getSaveFileName(self, "Save as CSV", ".", "*.csv")
258
267
  if not filename.endswith(".csv"):
259
268
  filename += ".csv"
260
269
 
261
270
  self.df.to_csv(filename)
262
271
  print(f"Saved {filename}!")
272
+
273
+ def _handle_double_click(self, *args, **kwargs):
274
+ if self.shapes_layer.mode in ["direct", "select"]:
275
+ selected_data = self.shapes_layer.selected_data
276
+ if len(selected_data) == 1:
277
+ selected_shape_idx = list(self.shapes_layer.selected_data)[0]
278
+ boxes = self.shapes_layer.data
279
+ image = self.cb_image.currentData()
280
+ box = boxes[selected_shape_idx]
281
+ line_profile_width, line_profile_centerline = line_profile_measurements(box, image)
282
+ crossover_distance = crossover_distance_measurement(box, image)
283
+ crossover_distance_centerline = crossover_distance_measurement(box, image, method="centerline")
284
+ self._draw(line_profile_width, line_profile_centerline, crossover_distance, crossover_distance_centerline)
285
+
286
+ def _draw(self, line_profile_width, line_profile_centerline, crossover_distance, crossover_distance_centerline):
287
+ self.axes.cla()
288
+ self.axes.plot(line_profile_width, label=f"box_width (crossover_dist: {crossover_distance:.0f} px)")
289
+ self.axes.plot(line_profile_centerline, color="orange", label=f"centerline (crossover_dist: {crossover_distance_centerline:.0f} px)")
290
+ self.axes.set_ylabel("Intensity")
291
+ self.axes.set_xlabel("Pixels")
292
+ self.axes.set_xlim(0, len(line_profile_width))
293
+ self.axes.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=1)
294
+ self.canvas.draw()
tau_fibrils_yolo/cli.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from tau_fibrils_yolo.predict import FibrilsDetector
2
- from tau_fibrils_yolo.crossover_distance import crossover_distance_measurement
2
+ from tau_fibrils_yolo.measure import crossover_distance_measurement
3
3
  import tifffile
4
4
  from pathlib import Path
5
5
  import argparse
@@ -0,0 +1,91 @@
1
+ import numpy as np
2
+ from numpy.linalg import norm
3
+ from scipy.fft import fft
4
+ from scipy.fft import fftfreq
5
+ from skimage.measure import profile_line
6
+ from skimage.exposure import rescale_intensity
7
+
8
+
9
def _oriented_corners(box):
    """Return the four box corners ordered so that ``p0 -> p1`` is a short edge.

    The corners are expected in cyclic order around the rectangle
    (p0-p1-p2-p3), so that p0-p1 and p3-p2 are opposite edges and p2 is the
    corner diagonally opposite p0.

    Returns
    -------
    tuple of ndarray
        ``(p0, p1, p2, p3)`` where ``p0 -> p1`` (and ``p3 -> p2``) span the
        box width and ``p0 -> p3`` (and ``p1 -> p2``) span its length.
    """
    corners = np.asarray(box, dtype=float)
    p0, p1, p2, p3 = corners
    # Compare the two edges that meet at p0. (p2 is the opposite corner, so
    # ``norm(p2 - p0)`` is the diagonal, which is never shorter than an edge.)
    if norm(p3 - p0) < norm(p1 - p0):
        # Walk the corners the other way round so that p0 -> p1 is short.
        p0, p3, p2, p1 = corners
    return p0, p1, p2, p3


def _long_axis(box):
    """Return ``(src, dst, width, length)`` describing the box's long axis.

    ``src`` and ``dst`` are the midpoints of the two short edges, so the
    segment ``src -> dst`` runs down the middle of the box over its length.
    ``width`` and ``length`` are the short and long edge lengths.
    """
    p0, p1, p2, p3 = _oriented_corners(box)
    src = p0 + (p1 - p0) / 2
    dst = p3 + (p2 - p3) / 2
    width = norm(p1 - p0)
    length = norm(p3 - p0)
    return src, dst, width, length


def _dominant_period(values):
    """Return the period, in samples, of the strongest non-zero frequency.

    Only the non-negative half of the spectrum (excluding the zero-frequency
    term) is searched; for a real-valued signal the negative half mirrors it.
    Returns NaN when fewer than two samples are given.
    """
    n = len(values)
    if n < 2:
        return np.nan
    # Indices 1 .. n // 2 cover every distinct non-zero frequency of a real
    # signal (index n // 2 is the Nyquist term when n is even). Spectrum and
    # frequency axis are sliced identically so their indices stay aligned.
    stop = n // 2 + 1
    spectrum = np.abs(fft(values)[1:stop])
    freqs = np.abs(fftfreq(n)[1:stop])
    return 1.0 / freqs[np.argmax(spectrum)]


def crossover_distance_measurement(box, image, method="width"):
    """Estimate a fibril's crossover distance from its intensity profile.

    The intensity is sampled along the long axis of ``box``, normalized to
    [-1, 1], and the period of its strongest non-zero Fourier component is
    converted from samples to pixels.

    Parameters
    ----------
    box : array-like
        The four corners of the oriented bounding box, in cyclic order.
    image : ndarray
        Image to sample; for a 3-D array only ``image[..., 0]`` is used.
    method : {"width", "centerline"}
        "width" averages the profile over the box width; "centerline" samples
        a single-pixel-wide line.

    Returns
    -------
    float
        Estimated crossover distance in pixels, or NaN when the profile has
        fewer than two samples.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported values.
    """
    if len(image.shape) == 3:
        image = image[..., 0]

    src, dst, width, _ = _long_axis(box)

    if method == "width":
        # Boxes narrower than one pixel would give linewidth 0, which does not
        # yield a meaningful profile; clamp to a single pixel.
        linewidth = max(1, int(width))
    elif method == "centerline":
        linewidth = 1
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'width' or 'centerline'."
        )

    pixel_values = profile_line(
        image, src=src, dst=dst, linewidth=linewidth, order=1
    )
    pixel_values = rescale_intensity(pixel_values, out_range=(-1, 1))

    n_samples = len(pixel_values)
    if n_samples < 2:
        return np.nan
    # profile_line includes both end points, so consecutive samples are
    # (n_samples - 1) steps apart along the src -> dst segment.
    spacing = norm(dst - src) / (n_samples - 1)
    return _dominant_period(pixel_values) * spacing


def box_measurements(box):
    """Return the center, length and width of an oriented bounding box.

    Parameters
    ----------
    box : array-like
        The four corners of the box, in cyclic order.

    Returns
    -------
    center : ndarray
        Midpoint of the box.
    length : float
        Length of the long edges.
    width : float
        Length of the short edges.
    """
    src, dst, width, length = _long_axis(box)
    center = src + (dst - src) / 2
    return center, length, width


def line_profile_measurements(box, image):
    """Return intensity profiles sampled along the long axis of ``box``.

    Parameters
    ----------
    box : array-like
        The four corners of the box, in cyclic order.
    image : ndarray
        Image to sample; for a 3-D array only ``image[..., 0]`` is used.

    Returns
    -------
    tuple of ndarray
        ``(pixel_values_width, pixel_values_centerline)``: the profile
        averaged over the box width and the single-pixel centerline profile.
    """
    if len(image.shape) == 3:
        image = image[..., 0]

    src, dst, width, _ = _long_axis(box)

    # Intensity profile along the axis, averaged over the box width...
    pixel_values_width = profile_line(
        image, src=src, dst=dst, linewidth=max(1, int(width)), order=1
    )
    # ...and along the centerline only.
    pixel_values_centerline = profile_line(
        image, src=src, dst=dst, linewidth=1, order=1
    )
    return pixel_values_width, pixel_values_centerline
@@ -53,11 +53,12 @@ def boxes_center_coordinates(boxes):
53
53
  def boxes_kernel_density_map(boxes, image, gaussian_sigma_px=50, downscale_factor=8):
54
54
  """Kernel density estimate from bounding box coordinates"""
55
55
 
56
- image_shape = np.array(image.shape) // downscale_factor
56
+ grid_size_x = image.shape[0] // downscale_factor
57
+ grid_size_y = image.shape[1] // downscale_factor
57
58
 
58
59
  x_grid, y_grid = np.meshgrid(
59
- np.linspace(0, image_shape[1] - 1, image_shape[1]),
60
- np.linspace(0, image_shape[0] - 1, image_shape[0]),
60
+ np.linspace(0, grid_size_y - 1, grid_size_y),
61
+ np.linspace(0, grid_size_x - 1, grid_size_x),
61
62
  )
62
63
 
63
64
  grid_points = np.vstack([y_grid.ravel(), x_grid.ravel()]).T
@@ -70,7 +71,7 @@ def boxes_kernel_density_map(boxes, image, gaussian_sigma_px=50, downscale_facto
70
71
 
71
72
  kde.fit(boxes_center_coordinates(boxes) / downscale_factor)
72
73
 
73
- density_map = np.exp(kde.score_samples(grid_points)).reshape(image_shape)
74
+ density_map = np.exp(kde.score_samples(grid_points)).reshape((grid_size_x, grid_size_y))
74
75
  density_map = zoom(density_map, zoom=downscale_factor, order=1)
75
76
  density_map = rescale_intensity(density_map, out_range=(0, 1))
76
77