locals-api 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locals_api-0.1.0/PKG-INFO +7 -0
- locals_api-0.1.0/README.md +28 -0
- locals_api-0.1.0/pyproject.toml +20 -0
- locals_api-0.1.0/setup.cfg +4 -0
- locals_api-0.1.0/src/locals/__init__.py +6 -0
- locals_api-0.1.0/src/locals/architectures.py +65 -0
- locals_api-0.1.0/src/locals/blocks.py +56 -0
- locals_api-0.1.0/src/locals/constants.py +1 -0
- locals_api-0.1.0/src/locals/dataset.py +78 -0
- locals_api-0.1.0/src/locals/evaluator.py +192 -0
- locals_api-0.1.0/src/locals/figmaker.py +212 -0
- locals_api-0.1.0/src/locals/locals.py +50 -0
- locals_api-0.1.0/src/locals/losses.py +63 -0
- locals_api-0.1.0/src/locals/math.py +29 -0
- locals_api-0.1.0/src/locals/runner.py +67 -0
- locals_api-0.1.0/src/locals/trainer.py +83 -0
- locals_api-0.1.0/src/locals_api.egg-info/PKG-INFO +7 -0
- locals_api-0.1.0/src/locals_api.egg-info/SOURCES.txt +19 -0
- locals_api-0.1.0/src/locals_api.egg-info/dependency_links.txt +1 -0
- locals_api-0.1.0/src/locals_api.egg-info/requires.txt +3 -0
- locals_api-0.1.0/src/locals_api.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# LOCALS
|
|
2
|
+
|
|
3
|
+
Rapid experimentation framework for the LOCALS detector.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
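Install PyTorch first, choosing the build that matches your CUDA version (or the CPU-only build); it is not installed automatically as a dependency of this package.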
|
|
8
|
+
|
|
9
|
+
Then:
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
pip install locals-api
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
## Example
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
from locals import run
|
|
19
|
+
|
|
20
|
+
run(
|
|
21
|
+
seed=1111,
|
|
22
|
+
train_split=0.8,
|
|
23
|
+
test_split=0.1,
|
|
24
|
+
images_dir="dataset/images",
|
|
25
|
+
labels_dir="dataset/labels",
|
|
26
|
+
num_epochs=50,
|
|
27
|
+
)
|
|
28
|
+
```
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "locals-api"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
requires-python = ">=3.10"
|
|
9
|
+
|
|
10
|
+
# NOTE: torch is not listed here; the README asks users to install the
# PyTorch build that matches their CUDA version before installing this package.
dependencies = [
    "numpy",
    "opencv-python",
    "matplotlib",
    "tqdm",  # imported by locals.trainer
]
|
|
15
|
+
|
|
16
|
+
[tool.setuptools]
|
|
17
|
+
package-dir = {"" = "src"}
|
|
18
|
+
|
|
19
|
+
[tool.setuptools.packages.find]
|
|
20
|
+
where = ["src"]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from .blocks import ResnetBlock, ConvBlock
|
|
5
|
+
|
|
6
|
+
class LOCALSn(nn.Module):
    """Nano LOCALS network: a strided backbone followed by one 1x1 head.

    The trailing size comments give the spatial resolution after each
    layer and hold for a 448x448 input (NOTE(review): 448 is the size the
    dataset resizes to -- confirm no other input size is used).
    """

    def __init__(self):
        super().__init__()

        # Layers are instantiated in the same order as they run, so parameter
        # initialisation order (and therefore seeded weights) is unchanged.
        backbone_layers = [
            ConvBlock(input_channels=3, output_channels=16, kernel_size=2, stride=2),  # 224
            ResnetBlock(input_channels=16, output_channels=32),  # 112
            ResnetBlock(input_channels=32, output_channels=64),  # 56
            ConvBlock(input_channels=64, output_channels=128, kernel_size=2, stride=2),  # 28
            ResnetBlock(input_channels=128, output_channels=256),  # 14
            ResnetBlock(input_channels=256, output_channels=256),  # 7
        ]
        self.model = nn.Sequential(*backbone_layers)

        # 1x1 conv is important here because we do not want to do feature
        # extraction in heads
        self.head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=3, activation=False),
        )

    def forward(self, x):
        """Return sigmoid-activated, channels-last predictions for ``x``."""
        features = self.model(x)
        channels_last = self.head(features).permute(0, 2, 3, 1)  # [B, 7, 7, 3]
        return torch.sigmoid(channels_last)
|
|
31
|
+
|
|
32
|
+
class LOCALSs(nn.Module):
    """Small LOCALS network with separate localisation and class heads.

    The trailing size comments give the spatial resolution after each
    layer and hold for a 448x448 input (NOTE(review): confirm no other
    input size is used).
    """

    def __init__(self):
        super().__init__()

        # (input_channels, output_channels) of the six down-sampling stages.
        stage_channels = [
            (16, 32),    # 224
            (32, 64),    # 112
            (64, 128),   # 56
            (128, 256),  # 28
            (256, 512),  # 14
            (512, 256),  # 7
        ]
        self.model = nn.Sequential(
            ConvBlock(input_channels=3, output_channels=16, kernel_size=3, padding=1),  # 448
            *[
                ResnetBlock(input_channels=c_in, output_channels=c_out)
                for c_in, c_out in stage_channels
            ],
        )

        # 1x1 conv is important here because we do not want to do feature
        # extraction in heads
        self.loc_head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=64),
            ConvBlock(input_channels=64, output_channels=16),
            ConvBlock(input_channels=16, output_channels=2, activation=False),
        )

        self.class_head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=1, activation=False),
        )

    def forward(self, x):
        """Return sigmoid-activated, channels-last predictions for ``x``."""
        features = self.model(x)
        heads = [
            self.loc_head(features),    # [B, 2, 7, 7]
            self.class_head(features),  # [B, 1, 7, 7]
        ]
        merged = torch.cat(heads, dim=1)  # [B, 3, 7, 7]
        return torch.sigmoid(merged.permute(0, 2, 3, 1))  # [B, 7, 7, 3]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
'''
|
|
4
|
+
This block models the skip connection introduced in the Resnet paper 2015.
|
|
5
|
+
'''
|
|
6
|
+
|
|
7
|
+
class ResnetBlock(nn.Module):
    """Residual block that changes the channel count and halves H and W.

    The main path is conv3x3 -> BN -> ReLU -> strided conv3x3 -> BN; the
    shortcut is a strided 2x2 conv -> BN. Both branches only produce the
    same spatial size for even H and W.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # Convolutions (created in this order so seeded initialisation and
        # state-dict key order stay the same).
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=2, padding=1)
        self.downsamp_conv = nn.Conv2d(input_channels, output_channels, kernel_size=2, stride=2)
        self.relu = nn.ReLU()

        # One batch norm per convolution.
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.downsamp_bn = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        # Main path: feature extraction, then reduce H and W by 2.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut: bring x to the required channel count and resolution.
        shortcut = self.downsamp_bn(self.downsamp_conv(x))

        # Skip connection.
        return self.relu(main + shortcut)
|
|
36
|
+
|
|
37
|
+
'''
|
|
38
|
+
Convolution + batch-norm block with an optional ReLU; a 1x1 convolution by default.
|
|
39
|
+
'''
|
|
40
|
+
|
|
41
|
+
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d, optionally followed by ReLU.

    With the default arguments this is a 1x1 convolution that only changes
    the channel count; ``kernel_size``, ``padding`` and ``stride`` are passed
    straight to ``nn.Conv2d``. Set ``activation=False`` to skip the ReLU.
    """

    def __init__(self, input_channels, output_channels, kernel_size=1, padding=0, stride=1, activation=True):
        super().__init__()

        self.channel_downsamp_conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            padding=padding,
        )
        self.channel_downsamp_bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.activation = activation

    def forward(self, x):
        normed = self.channel_downsamp_bn(self.channel_downsamp_conv(x))
        return self.relu(normed) if self.activation else normed
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
NUM_GRID_CELLS = 7
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
from torch.utils.data import DataLoader, Dataset, random_split
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
class LOCALSDataset(Dataset):
    """Image / annotation dataset backed by two directories.

    ``images_dir`` holds ``.png`` / ``.jpg`` / ``.jpeg`` files and
    ``labels_dir`` holds one ``.npy`` annotation matrix per image, named after
    the image (NOTE(review): the annotation layout is not visible here;
    presumably a [7, 7, 3] grid of (x offset, y offset, objectness)).
    """

    def __init__(self, images_dir, labels_dir):
        # Sorted so that indices -- and therefore seeded random splits -- do not
        # depend on the arbitrary order returned by os.listdir().
        self.image_list = sorted(
            f_name for f_name in os.listdir(images_dir)
            if f_name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_size = 448

    def __len__(self):
        return len(self.image_list)

    def _label_path(self, image_name):
        """Return the path of the ``.npy`` annotation belonging to ``image_name``.

        Only the final extension is stripped, so ``a.v2.png`` maps to
        ``a.v2.npy``. If that file is missing, the historical naming (text
        before the first dot, i.e. ``a.npy``) is used when it exists.
        """
        stem, _ = os.path.splitext(image_name)
        preferred = os.path.join(self.labels_dir, f'{stem}.npy')
        if os.path.exists(preferred):
            return preferred

        legacy = os.path.join(self.labels_dir, f"{image_name.split('.')[0]}.npy")
        if os.path.exists(legacy):
            return legacy
        return preferred

    def __getitem__(self, idx):
        """Return ``(image, annotation_matrix)`` for sample ``idx``.

        The image is an RGB float32 tensor of shape
        [3, image_size, image_size] scaled to [0, 1]; the annotation is the
        NumPy array loaded from disk.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        image_name = self.image_list[idx]
        image_path = os.path.join(self.images_dir, image_name)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = image.astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1)

        annotation_matrix = np.load(self._label_path(image_name))

        return image, annotation_matrix

    def get_dataloaders(self, train_split = 0.8, test_split = 0.1, batch_size=16):
        """Randomly split the dataset and wrap each part in a DataLoader.

        Whatever is left after the train and test fractions becomes the
        validation set.

        Returns:
            ``(train_loader, test_loader, val_loader)`` when the validation
            split is non-empty, otherwise ``(train_loader, test_loader)``.
        """
        assert (
            0 < train_split < 1
            and 0 < test_split < 1
            and train_split + test_split <= 1
        ), "train_split + test_split must not exceed 1."

        N = len(self)
        train_size = int(train_split * N)
        test_size = int(N * test_split)
        val_size = N - (train_size + test_size)

        train_dataset, test_dataset, val_dataset = random_split(
            self,
            [train_size, test_size, val_size]
        )

        # Pinned memory only helps (and only avoids a warning) with CUDA.
        pin_memory = torch.cuda.is_available()

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=pin_memory
        )

        # NOTE(review): the test loader is shuffled; evaluators that stop after
        # a fixed number of batches therefore see a random subset.
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=True,
            pin_memory=pin_memory
        )

        if val_size <= 0:
            return train_loader, test_loader

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=pin_memory
        )
        return train_loader, test_loader, val_loader
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
|
|
5
|
+
from .math import pearson_corr
|
|
6
|
+
from .constants import NUM_GRID_CELLS
|
|
7
|
+
|
|
8
|
+
def find_recall_precision_f1_score(model, data, threshold=0.5, num_batches=100):
    """Compute cell-level recall, precision and F1 of the objectness channel.

    A grid cell counts as predicted-positive when the last channel of the
    model output exceeds ``threshold`` and as labelled-positive when the last
    channel of the label is greater than 0.

    Args:
        model: callable with ``eval()`` and a ``device`` attribute returning
            channels-last predictions of the same grid shape as the labels.
        data: a DataLoader, or a dataset that is wrapped in one (batch size 1).
        threshold: objectness threshold for a positive prediction.
        num_batches: maximum number of batches to evaluate.

    Returns:
        ``(recall, precision, f1_score)``; each is 0 when its denominator is 0.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    true_positives = 0
    false_positives = 0
    false_negatives = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data):
            if batch_idx >= num_batches:
                break

            predictions_batch = model(images.to(model.device))  # [B, H, W, C]

            # Compare whole grids at once; the grid size comes from the
            # tensors themselves rather than a hard-coded 7.
            pred_obj = predictions_batch[..., -1].cpu()
            label_obj = torch.as_tensor(labels)[..., -1].cpu()

            predicted_positive = pred_obj > threshold
            predicted_negative = pred_obj <= threshold
            labelled_positive = label_obj > 0
            labelled_negative = label_obj == 0

            true_positives += int((predicted_positive & labelled_positive).sum())
            false_positives += int((predicted_positive & labelled_negative).sum())
            false_negatives += int((predicted_negative & labelled_positive).sum())

    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return recall, precision, f1_score
|
|
55
|
+
|
|
56
|
+
def _map_absolute_points(grid):
    """Convert a [H, W, C] grid of per-cell offsets to [H*W, 2] image-normalised (x, y)."""
    rows, cols = grid.shape[0], grid.shape[1]
    flat = grid.reshape(rows * cols, -1).to(torch.float64)
    cell = torch.arange(rows * cols)
    col = (cell % cols).to(torch.float64)
    row = (cell // cols).to(torch.float64)
    x = (col + flat[:, 0]) / cols
    y = (row + flat[:, 1]) / rows
    return torch.stack([x, y], dim=1)


def _map_image_candidates(predictions, labels, class_threshold):
    """Return the confident detections of one image and its ground-truth count.

    Each detection is ``(confidence, nearest_true_cell, distance_score)`` in
    descending confidence order; ``nearest_true_cell`` is -1 when the image
    has no ground-truth point.
    """
    pred_points = _map_absolute_points(predictions)
    true_points = _map_absolute_points(labels)
    pred_conf = predictions.reshape(-1, predictions.shape[-1])[:, -1]
    true_conf = labels.reshape(-1, labels.shape[-1])[:, -1]
    true_idx = torch.nonzero(true_conf >= 0.5).flatten()

    candidates = []
    for i in pred_conf.argsort(descending=True).tolist():
        conf = pred_conf[i].item()
        if not conf >= class_threshold:
            continue
        if true_idx.numel() == 0:
            # Nothing to match against: always a false positive.
            candidates.append((conf, -1, 0.0))
            continue

        dists = torch.sqrt(torch.sum((true_points[true_idx] - pred_points[i]) ** 2, dim=-1))
        nearest = int(torch.argmin(dists))
        # Soft distance score in (0, 1): 0.5 at a distance of 0.1, close to 1
        # at distance 0 and close to 0 for distant points.
        dist_conf = (1 - torch.sigmoid(69 * (dists[nearest] - 0.1))).item()
        candidates.append((conf, int(true_idx[nearest]), dist_conf))

    return candidates, int(true_idx.numel())


def find_mAP(model, data, class_threshold=0.5):
    """Mean average precision of point detections over ten distance-score thresholds.

    A detection (objectness >= ``class_threshold``) is a true positive when
    its nearest ground-truth point has not been claimed by a more confident
    detection of the same image and its distance score reaches the current
    threshold (0.5, 0.55, ..., 0.95). The precision-recall curve of each
    threshold is integrated with the trapezoidal rule, anchored at
    (recall 0, precision 1), and the ten areas are averaged.

    Args:
        model: callable with ``eval()`` and a ``device`` attribute returning
            channels-last [B, H, W, C] predictions (x offset, y offset, ...,
            objectness).
        data: a DataLoader, or a dataset that is wrapped in one (batch size 1).
        class_threshold: minimum objectness for a detection to be considered.

    Returns:
        The mean of the per-threshold average precisions as a float.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    dist_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

    # np.trapezoid only exists from NumPy 2.0; older releases call it np.trapz.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz

    # Run the model once and cache the candidates; only the matching below
    # depends on the distance threshold.
    per_image_candidates = []
    num_trues = 0
    with torch.no_grad():
        for batched_images, batched_labels in data:
            batched_predictions = model(batched_images.to(model.device)).detach().cpu()
            batched_labels = torch.as_tensor(batched_labels).detach().cpu()

            for predictions, labels in zip(batched_predictions, batched_labels):
                candidates, image_trues = _map_image_candidates(predictions, labels, class_threshold)
                per_image_candidates.append(candidates)
                num_trues += image_trues

    APs = []
    for dist_threshold in dist_thresholds:
        all_detections = []
        for candidates in per_image_candidates:
            taken_trues = set()
            for conf, closest_idx, dist_conf in candidates:
                is_match = (
                    closest_idx != -1
                    and closest_idx not in taken_trues
                    and dist_conf >= dist_threshold
                )
                if is_match:
                    taken_trues.add(closest_idx)
                all_detections.append((conf, 1 if is_match else 0))

        all_detections.sort(key=lambda d: d[0], reverse=True)
        TP = np.array([d[1] for d in all_detections], dtype=float)

        TP_cum = np.cumsum(TP)
        FP_cum = np.cumsum(1 - TP)

        precisions = TP_cum / (TP_cum + FP_cum + 1e-6)
        recalls = TP_cum / (num_trues + 1e-6)

        # Anchor the curve at recall 0; without this the area up to the first
        # detection is dropped (a single perfect detection would score 0).
        precisions = np.concatenate(([1.0], precisions))
        recalls = np.concatenate(([0.0], recalls))

        APs.append(float(trapezoid(precisions, recalls)))

    return sum(APs) / len(APs)
|
|
137
|
+
|
|
138
|
+
def _mcs_grid_points(grid, threshold):
    """List the image-normalised [x, y] of every cell whose last channel exceeds ``threshold``.

    ``grid`` is a [H, W, C] NumPy array; cells are visited in row-major order.
    """
    rows, cols = np.nonzero(grid[..., -1] > threshold)
    xs = (cols / NUM_GRID_CELLS) + (grid[rows, cols, 0] * (1 / NUM_GRID_CELLS))
    ys = (rows / NUM_GRID_CELLS) + (grid[rows, cols, 1] * (1 / NUM_GRID_CELLS))
    return np.stack([xs, ys], axis=1).tolist()


def find_mCS(model, data, threshold=0.5, num_batches=100):
    """Average, over images, the absolute Pearson correlation of point coordinates.

    For each image the labelled points (last label channel > 0) and the
    predicted points (objectness > ``threshold``) are pooled, and the
    absolute correlation between their x and y coordinates is recorded. An
    image scores 0 when the model predicts no point or when the correlation
    is undefined (NaN, e.g. fewer than two points or constant coordinates).

    Args:
        model: callable with ``eval()`` and a ``device`` attribute returning
            channels-last predictions.
        data: a DataLoader, or a dataset that is wrapped in one (batch size 1).
        threshold: objectness threshold for a predicted point.
        num_batches: maximum number of batches to evaluate.

    Returns:
        The mean score as a float, or 0.0 when no image was evaluated.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    correlations = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data):
            if batch_idx >= num_batches:
                break

            predictions_batch = model(images.to(model.device)).cpu().numpy()
            # Work on NumPy arrays for both sides so the pooled point list
            # holds plain floats rather than a mix of tensors and scalars.
            labels_batch = torch.as_tensor(labels).cpu().numpy()

            for prediction, label in zip(predictions_batch, labels_batch):
                predicted_points = _mcs_grid_points(prediction, threshold)
                label_points = _mcs_grid_points(label, 0)

                if not predicted_points:
                    correlations.append(0)
                    continue

                corr = pearson_corr(label_points + predicted_points)
                # An undefined correlation would turn the whole mean into NaN.
                correlations.append(0.0 if np.isnan(corr) else abs(corr))

    return float(np.mean(correlations)) if correlations else 0.0
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
from .constants import NUM_GRID_CELLS
|
|
8
|
+
from .math import smooth_curve, closest_factors
|
|
9
|
+
|
|
10
|
+
image_size = 448
|
|
11
|
+
|
|
12
|
+
def save_loss_fig(train_loss_ot: list, val_loss_ot: list):
    """Plot the smoothed loss curves and save them under ``figures/``.

    Args:
        train_loss_ot: training loss per epoch.
        val_loss_ot: validation loss per epoch; may be empty or ``None``, in
            which case only the training curve is drawn.

    The figure is written to ``figures/smoothed_training_loss.png`` and then
    closed so repeated runs do not accumulate open figures.
    """
    os.makedirs('figures', exist_ok=True)

    # smooth the losses
    smooth_train_loss = smooth_curve(train_loss_ot)
    smooth_val_loss = smooth_curve(val_loss_ot) if val_loss_ot else None

    # plot the smoothed losses
    fig = plt.figure(figsize=(10, 6), dpi=300)
    try:
        plt.plot(
            range(1, len(smooth_train_loss) + 1),
            smooth_train_loss,
            color='blue',
            linewidth=2,
            label='Training',
        )
        if smooth_val_loss is not None:
            # Each curve gets its own x range so histories of different
            # lengths can still be plotted.
            plt.plot(
                range(1, len(smooth_val_loss) + 1),
                smooth_val_loss,
                color='red',
                linewidth=2,
                label='Validation',
            )
        plt.legend()
        plt.title(r'Smoothed Training Loss Over Epochs', fontsize=14)
        plt.xlabel(r'Epoch', fontsize=12)
        plt.ylabel(r'Loss', fontsize=12)
        plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        plt.tight_layout()
        plt.savefig("figures/smoothed_training_loss.png", dpi=300)
    finally:
        plt.close(fig)
|
|
33
|
+
|
|
34
|
+
def visualise_dataset(data, num_images, plot_title=None):
    """Show up to ``num_images`` samples with the grid and labelled points overlaid.

    Args:
        data: a DataLoader, or a dataset that is wrapped in one (batch size 1).
        num_images: number of subplots to create; must be at least 1.
        plot_title: optional figure title; when given it is also used to
            derive the output file name.

    The figure is saved to ``figures/visualised_dataset.png`` (or
    ``figures/visualised_dataset_<title>.png``) and then shown.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)

    rows, cols = closest_factors(num_images)
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, max(3 * rows, 4)), squeeze=False)
    fig.suptitle(plot_title if plot_title else 'Visualising Dataset of LOCALS for Pencils')
    axes = axes.flatten()
    for ax in axes:
        # Also hides the frames of subplots that end up unused.
        ax.axis('off')

    grid_size = image_size / NUM_GRID_CELLS
    image_idx = 0

    for images, labels in data:
        if image_idx >= num_images:
            break
        for image, label in zip(images, labels):
            if image_idx >= num_images:
                break

            ax = axes[image_idx]
            ax.set_title(f'Sample {image_idx + 1}')
            ax.imshow(image.permute(1, 2, 0))

            # grid lines
            for g in range(NUM_GRID_CELLS + 1):
                pos = g * grid_size
                ax.axvline(x=pos, color='white', linewidth=0.5, alpha=0.5)
                ax.axhline(y=pos, color='white', linewidth=0.5, alpha=0.5)

            # labelled points
            for j in range(NUM_GRID_CELLS):
                for k in range(NUM_GRID_CELLS):
                    xnb, ynb, c = label[j][k]
                    if c < 0.5:
                        continue
                    x = ((k / NUM_GRID_CELLS) + (xnb * (1 / NUM_GRID_CELLS))) * image_size
                    y = ((j / NUM_GRID_CELLS) + (ynb * (1 / NUM_GRID_CELLS))) * image_size

                    ax.scatter(float(x), float(y), color='red', marker='x', s=40)

            image_idx += 1

    if plot_title:
        out_path = f"figures/visualised_dataset_{plot_title.lower().replace(' ', '_')}.png"
    else:
        out_path = 'figures/visualised_dataset.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
|
|
94
|
+
|
|
95
|
+
def visualise_predictions(model, data, num_images, plot_title=None):
    """Show up to ``num_images`` samples with ground truth and predictions overlaid.

    Ground-truth points (label objectness > 0.5) are drawn as yellow crosses;
    predicted points (objectness >= 0.5) as red crosses annotated with their
    confidence.

    Args:
        model: callable with ``eval()`` and a ``device`` attribute returning
            channels-last predictions.
        data: a DataLoader, or a dataset that is wrapped in one (batch size 1).
        num_images: number of subplots to create; must be at least 1.
        plot_title: optional figure title; when given it is also used to
            derive the output file name.

    The figure is saved to ``figures/visualised_predictions.png`` (or
    ``figures/visualised_predictions_<title>.png``) and then shown.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)

    rows, cols = closest_factors(num_images)
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(3 * cols, 3 * rows + 1),
        squeeze=False
    )
    fig.suptitle(plot_title if plot_title else 'Visualising Predictions of LOCALS for Pencils')
    axes = axes.flatten()
    for ax in axes:
        # Also hides the frames of subplots that end up unused.
        ax.axis('off')

    grid_size = image_size / NUM_GRID_CELLS
    image_idx = 0

    with torch.no_grad():
        for images, labels in data:
            if image_idx >= num_images:
                break
            outputs = model(images.to(model.device)).detach().cpu()

            for image, label, output in zip(images, labels, outputs):
                # Without this guard a batch larger than the remaining
                # subplots would index past the end of ``axes``.
                if image_idx >= num_images:
                    break

                ax = axes[image_idx]
                ax.imshow(image.permute(1, 2, 0).detach().cpu())
                ax.set_title(f'Prediction {image_idx + 1}')

                # grid lines
                for g in range(NUM_GRID_CELLS + 1):
                    pos = g * grid_size
                    ax.axvline(x=pos, color='white', linewidth=0.5, alpha=0.5)
                    ax.axhline(y=pos, color='white', linewidth=0.5, alpha=0.5)

                for j in range(NUM_GRID_CELLS):
                    for k in range(NUM_GRID_CELLS):

                        # ground truth
                        xnb, ynb, c = label[j][k]
                        if float(c) > 0.5:
                            x = ((k / NUM_GRID_CELLS) + (xnb * (1 / NUM_GRID_CELLS))) * image_size
                            y = ((j / NUM_GRID_CELLS) + (ynb * (1 / NUM_GRID_CELLS))) * image_size
                            ax.scatter(
                                float(x),
                                float(y),
                                color='yellow',
                                marker='x',
                                s=40,
                                alpha=1
                            )

                        # prediction
                        pxnb, pynb, pc = output[j][k]
                        if float(pc) >= 0.5:
                            x = float(((k / NUM_GRID_CELLS) + (pxnb * (1 / NUM_GRID_CELLS))) * image_size)
                            y = float(((j / NUM_GRID_CELLS) + (pynb * (1 / NUM_GRID_CELLS))) * image_size)

                            ax.scatter(
                                x,
                                y,
                                color='red',
                                marker='x',
                                s=40,
                                alpha=0.5,
                            )

                            ax.text(
                                x + 5,  # x offset
                                y + 5,  # y offset
                                f'{float(pc):.2f}',
                                color='white',
                                fontsize=10
                            )

                image_idx += 1

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    if plot_title:
        out_path = f"figures/visualised_predictions_{plot_title.lower().replace(' ', '_')}.png"
    else:
        out_path = 'figures/visualised_predictions.png'
    # Same save options for both file names.
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
|
|
199
|
+
|
|
200
|
+
def save_recall_precision_f1_score(recall, precision, f1_score):
    """Save a bar chart of the three scores to ``figures/recall-precision-f1score.png``.

    Each bar is annotated with its value to two decimals. The ``figures``
    directory is created when missing and the figure is closed afterwards.
    """
    metrics = ['Recall', 'Precision', 'F1 Score']
    values = [recall, precision, f1_score]

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)

    fig = plt.figure(figsize=(6, 4))
    try:
        plt.bar(metrics, values, color=['skyblue', 'lightgreen', 'salmon'])
        plt.ylim(0, 1.1)
        plt.title('Model Classification Performance Metrics')
        plt.ylabel('Score')
        for i, v in enumerate(values):
            plt.text(i, v + 0.02, f"{v:.2f}", ha='center')
        plt.tight_layout()
        plt.savefig("figures/recall-precision-f1score.png", dpi=300)
    finally:
        plt.close(fig)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import DataLoader
|
|
3
|
+
|
|
4
|
+
from .trainer import train
|
|
5
|
+
from .architectures import LOCALSn, LOCALSs
|
|
6
|
+
from .figmaker import save_loss_fig, save_recall_precision_f1_score
|
|
7
|
+
from .evaluator import find_recall_precision_f1_score, find_mAP, find_mCS
|
|
8
|
+
|
|
9
|
+
class LOCALS:
    """Convenience wrapper bundling a LOCALS network with training and evaluation.

    Args:
        architecture: ``'n'`` for ``LOCALSn`` or ``'s'`` for ``LOCALSs``.
        file_path: optional path to a state dict to load into the network.
        device: device the network is moved to (and checkpoints are mapped to).

    Raises:
        ValueError: if ``architecture`` is not one of the supported names.
    """

    def __init__(self, architecture='n', file_path=None, device='cuda'):
        self.architecture = architecture
        self.device = device

        if architecture == 'n':
            self.model = LOCALSn()
        elif architecture == 's':
            self.model = LOCALSs()
        else:
            raise ValueError(
                f"Unknown architecture {architecture!r}; expected 'n' or 's'."
            )
        self.model.to(device)

        if file_path:
            # map_location lets a checkpoint saved on another device load here.
            self.model.load_state_dict(
                torch.load(file_path, weights_only=True, map_location=device)
            )

    def __call__(self, images):
        return self.model(images)

    def fit(self, train_loader: DataLoader, *, num_epochs=50, val_loader=None, epoch_to_lr=None, do_save_loss_fig=True):
        """Train the network, reload the weights stored in ``best.pth`` and optionally plot the losses.

        ``epoch_to_lr`` is forwarded to the trainer; ``None`` means an empty
        mapping.
        """
        if epoch_to_lr is None:
            epoch_to_lr = {}

        train_loss_ot, val_loss_ot = train(train_loader, num_epochs, val_loader, self.model, self.device, epoch_to_lr)
        # NOTE(review): "best.pth" is expected in the current working
        # directory; presumably written by the trainer -- confirm.
        self.model.load_state_dict(
            torch.load("best.pth", weights_only=True, map_location=self.device)
        )

        if do_save_loss_fig:
            save_loss_fig(train_loss_ot, val_loss_ot)

    def eval(self):
        self.model.eval()

    def train(self):
        self.model.train()

    def evaluate(self, data, do_save_fig=True):
        """Compute recall, precision, F1, mAP and mCS on ``data`` and return them as a dict."""
        recall, precision, f1_score = find_recall_precision_f1_score(self, data)
        mAP = find_mAP(self, data)
        mCS = find_mCS(self, data)

        if do_save_fig:
            save_recall_precision_f1_score(recall, precision, f1_score)

        return {'recall': recall, 'precision': precision, 'f1_score': f1_score, 'mAP': mAP, 'mCS': mCS}

    def get_num_params(self):
        """Return the total number of elements over all network parameters."""
        return sum(p.numel() for p in self.model.parameters())
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
def locals_loss(beta=1.0, gamma=5.0):
    '''
    Build the LOCALS training loss.

    beta: weight for localization loss
    gamma: weight for objectness loss

    Returns a function ``loss_func(predicted, true)`` taking channels-last
    tensors whose last dimension is (x offset, y offset, objectness) and
    returning a scalar tensor.
    '''

    def binary_focal_loss(pred, target, alpha=0.25, gamma=2.0, eps=1e-8):
        '''
        Element-wise binary focal loss.

        pred: predicted probabilities (after sigmoid)
        target: ground truth (0 or 1)
        alpha: weighting factor for class imbalance
        gamma: focusing parameter for hard examples
        '''
        pred = pred.clamp(eps, 1.0 - eps) # avoid log(0)
        pt = pred * target + (1 - pred) * (1 - target)
        log_likelihood = target * torch.log(pred + eps) + (1 - target) * torch.log(1 - pred + eps)
        return - alpha * (1 - pt) ** gamma * log_likelihood

    def focal_localization_loss(pred_coords, true_coords, mask, alpha=0.25, gamma=2.0, eps=1e-8):
        '''
        Element-wise focal loss on the distance between predicted and true
        offsets; cells where mask is 0 contribute nothing.
        '''
        squared = torch.sum((pred_coords - true_coords) ** 2, dim=-1)
        # sqrt has an infinite derivative at 0, which turns the gradient into
        # NaN whenever a prediction matches its target exactly (even in
        # masked-out cells); clamping keeps the backward pass finite.
        d = torch.sqrt(squared.clamp_min(eps))
        pt = 1 - torch.sigmoid(20 * (d - 0.1))
        pt = pt.clamp(eps, 1.0 - eps)
        return - alpha * (1 - pt) ** gamma * mask * torch.log(pt)

    # actual loss function
    def loss_func(predicted, true):
        obj_mask = true[..., 2]

        # objectness loss: every image has the same number of cells, so the
        # mean over everything equals the mean of the per-image means
        obj_loss = binary_focal_loss(predicted[..., 2], obj_mask).mean()

        # localization loss: per image, sum over cells / number of objects;
        # images without any object contribute 0
        loc_terms = focal_localization_loss(predicted[..., :2], true[..., :2], obj_mask)
        per_image_loss = loc_terms.flatten(1).sum(dim=1)
        per_image_mask = obj_mask.flatten(1).sum(dim=1)
        has_objects = per_image_mask != 0
        safe_denominator = torch.where(has_objects, per_image_mask, torch.ones_like(per_image_mask))
        loc_loss = torch.where(
            has_objects,
            per_image_loss / safe_denominator,
            torch.zeros_like(per_image_loss),
        ).mean()

        # weighted total
        return beta * loc_loss + gamma * obj_loss

    return loss_func
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
def smooth_curve(points, factor=0.9):
    """Return an exponentially smoothed copy of ``points``.

    The first value is kept as is; every later value is
    ``previous_smoothed * factor + point * (1 - factor)``.
    """
    remaining = iter(points)
    try:
        smoothed = [next(remaining)]
    except StopIteration:
        return []

    for point in remaining:
        smoothed.append(smoothed[-1] * factor + point * (1 - factor))
    return smoothed
|
|
14
|
+
|
|
15
|
+
def closest_factors(n):
    """Return ``(d, n // d)`` where ``d`` is the largest divisor of ``n`` not above ``sqrt(n)``.

    The pair is the most square-like factorisation of ``n`` (for a prime
    ``n`` it is ``(1, n)``).

    Raises:
        ValueError: if ``n`` is smaller than 1 (0 has no usable factor pair
            and would otherwise end in a division by zero).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    d = math.isqrt(n) # integer sqrt
    while n % d != 0:
        d -= 1
    return d, n // d
|
|
20
|
+
|
|
21
|
+
# function that calculates the pearson correlation coefficient given a list of points
|
|
22
|
+
def pearson_corr(points):
    """Return the Pearson correlation between the x and y coordinates of ``points``.

    ``points`` is a sequence of ``[x, y]`` pairs. The result is NaN when the
    correlation is undefined: fewer than two points, or all x (or all y)
    values identical. These cases are detected up front so they neither raise
    nor emit NumPy runtime warnings.
    """
    points_array = np.asarray(points, dtype=float)
    if points_array.ndim != 2 or points_array.shape[0] < 2:
        return float('nan')

    x = points_array[:, 0]
    y = points_array[:, 1]

    # Zero variance in either coordinate makes the coefficient 0/0.
    if x.max() == x.min() or y.max() == y.min():
        return float('nan')

    return float(np.corrcoef(x, y)[0, 1])
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from torch.utils.flop_counter import FlopCounterMode
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
import uuid
|
|
8
|
+
|
|
9
|
+
from .locals import LOCALS
|
|
10
|
+
from .dataset import LOCALSDataset
|
|
11
|
+
|
|
12
|
+
def save_run_metrics(run_details: dict):
    """Write ``run_details`` as JSON to ``runs/run_metrics_<uuid4>.json``.

    NumPy scalars/arrays and tensors (anything exposing ``tolist()``) are
    converted to built-in types. The payload is serialised before the file is
    opened, so a value that cannot be serialised no longer leaves an empty or
    truncated file behind.

    Raises:
        TypeError: if a value is neither JSON-native nor convertible via ``tolist()``.
    """
    def _to_builtin(value):
        to_list = getattr(value, 'tolist', None)
        if callable(to_list):
            return to_list()
        raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')

    payload = json.dumps(run_details, default=_to_builtin)

    os.makedirs('runs', exist_ok=True)

    run_id = str(uuid.uuid4())

    with open(f'runs/run_metrics_{run_id}.json', 'w') as run_metrics_file:
        run_metrics_file.write(payload)
|
|
19
|
+
|
|
20
|
+
def run(*, seed: int,
        train_split: float,
        test_split: float,
        images_dir: str,
        labels_dir: str,
        architecture='n',
        num_epochs=50,
        epoch_to_lr=None):
    """Train and evaluate a LOCALS model, then persist the run's metrics.

    Seeds NumPy and PyTorch, builds the dataset and its loaders, fits the
    model, counts the FLOPs of one forward pass over the first test batch,
    evaluates on the test loader and writes the collected details via
    ``save_run_metrics``.

    Args:
        seed: Seed applied to NumPy and PyTorch (CPU and, if available, CUDA).
        train_split: Passed to ``LOCALSDataset.get_dataloaders``.
        test_split: Passed to ``LOCALSDataset.get_dataloaders``.
        images_dir: Directory handed to ``LOCALSDataset`` for images.
        labels_dir: Directory handed to ``LOCALSDataset`` for labels.
        architecture: Architecture identifier forwarded to ``LOCALS``.
        num_epochs: Number of training epochs forwarded to ``model.fit``.
        epoch_to_lr: Optional mapping forwarded to ``model.fit``; ``None``
            (the default) is treated as an empty mapping.

    Returns:
        The dict of run details that was written to disk: the seed, the two
        splits, the evaluation metrics (``recall``, ``precision``,
        ``f1_score``, ``mAP``, ``mCS``), ``num_params`` and ``total_flops``.
    """
    # Use a None sentinel instead of a shared mutable default dict.
    if epoch_to_lr is None:
        epoch_to_lr = {}

    # NumPy
    np.random.seed(seed)

    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU

    # Trade cuDNN autotuning for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dataset = LOCALSDataset(images_dir, labels_dir)
    train_loader, test_loader, val_loader = dataset.get_dataloaders(train_split, test_split)

    model = LOCALS(architecture)
    model.fit(train_loader, num_epochs=num_epochs, val_loader=val_loader, epoch_to_lr=epoch_to_lr)

    # Count FLOPs for a single forward pass over the first test batch only.
    # Gradients are not needed here, so skip building the autograd graph.
    flop_counter = FlopCounterMode(display=False)
    with flop_counter, torch.no_grad():
        for images, labels in test_loader:
            model(images.to(model.device))
            break
    total_flops = flop_counter.get_total_flops()

    metrics = model.evaluate(test_loader)
    run_details = {'seed': seed,
                   'train_split': train_split,
                   'test_split': test_split,
                   'recall': metrics['recall'],
                   'precision': metrics['precision'],
                   'f1_score': metrics['f1_score'],
                   'mAP': metrics['mAP'],
                   'mCS': metrics['mCS'],
                   'num_params': model.get_num_params(),
                   'total_flops': total_flops}

    save_run_metrics(run_details)
    return run_details
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
from torch.optim import Adam
|
|
4
|
+
from torch.utils.data import DataLoader
|
|
5
|
+
|
|
6
|
+
from .losses import locals_loss
|
|
7
|
+
|
|
8
|
+
def train(train_loader: DataLoader,
          num_epochs: int,
          val_loader: DataLoader,
          model: torch.nn.Module,
          device='cuda',
          epoch_to_lr=None):
    """Train ``model`` with Adam and keep the best weights in ``best.pth``.

    Args:
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        num_epochs: Number of epochs to run (epochs are numbered from 1).
        val_loader: Loader yielding ``(images, labels)`` validation batches,
            or ``None`` to skip validation.
        model: Module to optimise; it is moved to ``device``.
        device: Device that the model and every batch are moved to.
        epoch_to_lr: Optional mapping from epoch number to the learning rate
            to apply from that epoch on; ``None`` (the default) is treated
            as an empty mapping.

    Returns:
        A tuple ``(train_loss_ot, val_loss_ot)`` holding the average training
        loss and the average validation loss of each epoch. The second list
        stays empty when ``val_loader`` is ``None``.
    """
    # Use a None sentinel instead of a shared mutable default dict.
    if epoch_to_lr is None:
        epoch_to_lr = {}

    model.to(device)
    # Write an initial checkpoint so "best.pth" exists before any reload.
    torch.save(model.state_dict(), "best.pth")

    optimizer = Adam(model.parameters(), lr=1e-3)
    criterion = locals_loss()
    validate = val_loader is not None

    train_loss_ot = []
    val_loss_ot = []
    min_val_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        if epoch in epoch_to_lr:
            new_lr = epoch_to_lr[epoch]
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            if validate:
                # Continue from the best checkpoint so far when the learning
                # rate changes.
                model.load_state_dict(torch.load("best.pth", weights_only=True))

        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        train_loss_ot.append(avg_train_loss)
        print(f'Avg Training Loss = {avg_train_loss}')

        if validate:
            model.eval()
            with torch.no_grad():
                total_val_loss = 0

                for images, labels in val_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(images)

                    loss = criterion(outputs, labels)
                    total_val_loss += loss.item()

                avg_val_loss = total_val_loss / len(val_loader)
                val_loss_ot.append(avg_val_loss)
                if avg_val_loss < min_val_loss:
                    # New best validation loss: overwrite the checkpoint.
                    min_val_loss = avg_val_loss
                    torch.save(model.state_dict(), "best.pth")

            print(f'Val Loss = {avg_val_loss}')

    if not validate:
        # Without validation there is no "best" epoch; keep the final weights.
        torch.save(model.state_dict(), "best.pth")

    return train_loss_ot, val_loss_ot
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/locals/__init__.py
|
|
4
|
+
src/locals/architectures.py
|
|
5
|
+
src/locals/blocks.py
|
|
6
|
+
src/locals/constants.py
|
|
7
|
+
src/locals/dataset.py
|
|
8
|
+
src/locals/evaluator.py
|
|
9
|
+
src/locals/figmaker.py
|
|
10
|
+
src/locals/locals.py
|
|
11
|
+
src/locals/losses.py
|
|
12
|
+
src/locals/math.py
|
|
13
|
+
src/locals/runner.py
|
|
14
|
+
src/locals/trainer.py
|
|
15
|
+
src/locals_api.egg-info/PKG-INFO
|
|
16
|
+
src/locals_api.egg-info/SOURCES.txt
|
|
17
|
+
src/locals_api.egg-info/dependency_links.txt
|
|
18
|
+
src/locals_api.egg-info/requires.txt
|
|
19
|
+
src/locals_api.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
locals
|