senoquant 1.0.0b2__py3-none-any.whl → 1.0.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. senoquant/__init__.py +6 -2
  2. senoquant/_reader.py +1 -1
  3. senoquant/reader/core.py +201 -18
  4. senoquant/tabs/batch/backend.py +18 -3
  5. senoquant/tabs/batch/frontend.py +8 -4
  6. senoquant/tabs/quantification/features/marker/dialog.py +26 -6
  7. senoquant/tabs/quantification/features/marker/export.py +97 -24
  8. senoquant/tabs/quantification/features/marker/rows.py +2 -2
  9. senoquant/tabs/quantification/features/spots/dialog.py +41 -11
  10. senoquant/tabs/quantification/features/spots/export.py +163 -10
  11. senoquant/tabs/quantification/frontend.py +2 -2
  12. senoquant/tabs/segmentation/frontend.py +46 -9
  13. senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
  14. senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
  15. senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
  16. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
  17. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
  18. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
  19. senoquant/tabs/spots/frontend.py +42 -5
  20. senoquant/tabs/spots/models/ufish/details.json +17 -0
  21. senoquant/tabs/spots/models/ufish/model.py +129 -0
  22. senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
  23. senoquant/tabs/spots/ufish_utils/core.py +357 -0
  24. senoquant/utils.py +1 -1
  25. senoquant-1.0.0b3.dist-info/METADATA +161 -0
  26. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/RECORD +41 -28
  27. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/top_level.txt +1 -0
  28. ufish/__init__.py +1 -0
  29. ufish/api.py +778 -0
  30. ufish/model/__init__.py +0 -0
  31. ufish/model/loss.py +62 -0
  32. ufish/model/network/__init__.py +0 -0
  33. ufish/model/network/spot_learn.py +50 -0
  34. ufish/model/network/ufish_net.py +204 -0
  35. ufish/model/train.py +175 -0
  36. ufish/utils/__init__.py +0 -0
  37. ufish/utils/img.py +418 -0
  38. ufish/utils/log.py +8 -0
  39. ufish/utils/spot_calling.py +115 -0
  40. senoquant/tabs/spots/models/rmp/details.json +0 -61
  41. senoquant/tabs/spots/models/rmp/model.py +0 -499
  42. senoquant/tabs/spots/models/udwt/details.json +0 -103
  43. senoquant/tabs/spots/models/udwt/model.py +0 -482
  44. senoquant-1.0.0b2.dist-info/METADATA +0 -193
  45. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/WHEEL +0 -0
  46. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/entry_points.txt +0 -0
  47. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/licenses/LICENSE +0 -0
ufish/api.py ADDED
@@ -0,0 +1,778 @@
+ import os
+ import os.path as osp
+ import time
+ import typing as T
+ from pathlib import Path
+ from functools import partial
+
+ import numpy as np
+ import pandas as pd
+
+ from .utils.log import logger
+
+ if T.TYPE_CHECKING:
+     from torch import nn
+     from matplotlib.figure import Figure
+     import onnxruntime
+
+
+ BASE_STORE_URL = 'https://huggingface.co/GangCaoLab/U-FISH/resolve/main/'
+ DEFAULT_WEIGHTS_FILE = 'v1.0-alldata-ufish_c32.onnx'
+ STATC_STORE_PATH = osp.abspath(
+     osp.join(osp.dirname(__file__), "model/weights/"))
+
+
+ class UFish():
+     def __init__(
+             self, cuda: bool = True,
+             default_weights_file: T.Optional[str] = None,
+             local_store_path: str = '~/.ufish/'
+             ) -> None:
+         """
+         Args:
+             cuda: Whether to use GPU.
+             default_weights_file: The default weights file to use.
+             local_store_path: The local path to store the weights.
+         """
+         self._cuda = cuda
+         self._infer_mode = False
+         self.model: T.Optional["nn.Module"] = None
+         self.ort_session: T.Optional["onnxruntime.InferenceSession"] = None
+         if default_weights_file is None:
+             default_weights_file = DEFAULT_WEIGHTS_FILE
+         self.default_weights_file = default_weights_file
+         self.store_base_url = BASE_STORE_URL
+         self.local_store_path = Path(
+             os.path.expanduser(local_store_path))
+         self.weight_path: T.Optional[str] = None
+
+     def init_model(
+             self,
+             model_type: str = 'ufish',
+             **kwargs) -> None:
+         """Initialize the model.
+
+         Args:
+             model_type: The type of the model. For example,
+                 'ufish', 'spot_learn', ...
+             kwargs: Other arguments for the model.
+         """
+         import torch
+         if model_type == 'ufish':
+             from .model.network.ufish_net import UFishNet
+             self.model = UFishNet(**kwargs)
+         elif model_type == 'spot_learn':
+             from .model.network.spot_learn import SpotLearn
+             self.model = SpotLearn(**kwargs)
+         else:
+             raise ValueError(f'Unknown model type: {model_type}')
+         params = sum(p.numel() for p in self.model.parameters())
+         logger.info(
+             f'Initializing {model_type} model with kwargs: {kwargs}')
+         logger.info(f'Number of parameters: {params}')
+         self.cuda = False
+         if self._cuda:
+             if torch.cuda.is_available():
+                 self.model = self.model.cuda()
+                 self.cuda = True
+                 logger.info('CUDA is available, using GPU.')
+             else:
+                 logger.warning('CUDA is not available, using CPU.')
+         else:
+             logger.info('CUDA is not used, using CPU.')
+
+     def convert_to_onnx(
+             self,
+             output_path: T.Union[Path, str]) -> None:
+         """Convert the model to ONNX format.
+
+         Args:
+             output_path: The path to the output ONNX file.
+         """
+         if self.model is None:
+             raise RuntimeError('Model is not initialized.')
+         self._turn_on_infer_mode(trace_model=True)
+         import torch
+         import torch.onnx
+         output_path = str(output_path)
+         logger.info(
+             f'Converting model to ONNX format, saving to {output_path}')
+         device = torch.device('cuda' if self.cuda else 'cpu')
+         inp = torch.rand(1, 1, 512, 512).to(device)
+         dyn_axes = {0: 'batch_size', 2: 'y', 3: 'x'}
+         torch.onnx.export(
+             self.model, inp, output_path,
+             input_names=['input'],
+             output_names=['output'],
+             opset_version=11,
+             do_constant_folding=True,
+             dynamic_axes={
+                 'input': dyn_axes,
+                 'output': dyn_axes,
+             },
+         )
+
+     def _turn_on_infer_mode(self, trace_model: bool = False) -> None:
+         """Turn on the infer mode."""
+         if self._infer_mode:
+             return
+         self._infer_mode = True
+         self.model.eval()
+         if trace_model:
+             import torch
+             device = next(self.model.parameters()).device
+             inp = torch.rand(1, 1, 512, 512).to(device)
+             self.model = torch.jit.trace(self.model, inp)
+
+     def load_weights_from_internet(
+             self,
+             weights_file: T.Optional[str] = None,
+             max_retry: int = 8,
+             force_download: bool = False,
+             ) -> None:
+         """Load weights from the huggingface repo.
+
+         Args:
+             weights_file: The name of the weights file on the internet.
+                 See https://huggingface.co/GangCaoLab/U-FISH/tree/main
+                 for available weights files.
+             max_retry: The maximum number of retries.
+             force_download: Whether to force download the weights.
+         """
+         import torch
+         weights_file = weights_file or self.default_weights_file
+         weight_url = self.store_base_url + weights_file
+         local_weight_path = self.local_store_path / weights_file
+         if local_weight_path.exists() and (not force_download):
+             logger.info(
+                 f'Local weights {local_weight_path} exists, '
+                 'skip downloading.'
+             )
+         else:
+             logger.info(
+                 f'Downloading weights from {weight_url}, '
+                 f'storing to {local_weight_path}')
+             local_weight_path.parent.mkdir(parents=True, exist_ok=True)
+             try_count = 0
+             while try_count < max_retry:
+                 try:
+                     torch.hub.download_url_to_file(
+                         weight_url, local_weight_path)
+                     break
+                 except Exception as e:
+                     logger.warning(f'Error downloading weights: {e}')
+                     try_count += 1
+                     time.sleep(0.5)
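+             # while/else: this branch runs only if the loop exhausts
+             # max_retry without hitting `break`, i.e. every attempt failed.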
+             else:
+                 raise RuntimeError(
+                     f'Error downloading weights from {weight_url}.')
+         self.load_weights_from_path(local_weight_path)
+
+     def load_weights_from_path(
+             self,
+             path: T.Union[Path, str],
+             ) -> None:
+         """Load weights from a local file.
+         The file can be a .pth file or an .onnx file.
+
+         Args:
+             path: The path to the weights file.
+         """
+         path = str(path)
+         self.weight_path = path
+         if path.endswith('.pth'):
+             self._load_pth_file(path)
+         elif path.endswith('.onnx'):
+             self._load_onnx(path)
+         else:
+             raise ValueError(
+                 'Weights file must be a pth file or an onnx file.')
+
+     def load_weights(
+             self,
+             weights_path: T.Optional[str] = None,
+             weights_file: T.Optional[str] = None,
+             max_retry: int = 8,
+             force_download: bool = False,
+             ):
+         """Load weights from a local file or the internet.
+
+         Args:
+             weights_path: The path to the weights file.
+             weights_file: The name of the weights file on the internet.
+                 See https://huggingface.co/GangCaoLab/U-FISH/tree/main
+                 for available weights files.
+             max_retry: The maximum number of retries to download the weights.
+             force_download: Whether to force download the weights.
+         """
+         if weights_path is not None:
+             self.load_weights_from_path(weights_path)
+         else:
+             if weights_file is not None:
+                 self.load_weights_from_internet(
+                     weights_file=weights_file,
+                     max_retry=max_retry,
+                     force_download=force_download,
+                 )
+             else:
+                 weights_path = osp.join(STATC_STORE_PATH, DEFAULT_WEIGHTS_FILE)
+                 self.load_weights_from_path(weights_path)
+         return self
+
+     def _load_pth_file(self, path: T.Union[Path, str]) -> None:
+         """Load weights from a local file.
+
+         Args:
+             path: The path to the pth weights file."""
+         import torch
+         if self.model is None:
+             self.init_model()
+         assert self.model is not None
+         path = str(path)
+         logger.info(f'Loading weights from {path}')
+         device = torch.device('cuda' if self.cuda else 'cpu')
+         state_dict = torch.load(path, map_location=device)
+         self.model.load_state_dict(state_dict)
+
+     def _load_onnx(
+             self,
+             onnx_path: T.Union[Path, str],
+             providers: T.Optional[T.List[str]] = None,
+             ) -> None:
+         """Load weights from a local ONNX file,
+         and create an onnxruntime session.
+
+         Args:
+             onnx_path: The path to the ONNX file.
+             providers: The providers to use.
+         """
+         import onnxruntime
+         onnx_path = str(onnx_path)
+         logger.info(f'Loading ONNX from {onnx_path}')
+         if self._cuda:
+             providers = providers or ['CUDAExecutionProvider']
+         else:
+             providers = providers or ['CPUExecutionProvider']
+         self.ort_session = onnxruntime.InferenceSession(
+             onnx_path, providers=providers)
+
+     def infer(self, img: np.ndarray) -> np.ndarray:
+         """Infer the image using the U-Net model."""
+         if self.ort_session is not None:
+             output = self._infer_onnx(img)
+         elif self.model is not None:
+             output = self._infer_torch(img)
+         else:
+             raise RuntimeError(
+                 'Both torch model and ONNX model are not initialized.')
+         return output
+
+     def _infer_torch(self, img: np.ndarray) -> np.ndarray:
+         """Infer the image using the torch model."""
+         self._turn_on_infer_mode()
+         import torch
+         tensor = torch.from_numpy(img).float()
+         if self.cuda:
+             tensor = tensor.cuda()
+         with torch.no_grad():
+             output = self.model(tensor)
+         output = output.detach().cpu().numpy()
+         return output
+
+     def _infer_onnx(self, img: np.ndarray) -> np.ndarray:
+         """Infer the image using the ONNX model."""
+         ort_inputs = {self.ort_session.get_inputs()[0].name: img}
+         ort_outs = self.ort_session.run(None, ort_inputs)
+         output = ort_outs[0]
+         return output
+
+     def _enhance_img2d(self, img: np.ndarray) -> np.ndarray:
+         """Enhance a 2D image."""
+         output = self.infer(img[np.newaxis, np.newaxis])[0, 0]
+         return output
+
+     def _enhance_img3d(
+             self, img: np.ndarray, batch_size: int = 4) -> np.ndarray:
+         """Enhance a 3D image."""
+         logger.info(
+             f'Enhancing 3D image in shape {img.shape}, '
+             f'batch size: {batch_size}')
+         output = np.zeros_like(img, dtype=np.float32)
+         for i in range(0, output.shape[0], batch_size):
+             logger.info(
+                 f'Enhancing slice {i}-{i+batch_size}/{output.shape[0]}')
+             _slice = img[i:i+batch_size][:, np.newaxis]
+             output[i:i+batch_size] = self.infer(_slice)[:, 0]
+         return output
+
+     def _enhance_2d_or_3d(
+             self,
+             img: np.ndarray,
+             axes: str,
+             batch_size: int = 4,
+             blend_3d: bool = False,
+             ) -> np.ndarray:
+         """Enhance a 2D or 3D image."""
+         from .utils.img import scale_image
+         img = scale_image(img, warning=True)
+         if img.ndim == 2:
+             output = self._enhance_img2d(img)
+         elif img.ndim == 3:
+             if blend_3d:
+                 if 'z' not in axes:
+                     logger.warning(
+                         'Image does not have a z axis, ' +
+                         'cannot blend along z axis.')
+                 from .utils.img import enhance_blend_3d
+                 logger.info(
+                     "Blending 3D image from 3 directions: z, y, x.")
+                 output = enhance_blend_3d(
+                     img, self._enhance_img3d, axes=axes,
+                     batch_size=batch_size)
+             else:
+                 output = self._enhance_img3d(img, batch_size=batch_size)
+         else:
+             raise ValueError('Image must be 2D or 3D.')
+         return output
+
+     def call_spots(
+             self,
+             enhanced_img: np.ndarray,
+             method: str = 'local_maxima',
+             **kwargs,
+             ) -> pd.DataFrame:
+         """Call spots from enhanced image.
+
+         Args:
+             enhanced_img: The enhanced image.
+             method: The method to use for spot calling.
+             kwargs: Other arguments for the spot calling function.
+         """
+         assert enhanced_img.ndim in (2, 3), 'Image must be 2D or 3D.'
+         if method == 'cc_center':
+             from .utils.spot_calling import call_spots_cc_center as call_func
+         else:
+             from .utils.spot_calling import call_spots_local_maxima as call_func  # noqa
+         df = call_func(enhanced_img, **kwargs)
+         return df
+
+     def _pred_2d_or_3d(
+             self, img: np.ndarray, axes: str,
+             blend_3d: bool = False,
+             batch_size: int = 4,
+             spots_calling_method: str = 'local_maxima',
+             **kwargs,
+             ) -> T.Tuple[pd.DataFrame, np.ndarray]:
+         """Predict the spots in a 2D or 3D image."""
+         assert img.ndim in (2, 3), 'Image must be 2D or 3D.'
+         assert len(axes) == img.ndim, \
+             "axes and image dimension must have the same length"
+         enhanced_img = self._enhance_2d_or_3d(
+             img, axes,
+             batch_size=batch_size,
+             blend_3d=(blend_3d and ('z' in axes))
+         )
+         df = self.call_spots(
+             enhanced_img,
+             method=spots_calling_method,
+             **kwargs)
+         return df, enhanced_img
+
+     def predict(
+             self, img: np.ndarray,
+             enh_img: T.Optional[np.ndarray] = None,
+             axes: T.Optional[str] = None,
+             blend_3d: bool = True,
+             batch_size: int = 4,
+             spots_calling_method: str = 'local_maxima',
+             **kwargs,
+             ) -> T.Tuple[pd.DataFrame, np.ndarray]:
+         """Predict the spots in an image.
+
+         Args:
+             img: The image to predict; it should be a multi-dimensional
+                 array. For example, shape (c, z, y, x) for a 4D image,
+                 shape (z, y, x) for a 3D image,
+                 shape (y, x) for a 2D image.
+             enh_img: The enhanced image; if None, will be created.
+                 It can be a multi-dimensional array or a zarr array.
+             axes: The axes of the image.
+                 For example, 'czyx' for a 4D image,
+                 'yx' for a 2D image.
+                 If None, will try to infer the axes from the shape.
+             blend_3d: Whether to blend the 3D image.
+                 Used only when the image contains a z axis.
+                 If True, will blend the 3D enhanced images along
+                 the z, y, x axes.
+             batch_size: The batch size for inference.
+                 Used only when the image dimension is 3 or higher.
+             spots_calling_method: The method to use for spot calling.
+             kwargs: Other arguments for the spot calling function.
+         """
+         from .utils.img import (
+             infer_img_axes, check_img_axes,
+             map_predfunc_to_img
+         )
+         if axes is None:
+             logger.info("Axes not specified, inferring from image shape.")
+             axes = infer_img_axes(img.shape)
+             logger.info(f"Inferred axes: {axes}, image shape: {img.shape}")
+         check_img_axes(img, axes)
+         if not isinstance(img, np.ndarray):
+             img = np.array(img)
+         predfunc = partial(
+             self._pred_2d_or_3d,
+             blend_3d=blend_3d,
+             batch_size=batch_size,
+             spots_calling_method=spots_calling_method,
+             **kwargs,
+         )
+         df, enhanced_img = map_predfunc_to_img(
+             predfunc, img, axes)
+         if enh_img is not None:
+             enh_img[:] = enhanced_img
+         return df, enhanced_img
+
+     def predict_chunks(
+             self,
+             img: np.ndarray,
+             enh_img: T.Optional[np.ndarray] = None,
+             axes: T.Optional[str] = None,
+             blend_3d: bool = True,
+             batch_size: int = 4,
+             chunk_size: T.Optional[T.Tuple[T.Union[int, str], ...]] = None,
+             spots_calling_method: str = 'local_maxima',
+             **kwargs,
+             ):
+         """Predict the spots in an image chunk by chunk.
+
+         Args:
+             img: The image to predict; it should be a multi-dimensional
+                 array. For example, shape (c, z, y, x) for a 4D image,
+                 shape (z, y, x) for a 3D image,
+                 shape (y, x) for a 2D image.
+             enh_img: The enhanced image; if None, will be created.
+                 It can be a multi-dimensional array or a zarr array.
+             axes: The axes of the image.
+                 For example, 'czyx' for a 4D image,
+                 'yx' for a 2D image.
+                 If None, will try to infer the axes from the shape.
+             blend_3d: Whether to blend the 3D image.
+                 Used only when the image contains a z axis.
+                 If True, will blend the 3D enhanced images along
+                 the z, y, x axes.
+             batch_size: The batch size for inference.
+                 Used only when the image dimension is 3 or higher.
+             chunk_size: The chunk size for processing.
+                 For example, (1, 512, 512) for a 3D image,
+                 (512, 512) for a 2D image.
+                 Using 'image' as a dimension will use the whole image
+                 as a chunk. For example, (1, 'image', 'image') for a
+                 3D image, ('image', 'image', 'image', 512, 512) for a
+                 5D image. If None, will use the default chunk size.
+             spots_calling_method: The method to use for spot calling.
+             kwargs: Other arguments for the spot calling function.
+         """
+         from .utils.img import (
+             check_img_axes, chunks_iterator,
+             process_chunk_size, infer_img_axes)
+         if axes is None:
+             axes = infer_img_axes(img.shape)
+         check_img_axes(img, axes)
+         if chunk_size is None:
+             from .utils.img import get_default_chunk_size
+             chunk_size = get_default_chunk_size(axes)
+             logger.info(f"Chunk size not specified, using {chunk_size}.")
+         chunk_size = process_chunk_size(chunk_size, img.shape)
+         logger.info(f"Chunk size: {chunk_size}")
+         total_dfs = []
+         if enh_img is None:
+             enh_img = np.zeros_like(img, dtype=np.float32)
+         for c_range, chunk in chunks_iterator(img, chunk_size):
+             logger.info("Processing chunk: " + str(c_range)
+                         + ", chunk shape: " + str(chunk.shape))
+             c_df, c_enh = self.predict(
+                 chunk, axes=axes, blend_3d=blend_3d,
+                 batch_size=batch_size,
+                 spots_calling_method=spots_calling_method,
+                 **kwargs)
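+             # Spot coordinates come back relative to the chunk; shift them
+             # by the chunk's start offset to get whole-image coordinates.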
+             dim_start = [c_range[i][0] for i in range(len(axes))]
+             c_df += dim_start
+             total_dfs.append(c_df)
+             c_enh = c_enh[
+                 tuple(slice(0, (r[1]-r[0])) for r in c_range)]
+             enh_img[tuple(slice(*r) for r in c_range)] = c_enh
+         df = pd.concat(total_dfs, ignore_index=True)
+         return df, enh_img
+
+     def evaluate_result_dp(
+             self,
+             pred: pd.DataFrame,
+             true: pd.DataFrame,
+             mdist: float = 3.0,
+             ) -> pd.DataFrame:
+         """Evaluate the prediction result using deepblink metrics.
+
+         Args:
+             pred: The predicted spots.
+             true: The true spots.
+             mdist: The maximum distance to consider a spot as a true positive.
+
+         Returns:
+             A pandas dataframe containing the evaluation metrics."""
+         from .utils.metrics_deepblink import compute_metrics
+         axis_names = list(pred.columns)
+         axis_cols = [n for n in axis_names if n.startswith('axis')]
+         pred = pred[axis_cols].values
+         true = true[axis_cols].values
+         metrics = compute_metrics(
+             pred, true, mdist=mdist)
+         return metrics
+
+     def evaluate_result(
+             self,
+             pred: pd.DataFrame,
+             true: pd.DataFrame,
+             cutoff: float = 3.0,
+             ) -> float:
+         """Calculate the F1 score of the prediction result.
+
+         Args:
+             pred: The predicted spots.
+             true: The true spots.
+             cutoff: The maximum distance to consider a spot as a true positive.
+         """
+         from .utils.metrics import compute_metrics
+         res = compute_metrics(pred.values, true.values, cutoff=cutoff)
+         return res
+
+     def plot_result(
+             self,
+             img: np.ndarray,
+             pred: pd.DataFrame,
+             fig_size: T.Tuple[int, int] = (10, 10),
+             image_cmap: str = 'gray',
+             marker_size: int = 20,
+             marker_color: str = 'red',
+             marker_style: str = 'x',
+             ) -> "Figure":
+         """Plot the prediction result.
+
+         Args:
+             img: The image to plot.
+             pred: The predicted spots.
+             fig_size: The figure size.
+             image_cmap: The colormap for the image.
+             marker_size: The marker size.
+             marker_color: The marker color.
+             marker_style: The marker style.
+         """
+         from .utils.plot import Plot2d
+         plt2d = Plot2d()
+         plt2d.default_figsize = fig_size
+         plt2d.default_marker_size = marker_size
+         plt2d.default_marker_color = marker_color
+         plt2d.default_marker_style = marker_style
+         plt2d.default_imshow_cmap = image_cmap
+         plt2d.new_fig()
+         plt2d.image(img)
+         plt2d.spots(pred.values)
+         return plt2d.fig
+
+     def plot_evaluate(
+             self,
+             img: np.ndarray,
+             pred: pd.DataFrame,
+             true: pd.DataFrame,
+             cutoff: float = 3.0,
+             fig_size: T.Tuple[int, int] = (10, 10),
+             image_cmap: str = 'gray',
+             marker_size: int = 20,
+             tp_color: str = 'green',
+             fp_color: str = 'red',
+             fn_color: str = 'yellow',
+             tp_marker: str = 'x',
+             fp_marker: str = 'x',
+             fn_marker: str = 'x',
+             ) -> "Figure":
+         """Plot the prediction result against the ground truth.
+
+         Args:
+             img: The image to plot.
+             pred: The predicted spots.
+             true: The true spots.
+             cutoff: The maximum distance to consider a spot as a true positive.
+             fig_size: The figure size.
+             image_cmap: The colormap for the image.
+             marker_size: The marker size.
+             tp_color: The color for true positives.
+             fp_color: The color for false positives.
+             fn_color: The color for false negatives.
+             tp_marker: The marker style for true positives.
+             fp_marker: The marker style for false positives.
+             fn_marker: The marker style for false negatives.
+         """
+         from .utils.plot import Plot2d
+         plt2d = Plot2d()
+         plt2d.default_figsize = fig_size
+         plt2d.default_marker_size = marker_size
+         plt2d.default_imshow_cmap = image_cmap
+         plt2d.new_fig()
+         plt2d.image(img)
+         plt2d.evaluate_result(
+             pred.values, true.values,
+             cutoff=cutoff,
+             tp_color=tp_color,
+             fp_color=fp_color,
+             fn_color=fn_color,
+             tp_marker=tp_marker,
+             fp_marker=fp_marker,
+             fn_marker=fn_marker,
+         )
+         return plt2d.fig
+
+     def _load_dataset(
+             self,
+             path: str,
+             root_dir_path: T.Optional[str] = None,
+             img_glob: str = '*.tif',
+             coord_glob: str = '*.csv',
+             process_func=None,
+             transform=None,
+             ):
+         """Load a dataset from a path."""
+         from .data import FISHSpotsDataset
+         _path = Path(path)
+         if _path.is_dir():
+             if root_dir_path is not None:
+                 logger.info(f"Dataset's root directory: {root_dir_path}")
+                 _path = Path(root_dir_path) / _path
+             logger.info(f"Loading dataset from dir: {_path}")
+             logger.info(
+                 f'Image glob: {img_glob}, Coordinate glob: {coord_glob}')
+             _path_str = str(_path)
+             dataset = FISHSpotsDataset.from_dir(
+                 _path_str, _path_str,
+                 img_glob=img_glob, coord_glob=coord_glob,
+                 process_func=process_func, transform=transform)
+         else:
+             logger.info(f"Loading dataset using meta csv: {_path}")
+             assert _path.suffix == '.csv', \
+                 "Meta file must be a csv file."
+             root_dir = root_dir_path or _path.parent
+             logger.info(f'Data root directory: {root_dir}')
+             dataset = FISHSpotsDataset.from_meta_csv(
+                 root_dir=root_dir, meta_csv_path=str(_path),
+                 process_func=process_func, transform=transform)
+         return dataset
+
+     def train(
+             self,
+             train_path: str,
+             valid_path: str,
+             root_dir: T.Optional[str] = None,
+             img_glob: str = '*.tif',
+             coord_glob: str = '*.csv',
+             target_process: T.Optional[str] = 'gaussian',
+             loss_type: str = 'DiceRMSELoss',
+             loader_workers: int = 4,
+             data_argu: bool = False,
+             argu_prob: float = 0.5,
+             num_epochs: int = 50,
+             batch_size: int = 8,
+             lr: float = 1e-3,
+             summary_dir: str = "runs/unet",
+             model_save_dir: str = "./models",
+             save_period: int = 5,
+             ):
+         """Train the U-Net model.
+
+         Args:
+             train_path: The path to the training dataset.
+                 Either a directory containing images and coordinates,
+                 or a meta csv file.
+             valid_path: The path to the validation dataset.
+                 Either a directory containing images and coordinates,
+                 or a meta csv file.
+             root_dir: The root directory of the dataset, against which
+                 train_path and valid_path are resolved.
+             img_glob: The glob pattern for the image files.
+             coord_glob: The glob pattern for the coordinate files.
+             target_process: The target image processing method:
+                 'gaussian', 'dialation', or None.
+                 Any other string is passed to dialate_mask as its
+                 footprint. If None, no processing will be applied.
+                 Default 'gaussian'.
+             loss_type: The loss function type.
+             loader_workers: The number of workers to use for the data loader.
+             data_argu: Whether to use data augmentation.
+             argu_prob: The probability of applying data augmentation.
+             num_epochs: The number of epochs to train.
+             batch_size: The batch size.
+             lr: The learning rate.
+             summary_dir: The directory to save the TensorBoard summary to.
+             model_save_dir: The directory to save the model to.
+             save_period: Save the model every `save_period` epochs.
+         """
+         from .model.train import train_on_dataset
+         from .data import FISHSpotsDataset
+         if self.model is None:
+             logger.info('Model is not initialized. Will initialize a new one.')
+             self.init_model()
+         assert self.model is not None
+
+         if data_argu:
+             logger.info(
+                 'Using data augmentation. ' +
+                 f'Probability: {argu_prob}'
+             )
+             from .data import DataAugmentation
+             transform = DataAugmentation(p=argu_prob)
+         else:
+             transform = None
+
+         logger.info(f'Using {target_process} as target process.')
+         if target_process == 'gaussian':
+             process_func = FISHSpotsDataset.gaussian_filter
+         elif target_process == 'dialation':
+             process_func = FISHSpotsDataset.dialate_mask
+         elif isinstance(target_process, str):
+             process_func = partial(
+                 FISHSpotsDataset.dialate_mask,
+                 footprint=target_process)
+         else:
+             process_func = None
+
+         logger.info(f"Loading training dataset from {train_path}")
+         train_dataset = self._load_dataset(
+             train_path, root_dir_path=root_dir,
+             img_glob=img_glob, coord_glob=coord_glob,
+             process_func=process_func, transform=transform,
+         )
+         logger.info(f"Loading validation dataset from {valid_path}")
+         valid_dataset = self._load_dataset(
+             valid_path, root_dir_path=root_dir,
+             img_glob=img_glob, coord_glob=coord_glob,
+             process_func=process_func,
+         )
+         logger.info(
+             f"Training dataset size: {len(train_dataset)}, "
+             f"Validation dataset size: {len(valid_dataset)}"
+         )
+         logger.info(
+             f"Number of epochs: {num_epochs}, "
+             f"Batch size: {batch_size}, "
+             f"Learning rate: {lr}"
+         )
+         train_on_dataset(
+             self.model,
+             train_dataset, valid_dataset,
+             loss_type=loss_type,
+             loader_workers=loader_workers,
+             num_epochs=num_epochs,
+             batch_size=batch_size,
+             lr=lr,
+             summary_dir=summary_dir,
+             model_save_dir=model_save_dir,
+             save_period=save_period,
+         )
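
Usage note: the new top-level `ufish` package added in this release exposes the `UFish` class shown above. The snippet below is a minimal sketch of the API flow, assuming the 1.0.0b3 wheel is installed; the random `img` array and the (512, 512) chunk size are illustrative stand-ins, while every class and method name comes from the diff above.

import numpy as np
from ufish.api import UFish

# Illustrative stand-in for a real 2D FISH image with axes (y, x).
img = np.random.rand(1024, 1024).astype(np.float32)

ufish = UFish(cuda=False)  # set cuda=True to prefer the GPU when available
ufish.load_weights()       # no arguments: loads the ONNX weights bundled under model/weights/

# Whole-image prediction: a DataFrame of spot coordinates plus the
# enhanced image that spot calling ran on.
spots, enhanced = ufish.predict(img, axes='yx')

# Chunked prediction for large images; per-chunk spot coordinates are
# shifted back to whole-image coordinates inside predict_chunks.
spots, enhanced = ufish.predict_chunks(img, axes='yx', chunk_size=(512, 512))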