scribble-annotation-generator 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scribble_annotation_generator/__init__.py +0 -0
- scribble_annotation_generator/cli.py +195 -0
- scribble_annotation_generator/crop_field.py +366 -0
- scribble_annotation_generator/dataset.py +96 -0
- scribble_annotation_generator/debug.py +43 -0
- scribble_annotation_generator/nn.py +570 -0
- scribble_annotation_generator/utils.py +495 -0
- scribble_annotation_generator-0.0.1.dist-info/METADATA +108 -0
- scribble_annotation_generator-0.0.1.dist-info/RECORD +11 -0
- scribble_annotation_generator-0.0.1.dist-info/WHEEL +4 -0
- scribble_annotation_generator-0.0.1.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import numpy as np
|
|
3
|
+
import cv2
|
|
4
|
+
from typing import Tuple, List, Optional
|
|
5
|
+
from skimage.morphology import skeletonize
|
|
6
|
+
from scipy.interpolate import splprep, splev
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
FEATURE_TO_KEY = {
|
|
10
|
+
"start_x": 0,
|
|
11
|
+
"start_y": 1,
|
|
12
|
+
"end_x": 2,
|
|
13
|
+
"end_y": 3,
|
|
14
|
+
"num_spurs": 4,
|
|
15
|
+
"curvature": 5,
|
|
16
|
+
"cos_angle": 6,
|
|
17
|
+
"sin_angle": 7,
|
|
18
|
+
}
|
|
19
|
+
KEY_TO_FEATURE = {v: k for k, v in FEATURE_TO_KEY.items()}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_curvature(mask: np.ndarray) -> float:
    """
    Compute the average curvature of the skeletonized binary mask.

    Args:
        mask: (H, W) binary mask

    Returns:
        Mean turning angle (radians) between successive skeleton-pixel
        direction vectors, or 0.0 when the skeleton has fewer than 10 pixels.
    """
    skeleton = skeletonize(mask > 0).astype(np.uint8)
    ys, xs = np.nonzero(skeleton)

    # Too few pixels to form a meaningful estimate.
    # Return 0.0 (float) so the return type is consistent with the main path.
    if len(xs) < 10:
        return 0.0

    # Curvature proxy: angle changes between consecutive direction vectors.
    # NOTE(review): np.nonzero yields pixels in row-major scan order, not in
    # path order along the skeleton, so this proxy is only approximate for
    # curves that double back across rows — confirm acceptable upstream.
    coords = np.stack([xs, ys], axis=1)
    diffs = np.diff(coords, axis=0)
    norms = np.linalg.norm(diffs, axis=1, keepdims=True) + 1e-6  # avoid /0
    directions = diffs / norms

    # Clip the dot product into arccos' domain to guard against FP drift.
    angles = np.arccos(
        np.clip(
            np.sum(directions[:-1] * directions[1:], axis=1),
            -1.0,
            1.0,
        )
    )

    return float(np.mean(angles))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_endpoints_and_spurs(mask: np.ndarray) -> Tuple[List[Tuple[int, int]], int]:
    """
    Get endpoints of a skeletonized binary mask.

    Args:
        mask: (H, W) binary mask

    Returns:
        List of (x, y) coordinates of endpoints, normalized to [-1, 1]
        Number of spurs (extra endpoints beyond two)
    """
    # Skeletonize the mask
    skeleton = skeletonize(mask > 0).astype(np.uint8)

    # Define a kernel to find endpoints: centre weight 10, neighbours 1 each
    kernel = np.array([[1, 1, 1], [1, 10, 1], [1, 1, 1]], dtype=np.uint8)

    # Convolve to find endpoints (borderType=0 is cv2.BORDER_CONSTANT)
    filtered = cv2.filter2D(skeleton, -1, kernel, borderType=0)

    # Endpoints will have a value of 11 in the filtered image
    # (10 for the pixel itself + exactly one skeleton neighbour)
    endpoints = np.argwhere(filtered == 11)

    # Get the endpoints with the greatest distance between them
    # (O(n^2) pair scan — endpoint counts are small in practice)
    if len(endpoints) >= 2:
        max_dist = 0
        pt1, pt2 = endpoints[0], endpoints[1]
        for i in range(len(endpoints) - 1):
            for j in range(i + 1, len(endpoints)):
                dist = np.linalg.norm(endpoints[i] - endpoints[j])
                if dist > max_dist:
                    max_dist = dist
                    pt1, pt2 = endpoints[i], endpoints[j]

        # Ensure pt1 is the topmost point (smallest row index)
        if pt1[0] > pt2[0]:
            pt1, pt2 = pt2, pt1
    else:
        # Fewer than 2 detectable endpoints (e.g. a closed loop, or a
        # degenerate skeleton).  Fabricate a vertical one-pixel pair at the
        # first "on" pixel found by argmax.
        # NOTE(review): on an all-zero skeleton argmax returns 0, so this
        # yields the (0, 0) corner — confirm that is the intended fallback.
        pt1 = np.unravel_index(skeleton.argmax(), skeleton.shape)

        # Keep the fabricated second point inside the image bounds
        if pt1[0] == skeleton.shape[0] - 1:
            pt2 = pt1
            pt1 = (pt2[0] - 1, pt2[1])
        else:
            pt2 = (pt1[0] + 1, pt1[1])

    # Normalize to [-1, 1]; points are still (row, col) at this stage
    pt1 = (pt1 / np.array(mask.shape)) * 2 - 1
    pt2 = (pt2 / np.array(mask.shape)) * 2 - 1

    # Swap (row, col) -> (x, y) on the way out
    return [(pt1[1], pt1[0]), (pt2[1], pt2[0])], max(
        len(endpoints) - 2, 0
    )  # Return as (x, y) tuples, number of spurs
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def is_rgb_annotation(mask: np.ndarray) -> bool:
    """
    Detect if annotation is RGB (H, W, 3) or indexed (H, W).

    Args:
        mask: Annotation array

    Returns:
        True if RGB, False if indexed
    """
    # An RGB annotation is a rank-3 array whose last axis holds 3 channels.
    return mask.ndim == 3 and mask.shape[-1] == 3
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def rgb_to_indexed(
    rgb_mask: np.ndarray, colour_map: dict[tuple[int, int, int], int]
) -> np.ndarray:
    """
    Convert an RGB segmentation mask to an index mask.

    Args:
        rgb_mask: (H, W, 3) uint8 array
        colour_map: dict mapping (R, G, B) -> class index

    Returns:
        index_mask: (H, W) int64 array
    """
    if rgb_mask.ndim != 3 or rgb_mask.shape[-1] != 3:
        raise ValueError("mask_rgb must have shape (H, W, 3)")

    height, width, _ = rgb_mask.shape
    index_mask = np.zeros((height, width), dtype=np.int64)

    # One vectorized equality test per class colour; matching pixels get
    # stamped with that class index.
    for colour, class_idx in colour_map.items():
        target = np.array(colour, dtype=rgb_mask.dtype)
        index_mask[(rgb_mask == target).all(axis=-1)] = class_idx

    return index_mask
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def indexed_to_rgb(
    indexed_mask: np.ndarray, colour_map: dict[tuple[int, int, int], int]
) -> np.ndarray:
    """
    Convert indexed annotation back to RGB using color palette.

    Args:
        indexed_mask: (H, W) array with class indices
        colour_map: dict mapping (R, G, B) -> class index

    Returns:
        (H, W, 3) RGB annotation
    """
    # Build a lookup table: row i holds the RGB triple for class i.
    table_size = max(colour_map.values()) + 1
    lut = np.zeros((table_size, 3), dtype=np.uint8)
    for colour, class_idx in colour_map.items():
        lut[class_idx] = np.array(colour, dtype=np.uint8)

    # Fancy indexing maps every class index to its palette colour at once.
    return lut[indexed_mask].astype(np.uint8)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def extract_class_masks(indexed_mask: np.ndarray) -> dict[int, np.ndarray]:
    """
    Extract individual binary masks for each class.

    Args:
        indexed_mask: (H, W) indexed annotation

    Returns:
        Dict mapping class_id -> (H, W) uint8 binary mask. The background
        class (0) is excluded.
    """
    # (Docstring previously claimed a list of tuples; the function has
    # always returned a dict — documentation corrected to match.)
    return {
        class_id: (indexed_mask == class_id).astype(np.uint8)
        for class_id in np.unique(indexed_mask)
        if class_id != 0  # Skip background
    }
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def count_objects(mask: np.ndarray) -> int:
    """
    Count the number of connected components (objects) in a binary mask.

    Args:
        mask: (H, W) binary mask

    Returns:
        Number of separate objects (connected components)
    """
    # An all-zero mask has nothing to label
    if mask.max() == 0:
        return 0

    # connectedComponents returns (label_count, label_image); label 0 is
    # the background, so exclude it from the object count.
    label_count = cv2.connectedComponents(mask.astype(np.uint8))[0]
    return max(label_count - 1, 0)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_objects(mask: np.ndarray) -> List[np.ndarray]:
    """
    Extract individual object masks from a binary mask.

    Args:
        mask: (H, W) binary mask

    Returns:
        List of (H, W) binary masks, one per connected component
    """
    # Nothing to extract from an all-zero mask
    if mask.max() == 0:
        return []

    num_labels, labels = cv2.connectedComponents(mask.astype(np.uint8))

    # Label 0 is the background; every other label is one object
    return [(labels == lbl).astype(np.uint8) for lbl in range(1, num_labels)]
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def extract_object_features(mask: np.ndarray) -> np.ndarray:
    """
    Extract features of a single object in a binary mask.

    Args:
        mask: (H, W) binary mask of a single object

    Returns:
        Feature vector as a numpy array
    """
    endpoints, num_spurs = get_endpoints_and_spurs(mask)
    if len(endpoints) >= 2:
        (x1, y1), (x2, y2) = endpoints
    else:
        (x1, y1), (x2, y2) = (0, 0), (0, 0)

    # Orientation of the endpoint-to-endpoint chord
    delta = np.array([x2 - x1, y2 - y1], dtype=np.float32)
    if np.linalg.norm(delta) > 0:
        angle = np.arctan2(delta[1], delta[0])
    else:
        angle = 0.0

    # Normalize curvature: map [0, 0.5] onto [-1, 1]
    curvature_norm = min(get_curvature(mask), 0.5) * 4 - 1

    return pack_feature_vector(
        {
            "start_x": x1,
            "start_y": y1,
            "end_x": x2,
            "end_y": y2,
            "num_spurs": num_spurs,
            "curvature": curvature_norm,
            "cos_angle": math.cos(angle),
            "sin_angle": math.sin(angle),
        }
    )
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def create_spline(
    image_shape: Tuple[int, int],
    start: Tuple[int, int],
    end: Tuple[int, int],
    num_ctrl: int = 5,
    curvature_scale: float = 0.2,
) -> np.ndarray:
    """
    Generate a smooth random spline curve between two points.

    Args:
        image_shape: Tuple (H, W) for the target image size
            (currently unused; kept for API compatibility)
        start: Tuple (x, y) for the starting point
        end: Tuple (x, y) for the ending point
        num_ctrl: Number of control points for the spline
        curvature_scale: Scale factor for curvature randomness

    Returns:
        Array of shape (N, 2) containing the spline points [x, y]
    """
    # Control points: evenly spaced between start and end, jittered by
    # Gaussian noise scaled by curvature_scale.
    ctrl_x = (
        np.linspace(start[0], end[0], num_ctrl)
        + np.random.randn(num_ctrl) * curvature_scale
    )
    ctrl_y = (
        np.linspace(start[1], end[1], num_ctrl)
        + np.random.randn(num_ctrl) * curvature_scale
    )
    # Sample more points for curvier splines, but never fewer than 10
    num_samples = max(int(300 * curvature_scale), 10)

    try:
        tck, _ = splprep([ctrl_x, ctrl_y], s=0)
        u = np.linspace(0, 1, num_samples)
        x, y = splev(u, tck)
    except (TypeError, ValueError):
        # Fallback to linear if spline fitting fails. splprep raises
        # TypeError when there are too few control points for the spline
        # degree (m <= k) and ValueError for degenerate inputs; previously
        # only ValueError was caught, so small num_ctrl crashed instead of
        # falling back as documented.
        x = np.linspace(start[0], end[0], num_samples)
        y = np.linspace(start[1], end[1], num_samples)

    return np.stack([x, y], axis=1)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def add_branch(points, angle_std=0.5, length_scale=0.5):
    """
    Add a branch to an existing spline.

    Args:
        points: Main spline points array of shape (N, 2)
        angle_std: Standard deviation for branch angle variation
        length_scale: Scale factor for branch length

    Returns:
        Array of branch points
    """
    # Attach the branch somewhere in the middle half of the main curve
    attach_idx = np.random.randint(len(points) // 4, len(points) * 3 // 4)
    attach_pt = points[attach_idx]

    # Branch heading: local tangent direction plus Gaussian angular jitter
    tangent = points[attach_idx + 1] - points[attach_idx]
    heading = np.arctan2(tangent[1], tangent[0]) + np.random.randn() * angle_std

    # Branch length scales with the local point spacing
    branch_len = np.linalg.norm(tangent) * 50 * length_scale
    tip = attach_pt + branch_len * np.array([np.cos(heading), np.sin(heading)])

    return np.linspace(attach_pt, tip, 50)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def draw_polyline(
    image_shape,
    points,
    thickness=10,
):
    """
    Rasterize a polyline (spline) to an image.

    Args:
        image_shape: Tuple (H, W) for the canvas size
        points: Array of shape (N, 2) containing points [x, y]
        thickness: Line thickness in pixels

    Returns:
        Binary image with the drawn polyline
    """
    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = points.astype(np.int32)

    # Draw every consecutive segment of the polyline onto the canvas
    for seg_start, seg_end in zip(vertices[:-1], vertices[1:]):
        cv2.line(
            canvas,
            tuple(seg_start),
            tuple(seg_end),
            color=1,
            thickness=thickness,
        )
    return canvas
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def generate_scribble(
    image_shape,
    features: dict,
    class_id: int = 1,
):
    """
    Generate a synthetic scribble annotation based on dataset statistics.

    Args:
        image_shape: Tuple (H, W) for the output size
        features: Dictionary with scribble features
        class_id: Class index to assign to the generated scribble (default: 1)

    Returns:
        Binary scribble mask of shape (H, W) with values 0 (background) and class_id
    """
    # Main stroke: a random spline between the feature-specified endpoints
    main_curve = create_spline(
        image_shape,
        start=(features["start_x"], features["start_y"]),
        end=(features["end_x"], features["end_y"]),
        num_ctrl=10,
        curvature_scale=min(features["curvature"], 0.5) * 20,
    )

    scribble = draw_polyline(image_shape, main_curve)

    # Add one branch off the main curve per spur
    for _ in range(int(features["num_spurs"])):
        scribble |= draw_polyline(image_shape, add_branch(main_curve))

    # Encode the class index into the binary mask
    return scribble * class_id
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def generate_multiclass_scribble(
    image_shape,
    objects: list[dict],
    classes: np.ndarray | torch.Tensor,
    colour_map: Optional[dict[tuple[int, int, int], int]] = None,
) -> np.ndarray:
    """
    Generate multi-class synthetic scribble annotation.

    Args:
        image_shape: Tuple (H, W) for the output size
        objects: List of per-object feature dicts (normalized model-space
            values), aligned with `classes`
        classes: Per-object class indices, aligned with `objects`
        colour_map: Optional dict mapping (R, G, B) -> class index. If provided, returns RGB output.

    Returns:
        Either indexed (H, W) or RGB (H, W, 3) scribble annotation
    """
    # isinstance (not `type(...) is`) so Tensor subclasses are handled too
    if isinstance(classes, torch.Tensor):
        classes = classes.cpu().numpy()

    # Denormalize from model space ([-1, 1] coords, [-1, 1] curvature) to
    # pixel space. Work on shallow copies so the caller's dicts are not
    # mutated — previously they were modified in place, which corrupted
    # them for any repeated use.
    denormalized = []
    for obj in objects:
        obj = dict(obj)
        obj["curvature"] = (obj["curvature"] / 4.0) + 0.25
        obj["start_x"] = int((obj["start_x"] + 1) * image_shape[1] / 2)
        obj["start_y"] = int((obj["start_y"] + 1) * image_shape[0] / 2)
        obj["end_x"] = int((obj["end_x"] + 1) * image_shape[1] / 2)
        obj["end_y"] = int((obj["end_y"] + 1) * image_shape[0] / 2)
        denormalized.append(obj)

    # Create empty canvas
    indexed_output = np.zeros(image_shape, dtype=np.uint8)

    # Generate a scribble per object
    for class_id, features in zip(classes, denormalized):
        if class_id == 0:  # Skip background
            continue

        class_scribble = generate_scribble(
            image_shape=image_shape,
            features=features,
            class_id=class_id,
        )

        # Add to canvas (later objects overwrite earlier ones at overlaps)
        indexed_output = np.where(class_scribble > 0, class_scribble, indexed_output)

    # Convert to RGB if colour_map provided
    if colour_map is not None:
        return indexed_to_rgb(indexed_output, colour_map)

    return indexed_output
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def pack_feature_vector(features: dict) -> torch.Tensor:
    """
    Pack a feature vector into a Tensor from a dict.

    Args:
        features: dictionary of feature names to values

    Returns:
        feature vector as a torch Tensor
    """
    packed = torch.zeros(len(FEATURE_TO_KEY), dtype=torch.float32)
    # Fill each known slot; unknown feature names are silently ignored and
    # missing ones stay at their zero default.
    for name, slot in FEATURE_TO_KEY.items():
        if name in features:
            packed[slot] = features[name]
    return packed
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def unpack_feature_vector(vector: torch.Tensor) -> dict:
    """
    Unpack a feature vector Tensor into a dict.

    Args:
        vector: feature vector as a torch Tensor

    Returns:
        dictionary of feature names to values
    """
    # Map each tensor slot back to its feature name as a Python scalar
    return {
        KEY_TO_FEATURE[slot]: component.item()
        for slot, component in enumerate(vector)
    }
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: scribble-annotation-generator
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Programmatically generate semi-realistic synthetic scribble annotations based on statistics from existing scribble datasets
|
|
5
|
+
Project-URL: Homepage, https://github.com/alexsenden/scribble-annotation-generator
|
|
6
|
+
Project-URL: Repository, https://github.com/alexsenden/scribble-annotation-generator
|
|
7
|
+
Project-URL: Issues, https://github.com/alexsenden/scribble-annotation-generator/issues
|
|
8
|
+
Author: Alex Senden
|
|
9
|
+
License: MIT
|
|
10
|
+
Keywords: annotation,computer-vision,scribble,segmentation,synthetic-data
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering :: Image Processing
|
|
23
|
+
Requires-Python: >=3.8
|
|
24
|
+
Requires-Dist: numpy
|
|
25
|
+
Requires-Dist: opencv-python
|
|
26
|
+
Requires-Dist: scikit-image
|
|
27
|
+
Requires-Dist: scipy
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# Scribble Annotation Generator
|
|
31
|
+
|
|
32
|
+
Programmatically generate semi-realistic scribble annotations for segmentation-style tasks. The project exposes a single CLI entrypoint for two workflows: synthetic crop-field generation and training/inference of the neural scribble generator.
|
|
33
|
+
|
|
34
|
+
## Installation
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
pip install -e .
|
|
38
|
+
# or
|
|
39
|
+
pip install scribble-annotation-generator
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
After installation, the CLI command `scribble-annotation-generator` becomes available.
|
|
43
|
+
|
|
44
|
+
## Colour Map Specification
|
|
45
|
+
|
|
46
|
+
Many commands require a colour map that links RGB tuples to class IDs. Provide it in either form:
|
|
47
|
+
|
|
48
|
+
- Inline string: `R,G,B=class;R,G,B=class` (also accepts `R,G,B:class`)
|
|
49
|
+
- Example: `0,0,0=0;0,128,255=1;124,255,121=2`
|
|
50
|
+
- File path: a text file with one entry per line. Each line is `R,G,B,class`. If the class column is omitted, class IDs are assigned by line order starting at 0.
|
|
51
|
+
|
|
52
|
+
## CLI
|
|
53
|
+
|
|
54
|
+
### 1) Crop-field synthesis
|
|
55
|
+
|
|
56
|
+
Generate synthetic crop-field scribble images using a procedural model.
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
scribble-annotation-generator crop-field \
|
|
60
|
+
--colour-map "0,0,0=0;0,128,255=1;124,255,121=2" \
|
|
61
|
+
--output-dir ./path/to/output \
|
|
62
|
+
--num-samples 50 \
|
|
63
|
+
--min-rows 4 \
|
|
64
|
+
--max-rows 6
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
Key flags:
|
|
68
|
+
|
|
69
|
+
- `--colour-map` (required): inline or file as described above
|
|
70
|
+
- `--output-dir`: where PNGs are written (default `./local/crop_field`)
|
|
71
|
+
- `--num-samples`: number of images to create (default `200`)
|
|
72
|
+
- `--min-rows`, `--max-rows`: range for rows per sample
|
|
73
|
+
|
|
74
|
+
### 2) Train and run neural generator
|
|
75
|
+
|
|
76
|
+
Train the transformer-based object generator on a dataset of scribble annotations, then render model predictions on the validation set.
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
scribble-annotation-generator train-nn \
|
|
80
|
+
--train-dir ./local/soybean1/train \
|
|
81
|
+
--val-dir ./local/soybean1/val \
|
|
82
|
+
--colour-map ./colour_map.csv \
|
|
83
|
+
--checkpoint-dir ./local/nn-checkpoints \
|
|
84
|
+
--inference-dir ./local/nn-inference \
|
|
85
|
+
--batch-size 8 \
|
|
86
|
+
--num-workers 4 \
|
|
87
|
+
--max-epochs 50
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
Key flags:
|
|
91
|
+
|
|
92
|
+
- `--train-dir`, `--val-dir` (required): directories containing training and validation data
|
|
93
|
+
- `--colour-map` (required): inline or file form
|
|
94
|
+
- `--checkpoint-dir`: where PyTorch Lightning checkpoints are stored (default `./local/nn-checkpoints`)
|
|
95
|
+
- `--inference-dir`: where rendered scribbles from validation samples are saved (default `./local/nn-inference`)
|
|
96
|
+
- `--batch-size`, `--num-workers`, `--max-epochs`: training configuration
|
|
97
|
+
- `--num-classes`: override number of classes; by default derived from the colour map
|
|
98
|
+
|
|
99
|
+
## Python API
|
|
100
|
+
|
|
101
|
+
Instead of calling the CLI, you can call the main functions directly:
|
|
102
|
+
|
|
103
|
+
- `scribble_annotation_generator.crop_field.generate_crop_field_dataset(output_dir, colour_map, num_samples=..., min_rows=..., max_rows=...)`
|
|
104
|
+
- `scribble_annotation_generator.nn.train_and_infer(train_dir, val_dir, colour_map, checkpoint_dir=..., inference_dir=..., batch_size=..., num_workers=..., max_epochs=..., num_classes=None)`
|
|
105
|
+
|
|
106
|
+
## License
|
|
107
|
+
|
|
108
|
+
MIT
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
scribble_annotation_generator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
scribble_annotation_generator/cli.py,sha256=e6Ufmf0CLMrKZCL08pB5RqiiQMNJBU9uZwcL86S-Jcw,6251
|
|
3
|
+
scribble_annotation_generator/crop_field.py,sha256=-Sil1SXywALHwWHF7Q03FnuC4jZaXdohjS_ZiD9E58o,10732
|
|
4
|
+
scribble_annotation_generator/dataset.py,sha256=jauKr8ZBJ1o8jEn8T_RKpVgqu8kwLmucIMyhZCkPiTg,3298
|
|
5
|
+
scribble_annotation_generator/debug.py,sha256=YJnfkBJL7Vwlqz9SWeybAhwID8Pcwzp_RFnL9xPOQyI,1194
|
|
6
|
+
scribble_annotation_generator/nn.py,sha256=aSQPkVpvsya942hz02LKoUEkZfL_f7lCvYQ5cI8R3Ts,16627
|
|
7
|
+
scribble_annotation_generator/utils.py,sha256=gluwQSroMd4bg6iwchiv4VBTK57t4OvScH8uv29erWY,13628
|
|
8
|
+
scribble_annotation_generator-0.0.1.dist-info/METADATA,sha256=6qXFoYfQXb5cI-4wOD-Fm_OAHHAIW-bGwIvdTbmAHT0,4281
|
|
9
|
+
scribble_annotation_generator-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
10
|
+
scribble_annotation_generator-0.0.1.dist-info/entry_points.txt,sha256=A5UbznzAcE5XF5MZrth2rdLvG2IQXhjK2lhklUf9QyU,89
|
|
11
|
+
scribble_annotation_generator-0.0.1.dist-info/RECORD,,
|