senoquant-1.0.0b2-py3-none-any.whl → senoquant-1.0.0b4-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -2
- senoquant/_reader.py +1 -1
- senoquant/_widget.py +9 -1
- senoquant/reader/core.py +201 -18
- senoquant/tabs/__init__.py +2 -0
- senoquant/tabs/batch/backend.py +76 -27
- senoquant/tabs/batch/frontend.py +127 -25
- senoquant/tabs/quantification/features/marker/dialog.py +26 -6
- senoquant/tabs/quantification/features/marker/export.py +97 -24
- senoquant/tabs/quantification/features/marker/rows.py +2 -2
- senoquant/tabs/quantification/features/spots/dialog.py +41 -11
- senoquant/tabs/quantification/features/spots/export.py +163 -10
- senoquant/tabs/quantification/frontend.py +2 -2
- senoquant/tabs/segmentation/frontend.py +46 -9
- senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
- senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
- senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
- senoquant/tabs/spots/frontend.py +96 -5
- senoquant/tabs/spots/models/rmp/details.json +3 -9
- senoquant/tabs/spots/models/rmp/model.py +341 -266
- senoquant/tabs/spots/models/ufish/details.json +32 -0
- senoquant/tabs/spots/models/ufish/model.py +327 -0
- senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
- senoquant/tabs/spots/ufish_utils/core.py +387 -0
- senoquant/tabs/visualization/__init__.py +1 -0
- senoquant/tabs/visualization/backend.py +306 -0
- senoquant/tabs/visualization/frontend.py +1113 -0
- senoquant/tabs/visualization/plots/__init__.py +80 -0
- senoquant/tabs/visualization/plots/base.py +152 -0
- senoquant/tabs/visualization/plots/double_expression.py +187 -0
- senoquant/tabs/visualization/plots/spatialplot.py +156 -0
- senoquant/tabs/visualization/plots/umap.py +140 -0
- senoquant/utils.py +1 -1
- senoquant-1.0.0b4.dist-info/METADATA +162 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/RECORD +53 -30
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/top_level.txt +1 -0
- ufish/__init__.py +1 -0
- ufish/api.py +778 -0
- ufish/model/__init__.py +0 -0
- ufish/model/loss.py +62 -0
- ufish/model/network/__init__.py +0 -0
- ufish/model/network/spot_learn.py +50 -0
- ufish/model/network/ufish_net.py +204 -0
- ufish/model/train.py +175 -0
- ufish/utils/__init__.py +0 -0
- ufish/utils/img.py +418 -0
- ufish/utils/log.py +8 -0
- ufish/utils/spot_calling.py +115 -0
- senoquant/tabs/spots/models/udwt/details.json +0 -103
- senoquant/tabs/spots/models/udwt/model.py +0 -482
- senoquant-1.0.0b2.dist-info/METADATA +0 -193
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/WHEEL +0 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/entry_points.txt +0 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/licenses/LICENSE +0 -0
ufish/api.py
ADDED
@@ -0,0 +1,778 @@
import os
import os.path as osp
import time
import typing as T
from pathlib import Path
from functools import partial

import numpy as np
import pandas as pd

from .utils.log import logger

if T.TYPE_CHECKING:
    from torch import nn
    from matplotlib.figure import Figure
    import onnxruntime


BASE_STORE_URL = 'https://huggingface.co/GangCaoLab/U-FISH/resolve/main/'
DEFAULT_WEIGHTS_FILE = 'v1.0-alldata-ufish_c32.onnx'
STATC_STORE_PATH = osp.abspath(
    osp.join(osp.dirname(__file__), "model/weights/"))


class UFish():
    def __init__(
            self, cuda: bool = True,
            default_weights_file: T.Optional[str] = None,
            local_store_path: str = '~/.ufish/'
            ) -> None:
        """
        Args:
            cuda: Whether to use GPU.
            default_weight_file: The default weight file to use.
            local_store_path: The local path to store the weights.
        """
        self._cuda = cuda
        self._infer_mode = False
        self.model: T.Optional["nn.Module"] = None
        self.ort_session: T.Optional["onnxruntime.InferenceSession"] = None
        if default_weights_file is None:
            default_weights_file = DEFAULT_WEIGHTS_FILE
        self.default_weights_file = default_weights_file
        self.store_base_url = BASE_STORE_URL
        self.local_store_path = Path(
            os.path.expanduser(local_store_path))
        self.weight_path: T.Optional[str] = None

    def init_model(
            self,
            model_type: str = 'ufish',
            **kwargs) -> None:
        """Initialize the model.

        Args:
            model_type: The type of the model. For example,
                'ufish', 'spot_learn', ...
            kwargs: Other arguments for the model.
        """
        import torch
        if model_type == 'ufish':
            from .model.network.ufish_net import UFishNet
            self.model = UFishNet(**kwargs)
        elif model_type == 'spot_learn':
            from .model.network.spot_learn import SpotLearn
            self.model = SpotLearn(**kwargs)
        else:
            raise ValueError(f'Unknown model type: {model_type}')
        params = sum(p.numel() for p in self.model.parameters())
        logger.info(
            f'Initializing {model_type} model with kwargs: {kwargs}')
        logger.info(f'Number of parameters: {params}')
        self.cuda = False
        if self._cuda:
            if torch.cuda.is_available():
                self.model = self.model.cuda()
                self.cuda = True
                logger.info('CUDA is available, using GPU.')
            else:
                logger.warning('CUDA is not available, using CPU.')
        else:
            logger.info('CUDA is not used, using CPU.')

    def convert_to_onnx(
            self,
            output_path: T.Union[Path, str],) -> None:
        """Convert the model to ONNX format.

        Args:
            output_path: The path to the output ONNX file.
        """
        if self.model is None:
            raise RuntimeError('Model is not initialized.')
        self._turn_on_infer_mode(trace_model=True)
        import torch
        import torch.onnx
        output_path = str(output_path)
        logger.info(
            f'Converting model to ONNX format, saving to {output_path}')
        device = torch.device('cuda' if self.cuda else 'cpu')
        inp = torch.rand(1, 1, 512, 512).to(device)
        dyn_axes = {0: 'batch_size', 2: 'y', 3: 'x'}
        torch.onnx.export(
            self.model, inp, output_path,
            input_names=['input'],
            output_names=['output'],
            opset_version=11,
            do_constant_folding=True,
            dynamic_axes={
                'input': dyn_axes,
                'output': dyn_axes,
            },
        )

    def _turn_on_infer_mode(self, trace_model: bool = False) -> None:
        """Turn on the infer mode."""
        if self._infer_mode:
            return
        self._infer_mode = True
        self.model.eval()
        if trace_model:
            import torch
            device = next(self.model.parameters()).device
            inp = torch.rand(1, 1, 512, 512).to(device)
            self.model = torch.jit.trace(self.model, inp)

    def load_weights_from_internet(
            self,
            weights_file: T.Optional[str] = None,
            max_retry: int = 8,
            force_download: bool = False,
            ) -> None:
        """Load weights from the huggingface repo.

        Args:
            weights_file: The name of the weights file on the internet.
                See https://huggingface.co/GangCaoLab/U-FISH/tree/main
                for available weights files.
            max_retry: The maximum number of retries.
            force_download: Whether to force download the weights.
        """
        import torch
        weights_file = weights_file or self.default_weights_file
        weight_url = self.store_base_url + weights_file
        local_weight_path = self.local_store_path / weights_file
        if local_weight_path.exists() and (not force_download):
            logger.info(
                f'Local weights {local_weight_path} exists, '
                'skip downloading.'
            )
        else:
            logger.info(
                f'Downloading weights from {weight_url}, '
                f'storing to {local_weight_path}')
            local_weight_path.parent.mkdir(parents=True, exist_ok=True)
            try_count = 0
            while try_count < max_retry:
                try:
                    torch.hub.download_url_to_file(
                        weight_url, local_weight_path)
                    break
                except Exception as e:
                    logger.warning(f'Error downloading weights: {e}')
                    try_count += 1
                    time.sleep(0.5)
            else:
                raise RuntimeError(
                    f'Error downloading weights from {weight_url}.')
        self.load_weights_from_path(local_weight_path)

    def load_weights_from_path(
            self,
            path: T.Union[Path, str],
            ) -> None:
        """Load weights from a local file.
        The file can be a .pth file or an .onnx file.

        Args:
            path: The path to the weights file.
        """
        path = str(path)
        self.weight_path = path
        if path.endswith('.pth'):
            self._load_pth_file(path)
        elif path.endswith('.onnx'):
            self._load_onnx(path)
        else:
            raise ValueError(
                'Weights file must be a pth file or an onnx file.')

    def load_weights(
            self,
            weights_path: T.Optional[str] = None,
            weights_file: T.Optional[str] = None,
            max_retry: int = 8,
            force_download: bool = False,
            ):
        """Load weights from a local file or the internet.

        Args:
            weights_path: The path to the weights file.
            weights_file: The name of the weights file on the internet.
                See https://huggingface.co/GangCaoLab/U-FISH/tree/main
                for available weights files.
            max_retry: The maximum number of retries to download the weights.
            force_download: Whether to force download the weights.
        """
        if weights_path is not None:
            self.load_weights_from_path(weights_path)
        else:
            if weights_file is not None:
                self.load_weights_from_internet(
                    weights_file=weights_file,
                    max_retry=max_retry,
                    force_download=force_download,
                )
            else:
                weights_path = osp.join(STATC_STORE_PATH, DEFAULT_WEIGHTS_FILE)
                self.load_weights_from_path(weights_path)
        return self

    def _load_pth_file(self, path: T.Union[Path, str]) -> None:
        """Load weights from a local file.

        Args:
            path: The path to the pth weights file."""
        import torch
        if self.model is None:
            self.init_model()
        assert self.model is not None
        path = str(path)
        logger.info(f'Loading weights from {path}')
        device = torch.device('cuda' if self.cuda else 'cpu')
        state_dict = torch.load(path, map_location=device)
        self.model.load_state_dict(state_dict)

    def _load_onnx(
            self,
            onnx_path: T.Union[Path, str],
            providers: T.Optional[T.List[str]] = None,
            ) -> None:
        """Load weights from a local ONNX file,
        and create an onnxruntime session.

        Args:
            onnx_path: The path to the ONNX file.
            providers: The providers to use.
        """
        import onnxruntime
        onnx_path = str(onnx_path)
        logger.info(f'Loading ONNX from {onnx_path}')
        if self._cuda:
            providers = providers or ['CUDAExecutionProvider']
        else:
            providers = providers or ['CPUExecutionProvider']
        self.ort_session = onnxruntime.InferenceSession(
            onnx_path, providers=providers)

    def infer(self, img: np.ndarray) -> np.ndarray:
        """Infer the image using the U-Net model."""
        if self.ort_session is not None:
            output = self._infer_onnx(img)
        elif self.model is not None:
            output = self._infer_torch(img)
        else:
            raise RuntimeError(
                'Both torch model and ONNX model are not initialized.')
        return output

    def _infer_torch(self, img: np.ndarray) -> np.ndarray:
        """Infer the image using the torch model."""
        self._turn_on_infer_mode()
        import torch
        tensor = torch.from_numpy(img).float()
        if self.cuda:
            tensor = tensor.cuda()
        with torch.no_grad():
            output = self.model(tensor)
        output = output.detach().cpu().numpy()
        return output

    def _infer_onnx(self, img: np.ndarray) -> np.ndarray:
        """Infer the image using the ONNX model."""
        ort_inputs = {self.ort_session.get_inputs()[0].name: img}
        ort_outs = self.ort_session.run(None, ort_inputs)
        output = ort_outs[0]
        return output

    def _enhance_img2d(self, img: np.ndarray) -> np.ndarray:
        """Enhance a 2D image."""
        output = self.infer(img[np.newaxis, np.newaxis])[0, 0]
        return output

    def _enhance_img3d(
            self, img: np.ndarray, batch_size: int = 4) -> np.ndarray:
        """Enhance a 3D image."""
        logger.info(
            f'Enhancing 3D image in shape {img.shape}, '
            f'batch size: {batch_size}')
        output = np.zeros_like(img, dtype=np.float32)
        for i in range(0, output.shape[0], batch_size):
            logger.info(
                f'Enhancing slice {i}-{i+batch_size}/{output.shape[0]}')
            _slice = img[i:i+batch_size][:, np.newaxis]
            output[i:i+batch_size] = self.infer(_slice)[:, 0]
        return output

    def _enhance_2d_or_3d(
            self,
            img: np.ndarray,
            axes: str,
            batch_size: int = 4,
            blend_3d: bool = False,
            ) -> np.ndarray:
        """Enhance a 2D or 3D image."""
        from .utils.img import scale_image
        img = scale_image(img, warning=True)
        if img.ndim == 2:
            output = self._enhance_img2d(img)
        elif img.ndim == 3:
            if blend_3d:
                if 'z' not in axes:
                    logger.warning(
                        'Image does not have a z axis, ' +
                        'cannot blend along z axis.')
                from .utils.img import enhance_blend_3d
                logger.info(
                    "Blending 3D image from 3 directions: z, y, x.")
                output = enhance_blend_3d(
                    img, self._enhance_img3d, axes=axes,
                    batch_size=batch_size)
            else:
                output = self._enhance_img3d(img, batch_size=batch_size)
        else:
            raise ValueError('Image must be 2D or 3D.')
        return output

    def call_spots(
            self,
            enhanced_img: np.ndarray,
            method: str = 'local_maxima',
            **kwargs,
            ) -> pd.DataFrame:
        """Call spots from enhanced image.

        Args:
            enhanced_img: The enhanced image.
            method: The method to use for spot calling.
            kwargs: Other arguments for the spot calling function.
        """
        assert enhanced_img.ndim in (2, 3), 'Image must be 2D or 3D.'
        if method == 'cc_center':
            from .utils.spot_calling import call_spots_cc_center as call_func
        else:
            from .utils.spot_calling import call_spots_local_maxima as call_func  # noqa
        df = call_func(enhanced_img, **kwargs)
        return df

    def _pred_2d_or_3d(
            self, img: np.ndarray, axes: str,
            blend_3d: bool = False,
            batch_size: int = 4,
            spots_calling_method: str = 'local_maxima',
            **kwargs,
            ) -> T.Tuple[pd.DataFrame, np.ndarray]:
        """Predict the spots in a 2D or 3D image. """
        assert img.ndim in (2, 3), 'Image must be 2D or 3D.'
        assert len(axes) == img.ndim, \
            "axes and image dimension must have the same length"
        enhanced_img = self._enhance_2d_or_3d(
            img, axes,
            batch_size=batch_size,
            blend_3d=(blend_3d and ('z' in axes))
        )
        df = self.call_spots(
            enhanced_img,
            method=spots_calling_method,
            **kwargs)
        return df, enhanced_img

    def predict(
            self, img: np.ndarray,
            enh_img: T.Optional[np.ndarray] = None,
            axes: T.Optional[str] = None,
            blend_3d: bool = True,
            batch_size: int = 4,
            spots_calling_method: str = 'local_maxima',
            **kwargs,
            ) -> T.Tuple[pd.DataFrame, np.ndarray]:
        """Predict the spots in an image.

        Args:
            img: The image to predict, it should be a multi dimensional array.
                For example, shape (c, z, y, x) for a 4D image,
                shape (z, y, x) for a 3D image,
                shape (y, x) for a 2D image.
            enh_img: The enhanced image, if None, will be created.
                It can be a multi dimensional array or a zarr array.
            axes: The axes of the image.
                For example, 'czxy' for a 4D image,
                'yx' for a 2D image.
                If None, will try to infer the axes from the shape.
            blend_3d: Whether to blend the 3D image.
                Used only when the image contains a z axis.
                If True, will blend the 3D enhanced images along
                the z, y, x axes.
            batch_size: The batch size for inference.
                Used only when the image dimension is 3 or higher.
            spots_calling_method: The method to use for spot calling.
            kwargs: Other arguments for the spot calling function.
        """
        from .utils.img import (
            infer_img_axes, check_img_axes,
            map_predfunc_to_img
        )
        if axes is None:
            logger.info("Axes not specified, infering from image shape.")
            axes = infer_img_axes(img.shape)
            logger.info(f"Infered axes: {axes}, image shape: {img.shape}")
        check_img_axes(img, axes)
        if not isinstance(img, np.ndarray):
            img = np.array(img)
        predfunc = partial(
            self._pred_2d_or_3d,
            blend_3d=blend_3d,
            batch_size=batch_size,
            spots_calling_method=spots_calling_method,
            **kwargs,
        )
        df, enhanced_img = map_predfunc_to_img(
            predfunc, img, axes)
        if enh_img is not None:
            enh_img[:] = enhanced_img
        return df, enhanced_img

    def predict_chunks(
            self,
            img: np.ndarray,
            enh_img: T.Optional[np.ndarray] = None,
            axes: T.Optional[str] = None,
            blend_3d: bool = True,
            batch_size: int = 4,
            chunk_size: T.Optional[T.Tuple[T.Union[int, str], ...]] = None,
            spots_calling_method: str = 'local_maxima',
            **kwargs,
            ):
        """Predict the spots in an image chunk by chunk.

        Args:
            img: The image to predict, it should be a multi dimensional array.
                For example, shape (c, z, y, x) for a 4D image,
                shape (z, y, x) for a 3D image,
                shape (y, x) for a 2D image.
            enh_img: The enhanced image, if None, will be created.
                It can be a multi dimensional array or a zarr array.
            axes: The axes of the image.
                For example, 'czxy' for a 4D image,
                'yx' for a 2D image.
                If None, will try to infer the axes from the shape.
            blend_3d: Whether to blend the 3D image.
                Used only when the image contains a z axis.
                If True, will blend the 3D enhanced images along
                the z, y, x axes.
            batch_size: The batch size for inference.
                Used only when the image dimension is 3 or higher.
            chunk_size: The chunk size for processing.
                For example, (1, 512, 512) for a 3D image,
                (512, 512) for a 2D image.
                Using 'image' as a dimension will use the whole image
                as a chunk. For example, (1, 'image', 'image') for a 3D image,
                ('image', 'image', 'image', 512, 512) for a 5D image.
                If None, will use the default chunk size.
            spots_calling_method: The method to use for spot calling.
            kwargs: Other arguments for the spot calling function.
        """
        from .utils.img import (
            check_img_axes, chunks_iterator,
            process_chunk_size, infer_img_axes)
        if axes is None:
            axes = infer_img_axes(img.shape)
        check_img_axes(img, axes)
        if chunk_size is None:
            from .utils.img import get_default_chunk_size
            chunk_size = get_default_chunk_size(axes)
            logger.info(f"Chunk size not specified, using {chunk_size}.")
        chunk_size = process_chunk_size(chunk_size, img.shape)
        logger.info(f"Chunk size: {chunk_size}")
        total_dfs = []
        if enh_img is None:
            enh_img = np.zeros_like(img, dtype=np.float32)
        for c_range, chunk in chunks_iterator(img, chunk_size):
            logger.info("Processing chunk: " + str(c_range)
                        + ", chunk shape: " + str(chunk.shape))
            c_df, c_enh = self.predict(
                chunk, axes=axes, blend_3d=blend_3d,
                batch_size=batch_size,
                spots_calling_method=spots_calling_method,
                **kwargs)
            dim_start = [c_range[i][0] for i in range(len(axes))]
            c_df += dim_start
            total_dfs.append(c_df)
            c_enh = c_enh[
                tuple(slice(0, (r[1]-r[0])) for r in c_range)]
            enh_img[tuple(slice(*r) for r in c_range)] = c_enh
        df = pd.concat(total_dfs, ignore_index=True)
        return df, enh_img

    def evaluate_result_dp(
            self,
            pred: pd.DataFrame,
            true: pd.DataFrame,
            mdist: float = 3.0,
            ) -> pd.DataFrame:
        """Evaluate the prediction result using deepblink metrics.

        Args:
            pred: The predicted spots.
            true: The true spots.
            mdist: The maximum distance to consider a spot as a true positive.

        Returns:
            A pandas dataframe containing the evaluation metrics."""
        from .utils.metrics_deepblink import compute_metrics
        axis_names = list(pred.columns)
        axis_cols = [n for n in axis_names if n.startswith('axis')]
        pred = pred[axis_cols].values
        true = true[axis_cols].values
        metrics = compute_metrics(
            pred, true, mdist=mdist)
        return metrics

    def evaluate_result(
            self,
            pred: pd.DataFrame,
            true: pd.DataFrame,
            cutoff: float = 3.0,
            ) -> float:
        """Calculate the F1 score of the prediction result.

        Args:
            pred: The predicted spots.
            true: The true spots.
            cutoff: The maximum distance to consider a spot as a true positive.
        """
        from .utils.metrics import compute_metrics
        res = compute_metrics(pred.values, true.values, cutoff=cutoff)
        return res

    def plot_result(
            self,
            img: np.ndarray,
            pred: pd.DataFrame,
            fig_size: T.Tuple[int, int] = (10, 10),
            image_cmap: str = 'gray',
            marker_size: int = 20,
            marker_color: str = 'red',
            marker_style: str = 'x',
            ) -> "Figure":
        """Plot the prediction result.

        Args:
            img: The image to plot.
            pred: The predicted spots.
            fig_size: The figure size.
            image_cmap: The colormap for the image.
            marker_size: The marker size.
            marker_color: The marker color.
            marker_style: The marker style.
        """
        from .utils.plot import Plot2d
        plt2d = Plot2d()
        plt2d.default_figsize = fig_size
        plt2d.default_marker_size = marker_size
        plt2d.default_marker_color = marker_color
        plt2d.default_marker_style = marker_style
        plt2d.default_imshow_cmap = image_cmap
        plt2d.new_fig()
        plt2d.image(img)
        plt2d.spots(pred.values)
        return plt2d.fig

    def plot_evaluate(
            self,
            img: np.ndarray,
            pred: pd.DataFrame,
            true: pd.DataFrame,
            cutoff: float = 3.0,
            fig_size: T.Tuple[int, int] = (10, 10),
            image_cmap: str = 'gray',
            marker_size: int = 20,
            tp_color: str = 'green',
            fp_color: str = 'red',
            fn_color: str = 'yellow',
            tp_marker: str = 'x',
            fp_marker: str = 'x',
            fn_marker: str = 'x',
            ) -> "Figure":
        """Plot the prediction result.

        Args:
            img: The image to plot.
            pred: The predicted spots.
            true: The true spots.
            cutoff: The maximum distance to consider a spot as a true positive.
            fig_size: The figure size.
            image_cmap: The colormap for the image.
            marker_size: The marker size.
            tp_color: The color for true positive.
            fp_color: The color for false positive.
            fn_color: The color for false negative.
            tp_marker_style: The marker style for true positive.
            fp_marker_style: The marker style for false positive.
            fn_marker_style: The marker style for false negative.
        """
        from .utils.plot import Plot2d
        plt2d = Plot2d()
        plt2d.default_figsize = fig_size
        plt2d.default_marker_size = marker_size
        plt2d.default_imshow_cmap = image_cmap
        plt2d.new_fig()
        plt2d.image(img)
        plt2d.evaluate_result(
            pred.values, true.values,
            cutoff=cutoff,
            tp_color=tp_color,
            fp_color=fp_color,
            fn_color=fn_color,
            tp_marker=tp_marker,
            fp_marker=fp_marker,
            fn_marker=fn_marker,
        )
        return plt2d.fig

    def _load_dataset(
            self,
            path: str,
            root_dir_path: T.Optional[str] = None,
            img_glob: str = '*.tif',
            coord_glob: str = '*.csv',
            process_func=None,
            transform=None,
            ):
        """Load a dataset from a path."""
        from .data import FISHSpotsDataset
        _path = Path(path)
        if _path.is_dir():
            if root_dir_path is not None:
                logger.info(f"Dataset's root directory: {root_dir_path}")
                _path = Path(root_dir_path) / _path
            logger.info(f"Loading dataset from dir: {_path}")
            logger.info(
                f'Image glob: {img_glob}, Coordinate glob: {coord_glob}')
            _path_str = str(_path)
            dataset = FISHSpotsDataset.from_dir(
                _path_str, _path_str,
                img_glob=img_glob, coord_glob=coord_glob,
                process_func=process_func, transform=transform)
        else:
            logger.info(f"Loading dataset using meta csv: {_path}")
            assert _path.suffix == '.csv', \
                "Meta file must be a csv file."
            root_dir = root_dir_path or _path.parent
            logger.info(f'Data root directory: {root_dir}')
            dataset = FISHSpotsDataset.from_meta_csv(
                root_dir=root_dir, meta_csv_path=str(_path),
                process_func=process_func, transform=transform)
        return dataset

    def train(
            self,
            train_path: str,
            valid_path: str,
            root_dir: T.Optional[str] = None,
            img_glob: str = '*.tif',
            coord_glob: str = '*.csv',
            target_process: T.Optional[str] = 'gaussian',
            loss_type: str = 'DiceRMSELoss',
            loader_workers: int = 4,
            data_argu: bool = False,
            argu_prob: float = 0.5,
            num_epochs: int = 50,
            batch_size: int = 8,
            lr: float = 1e-3,
            summary_dir: str = "runs/unet",
            model_save_dir: str = "./models",
            save_period: int = 5,
            ):
        """Train the U-Net model.

        Args:
            train_path: The path to the training dataset.
                Path to a directory containing images and coordinates,
                or a meta csv file.
            valid_path: The path to the validation dataset.
                Path to a directory containing images and coordinates,
                or a meta csv file.
            root_dir: The root directory of the dataset.
                If using meta csv, the root directory of the dataset.
            img_glob: The glob pattern for the image files.
            coord_glob: The glob pattern for the coordinate files.
            target_process: The target image processing method.
                'gaussian' or 'dialation' or None.
                If None, no processing will be applied.
                default 'gaussian'.
            loss_type: The loss function type.
            loader_workers: The number of workers to use for the data loader.
            data_argu: Whether to use data augmentation.
            argu_prob: The probability to use data augmentation.
            num_epochs: The number of epochs to train.
            batch_size: The batch size.
            lr: The learning rate.
            summary_dir: The directory to save the TensorBoard summary to.
            model_save_dir: The directory to save the model to.
            save_period: Save the model every `save_period` epochs.
        """
        from .model.train import train_on_dataset
        from .data import FISHSpotsDataset
        if self.model is None:
            logger.info('Model is not initialized. Will initialize a new one.')
            self.init_model()
        assert self.model is not None

        if data_argu:
            logger.info(
                'Using data augmentation. ' +
                f'Probability: {argu_prob}'
            )
            from .data import DataAugmentation
            transform = DataAugmentation(p=argu_prob)
        else:
            transform = None

        logger.info(f'Using {target_process} as target process.')
        if target_process == 'gaussian':
            process_func = FISHSpotsDataset.gaussian_filter
        elif target_process == 'dialation':
            process_func = FISHSpotsDataset.dialate_mask
        elif isinstance(target_process, str):
            from functools import partial
            process_func = partial(
                FISHSpotsDataset.dialate_mask,
                footprint=target_process)
        else:
            process_func = None

        logger.info(f"Loading training dataset from {train_path}")
        train_dataset = self._load_dataset(
            train_path, root_dir_path=root_dir,
            img_glob=img_glob, coord_glob=coord_glob,
            process_func=process_func, transform=transform,
        )
        logger.info(f"Loading validation dataset from {valid_path}")
        valid_dataset = self._load_dataset(
            valid_path, root_dir_path=root_dir,
            img_glob=img_glob, coord_glob=coord_glob,
            process_func=process_func,
        )
        logger.info(
            f"Training dataset size: {len(train_dataset)}, "
            f"Validation dataset size: {len(valid_dataset)}"
        )
        logger.info(
            f"Number of epochs: {num_epochs}, "
            f"Batch size: {batch_size}, "
            f"Learning rate: {lr}"
        )
        train_on_dataset(
            self.model,
            train_dataset, valid_dataset,
            loss_type=loss_type,
            loader_workers=loader_workers,
            num_epochs=num_epochs,
            batch_size=batch_size,
            lr=lr,
            summary_dir=summary_dir,
            model_save_dir=model_save_dir,
            save_period=save_period,
        )
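
For orientation, below is a minimal sketch of how this vendored API can be driven, using only the methods defined in the listing above. The tifffile dependency and the input filename are illustrative assumptions, not part of the package; per load_weights above, calling it with no arguments falls back to the default ONNX weights bundled under model/weights/.

import tifffile  # assumed available for reading the example image
from ufish.api import UFish

ufish = UFish(cuda=False)  # CPU-only; cuda=True selects CUDAExecutionProvider if present
ufish.load_weights()  # no arguments: falls back to the bundled default ONNX weights
img = tifffile.imread("spots.tif")  # hypothetical 2D (y, x) or 3D (z, y, x) image
spots_df, enhanced = ufish.predict(img)  # DataFrame of spot coordinates + enhanced image
print(spots_df.head())

For images too large to enhance in one pass, predict_chunks offers the same interface with a chunk_size argument, as documented in its docstring above.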