scribble-annotation-generator 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scribble_annotation_generator/crop_field.py +3 -2
- scribble_annotation_generator/gmm_random_generator.py +652 -0
- {scribble_annotation_generator-0.0.1.dist-info → scribble_annotation_generator-0.1.0.dist-info}/METADATA +2 -1
- {scribble_annotation_generator-0.0.1.dist-info → scribble_annotation_generator-0.1.0.dist-info}/RECORD +6 -5
- {scribble_annotation_generator-0.0.1.dist-info → scribble_annotation_generator-0.1.0.dist-info}/WHEEL +0 -0
- {scribble_annotation_generator-0.0.1.dist-info → scribble_annotation_generator-0.1.0.dist-info}/entry_points.txt +0 -0
|
@@ -327,13 +327,14 @@ def generate_sample(
|
|
|
327
327
|
|
|
328
328
|
|
|
329
329
|
def generate_crop_field_dataset(
|
|
330
|
-
output_dir: str,
|
|
331
330
|
colour_map: dict,
|
|
331
|
+
output_dir: str = None,
|
|
332
332
|
num_samples: int = NUM_SAMPLES_TO_GENERATE,
|
|
333
333
|
min_rows: int = 4,
|
|
334
334
|
max_rows: int = 6,
|
|
335
335
|
):
|
|
336
|
-
|
|
336
|
+
if output_dir is not None:
|
|
337
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
337
338
|
|
|
338
339
|
for i in range(num_samples):
|
|
339
340
|
num_rows = random.randint(min_rows, max_rows)
|
|
@@ -0,0 +1,652 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import math
|
|
3
|
+
import random
|
|
4
|
+
from collections import Counter
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
from sklearn.mixture import GaussianMixture
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
from scribble_annotation_generator.dataset import ScribbleDataset
|
|
15
|
+
from scribble_annotation_generator.cli import parse_colour_map
|
|
16
|
+
from scribble_annotation_generator.utils import (
|
|
17
|
+
extract_class_masks,
|
|
18
|
+
extract_object_features,
|
|
19
|
+
generate_multiclass_scribble,
|
|
20
|
+
get_objects,
|
|
21
|
+
rgb_to_indexed,
|
|
22
|
+
unpack_feature_vector,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class CountModelEmpirical:
    """Categorical distribution over the distinct count vectors seen in data.

    Attributes:
        counts: Integer matrix with one row per distinct count vector, in the
            order the vectors were first observed.
        probs: Relative frequency of each row of ``counts``.
    """

    counts: np.ndarray
    probs: np.ndarray

    @classmethod
    def fit(cls, count_vectors: Iterable[np.ndarray]) -> "CountModelEmpirical":
        """Build the model from observed count vectors.

        Identical vectors are merged; each distinct vector's probability is its
        share of all observations.
        """
        tally = Counter(tuple(vector.tolist()) for vector in count_vectors)
        distinct = list(tally)
        frequencies = np.array([tally[key] for key in distinct], dtype=np.float64)
        total = max(frequencies.sum(), 1.0)
        return cls(
            counts=np.array(distinct, dtype=np.int64),
            probs=frequencies / total,
        )

    def sample(self, num_samples: int = 1) -> np.ndarray:
        """Draw ``num_samples`` rows of ``counts`` (with replacement) using ``probs``."""
        chosen_rows = np.random.choice(len(self.counts), size=num_samples, p=self.probs)
        return self.counts[chosen_rows]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class GMMRandomGenerator:
    """Fit per-class GMMs over object geometry and an empirical count model.

    Every annotated object is reduced to the feature vector
    ``[center_x, center_y, length, angle, curvature]`` (see ``_to_gmm_feature``)
    and one ``GaussianMixture`` is fitted per class on those vectors. For each
    foreground class, an empirical model over whole-image count vectors
    (restricted to images that contain that class) decides how many objects of
    every class are drawn together. Spur counts are resampled from the values
    observed for the class.
    """

    def __init__(
        self,
        num_classes: int,
        num_components: int = 5,
        random_state: int = 0,
        class_names: Optional[List[str]] = None,
    ) -> None:
        """Create an unfitted generator.

        Args:
            num_classes: Total number of classes. Class 0 gets no count model
                and is never sampled (treated as background).
            num_components: Upper bound on mixture components per class GMM.
            random_state: Seed used to fit the GMMs and to seed the RNG the
                fitted GMMs draw samples from.
            class_names: Names indexed by class id, used in descriptions.
                Defaults to ``class_<id>``.
        """
        self.num_classes = num_classes
        self.num_components = num_components
        self.random_state = random_state
        self.class_names = class_names or [f"class_{i}" for i in range(num_classes)]

        self.class_gmms: Dict[int, GaussianMixture] = {}
        self.class_stats: Dict[int, Dict[str, float]] = {}
        self.class_spur_counts: Dict[int, List[int]] = {}
        # Per-class count models, keyed by the class an image is conditioned on.
        self.count_models: Dict[int, Optional[CountModelEmpirical]] = {}
        # Stateful RNG shared by all fitted GMMs for sampling. With an *int*
        # random_state, sklearn re-seeds on every sample() call and returns the
        # same draws each time; a RandomState instance advances between calls.
        self._sample_rng = np.random.RandomState(random_state)

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------

    def fit(self, dataset: "ScribbleDataset") -> "GMMRandomGenerator":
        """Fit count models and per-class geometry GMMs on ``dataset``.

        Args:
            dataset: Dataset whose ``filenames`` (relative to ``data_dir``) are
                read as annotation masks by ``_iter_objects``.

        Returns:
            ``self``, so construction and fitting can be chained.
        """
        print(
            f"[GMMRandomGenerator] Starting fit on dataset with {len(dataset.filenames)} files..."
        )

        per_class_samples: Dict[int, List[np.ndarray]] = {
            i: [] for i in range(self.num_classes)
        }
        per_class_spurs: Dict[int, List[int]] = {i: [] for i in range(self.num_classes)}
        count_vectors: List[np.ndarray] = []

        for class_ids, features in self._iter_objects(dataset):
            # One count vector per image: number of objects of each class.
            counts = np.zeros(self.num_classes, dtype=np.int64)
            for class_id in class_ids:
                counts[class_id] += 1
            count_vectors.append(counts)

            for class_id, feat in zip(class_ids, features):
                gmm_feat, spur_count = self._to_gmm_feature(feat)
                per_class_samples[class_id].append(gmm_feat)
                per_class_spurs[class_id].append(spur_count)

        # Per-class count models, conditioned on the class being present.
        for class_id in range(1, self.num_classes):  # Skip background class (0)
            class_specific_counts = [
                counts for counts in count_vectors if counts[class_id] > 0
            ]

            if class_specific_counts:
                self.count_models[class_id] = CountModelEmpirical.fit(
                    class_specific_counts
                )
                print(
                    f"[GMMRandomGenerator] Fitted count model for class {class_id} with {len(class_specific_counts)} observations"
                )
            else:
                # Never observed: fall back to "exactly one object of this class".
                fallback_counts = np.zeros(self.num_classes, dtype=np.int64)
                fallback_counts[class_id] = 1
                self.count_models[class_id] = CountModelEmpirical.fit(
                    [fallback_counts]
                )
                print(
                    f"[GMMRandomGenerator] No samples for class {class_id}, using fallback singleton count"
                )

        fitted_classes = []
        for class_id, samples in per_class_samples.items():
            # Classes with fewer than two objects get no GMM and are never sampled.
            if len(samples) < 2:
                continue
            data = np.stack(samples, axis=0)
            gmm = GaussianMixture(
                n_components=min(self.num_components, len(samples)),
                covariance_type="full",
                random_state=self.random_state,
            )
            gmm.fit(data)
            # Fit reproducibly with the int seed, then sample from the shared
            # stateful RNG so successive sample() calls yield new draws.
            gmm.set_params(random_state=self._sample_rng)
            self.class_gmms[class_id] = gmm
            fitted_classes.append(class_id)

            # Observed ranges of length (column 2) and curvature (column 4);
            # _from_gmm_feature clips sampled values to them.
            lengths = data[:, 2]
            curvatures = data[:, 4]
            self.class_stats[class_id] = {
                "length_min": float(lengths.min()),
                "length_max": float(lengths.max()),
                "curvature_min": float(curvatures.min()),
                "curvature_max": float(curvatures.max()),
            }

        print(
            f"[GMMRandomGenerator] Fitted GMMs for {len(fitted_classes)} classes: {fitted_classes}"
        )
        self.class_spur_counts = per_class_spurs
        return self

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    def sample_counts(self, class_id: int, num_samples: int = 1) -> np.ndarray:
        """Draw count vectors from the model conditioned on ``class_id``.

        Only vectors containing at least one ``class_id`` object are returned
        (rejection sampling).

        Raises:
            ValueError: If no count model exists for ``class_id``, or the model
                can never produce a vector containing that class (which would
                otherwise make the rejection loop spin forever).
        """
        model = self.count_models.get(class_id)
        if model is None:
            raise ValueError(f"Count model not available for class {class_id}")

        support = np.asarray(model.counts)
        reachable = (
            support.ndim == 2
            and support.shape[0] > 0
            and bool(np.any((support[:, class_id] > 0) & (np.asarray(model.probs) > 0)))
        )
        if not reachable:
            raise ValueError(
                f"Count model for class {class_id} never contains class {class_id}"
            )

        samples = []
        for _ in range(num_samples):
            while True:
                sample = model.sample(1)[0]
                if sample[class_id] > 0:
                    samples.append(sample)
                    break

        return np.array(samples)

    def sample_objects(self, counts: np.ndarray) -> Tuple[List[dict], np.ndarray]:
        """Sample ``counts[c]`` objects from the GMM of every foreground class ``c``.

        Classes without a fitted GMM, and class 0, are skipped.

        Returns:
            ``(objects, classes)`` where ``classes[k]`` is the class of
            ``objects[k]``.
        """
        objects: List[dict] = []
        classes: List[int] = []

        for class_id in range(1, self.num_classes):
            gmm = self.class_gmms.get(class_id)
            if gmm is None:
                continue
            num = int(counts[class_id])
            if num <= 0:
                continue

            samples, _ = gmm.sample(num)
            for sample in samples:
                objects.append(self._from_gmm_feature(class_id, sample))
                classes.append(class_id)

        return objects, np.array(classes, dtype=np.int64)

    def sample_image(
        self,
        image_shape: Tuple[int, int],
        colour_map: Optional[Dict[Tuple[int, int, int], int]] = None,
        overlap_iters: int = 40,
        overlap_margin: float = 0.05,
    ) -> Tuple[np.ndarray, str]:
        """Generate one synthetic scribble image and a text description of it.

        A foreground class with a count model is picked uniformly at random. A
        count vector is drawn conditioned on it, objects are sampled, overlaps
        are resolved, and the result is rendered by
        ``generate_multiclass_scribble``.

        Returns:
            ``(image, description)``. When nothing can be sampled, this is a
            zero array of ``image_shape`` and ``"Empty image"``.
        """
        available_classes = [
            c
            for c in range(1, self.num_classes)
            if self.count_models.get(c) is not None
        ]
        if not available_classes:
            return np.zeros(image_shape, dtype=np.uint8), "Empty image"

        chosen_class = random.choice(available_classes)
        counts = self.sample_counts(chosen_class, 1)[0]
        objects, classes = self.sample_objects(counts)

        if len(objects) == 0:
            return np.zeros(image_shape, dtype=np.uint8), "Empty image"

        objects = self._resolve_overlaps(objects, overlap_iters, overlap_margin)

        image = generate_multiclass_scribble(
            image_shape=image_shape,
            objects=objects,
            classes=classes,
            colour_map=colour_map,
        )

        class_counts = Counter(int(c) for c in classes)
        description_parts = []
        for class_id in sorted(class_counts):
            count = class_counts[class_id]
            noun = "object" if count == 1 else "objects"
            description_parts.append(
                f'{count} {noun} of class "{self._class_name(class_id)}"'
            )

        return image, "; ".join(description_parts)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _class_name(self, class_id: int) -> str:
        """Return the configured name for ``class_id``, or ``class_<id>`` if missing."""
        if 0 <= class_id < len(self.class_names):
            return self.class_names[class_id]
        return f"class_{class_id}"

    def _iter_objects(
        self, dataset: "ScribbleDataset"
    ) -> Iterable[Tuple[List[int], List[np.ndarray]]]:
        """Yield ``(class_ids, features)`` for every annotation file in ``dataset``.

        RGB annotations are converted to class indices using
        ``dataset.colour_map``; other annotations are read as greyscale index
        masks.

        Raises:
            ValueError: If the dataset is RGB but has no colour map.
            OSError: If an annotation file cannot be read.
        """
        for filename in dataset.filenames:
            filepath = os.path.join(dataset.data_dir, filename)

            if dataset.is_rgb and dataset.colour_map is None:
                raise ValueError("colour_map must be provided for RGB annotations")

            read_flag = cv2.IMREAD_COLOR if dataset.is_rgb else cv2.IMREAD_GRAYSCALE
            mask = cv2.imread(str(filepath), read_flag)
            # cv2.imread signals failure by returning None instead of raising.
            if mask is None:
                raise OSError(f"Could not read annotation file: {filepath}")

            if dataset.is_rgb:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
                mask = rgb_to_indexed(mask, dataset.colour_map)

            class_ids: List[int] = []
            features: List[np.ndarray] = []
            for class_id, class_mask in extract_class_masks(mask).items():
                for obj_mask in get_objects(class_mask):
                    class_ids.append(int(class_id))
                    features.append(extract_object_features(obj_mask))

            yield class_ids, features

    def _to_gmm_feature(self, feat: np.ndarray) -> Tuple[np.ndarray, int]:
        """Convert a packed object feature vector to the GMM feature and spur count.

        Returns:
            ``(float32 [center_x, center_y, length, angle, curvature], num_spurs)``,
            where length and angle describe the start→end segment.
        """
        data = unpack_feature_vector(feat)

        start_x = data["start_x"]
        start_y = data["start_y"]
        end_x = data["end_x"]
        end_y = data["end_y"]

        center_x = (start_x + end_x) / 2.0
        center_y = (start_y + end_y) / 2.0

        dx = end_x - start_x
        dy = end_y - start_y
        length = math.hypot(dx, dy)
        angle = math.atan2(dy, dx)

        curvature = data["curvature"]
        num_spurs = int(round(data["num_spurs"]))

        return (
            np.array([center_x, center_y, length, angle, curvature], dtype=np.float32),
            num_spurs,
        )

    def _from_gmm_feature(self, class_id: int, sample: np.ndarray) -> dict:
        """Turn a sampled GMM feature back into an object description dict.

        Steps:
        - Clip length and curvature to the ranges observed for the class during
          fitting, or to fixed defaults when no stats exist.
        - Clip the centre to [-1, 1].
        - Shorten the segment so its endpoints stay inside [-1, 1] x [-1, 1],
          subject to a 2e-4 minimum length.
        - Draw ``num_spurs`` from the spur counts observed for the class
          (0 if none).
        """
        center_x, center_y, length, angle, curvature = sample.tolist()

        stats = self.class_stats.get(class_id, {})
        length_min = stats.get("length_min", 0.0)
        length_max = stats.get("length_max", 2.0)
        curvature_min = stats.get("curvature_min", -1.0)
        curvature_max = stats.get("curvature_max", 1.0)

        length = float(
            np.clip(
                abs(length), max(1e-4, length_min), max(length_min + 1e-4, length_max)
            )
        )
        curvature = float(np.clip(curvature, curvature_min, curvature_max))

        center_x = float(np.clip(center_x, -1.0, 1.0))
        center_y = float(np.clip(center_y, -1.0, 1.0))

        cos_angle = math.cos(angle)
        sin_angle = math.sin(angle)

        length = self._fit_length_to_bounds(
            center_x, center_y, length, cos_angle, sin_angle
        )

        half = 0.5 * length
        start_x = center_x - half * cos_angle
        start_y = center_y - half * sin_angle
        end_x = center_x + half * cos_angle
        end_y = center_y + half * sin_angle

        spur_list = self.class_spur_counts.get(class_id, [])
        num_spurs = int(random.choice(spur_list)) if spur_list else 0

        return {
            "start_x": start_x,
            "start_y": start_y,
            "end_x": end_x,
            "end_y": end_y,
            "num_spurs": num_spurs,
            "curvature": curvature,
            "cos_angle": cos_angle,
            "sin_angle": sin_angle,
        }

    def _fit_length_to_bounds(
        self,
        center_x: float,
        center_y: float,
        length: float,
        cos_angle: float,
        sin_angle: float,
    ) -> float:
        """Return the largest length <= ``length`` that keeps the segment in bounds.

        The segment is centred at ``(center_x, center_y)`` with direction
        ``(cos_angle, sin_angle)`` and must stay inside [-1, 1] x [-1, 1].
        The result is never below 2e-4.
        """
        max_half = 0.5 * length
        if abs(cos_angle) > 1e-6:
            max_half = min(max_half, (1.0 - abs(center_x)) / abs(cos_angle))
        if abs(sin_angle) > 1e-6:
            max_half = min(max_half, (1.0 - abs(center_y)) / abs(sin_angle))

        max_half = max(0.0, max_half)
        return max(2e-4, 2.0 * max_half)

    def _resolve_overlaps(
        self,
        objects: List[dict],
        max_iters: int,
        min_margin: float,
    ) -> List[dict]:
        """Push object centres apart and gently toward the image centre.

        Each object is approximated by a circle around its midpoint, with a
        radius of half its length. Each of ``max_iters`` rounds does two things:
        - Every pair closer than ``r_i + r_j + min_margin`` is moved apart
          symmetrically.
        - Every centre is pulled toward (0, 0), more strongly the farther it is
          from the centre.

        Each object is then rebuilt at its new (clipped) centre with its
        original orientation, shortened if needed to stay inside the bounds.
        ``objects`` is modified in place and also returned.
        """
        if len(objects) < 2:
            return objects

        centers = np.array(
            [
                [(obj["start_x"] + obj["end_x"]) / 2.0, (obj["start_y"] + obj["end_y"]) / 2.0]
                for obj in objects
            ],
            dtype=np.float32,
        )
        radii = np.array(
            [
                0.5 * math.hypot(obj["end_x"] - obj["start_x"], obj["end_y"] - obj["start_y"])
                for obj in objects
            ],
            dtype=np.float32,
        )

        num_objects = len(objects)
        # Gentle centering force strength (0 = no centering, 1 = full centering).
        centering_strength = 0.1
        # Diagonal distance from (0, 0) to a corner of normalized [-1, 1] space.
        max_dist = math.sqrt(2.0)

        # All iterations always run (no early exit once overlaps are gone) so
        # the centering pull keeps acting.
        for _ in range(max_iters):
            for i in range(num_objects):
                for j in range(i + 1, num_objects):
                    dx = centers[j, 0] - centers[i, 0]
                    dy = centers[j, 1] - centers[i, 1]
                    dist = math.sqrt(dx * dx + dy * dy)
                    min_dist = radii[i] + radii[j] + min_margin
                    if dist >= min_dist:
                        continue

                    if dist > 1e-6:
                        direction_x = dx / dist
                        direction_y = dy / dist
                    else:
                        # Coincident centres have no separating direction; pick a
                        # random one instead of leaving the objects stacked.
                        theta = random.uniform(0.0, 2.0 * math.pi)
                        direction_x = math.cos(theta)
                        direction_y = math.sin(theta)

                    displacement = (min_dist - dist) / 2.0 + 1e-4
                    centers[i, 0] -= direction_x * displacement
                    centers[i, 1] -= direction_y * displacement
                    centers[j, 0] += direction_x * displacement
                    centers[j, 1] += direction_y * displacement

            # Objects farther from (0, 0) are pulled in more strongly.
            for i in range(num_objects):
                dist_from_center = math.sqrt(centers[i, 0] ** 2 + centers[i, 1] ** 2)
                adaptive_strength = centering_strength * (dist_from_center / max_dist)
                centers[i, 0] *= 1.0 - adaptive_strength
                centers[i, 1] *= 1.0 - adaptive_strength

        for idx, obj in enumerate(objects):
            center_x = float(np.clip(centers[idx, 0], -1.0, 1.0))
            center_y = float(np.clip(centers[idx, 1], -1.0, 1.0))

            cos_angle = obj["cos_angle"]
            sin_angle = obj["sin_angle"]
            length = math.hypot(obj["end_x"] - obj["start_x"], obj["end_y"] - obj["start_y"])
            length = self._fit_length_to_bounds(
                center_x, center_y, length, cos_angle, sin_angle
            )

            half = 0.5 * length
            obj["start_x"] = center_x - half * cos_angle
            obj["start_y"] = center_y - half * sin_angle
            obj["end_x"] = center_x + half * cos_angle
            obj["end_y"] = center_y + half * sin_angle

        return objects
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for GMM fitting and sampling.

    Returns:
        A parser exposing:
        - the dataset location and class count;
        - an optional colour map and class-name file;
        - output settings;
        - the GMM and overlap-resolution hyper-parameters.
    """
    cli = argparse.ArgumentParser(
        description="Fit per-class GMMs and sample synthetic scribble annotations.",
    )

    # Dataset inputs.
    cli.add_argument(
        "--data-dir", required=True, help="Path to the training dataset directory."
    )
    cli.add_argument(
        "--num-classes",
        type=int,
        required=True,
        help="Total number of classes (including background).",
    )
    colour_map_help = (
        "Colour map specified inline as 'R,G,B=class;...' or a path to a file "
        "with one 'R,G,B,class' or 'R,G,B' entry per line. Required for RGB annotations."
    )
    cli.add_argument("--colour-map", default=None, help=colour_map_help)

    # Output location.
    cli.add_argument(
        "--output-dir",
        default="./local/gmm-inference",
        help="Directory to write generated samples.",
    )

    # Integer knobs, registered in a fixed order so --help output is stable.
    integer_options = (
        ("--num-samples", 10, "Number of images to generate."),
        ("--height", 512, "Output image height."),
        ("--width", 512, "Output image width."),
        ("--num-components", 5, "Number of GMM components per class."),
        ("--overlap-iters", 40, "Iterations for overlap resolution."),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    cli.add_argument(
        "--overlap-margin",
        type=float,
        default=0.05,
        help="Minimum margin between objects in normalized coordinates.",
    )
    cli.add_argument(
        "--class-names",
        type=str,
        default=None,
        help="Path to file with one class name per line.",
    )
    return cli
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def read_class_names(filepath: str) -> List[str]:
    """Read class names from a text file, one name per line.

    Line ``i`` (0-based) is used as the name of class id ``i``. Blank lines
    are therefore kept so the position-to-class mapping is not shifted.

    Args:
        filepath: Path to the UTF-8 class-name file.

    Returns:
        The file's lines without line terminators (``\\n``, ``\\r\\n``, ...).
    """
    # Explicit UTF-8: the platform default encoding varies (e.g. cp1252 on
    # Windows) and would garble or reject non-ASCII class names.
    with open(filepath, "r", encoding="utf-8") as f:
        return f.read().splitlines()
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def train_gmm_model(data_dir: str, num_components: int = 5) -> GMMRandomGenerator:
    """Train and return a GMMRandomGenerator on the dataset in ``data_dir``.

    Args:
        data_dir: Dataset root. It must contain:
            - a ``segmentation`` subdirectory with annotation files;
            - a ``colour_map.txt`` file;
            - a ``class_labelling.txt`` file with one class name per line.
        num_components: Maximum number of GMM components per class. The
            default of 5 matches the previously hard-coded value.

    Returns:
        The fitted GMMRandomGenerator.
    """
    colour_map = parse_colour_map(os.path.join(data_dir, "colour_map.txt"))
    # Class count is the number of colour-map entries (presumably one entry
    # per class, background included — confirm against the data format).
    num_classes = len(colour_map)

    dataset = ScribbleDataset(
        num_classes=num_classes,
        data_dir=os.path.join(data_dir, "segmentation"),
        colour_map=colour_map,
    )

    class_names = read_class_names(os.path.join(data_dir, "class_labelling.txt"))

    return GMMRandomGenerator(
        num_classes=num_classes,
        num_components=num_components,
        class_names=class_names,
    ).fit(dataset)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: fit GMMs on a dataset and write sample images.

    Args:
        argv: Arguments to parse. ``None`` makes argparse use ``sys.argv[1:]``.

    Raises:
        OSError: If a generated sample image cannot be written.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    print(f"\n{'='*70}")
    print(f"{'GMMRandomGenerator - Scribble Annotation Generation':^70}")
    print(f"{'='*70}\n")

    print("Configuration:")
    print(f" Data directory: {args.data_dir}")
    print(f" Number of classes: {args.num_classes}")
    print(f" GMM components: {args.num_components}")
    print(f" Output directory: {args.output_dir}")
    print(f" Image size: {args.width}x{args.height}")
    print(f" Samples to gen: {args.num_samples}\n")

    colour_map = parse_colour_map(args.colour_map) if args.colour_map else None

    print("[1/3] Loading dataset...")
    dataset = ScribbleDataset(
        num_classes=args.num_classes,
        data_dir=args.data_dir,
        colour_map=colour_map,
    )
    print(f" Loaded {len(dataset.filenames)} annotation files\n")

    print("[2/3] Fitting GMM models...")
    class_names = read_class_names(args.class_names) if args.class_names else None

    generator = GMMRandomGenerator(
        num_classes=args.num_classes,
        num_components=args.num_components,
        class_names=class_names,
    ).fit(dataset)
    print()

    print("[3/3] Generating samples...")
    os.makedirs(args.output_dir, exist_ok=True)

    for idx in range(args.num_samples):
        image, description = generator.sample_image(
            image_shape=(args.height, args.width),
            colour_map=colour_map,
            overlap_iters=args.overlap_iters,
            overlap_margin=args.overlap_margin,
        )

        output_path = os.path.join(args.output_dir, f"gmm_sample_{idx:04d}.png")
        # 3-channel images are generated as RGB; OpenCV writes BGR.
        to_write = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if image.ndim == 3 else image
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(output_path, to_write):
            raise OSError(f"Failed to write sample image: {output_path}")

        print(f" Sample {idx + 1}: {description}")

        if (idx + 1) % 10 == 0 or idx == args.num_samples - 1:
            print(f" Generated {idx + 1}/{args.num_samples} samples")

    print(f"\n{'='*70}")
    print(f"✓ Generation complete! Samples saved to: {args.output_dir}")
    print(f"{'='*70}\n")


if __name__ == "__main__":
    main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: scribble-annotation-generator
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.1.0
|
|
4
4
|
Summary: Programmatically generate semi-realistic synthetic scribble annotations based on statistics from existing scribble datasets
|
|
5
5
|
Project-URL: Homepage, https://github.com/alexsenden/scribble-annotation-generator
|
|
6
6
|
Project-URL: Repository, https://github.com/alexsenden/scribble-annotation-generator
|
|
@@ -24,6 +24,7 @@ Requires-Python: >=3.8
|
|
|
24
24
|
Requires-Dist: numpy
|
|
25
25
|
Requires-Dist: opencv-python
|
|
26
26
|
Requires-Dist: scikit-image
|
|
27
|
+
Requires-Dist: scikit-learn
|
|
27
28
|
Requires-Dist: scipy
|
|
28
29
|
Description-Content-Type: text/markdown
|
|
29
30
|
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
scribble_annotation_generator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
scribble_annotation_generator/cli.py,sha256=e6Ufmf0CLMrKZCL08pB5RqiiQMNJBU9uZwcL86S-Jcw,6251
|
|
3
|
-
scribble_annotation_generator/crop_field.py,sha256
|
|
3
|
+
scribble_annotation_generator/crop_field.py,sha256=zh0jNpjz5_pWzybQNfga9hYqRIYPXysMPp_xNvL37gM,10774
|
|
4
4
|
scribble_annotation_generator/dataset.py,sha256=jauKr8ZBJ1o8jEn8T_RKpVgqu8kwLmucIMyhZCkPiTg,3298
|
|
5
5
|
scribble_annotation_generator/debug.py,sha256=YJnfkBJL7Vwlqz9SWeybAhwID8Pcwzp_RFnL9xPOQyI,1194
|
|
6
|
+
scribble_annotation_generator/gmm_random_generator.py,sha256=j7ZLEYTH8COJsKFJXHBzY4qKTEreDDONmmbad3xaBi8,22161
|
|
6
7
|
scribble_annotation_generator/nn.py,sha256=aSQPkVpvsya942hz02LKoUEkZfL_f7lCvYQ5cI8R3Ts,16627
|
|
7
8
|
scribble_annotation_generator/utils.py,sha256=gluwQSroMd4bg6iwchiv4VBTK57t4OvScH8uv29erWY,13628
|
|
8
|
-
scribble_annotation_generator-0.0.
|
|
9
|
-
scribble_annotation_generator-0.0.
|
|
10
|
-
scribble_annotation_generator-0.0.
|
|
11
|
-
scribble_annotation_generator-0.0.
|
|
9
|
+
scribble_annotation_generator-0.1.0.dist-info/METADATA,sha256=p4MOjTbGTeAi5zS8HGXGj5pK5Op5vE5YCM3JDhRAuIY,4309
|
|
10
|
+
scribble_annotation_generator-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
scribble_annotation_generator-0.1.0.dist-info/entry_points.txt,sha256=A5UbznzAcE5XF5MZrth2rdLvG2IQXhjK2lhklUf9QyU,89
|
|
12
|
+
scribble_annotation_generator-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|