locals-api 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: locals-api
3
+ Version: 0.1.0
4
+ Requires-Python: >=3.10
5
+ Requires-Dist: numpy
6
+ Requires-Dist: opencv-python
7
+ Requires-Dist: matplotlib
@@ -0,0 +1,28 @@
1
+ # LOCALS
2
+
3
+ Rapid experimentation framework for the LOCALS detector.
4
+
5
+ ## Installation
6
+
7
+ Install PyTorch first, choosing the build that matches your CUDA version; it is not installed automatically with this package.
8
+
9
+ Then:
10
+
11
+ ```bash
12
+ pip install locals-api
13
+ ```
14
+
15
+ ## Example
16
+
17
+ ```python
18
+ from locals import run
19
+
20
+ run(
21
+ seed=1111,
22
+ train_split=0.8,
23
+ test_split=0.1,
24
+ images_dir="dataset/images",
25
+ labels_dir="dataset/labels",
26
+ num_epochs=50,
27
+ )
28
+ ```
@@ -0,0 +1,20 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "locals-api"
7
+ version = "0.1.0"
8
+ requires-python = ">=3.10"
9
+
10
+ dependencies = [
11
+ "numpy",
12
+ "opencv-python",
13
+ "matplotlib",
+ "tqdm",
14
+ ]
15
+
16
+ [tool.setuptools]
17
+ package-dir = {"" = "src"}
18
+
19
+ [tool.setuptools.packages.find]
20
+ where = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,6 @@
1
# Public API of the ``locals`` package: the model wrapper, the dataset, the
# two visualisation helpers and the ``run`` entry point that trains and
# evaluates a model end to end.
from .locals import LOCALS
from .dataset import LOCALSDataset
from .figmaker import visualise_dataset, visualise_predictions
from .runner import run

__all__ = ['LOCALS', 'LOCALSDataset', 'visualise_dataset', 'visualise_predictions', 'run']
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .blocks import ResnetBlock, ConvBlock
5
+
6
class LOCALSn(nn.Module):
    """Compact LOCALS detector ("n" variant).

    A backbone of stride-2 stages reduces a 448x448 RGB input to a 7x7 map
    with 256 channels. A single 1x1 head projects every cell to three values,
    which are squashed into (0, 1) with a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Spatial size after each stage is noted for a 448x448 input. The
        # layers are built in this exact order so that seeded weight
        # initialisation and state_dict keys stay unchanged.
        backbone_layers = [
            ConvBlock(input_channels=3, output_channels=16, kernel_size=2, stride=2),    # 224
            ResnetBlock(input_channels=16, output_channels=32),                          # 112
            ResnetBlock(input_channels=32, output_channels=64),                          # 56
            ConvBlock(input_channels=64, output_channels=128, kernel_size=2, stride=2),  # 28
            ResnetBlock(input_channels=128, output_channels=256),                        # 14
            ResnetBlock(input_channels=256, output_channels=256),                        # 7
        ]
        self.model = nn.Sequential(*backbone_layers)

        # The head is deliberately a 1x1 conv: all feature extraction happens
        # in the backbone, the head only projects channels.
        self.head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=3, activation=False),
        )

    def forward(self, x):
        """Return per-cell predictions, channels last, with values in (0, 1)."""
        logits = self.head(self.model(x))            # [B, 3, 7, 7]
        channels_last = logits.permute(0, 2, 3, 1)   # [B, 7, 7, 3]
        return torch.sigmoid(channels_last)
31
+
32
class LOCALSs(nn.Module):
    """Larger LOCALS detector ("s" variant) with separate location and class heads.

    The first 3x3 conv (padding 1) keeps full resolution. Each of the six
    residual blocks then halves the spatial size, so a 448x448 input ends as a
    7x7 map with 256 channels. A three-layer 1x1 head predicts two location
    values per cell and a 1x1 head predicts one class value. The three
    channels are concatenated and passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # (in, out) channels of the residual stages. Spatial size after each
        # stage for a 448x448 input is 224, 112, 56, 28, 14, 7. Construction
        # order matches the Sequential indices, keeping seeded init and
        # state_dict keys stable.
        residual_channels = [(16, 32), (32, 64), (64, 128), (128, 256), (256, 512), (512, 256)]
        self.model = nn.Sequential(
            ConvBlock(input_channels=3, output_channels=16, kernel_size=3, padding=1),  # 448
            *[ResnetBlock(input_channels=c_in, output_channels=c_out) for c_in, c_out in residual_channels],
        )

        # The heads use 1x1 convs only: feature extraction belongs to the backbone.
        self.loc_head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=64),
            ConvBlock(input_channels=64, output_channels=16),
            ConvBlock(input_channels=16, output_channels=2, activation=False),
        )
        self.class_head = nn.Sequential(
            ConvBlock(input_channels=256, output_channels=1, activation=False),
        )

    def forward(self, x):
        """Return per-cell predictions, channels last, with values in (0, 1)."""
        features = self.model(x)
        stacked = torch.cat(
            [
                self.loc_head(features),    # [B, 2, 7, 7]
                self.class_head(features),  # [B, 1, 7, 7]
            ],
            dim=1,
        )                                   # [B, 3, 7, 7]
        return torch.sigmoid(stacked.permute(0, 2, 3, 1))  # [B, 7, 7, 3]
@@ -0,0 +1,56 @@
1
+ import torch.nn as nn
2
+
3
class ResnetBlock(nn.Module):
    """Residual block with a skip connection, as in the 2015 ResNet paper.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 (stride 2) -> BN.
    Shortcut:  conv2x2 (stride 2) -> BN, matching channel count and size.
    The paths are summed and passed through a ReLU, so height and width are
    halved. For odd input sizes the two paths disagree in size (ceil vs floor
    of H/2), so even sizes are required.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # Attribute names define state_dict keys. The creation order fixes how
        # seeded initialisation consumes the RNG, so neither is changed.
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=2, padding=1)
        self.downsamp_conv = nn.Conv2d(input_channels, output_channels, kernel_size=2, stride=2)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm2d(output_channels)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.downsamp_bn = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        # Feature extraction, then a stride-2 conv that halves H and W.
        hidden = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))

        # Project the input to the output channel count and spatial size.
        shortcut = self.downsamp_bn(self.downsamp_conv(x))

        return self.relu(main + shortcut)
36
+
37
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU.

    The defaults give a 1x1, stride-1 convolution (a pure channel projection).
    Kernel size, padding and stride are configurable. Pass
    ``activation=False`` to return the normalised output without a ReLU.
    """

    def __init__(self, input_channels, output_channels, kernel_size=1, padding=0, stride=1, activation=True):
        super().__init__()

        # Attribute names are kept as-is: they define the state_dict keys.
        self.channel_downsamp_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        self.channel_downsamp_bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.activation = activation

    def forward(self, x):
        normalised = self.channel_downsamp_bn(self.channel_downsamp_conv(x))
        return self.relu(normalised) if self.activation else normalised
@@ -0,0 +1 @@
1
# Number of grid cells along each side of the square output grid.
# Evaluation and plotting code converts a cell (row, col) plus its in-cell
# offsets into a normalised image coordinate by dividing by this value.
NUM_GRID_CELLS = 7
@@ -0,0 +1,78 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from torch.utils.data import DataLoader, Dataset, random_split
5
+
6
+ import os
7
+
8
class LOCALSDataset(Dataset):
    """Image/label dataset for LOCALS.

    Every ``.png``/``.jpg``/``.jpeg`` file in ``images_dir`` is one sample.
    Its label is the NumPy array stored at ``labels_dir/<image stem>.npy``.
    Images are returned as float32 RGB tensors of shape [3, 448, 448] with
    values in [0, 1].
    """

    def __init__(self, images_dir, labels_dir):
        # Sorted so that sample indices, and therefore a seeded random_split,
        # do not depend on the filesystem's directory-listing order.
        self.image_list = sorted(
            f_name for f_name in os.listdir(images_dir)
            if f_name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_size = 448

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, annotation_matrix)`` for sample ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_name = self.image_list[idx]
        image_path = os.path.join(self.images_dir, image_name)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = image.astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1)

        # splitext only strips the extension, so stems containing dots
        # (e.g. "img.v2.png" -> "img.v2") still find their label file.
        prefix = os.path.splitext(image_name)[0]
        annotation_matrix = np.load(os.path.join(self.labels_dir, f'{prefix}.npy'))

        return image, annotation_matrix

    def get_dataloaders(self, train_split = 0.8, test_split = 0.1, batch_size=16):
        """Randomly split the dataset and wrap each split in a DataLoader.

        The validation split receives whatever ``train_split`` and
        ``test_split`` leave over.

        Returns:
            ``(train_loader, test_loader, val_loader)``, or
            ``(train_loader, test_loader)`` when no samples remain for validation.
            The test loader uses batch size 1 and a fixed order.
        """
        assert (
            0 < train_split < 1
            and 0 < test_split < 1
            and train_split + test_split <= 1
        ), "train_split and test_split must be in (0, 1) and train_split + test_split must not exceed 1."

        N = len(self)
        train_size = int(train_split * N)
        test_size = int(N * test_split)
        val_size = N - (train_size + test_size)

        train_dataset, test_dataset, val_dataset = random_split(
            self,
            [train_size, test_size, val_size]
        )

        # Pinned memory only helps host-to-GPU copies; avoid warnings without CUDA.
        pin_memory = torch.cuda.is_available()

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=pin_memory
        )

        # Fixed order: evaluation helpers may cap the number of batches, and a
        # shuffled test loader would make the evaluated subset RNG-dependent.
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=pin_memory
        )

        if val_size > 0:
            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=pin_memory
            )
            return train_loader, test_loader, val_loader
        return train_loader, test_loader
@@ -0,0 +1,192 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch.utils.data import DataLoader
4
+
5
+ from .math import pearson_corr
6
+ from .constants import NUM_GRID_CELLS
7
+
8
def find_recall_precision_f1_score(model, data, threshold=0.5, num_batches=100):
    """Cell-level recall, precision and F1 of the objectness (last) channel.

    A grid cell is a predicted positive when its predicted last channel
    exceeds ``threshold``, and a ground-truth positive when its label's last
    channel is > 0. The whole grid of every image is evaluated, whatever its size.

    Args:
        model: callable exposing ``.eval()`` and ``.device`` (e.g. ``LOCALS``).
            It maps an image batch to predictions indexed [B, row, col, channel].
        data: a DataLoader, or a dataset that is wrapped in one
            (batch_size=1, no shuffling).
        threshold: objectness threshold applied to predictions.
        num_batches: maximum number of batches to evaluate.

    Returns:
        ``(recall, precision, f1_score)``. Each is 0 when its denominator is 0.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    true_positives = 0
    false_positives = 0
    false_negatives = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data):
            if batch_idx >= num_batches:
                break

            predictions = model(images.to(model.device))  # [B, rows, cols, C]
            pred_obj = predictions[..., -1].cpu().numpy()
            label_obj = torch.as_tensor(labels)[..., -1].cpu().numpy()

            predicted_pos = pred_obj > threshold
            # Not simply ~predicted_pos: a NaN prediction must count as neither.
            predicted_neg = pred_obj <= threshold
            actual_pos = label_obj > 0
            actual_neg = label_obj == 0

            true_positives += int(np.sum(predicted_pos & actual_pos))
            false_positives += int(np.sum(predicted_pos & actual_neg))
            false_negatives += int(np.sum(predicted_neg & actual_pos))

    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return recall, precision, f1_score
55
+
56
def _grid_to_points(grid):
    """Decode a [rows, cols, 3] grid into (xy [rows*cols, 2], conf [rows*cols]) in float64.

    Cell (row, col) with in-cell offsets (xnb, ynb) maps to the normalised
    coordinate (col / cols + xnb / cols, row / rows + ynb / rows).
    """
    grid = torch.as_tensor(grid).detach().cpu().double()
    rows, cols = grid.shape[0], grid.shape[1]
    flat = grid.reshape(-1, grid.shape[-1])
    cell = torch.arange(flat.shape[0])
    row = (cell // cols).double()
    col = (cell % cols).double()
    x = (col / cols) + (flat[:, 0] * (1 / cols))
    y = (row / rows) + (flat[:, 1] * (1 / rows))
    return torch.stack([x, y], dim=-1), flat[:, -1]


def find_mAP(model, data, class_threshold=0.5):
    """Mean average precision of point predictions over ten distance thresholds.

    Each predicted cell with confidence >= ``class_threshold`` is decoded to a
    normalised (x, y) point. Per image, in descending confidence order, it is
    matched to the closest ground-truth point (label objectness >= 0.5) that is
    not yet matched. The match is a true positive when the distance score
    ``1 - sigmoid(69 * (d - 0.1))`` reaches the current threshold
    (0.5, 0.55, ..., 0.95); otherwise the prediction is a false positive.

    AP is the trapezoidal area under the precision/recall curve, anchored at
    (recall 0, precision 1) so the area up to the first detection is counted.

    Args:
        model: callable exposing ``.eval()`` and ``.device`` (e.g. ``LOCALS``).
        data: a DataLoader, or a dataset wrapped in one (batch_size=1, no shuffling).
        class_threshold: minimum confidence for a prediction to count as a detection.

    Returns:
        The mean AP over the distance thresholds.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    # Inference does not depend on the distance threshold, so run it only once.
    samples = []
    with torch.no_grad():
        for batched_images, batched_labels in data:
            batched_predictions = model(batched_images.to(model.device)).detach().cpu()
            for predictions, labels in zip(batched_predictions, batched_labels):
                pred_xy, pred_conf = _grid_to_points(predictions)
                true_xy, true_conf = _grid_to_points(labels)
                samples.append((pred_xy, pred_conf, true_xy[true_conf >= 0.5]))

    # np.trapezoid was added in NumPy 2.0; older versions only have np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    dist_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    APs = []

    for dist_threshold in dist_thresholds:
        all_detections = []  # (confidence, is_true_positive)
        num_trues = 0

        for pred_xy, pred_conf, true_xy in samples:
            num_trues += true_xy.shape[0]
            taken = torch.zeros(true_xy.shape[0], dtype=torch.bool)

            for i in pred_conf.argsort(descending=True).tolist():
                conf = pred_conf[i].item()
                # Written as "not >=" so a NaN confidence is skipped.
                if not conf >= class_threshold:
                    continue

                is_tp = False
                if true_xy.shape[0] > 0:
                    dists = torch.linalg.norm(true_xy - pred_xy[i], dim=-1)
                    # Already-matched points cannot be matched again (as in COCO),
                    # so a prediction may fall back to the next-closest free point.
                    dists[taken] = float('inf')
                    j = int(torch.argmin(dists))
                    dist_conf = 1 - torch.sigmoid(69 * (dists[j] - 0.1)).item()
                    if not taken[j] and dist_conf >= dist_threshold:
                        taken[j] = True
                        is_tp = True
                all_detections.append((conf, is_tp))

        all_detections.sort(key=lambda d: d[0], reverse=True)
        tp = np.array([d[1] for d in all_detections], dtype=float)
        tp_cum = np.cumsum(tp)
        fp_cum = np.cumsum(1.0 - tp)

        precisions = tp_cum / (tp_cum + fp_cum + 1e-6)
        recalls = tp_cum / (num_trues + 1e-6)
        # Without this anchor, a single perfect detection would have zero area.
        precisions = np.concatenate(([1.0], precisions))
        recalls = np.concatenate(([0.0], recalls))
        APs.append(float(trapezoid(precisions, recalls)))

    return sum(APs) / len(APs)
137
+
138
def _decode_points(grid, mask):
    """Return normalised [x, y] points for the cells of ``grid`` selected by boolean ``mask``."""
    rows, cols = mask.shape
    points = []
    for row, col in zip(*np.nonzero(mask)):
        xnb = float(grid[row, col, 0])
        ynb = float(grid[row, col, 1])
        points.append([(col / cols) + (xnb * (1 / cols)), (row / rows) + (ynb * (1 / rows))])
    return points


def find_mCS(model, data, threshold=0.5, num_batches=100):
    """Mean absolute Pearson correlation score over evaluated images.

    For each image, ground-truth points (cells whose label last channel is > 0)
    and predicted points (cells whose predicted last channel > ``threshold``)
    are pooled, and |corr(x, y)| of the pooled set is recorded.

    An image scores 0 when it has no predicted points, or when the correlation
    is undefined (fewer than two points, or all x or all y values equal).

    Args:
        model: callable exposing ``.eval()`` and ``.device`` (e.g. ``LOCALS``).
        data: a DataLoader, or a dataset wrapped in one (batch_size=1, no shuffling).
        threshold: objectness threshold applied to predictions.
        num_batches: maximum number of batches to evaluate.

    Returns:
        The mean score over evaluated images, or 0.0 if none were evaluated.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    correlations = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data):
            if batch_idx >= num_batches:
                break

            predictions_batch = model(images.to(model.device)).cpu().numpy()
            labels_batch = torch.as_tensor(labels).cpu().numpy()

            for prediction, label in zip(predictions_batch, labels_batch):
                predicted_points = _decode_points(prediction, prediction[..., -1] > threshold)
                label_points = _decode_points(label, label[..., -1] > 0)

                if not predicted_points:
                    correlations.append(0)
                    continue

                corr = pearson_corr(label_points + predicted_points)
                # An undefined (NaN) correlation would turn the whole mean into NaN.
                correlations.append(abs(corr) if np.isfinite(corr) else 0)

    return np.mean(correlations) if correlations else 0.0
@@ -0,0 +1,212 @@
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ from torch.utils.data import DataLoader
4
+
5
+ import os
6
+
7
+ from .constants import NUM_GRID_CELLS
8
+ from .math import smooth_curve, closest_factors
9
+
10
# Side length in pixels assumed for plotted images. It equals the 448x448
# resize done by LOCALSDataset and is used to scale normalised coordinates
# and grid lines to pixel positions.
image_size = 448
11
+
12
def save_loss_fig(train_loss_ot: list, val_loss_ot: list):
    """Plot exponentially smoothed per-epoch losses.

    The training curve is always drawn; the validation curve is drawn when
    ``val_loss_ot`` is non-empty. The figure is saved to
    ``figures/smoothed_training_loss.png``.
    """
    os.makedirs('figures', exist_ok=True)

    plt.figure(figsize=(10, 6), dpi=300)
    plt.plot(range(1, len(train_loss_ot) + 1), smooth_curve(train_loss_ot),
             color='blue', linewidth=2, label='Training')
    if val_loss_ot:
        # Use the validation list's own length so the x and y sizes always agree.
        plt.plot(range(1, len(val_loss_ot) + 1), smooth_curve(val_loss_ot),
                 color='red', linewidth=2, label='Validation')
    plt.title(r'Smoothed Training Loss Over Epochs', fontsize=14)
    plt.xlabel(r'Epoch', fontsize=12)
    plt.ylabel(r'Loss', fontsize=12)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    # Without a legend the two curves cannot be told apart.
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.savefig("figures/smoothed_training_loss.png", dpi=300)
33
+
34
def visualise_dataset(data, num_images, plot_title=None):
    """Plot up to ``num_images`` samples with the cell grid and ground-truth points.

    Cells whose label confidence is >= 0.5 are drawn as red crosses. The
    figure is saved to ``figures/visualised_dataset.png``, or to
    ``figures/visualised_dataset_<title>.png`` when ``plot_title`` is given,
    and then shown.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )

    n_cells = NUM_GRID_CELLS
    rows, cols = closest_factors(num_images)
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, max(3 * rows, 4)), squeeze=False)
    fig.suptitle(plot_title if plot_title else 'Visualising Dataset of LOCALS for Pencils')
    axes = axes.flatten()
    # Hide every frame up front so slots without a sample stay blank.
    for ax in axes:
        ax.axis('off')

    image_idx = 0
    for images, labels in data:
        if image_idx >= num_images:
            break
        for image, label in zip(images, labels):
            if image_idx >= num_images:
                break
            image_plt = image.permute(1, 2, 0)
            # Scale with the actual image size rather than assuming 448x448.
            height, width = image_plt.shape[0], image_plt.shape[1]
            ax = axes[image_idx]
            ax.set_title(f'Sample {image_idx + 1}')
            ax.imshow(image_plt)

            for g in range(n_cells + 1):
                ax.axvline(x=g * width / n_cells, color='white', linewidth=0.5, alpha=0.5)
                ax.axhline(y=g * height / n_cells, color='white', linewidth=0.5, alpha=0.5)

            for j in range(n_cells):
                for k in range(n_cells):
                    xnb, ynb, c = (float(v) for v in label[j][k])
                    if c < 0.5:
                        continue
                    x = ((k / n_cells) + (xnb * (1 / n_cells))) * width
                    y = ((j / n_cells) + (ynb * (1 / n_cells))) * height
                    ax.scatter(x, y, color='red', marker='x', s=40)
            image_idx += 1

    # The output directory may not exist yet when this is called before training.
    os.makedirs('figures', exist_ok=True)
    if plot_title:
        out_path = f"figures/visualised_dataset_{plot_title.lower().replace(' ', '_')}.png"
    else:
        out_path = 'figures/visualised_dataset.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
94
+
95
def visualise_predictions(model, data, num_images, plot_title=None):
    """Plot ground truth and model predictions for up to ``num_images`` images.

    Ground-truth cells with confidence > 0.5 are yellow crosses. Predicted
    cells with confidence >= 0.5 are red crosses annotated with the
    confidence. The figure is saved to ``figures/visualised_predictions.png``,
    or to ``figures/visualised_predictions_<title>.png`` when ``plot_title``
    is given, and then shown.
    """
    if not isinstance(data, DataLoader):
        data = DataLoader(
            data,
            batch_size=1,
            shuffle=False
        )
    model.eval()

    n_cells = NUM_GRID_CELLS
    with torch.no_grad():
        rows, cols = closest_factors(num_images)
        fig, axes = plt.subplots(
            rows,
            cols,
            figsize=(3 * cols, 3 * rows + 1),
            squeeze=False
        )
        fig.suptitle(plot_title if plot_title else 'Visualising Predictions of LOCALS for Pencils')
        axes = axes.flatten()
        # Hide every frame up front so slots without a sample stay blank.
        for ax in axes:
            ax.axis('off')
        image_idx = 0

        for images, labels in data:
            if image_idx >= num_images:
                break
            outputs = model(images.to(model.device)).detach().cpu()
            for image, label, output in zip(images, labels, outputs):
                # Without this check, a batch larger than the remaining subplot
                # slots would index past the last axis.
                if image_idx >= num_images:
                    break
                image_plt = image.permute(1, 2, 0).detach().cpu()
                height, width = image_plt.shape[0], image_plt.shape[1]
                ax = axes[image_idx]
                ax.imshow(image_plt)
                ax.set_title(f'Prediction {image_idx + 1}')

                for g in range(n_cells + 1):
                    ax.axvline(x=g * width / n_cells, color='white', linewidth=0.5, alpha=0.5)
                    ax.axhline(y=g * height / n_cells, color='white', linewidth=0.5, alpha=0.5)

                for j in range(n_cells):
                    for k in range(n_cells):
                        # ground truth
                        xnb, ynb, c = (float(v) for v in label[j][k])
                        if c > 0.5:
                            x = ((k / n_cells) + (xnb * (1 / n_cells))) * width
                            y = ((j / n_cells) + (ynb * (1 / n_cells))) * height
                            ax.scatter(x, y, color='yellow', marker='x', s=40, alpha=1)

                        # prediction
                        pxnb, pynb, pc = (float(v) for v in output[j][k])
                        if pc >= 0.5:
                            x = ((k / n_cells) + (pxnb * (1 / n_cells))) * width
                            y = ((j / n_cells) + (pynb * (1 / n_cells))) * height
                            ax.scatter(x, y, color='red', marker='x', s=40, alpha=0.5)
                            ax.text(
                                x + 5,  # x offset
                                y + 5,  # y offset
                                f'{pc:.2f}',
                                color='white',
                                fontsize=10
                            )

                image_idx += 1

        fig.tight_layout(rect=[0, 0, 1, 0.95])
        # The output directory may not exist yet when this is called standalone.
        os.makedirs('figures', exist_ok=True)
        if plot_title:
            out_path = f"figures/visualised_predictions_{plot_title.lower().replace(' ', '_')}.png"
        else:
            out_path = 'figures/visualised_predictions.png'
        # Same resolution and cropping for both naming variants.
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.show()
199
+
200
def save_recall_precision_f1_score(recall, precision, f1_score):
    """Save a bar chart of recall, precision and F1 to ``figures/recall-precision-f1score.png``."""
    metrics = ['Recall', 'Precision', 'F1 Score']
    values = [recall, precision, f1_score]

    plt.figure(figsize=(6, 4))
    plt.bar(metrics, values, color=['skyblue', 'lightgreen', 'salmon'])
    plt.ylim(0, 1.1)
    plt.title('Model Classification Performance Metrics')
    plt.ylabel('Score')
    # Print each value just above its bar.
    for i, v in enumerate(values):
        plt.text(i, v + 0.02, f"{v:.2f}", ha='center')
    plt.tight_layout()
    # The directory may not exist yet, e.g. when no loss figure was saved first.
    os.makedirs('figures', exist_ok=True)
    plt.savefig("figures/recall-precision-f1score.png", dpi=300)
@@ -0,0 +1,50 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+
4
+ from .trainer import train
5
+ from .architectures import LOCALSn, LOCALSs
6
+ from .figmaker import save_loss_fig, save_recall_precision_f1_score
7
+ from .evaluator import find_recall_precision_f1_score, find_mAP, find_mCS
8
+
9
class LOCALS:
    """High-level wrapper around a LOCALS network.

    Args:
        architecture: ``'n'`` for ``LOCALSn`` or ``'s'`` for ``LOCALSs``.
        file_path: optional path to a saved state dict to load into the model.
        device: torch device the model is moved to and weights are loaded onto.

    Raises:
        ValueError: if ``architecture`` is not one of the supported keys.
    """

    _ARCHITECTURES = {'n': LOCALSn, 's': LOCALSs}

    def __init__(self, architecture='n', file_path=None, device='cuda'):
        self.architecture = architecture
        self.device = device
        try:
            model_cls = self._ARCHITECTURES[architecture]
        except KeyError:
            raise ValueError(
                f"Unknown architecture {architecture!r}; expected one of {sorted(self._ARCHITECTURES)}"
            ) from None
        self.model = model_cls()
        self.model.to(device)
        if file_path:
            # map_location lets weights saved on one device (e.g. a GPU) load on another.
            self.model.load_state_dict(torch.load(file_path, map_location=device, weights_only=True))

    def __call__(self, images):
        """Run the wrapped network on a batch of images."""
        return self.model(images)

    def fit(self, train_loader: DataLoader, *, num_epochs=50, val_loader=None, epoch_to_lr=None, do_save_loss_fig=True):
        """Train the model, then reload the weights the trainer saved to ``best.pth``.

        Args:
            train_loader: training data.
            num_epochs: number of epochs to train.
            val_loader: optional validation data.
            epoch_to_lr: optional mapping {epoch: learning rate} applied at those epochs.
            do_save_loss_fig: whether to save the smoothed loss plot.
        """
        train_loss_ot, val_loss_ot = train(
            train_loader, num_epochs, val_loader, self.model, self.device,
            epoch_to_lr if epoch_to_lr is not None else {},
        )
        self.model.load_state_dict(torch.load("best.pth", map_location=self.device, weights_only=True))

        if do_save_loss_fig:
            save_loss_fig(train_loss_ot, val_loss_ot)

    def eval(self):
        """Put the wrapped network in evaluation mode."""
        self.model.eval()

    def train(self):
        """Put the wrapped network in training mode."""
        self.model.train()

    def evaluate(self, data, do_save_fig=True):
        """Compute evaluation metrics on ``data``.

        Returns:
            dict with keys ``'recall'``, ``'precision'``, ``'f1_score'``, ``'mAP'`` and ``'mCS'``.
        """
        recall, precision, f1_score = find_recall_precision_f1_score(self, data)
        mAP = find_mAP(self, data)
        mCS = find_mCS(self, data)

        if do_save_fig:
            save_recall_precision_f1_score(recall, precision, f1_score)

        return {'recall': recall, 'precision': precision, 'f1_score': f1_score, 'mAP': mAP, 'mCS': mCS}

    def get_num_params(self):
        """Return the total number of parameters in the wrapped network."""
        return sum(p.numel() for p in self.model.parameters())
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
def locals_loss(beta=1.0, gamma=5.0):
    '''
    Build the LOCALS training loss.

    beta: weight for localization loss
    gamma: weight for objectness loss

    Returns loss_func(predicted, true), which yields a scalar tensor. Both
    inputs are indexed [batch, row, col, channel]: channels 0-1 are the
    coordinates and channel 2 is the objectness target (0 or 1). The
    localization term only counts cells whose objectness target is set.
    '''

    def binary_focal_loss(pred, target, alpha=0.25, gamma=2.0, eps=1e-8):
        '''
        pred: predicted probabilities (after sigmoid)
        target: ground truth (0 or 1)
        alpha: constant weight applied to every cell. It scales the loss
            uniformly and does not rebalance positives against negatives.
        gamma: focusing parameter for hard examples
        '''
        pred = pred.clamp(eps, 1.0 - eps)  # avoid log(0)
        pt = pred * target + (1 - pred) * (1 - target)
        loss = - alpha * (1 - pt) ** gamma * (target * torch.log(pred + eps) + (1 - target) * torch.log(1 - pred + eps))
        return loss.mean()

    def focal_localization_loss(pred_coords, true_coords, mask, alpha=0.25, gamma=2.0, eps=1e-8):
        # Distance is mapped to a soft "correctness" score pt, ~0.5 at d = 0.1.
        d = torch.sqrt(torch.sum((pred_coords - true_coords) ** 2, dim=-1))
        pt = 1 - torch.sigmoid(20 * (d - 0.1))
        pt = pt.clamp(eps, 1.0 - eps)
        loss = - alpha * (1 - pt) ** gamma * mask * torch.log(pt)

        mask_total = mask.sum()
        if mask_total == 0:
            # No objects in this image: return a zero that stays attached to the
            # graph (not a Python int), so backward() works even when this is the
            # only loss term (e.g. gamma=0).
            return pred_coords.sum() * 0.0

        return loss.sum() / mask_total

    # actual loss function
    def loss_func(predicted, true):
        # Average each loss over the images in the batch.
        loc_loss = 0
        obj_loss = 0

        # iterate through each image in the batch
        for i in range(true.shape[0]):
            ith_predicted = predicted[i]
            ith_true = true[i]

            obj_mask = ith_true[..., 2]
            true_coordinates = ith_true[..., :2]

            obj_pred = ith_predicted[..., 2]
            pred_coordinates = ith_predicted[..., :2]

            # find localization loss
            loc_loss += focal_localization_loss(pred_coordinates, true_coordinates, obj_mask)

            # find objectness loss
            obj_loss += binary_focal_loss(obj_pred, obj_mask)

        # first find mean loss
        loc_loss /= true.shape[0]
        obj_loss /= true.shape[0]

        # then find total loss
        total_loss = beta * loc_loss + gamma * obj_loss
        return total_loss

    return loss_func
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+
3
+ import math
4
+
5
def smooth_curve(points, factor=0.9):
    """Return an exponentially smoothed copy of ``points``.

    s[0] = p[0], and s[i] = s[i-1] * factor + p[i] * (1 - factor) afterwards.
    An empty input gives an empty list.
    """
    result = []
    running = None
    for position, value in enumerate(points):
        running = value if position == 0 else running * factor + value * (1 - factor)
        result.append(running)
    return result
14
+
15
def closest_factors(n):
    """Return ``(a, b)`` with ``a * b == n``, ``a <= b`` and ``a`` as close to sqrt(n) as possible.

    Used to lay out ``n`` subplots in a near-square grid (a rows, b columns).

    Raises:
        ValueError: if ``n`` is smaller than 1 (previously n == 0 divided by zero).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    d = math.isqrt(n)  # integer sqrt; always >= 1 here
    while n % d != 0:
        d -= 1
    return d, n // d
20
+
21
def pearson_corr(points):
    """Return the Pearson correlation between the x and y columns of ``points``.

    ``points`` is a sequence of [x, y] pairs; any extra columns are ignored.
    Returns NaN, without emitting NumPy warnings, when the coefficient is
    undefined: fewer than two points, fewer than two columns, or all x (or
    all y) values equal.
    """
    points_array = np.asarray(points, dtype=float)
    if points_array.ndim != 2 or points_array.shape[0] < 2 or points_array.shape[1] < 2:
        return float('nan')

    x = points_array[:, 0]
    y = points_array[:, 1]
    # ptp (max - min) is exactly 0 for constant columns, unlike a float std.
    if np.ptp(x) == 0 or np.ptp(y) == 0:
        return float('nan')

    return float(np.corrcoef(x, y)[0, 1])
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch.utils.flop_counter import FlopCounterMode
4
+
5
+ import os
6
+ import json
7
+ import uuid
8
+
9
+ from .locals import LOCALS
10
+ from .dataset import LOCALSDataset
11
+
12
def save_run_metrics(run_details: dict):
    """Write ``run_details`` as JSON to a new file under ``runs/``.

    The directory is created when missing. A fresh random UUID in the file
    name gives every call its own file.
    """
    os.makedirs('runs', exist_ok=True)
    with open(f'runs/run_metrics_{uuid.uuid4()}.json', 'w') as out_file:
        json.dump(run_details, out_file)
19
+
20
def run(*, seed: int,
        train_split: float,
        test_split: float,
        images_dir: str,
        labels_dir: str,
        architecture='n',
        num_epochs=50,
        epoch_to_lr=None):
    """Seeded end-to-end experiment: split data, train, count FLOPs, evaluate, save metrics.

    Args:
        seed: seed for NumPy and PyTorch (including CUDA when available).
        train_split: fraction of samples used for training.
        test_split: fraction used for testing. Whatever remains is the validation set.
        images_dir: directory with the input images.
        labels_dir: directory with the ``.npy`` label files.
        architecture: ``'n'`` or ``'s'`` model variant.
        num_epochs: number of training epochs.
        epoch_to_lr: optional mapping {epoch: learning rate}.

    Returns:
        dict of run settings and metrics, which is also written to ``runs/``.
    """
    if epoch_to_lr is None:
        epoch_to_lr = {}

    # NumPy
    np.random.seed(seed)

    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dataset = LOCALSDataset(images_dir, labels_dir)
    # get_dataloaders omits the validation loader when the splits leave no
    # samples for it, so do not assume three return values.
    loaders = dataset.get_dataloaders(train_split, test_split)
    train_loader, test_loader = loaders[0], loaders[1]
    val_loader = loaders[2] if len(loaders) > 2 else None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = LOCALS(architecture, device=device)
    model.fit(train_loader, num_epochs=num_epochs, val_loader=val_loader, epoch_to_lr=epoch_to_lr)

    # Count the FLOPs of one forward pass in eval mode without autograd. In
    # train mode this pass would update BatchNorm running statistics with test data.
    model.eval()
    flop_counter = FlopCounterMode(display=False)
    with torch.no_grad(), flop_counter:
        for images, _ in test_loader:
            model(images.to(model.device))
            break
    total_flops = flop_counter.get_total_flops()

    metrics = model.evaluate(test_loader)
    run_details = {'seed': seed,
                   'train_split': train_split,
                   'test_split': test_split,
                   'recall': metrics['recall'],
                   'precision': metrics['precision'],
                   'f1_score': metrics['f1_score'],
                   'mAP': metrics['mAP'],
                   'mCS': metrics['mCS'],
                   'num_params': model.get_num_params(),
                   'total_flops': total_flops,
                   # Recorded so saved runs of different configurations can be told apart.
                   'architecture': architecture,
                   'num_epochs': num_epochs}

    save_run_metrics(run_details)
    return run_details
@@ -0,0 +1,83 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ from torch.optim import Adam
4
+ from torch.utils.data import DataLoader
5
+
6
+ from .losses import locals_loss
7
+
8
def train(train_loader: DataLoader,
          num_epochs: int,
          val_loader: DataLoader,
          model: torch.nn.Module,
          device='cuda',
          epoch_to_lr=None):
    """Train ``model`` with Adam (initial lr 1e-3) and the LOCALS loss.

    Checkpointing (``best.pth``):
    - The initial weights are saved before training starts.
    - With a validation loader, the file is overwritten whenever the mean
      validation loss improves.
    - Without one, it is overwritten with the final weights.
    - When ``epoch_to_lr`` sets a new learning rate and validation is
      enabled, training resumes from the best checkpoint.

    Args:
        train_loader: training batches of (inputs, targets).
        num_epochs: number of epochs.
        val_loader: optional validation batches (None disables validation).
        model: network to train in place.
        device: device for the model and batches.
        epoch_to_lr: optional mapping {epoch (1-based): learning rate}.

    Returns:
        ``(train_loss_ot, val_loss_ot)``: per-epoch mean losses. The
        validation list is empty when validation is disabled.
    """
    if epoch_to_lr is None:
        epoch_to_lr = {}

    model.to(device)
    torch.save(model.state_dict(), "best.pth")

    optimizer = Adam(model.parameters(), lr=1e-3)
    criterion = locals_loss()
    validate = val_loader is not None

    train_loss_ot = []
    val_loss_ot = []
    min_val_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        if epoch in epoch_to_lr:
            new_lr = epoch_to_lr[epoch]
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            if validate:
                # Start the new learning-rate phase from the best weights so far.
                model.load_state_dict(torch.load("best.pth", map_location=device, weights_only=True))

        model.train()
        total_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError for an empty loader.
        avg_train_loss = total_loss / max(len(train_loader), 1)
        train_loss_ot.append(avg_train_loss)
        print(f'Avg Training Loss = {avg_train_loss}')

        if validate:
            model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(images)

                    loss = criterion(outputs, labels)
                    total_val_loss += loss.item()

            avg_val_loss = total_val_loss / max(len(val_loader), 1)
            val_loss_ot.append(avg_val_loss)
            if avg_val_loss < min_val_loss:
                min_val_loss = avg_val_loss
                torch.save(model.state_dict(), "best.pth")

            print(f'Val Loss = {avg_val_loss}')

    if not validate:
        torch.save(model.state_dict(), "best.pth")

    return train_loss_ot, val_loss_ot
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: locals-api
3
+ Version: 0.1.0
4
+ Requires-Python: >=3.10
5
+ Requires-Dist: numpy
6
+ Requires-Dist: opencv-python
7
+ Requires-Dist: matplotlib
@@ -0,0 +1,19 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/locals/__init__.py
4
+ src/locals/architectures.py
5
+ src/locals/blocks.py
6
+ src/locals/constants.py
7
+ src/locals/dataset.py
8
+ src/locals/evaluator.py
9
+ src/locals/figmaker.py
10
+ src/locals/locals.py
11
+ src/locals/losses.py
12
+ src/locals/math.py
13
+ src/locals/runner.py
14
+ src/locals/trainer.py
15
+ src/locals_api.egg-info/PKG-INFO
16
+ src/locals_api.egg-info/SOURCES.txt
17
+ src/locals_api.egg-info/dependency_links.txt
18
+ src/locals_api.egg-info/requires.txt
19
+ src/locals_api.egg-info/top_level.txt
@@ -0,0 +1,3 @@
1
+ numpy
2
+ opencv-python
3
+ matplotlib