senoquant 1.0.0b1__py3-none-any.whl → 1.0.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. senoquant/__init__.py +6 -2
  2. senoquant/_reader.py +1 -1
  3. senoquant/reader/core.py +201 -18
  4. senoquant/tabs/batch/backend.py +18 -3
  5. senoquant/tabs/batch/frontend.py +8 -4
  6. senoquant/tabs/quantification/features/marker/dialog.py +26 -6
  7. senoquant/tabs/quantification/features/marker/export.py +97 -24
  8. senoquant/tabs/quantification/features/marker/rows.py +2 -2
  9. senoquant/tabs/quantification/features/spots/dialog.py +41 -11
  10. senoquant/tabs/quantification/features/spots/export.py +163 -10
  11. senoquant/tabs/quantification/frontend.py +2 -2
  12. senoquant/tabs/segmentation/frontend.py +46 -9
  13. senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
  14. senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
  15. senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
  16. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
  17. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
  18. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
  19. senoquant/tabs/spots/frontend.py +42 -5
  20. senoquant/tabs/spots/models/ufish/details.json +17 -0
  21. senoquant/tabs/spots/models/ufish/model.py +129 -0
  22. senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
  23. senoquant/tabs/spots/ufish_utils/core.py +357 -0
  24. senoquant/utils.py +1 -1
  25. senoquant-1.0.0b3.dist-info/METADATA +161 -0
  26. {senoquant-1.0.0b1.dist-info → senoquant-1.0.0b3.dist-info}/RECORD +41 -28
  27. {senoquant-1.0.0b1.dist-info → senoquant-1.0.0b3.dist-info}/top_level.txt +1 -0
  28. ufish/__init__.py +1 -0
  29. ufish/api.py +778 -0
  30. ufish/model/__init__.py +0 -0
  31. ufish/model/loss.py +62 -0
  32. ufish/model/network/__init__.py +0 -0
  33. ufish/model/network/spot_learn.py +50 -0
  34. ufish/model/network/ufish_net.py +204 -0
  35. ufish/model/train.py +175 -0
  36. ufish/utils/__init__.py +0 -0
  37. ufish/utils/img.py +418 -0
  38. ufish/utils/log.py +8 -0
  39. ufish/utils/spot_calling.py +115 -0
  40. senoquant/tabs/spots/models/rmp/details.json +0 -61
  41. senoquant/tabs/spots/models/rmp/model.py +0 -499
  42. senoquant/tabs/spots/models/udwt/details.json +0 -103
  43. senoquant/tabs/spots/models/udwt/model.py +0 -482
  44. senoquant-1.0.0b1.dist-info/METADATA +0 -193
  45. {senoquant-1.0.0b1.dist-info → senoquant-1.0.0b3.dist-info}/WHEEL +0 -0
  46. {senoquant-1.0.0b1.dist-info → senoquant-1.0.0b3.dist-info}/entry_points.txt +0 -0
  47. {senoquant-1.0.0b1.dist-info → senoquant-1.0.0b3.dist-info}/licenses/LICENSE +0 -0
File without changes
ufish/model/loss.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - Dice coefficient``, computed over the whole batch."""

    def __init__(self):
        super().__init__()

    def _dice_loss(self, score, target):
        # Smoothing keeps the ratio finite when both tensors are all zeros.
        eps = 1e-5
        target = target.float()
        overlap = torch.sum(score * target)
        denom = torch.sum(score * score) + torch.sum(target * target)
        return 1 - (2 * overlap + eps) / (denom + eps)

    def forward(self, y_hat, y):
        """Return the Dice loss between prediction *y_hat* and target *y*."""
        return self._dice_loss(y_hat, y)
21
+
22
+
23
class RMSELoss(nn.Module):
    """Root-mean-square error: the square root of ``nn.MSELoss``."""

    def __init__(self):
        super().__init__()
        # Kept as a registered submodule, matching torch.nn conventions.
        self.mse = nn.MSELoss()

    def forward(self, y_hat, y):
        """Return ``sqrt(MSE(y_hat, y))``."""
        mse_value = self.mse(y_hat, y)
        return mse_value.sqrt()
30
+
31
+
32
class DiceRMSELoss(nn.Module):
    """Weighted sum of Dice loss and RMSE loss.

    Args:
        dice_ratio: Weight applied to the Dice term.
        rmse_ratio: Weight applied to the RMSE term.
    """

    def __init__(self, dice_ratio=0.6, rmse_ratio=0.4):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.rmse_loss = RMSELoss()
        self.dice_ratio = dice_ratio
        self.rmse_ratio = rmse_ratio

    def forward(self, y_hat, y):
        """Return ``dice_ratio * Dice(y_hat, y) + rmse_ratio * RMSE(y_hat, y)``."""
        weighted_dice = self.dice_ratio * self.dice_loss(y_hat, y)
        weighted_rmse = self.rmse_ratio * self.rmse_loss(y_hat, y)
        return weighted_dice + weighted_rmse
46
+
47
+
48
class DiceCoefLoss(nn.Module):
    """Negative Dice coefficient, so that maximizing overlap minimizes the loss."""

    def __init__(self):
        super().__init__()

    def _dice_coef(self, y_true, y_pred):
        smooth = 1.0
        true_flat = torch.flatten(y_true)
        pred_flat = torch.flatten(y_pred)
        overlap = torch.sum(true_flat * pred_flat)
        numer = 2. * overlap + smooth
        denom = torch.sum(true_flat) + torch.sum(pred_flat) + smooth
        return numer / denom

    def forward(self, y_hat, y):
        # Arguments are deliberately swapped into (y_true, y_pred) order here.
        return -self._dice_coef(y, y_hat)
File without changes
@@ -0,0 +1,50 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+
5
class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions with ReLU activations and 2D dropout
    (p=0.2) between them."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names are part of the checkpoint state_dict layout.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout2d(p=0.2)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        hidden = self.dropout1(self.relu1(self.conv1(x)))
        return self.relu2(self.conv2(hidden))
21
+
22
+
23
class SpotLearn(nn.Module):
    """Small U-Net-style network: two encoder levels, a bottleneck, two decoder
    levels with skip connections, and a 1x1 projection to one output channel."""

    def __init__(self, input_channel: int = 1) -> None:
        super().__init__()
        # Encoder
        self.conv1 = ConvBlock(input_channel, 64)
        self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ConvBlock(64, 128)
        self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck
        self.bottom = ConvBlock(128, 256)
        # Decoder — channel counts double after skip concatenation.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = ConvBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = ConvBlock(128, 64)
        # 1x1 projection to a single-channel output map.
        self.conv5 = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        skip1 = self.conv1(x)
        skip2 = self.conv2(self.down1(skip1))
        out = self.bottom(self.down2(skip2))
        out = self.conv3(torch.cat([self.up1(out), skip2], dim=1))
        out = self.conv4(torch.cat([self.up2(out), skip1], dim=1))
        return self.conv5(out)
@@ -0,0 +1,204 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by ReLU."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = self.conv1(x).relu()
        return self.conv2(hidden).relu()
18
+
19
+
20
class ResidualBlock(nn.Module):
    """Residual unit: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU, then add
    the input (spatial size and channel count are preserved)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, padding=1)
        self.BN1 = nn.BatchNorm2d(channels)
        self.BN2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = F.relu(self.BN1(self.conv1(x)))
        branch = F.relu(self.BN2(self.conv2(branch)))
        return branch + x
39
+
40
+
41
class DownConv(nn.Module):
    """Halve spatial resolution with a strided 3x3 conv, then refine the result
    with a ResidualBlock."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv = ResidualBlock(out_channels)

    def forward(self, x):
        return self.conv(self.down_conv(x))
52
+
53
+
54
class UpConv(nn.Module):
    """Double spatial resolution via nearest-neighbour upsampling, then apply a
    3x3 conv to adjust the channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
64
+
65
+
66
class ChannelAttention(nn.Module):
    """Channel attention (CBAM style): squeeze with average and max pooling,
    excite through a shared two-layer 1x1-conv bottleneck, gate with a sigmoid.

    Args:
        input_nc: Number of input channels.
        ratio: Channel reduction ratio for the bottleneck.
    """

    def __init__(self, input_nc, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # BUGFIX: clamp the bottleneck width to at least one channel. With
        # input_nc < ratio, the original `input_nc // ratio` gave 0 channels,
        # which silently collapses the attention map to a constant
        # sigmoid(0) = 0.5 for every input.
        hidden = max(1, input_nc // ratio)
        self.f1 = nn.Conv2d(input_nc, hidden, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(hidden, input_nc, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (N, C, 1, 1) gating tensor in (0, 1) for x of shape (N, C, H, W)."""
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
81
+
82
+
83
class SpatialAttention(nn.Module):
    """Spatial attention: convolve channel-wise mean/max descriptor maps into a
    single sigmoid-gated (N, 1, H, W) attention map."""

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # 'same' padding for the two supported kernel sizes.
        padding = {7: 3, 3: 1}[kernel_size]
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Collapse channels to two descriptor maps, each 1*h*w.
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat([mean_map, max_map], dim=1)  # 2*h*w
        return self.sigmoid(self.conv(stacked))  # 1*h*w
100
+
101
+
102
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating followed by spatial
    gating of the channel-gated features."""

    def __init__(self, input_nc, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(input_nc, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_gated = x * self.channel_attention(x)
        return channel_gated * self.spatial_attention(channel_gated)
113
+
114
+
115
class EncoderBlock(nn.Module):
    """Encoder stage: a plain ConvBlock, wrapped to name its architectural role."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(x)
123
+
124
+
125
class DecoderBlock(nn.Module):
    """Decoder stage: a plain ConvBlock, wrapped to name its architectural role."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(x)
133
+
134
+
135
class BottoleneckBlock(nn.Module):
    """Bottleneck stage: residual block -> CBAM attention -> residual block.

    NOTE: the misspelling of "Bottleneck" is preserved — the class name is
    referenced by callers and is part of the public API.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ResidualBlock(channels)
        self.cbam = CBAM(channels)
        self.conv2 = ResidualBlock(channels)

    def forward(self, x):
        return self.conv2(self.cbam(self.conv1(x)))
147
+
148
+
149
class FinalDecoderBlock(nn.Module):
    """Final decoder stage: CBAM attention followed by a ConvBlock projection to
    the output channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cbam = CBAM(in_channels)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(self.cbam(x))
159
+
160
+
161
class UFishNet(nn.Module):
    """U-shaped network with residual strided downsampling, a CBAM bottleneck,
    and skip connections from each encoder level.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        channel_numbers: Feature widths per resolution level, top to bottom.
    """

    def __init__(
            self, in_channels=1, out_channels=1,
            channel_numbers=[32, 32, 32]):
        super().__init__()
        # The mutable default list is only read (sliced/reversed), never
        # mutated, so sharing it across instances is safe here.
        self.encoders = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for level, width in enumerate(channel_numbers[:-1]):
            source = in_channels if level == 0 else width
            self.encoders.append(EncoderBlock(source, width))
            self.downsamples.append(DownConv(width, channel_numbers[level + 1]))

        self.bottom = BottoleneckBlock(channel_numbers[-1])

        self.decoders = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        rev_nums = channel_numbers[::-1]
        for level, width in enumerate(rev_nums[1:]):
            self.upsamples.append(UpConv(rev_nums[level], width))
            # Decoder input is doubled by the skip concatenation in forward().
            self.decoders.append(DecoderBlock(2 * width, width))

        self.final_decoder = FinalDecoderBlock(
            channel_numbers[0], out_channels)

    def forward(self, x):
        skips = []
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            x = self.downsamples[level](x)

        x = self.bottom(x)

        for level, decoder in enumerate(self.decoders):
            skip = skips[-level - 1]
            x = self.upsamples[level](x)
            # Pad so odd input sizes line up with the stored skip tensor.
            gap_h = skip.size(2) - x.size(2)
            gap_w = skip.size(3) - x.size(3)
            x = F.pad(x, (gap_w // 2, gap_w - gap_w // 2,
                          gap_h // 2, gap_h - gap_h // 2))
            x = decoder(torch.cat([x, skip], dim=1))

        return self.final_decoder(x)
ufish/model/train.py ADDED
@@ -0,0 +1,175 @@
1
+ import typing as T
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import numpy as np
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ from torch.utils.data import DataLoader
8
+ from torch import Tensor
9
+
10
+ from . import loss as loss_mod # noqa: F401
11
+ from ..data import FISHSpotsDataset
12
+ from ..utils.log import logger
13
+
14
+
15
def training_loop(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: T.Callable[[Tensor, Tensor], Tensor],
        writer: SummaryWriter,
        device: torch.device,
        train_loader: DataLoader, valid_loader: DataLoader,
        model_save_dir: str,
        save_period: int,
        num_epochs=50,
        ):
    """
    The training loop.

    Args:
        model: The model to train.
        optimizer: The optimizer to use.
        criterion: The loss function to use.
        writer: The TensorBoard writer.
        device: The device to use.
        train_loader: The training data loader.
        valid_loader: The valid data loader.
        model_save_dir: The directory to save the model to.
        save_period: Save the model every `save_period` epochs.
        num_epochs: The number of epochs to train for.
    """
    best_val_loss = float("inf")
    model_dir_path = Path(model_save_dir)
    model_dir_path.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for idx, batch in enumerate(train_loader):
            images = batch["image"].to(device, dtype=torch.float)
            targets = batch["target"].to(device, dtype=torch.float)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Log scalars and example images every 10 batches.
            if idx % 10 == 0:
                logger.info(
                    f"Epoch: {epoch + 1}/{num_epochs}, "
                    f"Batch: {idx + 1}/{len(train_loader)}, "
                    f"Loss: {loss.item():.4f}"
                )
                step = epoch * len(train_loader) + idx
                writer.add_scalar("Loss/train_batch", loss.item(), step)
                # BUGFIX: a normalized 3-channel copy of the input was built
                # here but never passed to add_image (dead computation);
                # removed. The raw tensors are logged, as before.
                writer.add_image("Image/input", images[0], step)
                writer.add_image("Image/target", targets[0], step)
                writer.add_image("Image/pred", outputs[0], step)

            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        writer.add_scalar("Loss/train", epoch_loss, epoch)

        # Validation pass (no gradients).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in valid_loader:
                images = batch["image"].to(device, dtype=torch.float)
                targets = batch["target"].to(device, dtype=torch.float)

                outputs = model(images)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        val_loss /= len(valid_loader)
        writer.add_scalar("Loss/val", val_loss, epoch)

        logger.info(
            f"Epoch {epoch + 1}/{num_epochs}, "
            f"Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}"
        )

        # Keep the best checkpoint by validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(
                model.state_dict(),
                f"{model_save_dir}/best_model.pth")
            logger.info(f"Best model saved with Val Loss: {val_loss:.4f}")
        # Also checkpoint periodically (including epoch 0).
        if epoch % save_period == 0:
            torch.save(
                model.state_dict(),
                f"{model_save_dir}/model_{epoch}.pth")
            logger.info(f"Model saved at epoch {epoch}.")

    writer.close()
119
+
120
+
121
def train_on_dataset(
        model: torch.nn.Module,
        train_dataset: FISHSpotsDataset,
        valid_dataset: FISHSpotsDataset,
        loss_type: str = "DiceRMSELoss",
        loader_workers: int = 4,
        num_epochs: int = 50,
        batch_size: int = 8,
        lr: float = 1e-4,
        summary_dir: str = "runs/unet",
        model_save_dir: str = "./models",
        save_period: int = 5,
        ):
    """Train the UNet model.

    Args:
        model: The model to train.
        train_dataset: The training dataset.
        valid_dataset: The validation dataset.
        loss_type: The loss function to use.
        loader_workers: The number of workers to use for the data loader.
        num_epochs: The number of epochs to train for.
        batch_size: The batch size.
        lr: The learning rate.
        summary_dir: The directory to save the TensorBoard summary to.
        model_save_dir: The directory to save the model to.
        save_period: Save the model every `save_period` epochs.

    Raises:
        ValueError: If ``loss_type`` does not name a class in the loss module.
    """
    logger.info(
        f"Loader workers: {loader_workers}, " +
        f"TensorBoard summary dir: {summary_dir}"
    )
    logger.info(
        f"Model save dir: {model_save_dir}, " +
        f"Save period: {save_period}"
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=loader_workers)
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size,
        shuffle=False, num_workers=loader_workers)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Training using device: {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # BUGFIX: resolve the loss class by name with getattr instead of eval();
    # eval on an interpolated string executes arbitrary code and hides typos.
    try:
        criterion = getattr(loss_mod, loss_type)()
    except AttributeError as err:
        raise ValueError(f"Unknown loss_type: {loss_type!r}") from err

    writer = SummaryWriter(summary_dir)
    training_loop(
        model, optimizer, criterion, writer, device,
        train_loader, valid_loader, model_save_dir,
        save_period, num_epochs)
File without changes