scribble-annotation-generator 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scribble_annotation_generator/__init__.py +0 -0
- scribble_annotation_generator/cli.py +195 -0
- scribble_annotation_generator/crop_field.py +366 -0
- scribble_annotation_generator/dataset.py +96 -0
- scribble_annotation_generator/debug.py +43 -0
- scribble_annotation_generator/nn.py +570 -0
- scribble_annotation_generator/utils.py +495 -0
- scribble_annotation_generator-0.0.1.dist-info/METADATA +108 -0
- scribble_annotation_generator-0.0.1.dist-info/RECORD +11 -0
- scribble_annotation_generator-0.0.1.dist-info/WHEEL +4 -0
- scribble_annotation_generator-0.0.1.dist-info/entry_points.txt +2 -0
|
File without changes
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, Tuple
|
|
4
|
+
|
|
5
|
+
from scribble_annotation_generator.crop_field import (
|
|
6
|
+
NUM_SAMPLES_TO_GENERATE,
|
|
7
|
+
generate_crop_field_dataset,
|
|
8
|
+
)
|
|
9
|
+
from scribble_annotation_generator.nn import train_and_infer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def parse_colour_map(value: str) -> Dict[Tuple[int, int, int], int]:
    """Parse a colour map from an inline string or a file path.

    Two input forms are accepted:

    * A path to an existing file with one ``R,G,B,class`` or ``R,G,B`` entry
      per line. Blank lines are ignored. When the class is omitted, the entry
      receives its zero-based position among the non-blank lines, so blank
      lines never leave gaps in the assigned class IDs.
    * An inline string of ``;``-separated entries, each written as
      ``R,G,B=class`` or ``R,G,B:class``.

    Args:
        value: Path to a colour map file, or an inline colour map string.

    Returns:
        A mapping from ``(R, G, B)`` tuples to integer class IDs.

    Raises:
        ValueError: If an entry is malformed, a channel lies outside 0-255,
            a value is not an integer, or no colours were parsed at all.
    """

    def _validate_rgb(rgb: Tuple[int, int, int]) -> Tuple[int, int, int]:
        for channel in rgb:
            if channel < 0 or channel > 255:
                raise ValueError("RGB values must be between 0 and 255")
        return rgb

    mapping: Dict[Tuple[int, int, int], int] = {}

    if os.path.isfile(value):
        with open(value, "r", encoding="utf-8") as handle:
            # Counts non-blank entries only; used as the implicit class ID
            # for three-column lines.
            entry_index = 0
            for line in handle:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = [part.strip() for part in stripped.split(",") if part.strip()]
                if len(parts) == 4:
                    r, g, b, cls = parts
                elif len(parts) == 3:
                    r, g, b = parts
                    cls = entry_index
                else:
                    raise ValueError(
                        "Each line in the colour map file must have 3 (RGB) or 4 (RGB,class) comma-separated values"
                    )
                rgb = _validate_rgb((int(r), int(g), int(b)))
                mapping[rgb] = int(cls)
                entry_index += 1
    else:
        entries = [entry.strip() for entry in value.split(";") if entry.strip()]
        for entry in entries:
            # '=' takes precedence over ':' as the colour/class separator.
            if "=" in entry:
                colour_part, class_part = entry.split("=", 1)
            elif ":" in entry:
                colour_part, class_part = entry.split(":", 1)
            else:
                raise ValueError(
                    "Inline colour map entries must separate colour and class with '=' or ':'"
                )
            rgb_parts = [part.strip() for part in colour_part.split(",") if part.strip()]
            if len(rgb_parts) != 3:
                raise ValueError("Colours must be provided as R,G,B")
            rgb = _validate_rgb((int(rgb_parts[0]), int(rgb_parts[1]), int(rgb_parts[2])))
            mapping[rgb] = int(class_part.strip())

    if not mapping:
        raise ValueError("No colours were parsed for the colour map")

    return mapping
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with its two subcommands.

    ``crop-field`` configures synthetic crop field generation and
    ``train-nn`` configures training plus inference. Both subcommands
    require ``--colour-map``. The chosen subcommand name is stored in the
    ``command`` attribute of the parsed namespace.

    Returns:
        The fully configured top-level ``ArgumentParser``.
    """
    colour_map_help = (
        "Colour map specified inline as 'R,G,B=class;...' or a path to a file "
        "with one 'R,G,B,class' entry per line."
    )

    root = argparse.ArgumentParser(
        description="Scribble Annotation Generator CLI",
    )
    commands = root.add_subparsers(dest="command", required=True)

    # ---- crop-field -----------------------------------------------------
    crop = commands.add_parser(
        "crop-field", help="Generate synthetic crop field scribble images."
    )
    crop.add_argument(
        "--output-dir",
        default="./local/crop_field",
        help="Directory to write generated crop field images.",
    )
    crop_int_options = (
        ("--num-samples", NUM_SAMPLES_TO_GENERATE, "Number of images to generate."),
        ("--min-rows", 4, "Minimum number of crop rows per sample."),
        ("--max-rows", 6, "Maximum number of crop rows per sample."),
    )
    for flag, default, text in crop_int_options:
        crop.add_argument(flag, type=int, default=default, help=text)
    crop.add_argument("--colour-map", required=True, help=colour_map_help)

    # ---- train-nn -------------------------------------------------------
    train = commands.add_parser(
        "train-nn", help="Train the scribble object generator and run inference."
    )
    for flag, text in (
        ("--train-dir", "Path to the training dataset directory."),
        ("--val-dir", "Path to the validation dataset directory."),
    ):
        train.add_argument(flag, required=True, help=text)
    for flag, default, text in (
        ("--checkpoint-dir", "./local/nn-checkpoints", "Directory to save model checkpoints."),
        ("--inference-dir", "./local/nn-inference", "Directory to save inference visualisations."),
    ):
        train.add_argument(flag, default=default, help=text)
    train_int_options = (
        ("--batch-size", 8, "Batch size for training."),
        ("--num-workers", 4, "Number of worker processes for data loading."),
        ("--max-epochs", 50, "Maximum number of training epochs."),
        (
            "--num-classes",
            None,
            "Override the number of classes; defaults to the number of unique class IDs in the colour map.",
        ),
    )
    for flag, default, text in train_int_options:
        train.add_argument(flag, type=int, default=default, help=text)
    train.add_argument("--colour-map", required=True, help=colour_map_help)

    return root
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def main(argv=None):
    """Run the command-line interface.

    Parses ``argv`` (argparse falls back to ``sys.argv`` when it is None),
    converts the ``--colour-map`` option with ``parse_colour_map`` and
    dispatches to the handler of the selected subcommand.

    Args:
        argv: Optional list of argument strings.
    """
    cli = build_parser()
    options = cli.parse_args(argv)
    colours = parse_colour_map(options.colour_map)
    command = options.command

    if command == "crop-field":
        generate_crop_field_dataset(
            output_dir=options.output_dir,
            colour_map=colours,
            num_samples=options.num_samples,
            min_rows=options.min_rows,
            max_rows=options.max_rows,
        )
        return

    if command == "train-nn":
        train_and_infer(
            train_dir=options.train_dir,
            val_dir=options.val_dir,
            colour_map=colours,
            checkpoint_dir=options.checkpoint_dir,
            inference_dir=options.inference_dir,
            batch_size=options.batch_size,
            num_workers=options.num_workers,
            max_epochs=options.max_epochs,
            num_classes=options.num_classes,
        )
        return

    # Defensive fallback: the subparsers are registered as required, so
    # argparse normally rejects a missing command before this point.
    cli.error("A subcommand is required.")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import math
|
|
3
|
+
import numpy as np
|
|
4
|
+
import os
|
|
5
|
+
import random
|
|
6
|
+
|
|
7
|
+
from scribble_annotation_generator.utils import generate_multiclass_scribble
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Default number of images produced by generate_crop_field_dataset.
NUM_SAMPLES_TO_GENERATE = 200

# --- Crop rows (coordinates are normalised to [-1, 1]) ---
# Standard deviation of the Gaussian jitter applied to row end points.
ROW_STD = 0.02
# Mean / standard deviation of the Gaussian from which row curvature is drawn.
ROW_CURVATURE_MEAN = -0.8
ROW_CURVATURE_STD = 0.05
# NOTE(review): not referenced by the functions in this module - presumably a
# minimum row segment length; confirm intended use.
ROW_MIN_LENGTH = 0.1
# Mean / standard deviation of the gap length inserted when a row is split.
ROW_SPARSITY_DISTANCE_MEAN = 0.4
ROW_SPARSITY_DISTANCE_STD = 0.2

# --- Weeds ---
# Bounds of the uniform distribution for weed length.
WEED_MAX_LENGTH = 0.5
WEED_MIN_LENGTH = 0.001
# Standard deviation (radians) of the weed direction around 3*pi/2.
WEED_DIRECTIONAL_STD = math.pi / 6
# Parameters shaping the length-dependent weed curvature distribution.
WEED_CURVATURE_SHIFT_FACTOR = 0.3
WEED_CURVATURE_SCALE_CONSTANT = 0.2
WEED_CURVATURE_MIN_STD = 0.4
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Point:
    """Minimal 2D point exposing ``x`` and ``y`` attributes."""

    def __init__(self, x, y):
        self.x, self.y = x, y
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def ccw(A, B, C):
    """Return True when the cross product of AB and AC is strictly positive.

    In a y-up coordinate system this means A, B, C make a counter-clockwise
    turn. Collinear points yield False. Each argument only needs ``x`` and
    ``y`` attributes.
    """
    ab_dx, ab_dy = B.x - A.x, B.y - A.y
    ac_dx, ac_dy = C.x - A.x, C.y - A.y
    return ac_dy * ab_dx > ab_dy * ac_dx
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def intersect(A, B, C, D):
    """Return True if line segments AB and CD intersect.

    The segments are taken to cross when C and D fall on opposite sides of
    AB and A and B fall on opposite sides of CD, as decided by ``ccw``.
    Collinear or merely touching configurations get no special handling.
    """
    ab_separates_cd = ccw(A, B, C) != ccw(A, B, D)
    cd_separates_ab = ccw(A, C, D) != ccw(B, C, D)
    return cd_separates_ab and ab_separates_cd
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def x_at_y(p1, p2, y):
    """Interpolate the x coordinate at height ``y`` on the line through p1 and p2.

    Args:
        p1: ``(x, y)`` pair of the first point.
        p2: ``(x, y)`` pair of the second point; its y must differ from
            p1's, otherwise the division below is by zero.
        y: The y coordinate at which to evaluate the line.

    Returns:
        The x coordinate of the line at ``y`` (values outside the two points
        are extrapolated linearly).
    """
    (x_start, y_start), (x_end, y_end) = p1, p2
    fraction = (y - y_start) / (y_end - y_start)
    return x_start + fraction * (x_end - x_start)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def split_row(row_object, sparsity):
    """Break a row into shorter segments separated by random gaps.

    The number of candidate gaps is Poisson distributed with rate
    ``5 * (1 - sparsity) ** 2``. Each candidate has a centre drawn uniformly
    between the row's start and end y, and a length drawn from a normal
    distribution (clipped to at least 0.05). Candidates are visited from
    lowest to highest centre; one is kept only if its whole extent lies
    inside the row and starts after the previously kept gap ends. Every
    returned segment therefore satisfies ``start_y <= end_y``, including
    when only one candidate exists.

    Args:
        row_object: Dict with ``start_x``, ``start_y``, ``end_x``, ``end_y``,
            ``num_spurs``, ``curvature``, ``cos_angle`` and ``sin_angle``.
        sparsity: Value controlling how many gaps are attempted; 1.0 gives a
            Poisson rate of zero.

    Returns:
        A list of row dicts. When no gap is drawn the original ``row_object``
        is returned as the single element; otherwise new dicts are built that
        share the parent's curvature, angle and spur settings.
    """
    num_splits = np.random.poisson(lam=((1 - sparsity) ** 2) * 5)
    if num_splits == 0:
        return [row_object]

    # Draw order (lengths, then centres) is kept stable so seeded runs
    # consume the random stream consistently.
    gap_lengths = np.clip(
        np.random.normal(
            loc=ROW_SPARSITY_DISTANCE_MEAN,
            scale=ROW_SPARSITY_DISTANCE_STD,
            size=num_splits,
        ),
        0.05,
        None,
    )
    gap_centres = np.sort(
        np.random.uniform(row_object["start_y"], row_object["end_y"], size=num_splits)
    )

    row_start_y = row_object["start_y"]
    row_end_y = row_object["end_y"]

    # Greedy single pass: keep a gap only if it fits entirely between the
    # end of the previously kept gap (or the row start) and the row end.
    gaps = []
    lower_bound = row_start_y
    for centre, length in zip(gap_centres, gap_lengths):
        gap_start = centre - (length / 2.0)
        gap_end = centre + (length / 2.0)
        if gap_start < lower_bound or gap_end > row_end_y:
            continue
        gaps.append((gap_start, gap_end))
        lower_bound = gap_end

    # Segments run from the row start to the first gap, between consecutive
    # gaps, and from the last gap to the row end.
    segment_starts = [row_start_y] + [gap_end for _, gap_end in gaps]
    segment_ends = [gap_start for gap_start, _ in gaps] + [row_end_y]

    row_start = (row_object["start_x"], row_start_y)
    row_end = (row_object["end_x"], row_end_y)

    return [
        {
            "start_x": x_at_y(row_start, row_end, segment_start_y),
            "start_y": segment_start_y,
            "end_x": x_at_y(row_start, row_end, segment_end_y),
            "end_y": segment_end_y,
            "num_spurs": row_object["num_spurs"],
            "curvature": row_object["curvature"],
            "cos_angle": row_object["cos_angle"],
            "sin_angle": row_object["sin_angle"],
        }
        for segment_start_y, segment_end_y in zip(segment_starts, segment_ends)
    ]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def generate_row_object(
    row_starting_x: float,
    row_class: int,
    row_sparsity: float = 1.0,
):
    """Create one crop row, optionally broken into several segments.

    The row runs from roughly ``y = -1`` to ``y = 1`` around
    ``row_starting_x``; both end points and the curvature receive Gaussian
    jitter and are clipped to ``[-1, 1]``.

    Args:
        row_starting_x: Nominal x position of the row.
        row_class: Class ID assigned to every segment of the row.
        row_sparsity: When below 1.0 the row is passed through ``split_row``.

    Returns:
        A ``(objects, classes)`` pair of equally long lists: the row dicts
        and ``row_class`` repeated once per dict.
    """

    def _jittered(centre, spread):
        # Gaussian sample clipped to the normalised coordinate range.
        return np.clip(np.random.normal(loc=centre, scale=spread), -1.0, 1.0)

    # Sampling order matters for reproducibility under a fixed seed.
    start_x = _jittered(row_starting_x, ROW_STD)
    end_x = _jittered(row_starting_x, ROW_STD)
    start_y = _jittered(-1.0, ROW_STD)
    end_y = _jittered(1.0, ROW_STD)
    row_curvature = _jittered(ROW_CURVATURE_MEAN, ROW_CURVATURE_STD)

    heading = math.atan2(end_y - start_y, end_x - start_x)

    row = {
        "start_x": start_x,
        "start_y": start_y,
        "end_x": end_x,
        "end_y": end_y,
        "num_spurs": 0,
        "curvature": row_curvature,
        "cos_angle": math.cos(heading),
        "sin_angle": math.sin(heading),
    }

    pieces = split_row(row, row_sparsity) if row_sparsity < 1.0 else [row]
    return pieces, [row_class] * len(pieces)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def generate_weed_object():
    """Create one randomly placed weed stroke.

    The start point is uniform over ``[-1, 1]`` in both axes, the length is
    uniform between ``WEED_MIN_LENGTH`` and ``WEED_MAX_LENGTH`` and the
    direction is Gaussian around ``3 * pi / 2``. The end point is clipped to
    ``[-1, 1]``. Curvature is drawn from a Gaussian whose mean and spread
    depend on the normalised length, then clipped to ``[-1, 1]``.

    Returns:
        A dict with the same keys as a row object. ``cos_angle`` and
        ``sin_angle`` come from the sampled direction, not from the clipped
        end point.
    """
    origin_x = random.uniform(-1.0, 1.0)
    origin_y = random.uniform(-1.0, 1.0)
    length = random.uniform(WEED_MIN_LENGTH, WEED_MAX_LENGTH)
    direction = np.random.normal(loc=3 * math.pi / 2, scale=WEED_DIRECTIONAL_STD)

    cos_dir = math.cos(direction)
    sin_dir = math.sin(direction)
    tip_x = np.clip(origin_x + length * cos_dir, -1.0, 1.0)
    tip_y = np.clip(origin_y + length * sin_dir, -1.0, 1.0)

    # Length mapped to (0, 1]; the small epsilon keeps it strictly positive.
    length_factor = (length - WEED_MIN_LENGTH + 1e-6) / (
        WEED_MAX_LENGTH - WEED_MIN_LENGTH
    )
    centred = ((1 - length_factor) * 2) - 1
    curvature_mean = centred * (1 - WEED_CURVATURE_SHIFT_FACTOR) - WEED_CURVATURE_SHIFT_FACTOR
    curvature_std = max(
        (1 - length_factor) * WEED_CURVATURE_SCALE_CONSTANT,
        WEED_CURVATURE_MIN_STD,
    )
    weed_curvature = np.clip(
        np.random.normal(loc=curvature_mean, scale=curvature_std),
        -1.0,
        1.0,
    )

    return {
        "start_x": origin_x,
        "start_y": origin_y,
        "end_x": tip_x,
        "end_y": tip_y,
        "num_spurs": 0,
        "curvature": weed_curvature,
        "cos_angle": cos_dir,
        "sin_angle": sin_dir,
    }
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def generate_sample(
    colour_map: dict[tuple[int, int, int], int],
    num_rows: int = 5,
    row_class: int = 1,
    interspersed: bool = False,
    interspersed_num_rows: int = 0,
    interspersed_class: int = 2,
    row_sparsity: float = 1.0,
    num_weeds: "dict[int, int] | None" = None,
):
    """Generate one synthetic crop field scribble image.

    Main rows are laid out at a regular spacing starting from a random
    offset near the left edge. Optional interspersed rows sit halfway
    between main rows, starting at a randomly chosen free position. Weeds
    are then added one at a time, re-sampling each weed until its straight
    start-end segment crosses none of the objects placed so far.

    Args:
        colour_map: Mapping from RGB tuples to class IDs, forwarded to
            ``generate_multiclass_scribble``.
        num_rows: Number of main crop rows.
        row_class: Class ID of the main rows.
        interspersed: Whether to add rows between the main rows.
        interspersed_num_rows: Number of interspersed rows to add.
        interspersed_class: Class ID of the interspersed rows.
        row_sparsity: Forwarded to ``generate_row_object``; values below
            1.0 break rows into segments.
        num_weeds: Mapping from weed class ID to the number of weeds of
            that class; ``None`` means no weeds.

    Returns:
        The result of ``generate_multiclass_scribble`` for a 512x512 image
        (presumably an RGB array - confirm against that helper).
    """
    # Avoid a shared mutable default argument.
    if num_weeds is None:
        num_weeds = {}

    objects = []
    classes = []

    row_offset = 2.0 / (num_rows + 1)
    initial_row_starting_x = random.uniform(-1.0, -1.0 + row_offset)
    row_starting_x = initial_row_starting_x
    for _ in range(num_rows):
        row_objects, row_classes = generate_row_object(
            row_starting_x=row_starting_x,
            row_class=row_class,
            row_sparsity=row_sparsity,
        )
        objects.extend(row_objects)
        classes.extend(row_classes)
        row_starting_x += row_offset

    if interspersed:
        half_offset = row_offset / 2.0
        interspersed_row_starting_x = initial_row_starting_x - half_offset

        # Ensure interspersed row at index 0 is within bounds
        if interspersed_row_starting_x < -1.0:
            interspersed_row_starting_x += row_offset

        # Get maximum number of interspersed rows that fit
        num_interspersed_row_positions = num_rows
        if initial_row_starting_x - half_offset > -1.0:
            num_interspersed_row_positions += 1
        if initial_row_starting_x + (num_rows * row_offset) + half_offset < 1.0:
            num_interspersed_row_positions += 1

        # Select the starting position for interspersed rows. The upper
        # bound is the number of unused positions; the previous expression
        # subtracted the position count from itself and so always chose 0.
        interspersed_row_starting_index = random.randint(
            0, max(num_interspersed_row_positions - interspersed_num_rows, 0)
        )
        interspersed_row_starting_x += interspersed_row_starting_index * row_offset

        for _ in range(interspersed_num_rows):
            row_objects, row_classes = generate_row_object(
                row_starting_x=interspersed_row_starting_x,
                row_class=interspersed_class,
                row_sparsity=row_sparsity,
            )
            objects.extend(row_objects)
            classes.extend(row_classes)
            interspersed_row_starting_x += row_offset

    for weed_class, num_weed in num_weeds.items():
        for _ in range(num_weed):
            # Rejection sampling: redraw until the weed crosses nothing
            # already placed (rows and earlier weeds).
            while True:
                weed_object = generate_weed_object()
                weed_start = Point(weed_object["start_x"], weed_object["start_y"])
                weed_end = Point(weed_object["end_x"], weed_object["end_y"])
                crosses_existing = any(
                    intersect(
                        weed_start,
                        weed_end,
                        Point(obj["start_x"], obj["start_y"]),
                        Point(obj["end_x"], obj["end_y"]),
                    )
                    for obj in objects
                )
                if not crosses_existing:
                    break

            objects.append(weed_object)
            classes.append(weed_class)

    synthetic = generate_multiclass_scribble(
        image_shape=(512, 512),
        objects=objects,
        classes=classes,
        colour_map=colour_map,
    )

    return synthetic
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def generate_crop_field_dataset(
    output_dir: str,
    colour_map: dict,
    num_samples: int = NUM_SAMPLES_TO_GENERATE,
    min_rows: int = 4,
    max_rows: int = 6,
):
    """Generate a directory of synthetic crop field scribble images.

    For every sample the row count, row class (1-3), interspersed-row
    settings, row sparsity and per-class weed counts are randomised before
    calling ``generate_sample``. Each image is mirrored horizontally with
    probability 0.5 and written as ``synthetic_NNNN.png``.

    Args:
        output_dir: Destination directory; created if it does not exist.
        colour_map: Mapping from RGB tuples to class IDs.
        num_samples: Number of images to write.
        min_rows: Inclusive lower bound for the number of main rows.
        max_rows: Inclusive upper bound for the number of main rows.

    Raises:
        OSError: If an image could not be written to ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_samples):
        num_rows = random.randint(min_rows, max_rows)
        row_class = random.randint(1, 3)
        interspersed = random.choice([True, False])
        interspersed_num_rows = random.randint(1, num_rows + 1)
        # The interspersed class always differs from the main row class.
        interspersed_class = random.choice([c for c in [1, 2, 3] if c != row_class])
        row_sparsity = random.uniform(0.1, 1.0)
        num_weeds = {
            2: random.randint(0, 5),
            3: random.randint(0, 10),
            4: random.randint(0, 10),
        }

        sample = generate_sample(
            colour_map=colour_map,
            num_rows=num_rows,
            row_class=row_class,
            interspersed=interspersed,
            interspersed_num_rows=interspersed_num_rows,
            interspersed_class=interspersed_class,
            row_sparsity=row_sparsity,
            num_weeds=num_weeds,
        )

        if random.random() < 0.5:
            sample = cv2.flip(sample, 1)

        output_path = os.path.join(output_dir, f"synthetic_{i:04d}.png")
        # cv2.imwrite reports failure through its return value rather than
        # an exception, so surface it instead of silently losing the image.
        written = cv2.imwrite(output_path, cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
        if not written:
            raise OSError(f"Could not write {output_path}")
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
import cv2
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
from scribble_annotation_generator.utils import (
|
|
9
|
+
extract_class_masks,
|
|
10
|
+
extract_object_features,
|
|
11
|
+
get_objects,
|
|
12
|
+
rgb_to_indexed,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ScribbleDataset(torch.utils.data.Dataset):
    """Dataset turning scribble annotation images into per-object features.

    Every file in ``data_dir`` is treated as an annotation image. Whether
    the images are RGB or single-channel is detected once from the first
    file. Each item is decomposed into objects per class, shuffled, and
    split at a random index into visible objects and masked-out objects.
    """

    def __init__(
        self, num_classes, data_dir, colour_map=None, max_objects=50, late_shift=False
    ):
        """Index the files in ``data_dir`` and detect the image format.

        Args:
            num_classes: Minimum length of the per-class ``counts`` vector.
            data_dir: Directory containing the annotation images.
            colour_map: RGB-to-class mapping; required for RGB annotations.
            max_objects: Length every per-object tensor is padded to.
            late_shift: When True the split index is drawn from roughly the
                last quarter of the objects instead of the full range.

        Raises:
            ValueError: If ``data_dir`` contains no files.
            IOError: If the first file cannot be read as an image.
        """
        self.data_dir = data_dir
        self.filenames = sorted(os.listdir(data_dir))
        self.num_classes = num_classes
        self.colour_map = colour_map
        self.max_objects = max_objects
        self.late_shift = late_shift

        if len(self.filenames) == 0:
            raise ValueError(f"No files found in {data_dir}")

        # Auto-detect format from first image
        first_image_path = os.path.join(self.data_dir, self.filenames[0])
        first_img = cv2.imread(first_image_path, cv2.IMREAD_UNCHANGED)
        if first_img is None:
            raise IOError(f"Could not read {first_image_path}")
        self.is_rgb = len(first_img.shape) == 3 and first_img.shape[2] >= 3

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load one annotation and return its padded object tensors.

        Returns:
            A dict with:
            ``objects``: per-object feature rows, zero-padded to
            ``max_objects`` rows;
            ``classes``: class ID per object, zero-padded;
            ``mask``: 1 for objects before the split index, 0 afterwards
            and for padding;
            ``query_cls``: class of the object at the split index;
            ``targets``: feature rows of all objects of ``query_cls``,
            padded with 1e7;
            ``counts``: ``torch.bincount`` of the padded ``classes``.

        Raises:
            ValueError: If an RGB annotation is loaded without a colour
                map, or the image holds too few objects to split.
            IOError: If the image cannot be read.
        """
        filepath = os.path.join(self.data_dir, self.filenames[idx])

        if self.is_rgb:
            if self.colour_map is None:
                raise ValueError("colour_map must be provided for RGB annotations")

            image = cv2.imread(str(filepath), cv2.IMREAD_COLOR)
            if image is None:
                raise IOError(f"Could not read {filepath}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
            annotation = rgb_to_indexed(image, self.colour_map)
        else:
            annotation = cv2.imread(str(filepath), cv2.IMREAD_GRAYSCALE)
            if annotation is None:
                raise IOError(f"Could not read {filepath}")

        objects = []
        classes = []
        class_masks = extract_class_masks(annotation)

        for class_id, class_mask in class_masks.items():
            class_objects = get_objects(class_mask)
            objects.extend(class_objects)
            classes.extend([class_id] * len(class_objects))

        objects = [extract_object_features(obj) for obj in objects]

        num_objects = len(objects)
        if num_objects == 0:
            # torch.stack would otherwise fail with an opaque error.
            raise ValueError(f"No scribble objects found in {filepath}")
        if not self.late_shift and num_objects < 2:
            # random.randint(1, 0) would raise an uninformative ValueError.
            raise ValueError(
                f"At least two scribble objects are required in {filepath}"
            )

        permutation = list(range(num_objects))
        random.shuffle(permutation)

        objects = torch.stack([objects[i] for i in permutation])
        classes = torch.tensor([classes[i] for i in permutation])

        # Mask everything after a random point
        if self.late_shift:
            mask_start = random.randint((num_objects // 4) * 3, num_objects - 1)
        else:
            mask_start = random.randint(1, num_objects - 1)
        visible = torch.ones(num_objects)
        visible[mask_start:] = 0

        query_cls = classes[mask_start]
        targets = objects[classes == query_cls, :]

        # NOTE(review): a negative pad amount (more than max_objects objects)
        # makes F.pad crop instead of pad - confirm this is acceptable.
        objects = F.pad(objects, (0, 0, 0, self.max_objects - len(objects)), value=0)
        classes = F.pad(classes, (0, self.max_objects - len(classes)), value=0)
        visible = F.pad(visible, (0, self.max_objects - len(visible)), value=0)
        targets = F.pad(
            targets, (0, 0, 0, self.max_objects - targets.size(0)), value=1e7
        )

        return {
            "objects": objects,
            "classes": classes,
            "mask": visible,
            "query_cls": query_cls,
            "targets": targets,
            # NOTE(review): counted on the padded tensor, so padding entries
            # contribute to class 0 - confirm this is intended.
            "counts": torch.bincount(classes, minlength=self.num_classes),
        }
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
|
|
3
|
+
from scribble_annotation_generator.nn.nn import ScribbleDataset
|
|
4
|
+
from scribble_annotation_generator.nn.utils import (
|
|
5
|
+
generate_multiclass_scribble,
|
|
6
|
+
unpack_feature_vector,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def parameterize_and_unparameterize(
    data_dir="./local/soybean1", output_dir="./local/nn-out"
):
    """Round-trip annotations through the feature encoding for inspection.

    Every sample of a ``ScribbleDataset`` built on ``data_dir`` is turned
    back into object dicts with ``unpack_feature_vector``, rendered with
    ``generate_multiclass_scribble`` and written as a PNG to ``output_dir``.

    Args:
        data_dir: Directory containing the annotation images to load.
        output_dir: Directory for the rendered images; created if missing.

    Raises:
        OSError: If a rendered image could not be written.
    """
    import os  # local import: this module does not import os at file level

    colour_map = {
        (0, 0, 0): 0,
        (0, 128, 255): 1,
        (124, 255, 121): 2,
        (127, 0, 0): 3,
        (255, 148, 0): 4,
        (0, 0, 127): 5,
    }
    dataset = ScribbleDataset(
        num_classes=3, data_dir=data_dir, colour_map=colour_map
    )

    # cv2.imwrite does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(len(dataset)):
        sample = dataset[i]

        objects = sample["objects"]
        classes = sample["classes"]

        # NOTE(review): the dataset pads its tensors, so padding rows are
        # unpacked and rendered as well - confirm this is intended.
        objects = [unpack_feature_vector(obj) for obj in objects.numpy()]

        synthetic = generate_multiclass_scribble(
            image_shape=(512, 512),
            objects=objects,
            classes=classes,
            colour_map=colour_map,
        )

        # Save the synthetic scribble
        output_path = os.path.join(output_dir, f"synthetic_{i:04d}.png")

        # Convert RGB to BGR for saving with OpenCV
        synthetic_bgr = cv2.cvtColor(synthetic, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(str(output_path), synthetic_bgr):
            raise OSError(f"Could not write {output_path}")
|