senoquant 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -2
- senoquant/_reader.py +1 -1
- senoquant/_widget.py +9 -1
- senoquant/reader/core.py +201 -18
- senoquant/tabs/__init__.py +2 -0
- senoquant/tabs/batch/backend.py +76 -27
- senoquant/tabs/batch/frontend.py +127 -25
- senoquant/tabs/quantification/features/marker/dialog.py +26 -6
- senoquant/tabs/quantification/features/marker/export.py +97 -24
- senoquant/tabs/quantification/features/marker/rows.py +2 -2
- senoquant/tabs/quantification/features/spots/dialog.py +41 -11
- senoquant/tabs/quantification/features/spots/export.py +163 -10
- senoquant/tabs/quantification/frontend.py +2 -2
- senoquant/tabs/segmentation/frontend.py +46 -9
- senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
- senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
- senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
- senoquant/tabs/spots/frontend.py +96 -5
- senoquant/tabs/spots/models/rmp/details.json +3 -9
- senoquant/tabs/spots/models/rmp/model.py +341 -266
- senoquant/tabs/spots/models/ufish/details.json +32 -0
- senoquant/tabs/spots/models/ufish/model.py +327 -0
- senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
- senoquant/tabs/spots/ufish_utils/core.py +387 -0
- senoquant/tabs/visualization/__init__.py +1 -0
- senoquant/tabs/visualization/backend.py +306 -0
- senoquant/tabs/visualization/frontend.py +1113 -0
- senoquant/tabs/visualization/plots/__init__.py +80 -0
- senoquant/tabs/visualization/plots/base.py +152 -0
- senoquant/tabs/visualization/plots/double_expression.py +187 -0
- senoquant/tabs/visualization/plots/spatialplot.py +156 -0
- senoquant/tabs/visualization/plots/umap.py +140 -0
- senoquant/utils.py +1 -1
- senoquant-1.0.0b4.dist-info/METADATA +162 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/RECORD +53 -30
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/top_level.txt +1 -0
- ufish/__init__.py +1 -0
- ufish/api.py +778 -0
- ufish/model/__init__.py +0 -0
- ufish/model/loss.py +62 -0
- ufish/model/network/__init__.py +0 -0
- ufish/model/network/spot_learn.py +50 -0
- ufish/model/network/ufish_net.py +204 -0
- ufish/model/train.py +175 -0
- ufish/utils/__init__.py +0 -0
- ufish/utils/img.py +418 -0
- ufish/utils/log.py +8 -0
- ufish/utils/spot_calling.py +115 -0
- senoquant/tabs/spots/models/udwt/details.json +0 -103
- senoquant/tabs/spots/models/udwt/model.py +0 -482
- senoquant-1.0.0b2.dist-info/METADATA +0 -193
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/WHEEL +0 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/entry_points.txt +0 -0
- {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b4.dist-info}/licenses/LICENSE +0 -0
ufish/model/__init__.py
ADDED
|
File without changes
|
ufish/model/loss.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DiceLoss(nn.Module):
    """Soft Dice loss taken over every element of the inputs.

    Computes ``1 - (2 * sum(p * t) + eps) / (sum(p * p) + sum(t * t) + eps)``
    with ``eps = 1e-5``, where ``p`` is the prediction and ``t`` the target
    (cast to float).
    """

    def __init__(self):
        super().__init__()

    def _dice_loss(self, score, target):
        """Return ``1 - dice`` between prediction ``score`` and ``target``."""
        eps = 1e-5
        target = target.float()
        numerator = 2 * torch.sum(score * target) + eps
        denominator = torch.sum(score * score) + torch.sum(target * target) + eps
        return 1 - numerator / denominator

    def forward(self, y_hat, y):
        """Dice loss of prediction ``y_hat`` against target ``y``."""
        return self._dice_loss(y_hat, y)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RMSELoss(nn.Module):
    """Root-mean-square error loss.

    Args:
        eps: Small constant added to the mean squared error before taking the
            square root. ``sqrt`` has an infinite derivative at 0, so without
            it a prediction that exactly matches the target (MSE == 0)
            back-propagates NaN gradients (``inf * 0``).
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, y_hat, y):
        """Return ``sqrt(mean((y_hat - y) ** 2) + eps)``."""
        return torch.sqrt(self.mse(y_hat, y) + self.eps)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class DiceRMSELoss(nn.Module):
    """Weighted combination of :class:`DiceLoss` and :class:`RMSELoss`.

    Args:
        dice_ratio: Weight applied to the Dice term.
        rmse_ratio: Weight applied to the RMSE term.
    """

    def __init__(self, dice_ratio=0.6, rmse_ratio=0.4):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.rmse_loss = RMSELoss()
        self.dice_ratio = dice_ratio
        self.rmse_ratio = rmse_ratio

    def forward(self, y_hat, y):
        """Return ``dice_ratio * dice(y_hat, y) + rmse_ratio * rmse(y_hat, y)``."""
        weighted_dice = self.dice_ratio * self.dice_loss(y_hat, y)
        weighted_rmse = self.rmse_ratio * self.rmse_loss(y_hat, y)
        return weighted_dice + weighted_rmse
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DiceCoefLoss(nn.Module):
    """Negated smoothed Dice coefficient (smoothing constant 1.0).

    Minimising this value maximises the overlap between prediction and
    target; both inputs are flattened before the coefficient is computed.
    """

    def __init__(self):
        super().__init__()

    def _dice_coef(self, y_true, y_pred):
        """Return ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)``."""
        smooth = 1.0
        flat_true = y_true.flatten()
        flat_pred = y_pred.flatten()
        numerator = 2. * (flat_true * flat_pred).sum() + smooth
        denominator = flat_true.sum() + flat_pred.sum() + smooth
        return numerator / denominator

    def forward(self, y_hat, y):
        """Return ``-dice(y, y_hat)``."""
        return -self._dice_coef(y, y_hat)
|
|
ufish/model/network/__init__.py
ADDED
|
File without changes
|
|
ufish/model/network/spot_learn.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions with 2D dropout between them.

    Pipeline: conv1 -> ReLU -> Dropout2d(p=0.2) -> conv2 -> ReLU.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names determine state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout2d(p=0.2)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        stages = (self.conv1, self.relu1, self.dropout1, self.conv2, self.relu2)
        for stage in stages:
            x = stage(x)
        return x
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SpotLearn(nn.Module):
    """Two-level U-Net producing a single-channel output map.

    Encoder: ConvBlock(in, 64) -> max-pool -> ConvBlock(64, 128) -> max-pool
    -> ConvBlock(128, 256). The decoder mirrors it with stride-2 transposed
    convolutions and skip concatenations, then a 1x1 conv to one channel.

    Args:
        input_channel: Number of channels of the input image.
    """

    def __init__(self, input_channel: int = 1) -> None:
        super().__init__()
        self.conv1 = ConvBlock(input_channel, 64)
        self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ConvBlock(64, 128)
        self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottom = ConvBlock(128, 256)
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = ConvBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = ConvBlock(128, 64)
        self.conv5 = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` spatially so its H and W equal those of ``ref``.

        Max-pooling floors odd sizes, so after 2x upsampling the decoder map
        can be one pixel smaller than the skip connection it is concatenated
        with (previously this made ``torch.cat`` fail). When sizes already
        match, ``x`` is returned unchanged.
        """
        diff_y = ref.size(2) - x.size(2)
        diff_x = ref.size(3) - x.size(3)
        if diff_y == 0 and diff_x == 0:
            return x
        return nn.functional.pad(
            x, (diff_x // 2, diff_x - diff_x // 2,
                diff_y // 2, diff_y - diff_y // 2))

    def forward(self, x):
        in1 = self.conv1(x)
        x = self.down1(in1)
        in2 = self.conv2(x)
        x = self.down2(in2)
        x = self.bottom(x)
        x = self._match_size(self.up1(x), in2)
        x = self.conv3(torch.cat([x, in2], dim=1))
        x = self._match_size(self.up2(x), in1)
        x = self.conv4(torch.cat([x, in1], dim=1))
        return self.conv5(x)
|
|
ufish/model/network/ufish_net.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by a ReLU."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return x
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ResidualBlock(nn.Module):
    """Channel-preserving residual block.

    Computes ``relu(BN2(conv3x3(relu(BN1(conv1x1(x)))))) + x``. The identity
    shortcut is added after the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.BN1 = nn.BatchNorm2d(channels)
        self.BN2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        shortcut = x
        for conv, norm in ((self.conv1, self.BN1), (self.conv2, self.BN2)):
            x = F.relu(norm(conv(x)))
        return x + shortcut
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DownConv(nn.Module):
    """Stride-2 3x3 convolution followed by a :class:`ResidualBlock`.

    With kernel 3, stride 2 and padding 1 the spatial size becomes
    ``ceil(H / 2) x ceil(W / 2)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv = ResidualBlock(out_channels)

    def forward(self, x):
        return self.conv(self.down_conv(x))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class UpConv(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 same-padding conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ChannelAttention(nn.Module):
    """Channel attention as used in CBAM.

    The average- and max-pooled descriptors each go through the same
    bottleneck (1x1 conv ``input_nc -> input_nc // ratio``, ReLU, 1x1 conv
    back to ``input_nc``). Their sum is passed through a sigmoid, giving one
    weight per channel with spatial size 1x1.
    """

    def __init__(self, input_nc, ratio=16):
        super().__init__()
        hidden = input_nc // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(input_nc, hidden, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(hidden, input_nc, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, descriptor):
        """Apply the shared f1 -> ReLU -> f2 bottleneck to a pooled map."""
        return self.f2(self.relu(self.f1(descriptor)))

    def forward(self, x):
        avg_branch = self._shared_mlp(self.avg_pool(x))
        max_branch = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_branch + max_branch)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class SpatialAttention(nn.Module):
    """Spatial attention as used in CBAM.

    The channel-wise mean and max maps are stacked into two channels,
    convolved with one ``kernel_size`` x ``kernel_size`` filter (same
    padding, no bias) and passed through a sigmoid.

    Args:
        kernel_size: Positive odd size of the attention convolution. The
            default is 7. Padding is ``kernel_size // 2``, so the output keeps
            the input's H and W (3 for k=7 and 1 for k=3, as before).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # An explicit exception, unlike ``assert``, survives ``python -O``.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel size must be a positive odd integer, got {kernel_size}")
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Assumes x is (N, C, H, W); reductions are over the channel axis.
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_map, max_map], dim=1)  # 2 channels
        return self.sigmoid(self.conv(stacked))         # 1 channel
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Features are first re-weighted by channel attention. Spatial attention
    is then computed on those channel-refined features and applied to them.

    Args:
        input_nc: Number of input channels.
        ratio: Reduction ratio of the channel-attention bottleneck.
        kernel_size: Kernel size of the spatial-attention convolution.
    """

    def __init__(self, input_nc, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(input_nc, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        refined = x * self.channel_attention(x)
        return refined * self.spatial_attention(refined)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class EncoderBlock(nn.Module):
    """Encoder stage of UFishNet: a single :class:`ConvBlock`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(x)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class DecoderBlock(nn.Module):
    """Decoder stage of UFishNet.

    Applies a :class:`ConvBlock` mapping ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        return self.conv(x)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class BottoleneckBlock(nn.Module):
    """UFishNet bottleneck: ResidualBlock -> CBAM -> ResidualBlock.

    The class name's spelling is kept unchanged so existing references to it
    keep working.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ResidualBlock(channels)
        self.cbam = CBAM(channels)
        self.conv2 = ResidualBlock(channels)

    def forward(self, x):
        for layer in (self.conv1, self.cbam, self.conv2):
            x = layer(x)
        return x
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class FinalDecoderBlock(nn.Module):
    """Last UFishNet stage.

    Applies CBAM attention over the ``in_channels`` features, then a
    :class:`ConvBlock` to ``out_channels``. Because ConvBlock ends with a
    ReLU, the output is non-negative.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cbam = CBAM(in_channels)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        attended = self.cbam(x)
        return self.conv(attended)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class UFishNet(nn.Module):
    """U-Net style network with residual/attention bottleneck (U-FISH).

    Structure:

    - Each resolution level has an :class:`EncoderBlock` followed by a
      stride-2 :class:`DownConv`.
    - A :class:`BottoleneckBlock` sits at the lowest resolution.
    - Each decoder level upsamples with :class:`UpConv`, aligns the result to
      the matching skip connection, concatenates it and applies a
      :class:`DecoderBlock`.
    - A :class:`FinalDecoderBlock` produces ``out_channels`` maps.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        channel_numbers: Feature width per level, from the finest level
            (index 0) down to the bottleneck (last entry). Any sequence is
            accepted; defaults to ``(32, 32, 32)``.
    """

    def __init__(
            self, in_channels=1, out_channels=1,
            channel_numbers=(32, 32, 32)):
        super().__init__()
        # Immutable default + local copy: a list default would be a single
        # object shared by every instance built without this argument.
        channel_numbers = list(channel_numbers)

        self.encoders = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for i, c in enumerate(channel_numbers[:-1]):
            # The first encoder sees the raw image. Later encoders see the
            # previous DownConv's output, which already has ``c`` channels.
            enc_in = in_channels if i == 0 else c
            self.encoders.append(EncoderBlock(enc_in, c))
            self.downsamples.append(DownConv(c, channel_numbers[i + 1]))

        self.bottom = BottoleneckBlock(channel_numbers[-1])

        self.decoders = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        rev_nums = channel_numbers[::-1]
        for i, c in enumerate(rev_nums[1:]):
            self.upsamples.append(UpConv(rev_nums[i], c))
            # Input = upsampled map (c channels) + skip connection (c channels).
            self.decoders.append(DecoderBlock(2 * c, c))

        self.final_decoder = FinalDecoderBlock(
            channel_numbers[0], out_channels)

    def forward(self, x):
        skips = []
        for encoder, downsample in zip(self.encoders, self.downsamples):
            x = encoder(x)
            skips.append(x)
            x = downsample(x)

        x = self.bottom(x)

        # Decoders consume skip connections from coarsest to finest.
        for upsample, decoder, skip in zip(
                self.upsamples, self.decoders, reversed(skips)):
            x = upsample(x)
            # DownConv rounds odd sizes up, so the upsampled map can be a
            # pixel larger than the skip. F.pad with a negative amount crops,
            # so this aligns sizes in either direction.
            diff_y = skip.size(2) - x.size(2)
            diff_x = skip.size(3) - x.size(3)
            x = F.pad(x, (diff_x // 2, diff_x - diff_x // 2,
                          diff_y // 2, diff_y - diff_y // 2))
            x = torch.cat([x, skip], dim=1)
            x = decoder(x)

        return self.final_decoder(x)
|
ufish/model/train.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import typing as T
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from . import loss as loss_mod # noqa: F401
|
|
11
|
+
from ..data import FISHSpotsDataset
|
|
12
|
+
from ..utils.log import logger
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _to_display_range(image: Tensor) -> Tensor:
    """Min-max scale ``image`` into ``[0, 1]`` for TensorBoard display.

    Float images passed to ``SummaryWriter.add_image`` are expected in
    ``[0, 1]``, so raw intensities would otherwise be clipped. A constant
    image (max == min) maps to zeros instead of dividing by zero.
    """
    image = image.detach().float().cpu()
    lo, hi = image.min(), image.max()
    span = hi - lo
    if span > 0:
        return (image - lo) / span
    return torch.zeros_like(image)


def training_loop(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: T.Callable[[Tensor, Tensor], Tensor],
    writer: SummaryWriter,
    device: torch.device,
    train_loader: DataLoader, valid_loader: DataLoader,
    model_save_dir: str,
    save_period: int,
    num_epochs=50,
):
    """Train ``model`` and checkpoint it.

    Every epoch trains on ``train_loader`` and then evaluates the mean loss
    on ``valid_loader``.

    - Every 10th batch, the batch loss plus the first input (min-max scaled),
      target and prediction are logged to ``writer``.
    - ``best_model.pth`` is written whenever the mean validation loss
      improves.
    - ``model_{epoch}.pth`` is written every ``save_period`` epochs, starting
      with epoch 0.
    - ``writer`` is closed on return, even if an exception escapes.

    Args:
        model: The model to train.
        optimizer: The optimizer to use.
        criterion: The loss function to use.
        writer: The TensorBoard writer.
        device: The device to use.
        train_loader: The training data loader; batches are dicts with
            ``"image"`` and ``"target"`` tensors.
        valid_loader: The validation data loader (same batch format).
        model_save_dir: The directory to save the model to (created if
            missing).
        save_period: Save the model every `save_period` epochs. Values < 1
            disable periodic checkpoints; the best model is still saved.
        num_epochs: The number of epochs to train for.

    Raises:
        ValueError: If either loader yields no batches.
    """
    try:
        best_val_loss = float("inf")
        model_dir_path = Path(model_save_dir)
        model_dir_path.mkdir(parents=True, exist_ok=True)

        n_train = len(train_loader)
        n_valid = len(valid_loader)
        if n_train == 0 or n_valid == 0:
            raise ValueError(
                "train_loader and valid_loader must each yield at least "
                "one batch")

        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0.0
            for idx, batch in enumerate(train_loader):
                images = batch["image"].to(device, dtype=torch.float)
                targets = batch["target"].to(device, dtype=torch.float)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                if idx % 10 == 0:
                    step = epoch * n_train + idx
                    logger.info(
                        f"Epoch: {epoch + 1}/{num_epochs}, "
                        f"Batch: {idx + 1}/{n_train}, "
                        f"Loss: {batch_loss:.4f}"
                    )
                    writer.add_scalar("Loss/train_batch", batch_loss, step)
                    # Raw intensities are rescaled so add_image does not
                    # clip them.
                    writer.add_image(
                        "Image/input", _to_display_range(images[0]), step)
                    writer.add_image("Image/target", targets[0], step)
                    writer.add_image("Image/pred", outputs[0], step)

                epoch_loss += batch_loss

            epoch_loss /= n_train
            writer.add_scalar("Loss/train", epoch_loss, epoch)

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in valid_loader:
                    images = batch["image"].to(device, dtype=torch.float)
                    targets = batch["target"].to(device, dtype=torch.float)
                    outputs = model(images)
                    val_loss += criterion(outputs, targets).item()

            val_loss /= n_valid
            writer.add_scalar("Loss/val", val_loss, epoch)

            logger.info(
                f"Epoch {epoch + 1}/{num_epochs}, "
                f"Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}"
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    model.state_dict(), model_dir_path / "best_model.pth")
                logger.info(f"Best model saved with Val Loss: {val_loss:.4f}")
            # save_period < 1 would otherwise raise ZeroDivisionError here.
            if save_period > 0 and epoch % save_period == 0:
                torch.save(
                    model.state_dict(), model_dir_path / f"model_{epoch}.pth")
                logger.info(f"Model saved at epoch {epoch}.")
    finally:
        writer.close()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def train_on_dataset(
    model: torch.nn.Module,
    train_dataset: FISHSpotsDataset,
    valid_dataset: FISHSpotsDataset,
    loss_type: str = "DiceRMSELoss",
    loader_workers: int = 4,
    num_epochs: int = 50,
    batch_size: int = 8,
    lr: float = 1e-4,
    summary_dir: str = "runs/unet",
    model_save_dir: str = "./models",
    save_period: int = 5,
):
    """Train the UNet model.

    Builds a shuffled training loader, an ordered validation loader, an Adam
    optimizer and the loss named by ``loss_type``. It then runs
    :func:`training_loop` on CUDA when available, otherwise on CPU.

    Args:
        model: The model to train.
        train_dataset: The training dataset.
        valid_dataset: The validation dataset.
        loss_type: Name of a loss class in the ``loss`` module, e.g.
            ``"DiceRMSELoss"``, ``"DiceLoss"`` or ``"RMSELoss"``. Dotted
            names (e.g. ``"nn.MSELoss"``) are resolved attribute by
            attribute from that module. The class is instantiated with no
            arguments.
        loader_workers: The number of workers to use for the data loader.
        num_epochs: The number of epochs to train for.
        batch_size: The batch size.
        lr: The learning rate.
        summary_dir: The directory to save the TensorBoard summary to.
        model_save_dir: The directory to save the model to.
        save_period: Save the model every `save_period` epochs.

    Raises:
        ValueError: If ``loss_type`` does not resolve to an ``nn.Module``
            subclass.
    """
    # Resolve the loss via attribute lookup rather than ``eval`` so that an
    # arbitrary string cannot execute code. Dotted names keep working.
    loss_cls = loss_mod
    try:
        for part in loss_type.split("."):
            loss_cls = getattr(loss_cls, part)
    except AttributeError:
        loss_cls = None
    if not (isinstance(loss_cls, type)
            and issubclass(loss_cls, torch.nn.Module)):
        raise ValueError(
            f"Unknown loss_type {loss_type!r}: expected the name of a loss "
            f"class defined in {loss_mod.__name__}")

    logger.info(
        f"Loader workers: {loader_workers}, "
        f"TensorBoard summary dir: {summary_dir}"
    )
    logger.info(
        f"Model save dir: {model_save_dir}, "
        f"Save period: {save_period}"
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=loader_workers)
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size,
        shuffle=False, num_workers=loader_workers)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Training using device: {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion = loss_cls()

    writer = SummaryWriter(summary_dir)
    training_loop(
        model, optimizer, criterion, writer, device,
        train_loader, valid_loader, model_save_dir,
        save_period, num_epochs)
|
ufish/utils/__init__.py
ADDED
|
File without changes
|