scribble-annotation-generator 0.0.1__tar.gz → 0.1.0__tar.gz

This diff shows the changes between publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: scribble-annotation-generator
-Version: 0.0.1
+Version: 0.1.0
 Summary: Programmatically generate semi-realistic synthetic scribble annotations based on statistics from existing scribble datasets
 Project-URL: Homepage, https://github.com/alexsenden/scribble-annotation-generator
 Project-URL: Repository, https://github.com/alexsenden/scribble-annotation-generator
@@ -24,6 +24,7 @@ Requires-Python: >=3.8
 Requires-Dist: numpy
 Requires-Dist: opencv-python
 Requires-Dist: scikit-image
+Requires-Dist: scikit-learn
 Requires-Dist: scipy
 Description-Content-Type: text/markdown
 
@@ -31,7 +31,7 @@ classifiers = [
   "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
 
-dependencies = ["numpy", "scipy", "scikit-image", "opencv-python"]
+dependencies = ["numpy", "scipy", "scikit-image", "opencv-python", "scikit-learn"]
 
 # Version is automatically provided by hatch-vcs
 dynamic = ["version"]
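The new `scikit-learn` dependency (mirrored in the package metadata above) backs the `GaussianMixture` models used by the new sampling module at the end of this diff. As a minimal sketch of the underlying technique, independent of this package's code and using made-up feature data:

```python
# Minimal sketch of what scikit-learn provides here: fit a Gaussian mixture
# over per-object feature vectors, then draw new synthetic vectors from it.
# The feature data below is invented for illustration.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
features = rng.normal(size=(100, 5))  # e.g. [center_x, center_y, length, angle, curvature]

gmm = GaussianMixture(n_components=5, covariance_type="full", random_state=0)
gmm.fit(features)

samples, _ = gmm.sample(10)  # 10 new feature vectors drawn from the fitted mixture
print(samples.shape)  # (10, 5)
```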
@@ -327,13 +327,14 @@ def generate_sample(
 
 
 def generate_crop_field_dataset(
-    output_dir: str,
     colour_map: dict,
+    output_dir: str = None,
     num_samples: int = NUM_SAMPLES_TO_GENERATE,
     min_rows: int = 4,
     max_rows: int = 6,
 ):
-    os.makedirs(output_dir, exist_ok=True)
+    if output_dir is not None:
+        os.makedirs(output_dir, exist_ok=True)
 
     for i in range(num_samples):
         num_rows = random.randint(min_rows, max_rows)
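This hunk reorders the parameters of `generate_crop_field_dataset` (breaking for callers that passed `output_dir` positionally) and makes `output_dir` optional, skipping directory creation when it is omitted. A hedged sketch of the new call shape; the colour map values are made up and the function's module path is not shown in this diff:

```python
# Hypothetical call shapes under 0.1.0; the colour_map contents are invented.
colour_map = {(255, 0, 0): 1, (0, 255, 0): 2}  # made-up RGB -> class-index map

generate_crop_field_dataset(colour_map, num_samples=5)  # output_dir omitted
generate_crop_field_dataset(colour_map, output_dir="./out", num_samples=5)
```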
@@ -0,0 +1,652 @@
+import argparse
+import math
+import random
+from collections import Counter
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import cv2
+import numpy as np
+from sklearn.mixture import GaussianMixture
+
+import os
+
+from scribble_annotation_generator.dataset import ScribbleDataset
+from scribble_annotation_generator.cli import parse_colour_map
+from scribble_annotation_generator.utils import (
+    extract_class_masks,
+    extract_object_features,
+    generate_multiclass_scribble,
+    get_objects,
+    rgb_to_indexed,
+    unpack_feature_vector,
+)
+
+
+@dataclass
+class CountModelEmpirical:
+    """Empirical distribution over count vectors."""
+
+    counts: np.ndarray
+    probs: np.ndarray
+
+    @classmethod
+    def fit(cls, count_vectors: Iterable[np.ndarray]) -> "CountModelEmpirical":
+        tuples = [tuple(vec.tolist()) for vec in count_vectors]
+        counter = Counter(tuples)
+        keys = list(counter.keys())
+        values = np.array([counter[k] for k in keys], dtype=np.float64)
+        probs = values / max(values.sum(), 1.0)
+        counts = np.array(keys, dtype=np.int64)
+        return cls(counts=counts, probs=probs)
+
+    def sample(self, num_samples: int = 1) -> np.ndarray:
+        idx = np.random.choice(len(self.counts), size=num_samples, p=self.probs)
+        return self.counts[idx]
+
+
+class GMMRandomGenerator:
+    """Fit per-class GMMs over object geometry and an empirical count model."""
+
+    def __init__(
+        self,
+        num_classes: int,
+        num_components: int = 5,
+        random_state: int = 0,
+        class_names: Optional[List[str]] = None,
+    ) -> None:
+        self.num_classes = num_classes
+        self.num_components = num_components
+        self.random_state = random_state
+        self.class_names = class_names or [f"class_{i}" for i in range(num_classes)]
+
+        self.class_gmms: Dict[int, GaussianMixture] = {}
+        self.class_stats: Dict[int, Dict[str, float]] = {}
+        self.class_spur_counts: Dict[int, List[int]] = {}
+        self.count_models: Dict[int, Optional[CountModelEmpirical]] = (
+            {}
+        )  # Per-class count models
+
+    # ------------------------------------------------------------------
+    # Fitting
+    # ------------------------------------------------------------------
+
+    def fit(self, dataset: ScribbleDataset) -> "GMMRandomGenerator":
+        print(
+            f"[GMMRandomGenerator] Starting fit on dataset with {len(dataset.filenames)} files..."
+        )
+
+        per_class_samples: Dict[int, List[np.ndarray]] = {
+            i: [] for i in range(self.num_classes)
+        }
+        per_class_spurs: Dict[int, List[int]] = {i: [] for i in range(self.num_classes)}
+        count_vectors: List[np.ndarray] = []
+
+        for class_ids, features in self._iter_objects(dataset):
+            counts = np.zeros(self.num_classes, dtype=np.int64)
+            for class_id in class_ids:
+                counts[class_id] += 1
+            count_vectors.append(counts)
+
+            for class_id, feat in zip(class_ids, features):
+                gmm_feat, spur_count = self._to_gmm_feature(feat)
+                per_class_samples[class_id].append(gmm_feat)
+                per_class_spurs[class_id].append(spur_count)
+
+        # Fit per-class count models (conditional on class presence)
+        for class_id in range(1, self.num_classes):  # Skip background class (0)
+            # Filter count vectors to only those containing this class
+            class_specific_counts = [
+                counts for counts in count_vectors if counts[class_id] > 0
+            ]
+
+            if len(class_specific_counts) > 0:
+                self.count_models[class_id] = CountModelEmpirical.fit(
+                    class_specific_counts
+                )
+                print(
+                    f"[GMMRandomGenerator] Fitted count model for class {class_id} with {len(class_specific_counts)} observations"
+                )
+            else:
+                fallback_counts = np.zeros(self.num_classes, dtype=np.int64)
+                fallback_counts[class_id] = 1
+                self.count_models[class_id] = CountModelEmpirical.fit(
+                    [fallback_counts]
+                )
+                print(
+                    f"[GMMRandomGenerator] No samples for class {class_id}, using fallback singleton count"
+                )
+
+        fitted_classes = []
+        for class_id, samples in per_class_samples.items():
+            if len(samples) < 2:
+                continue
+            data = np.stack(samples, axis=0)
+            num_components = min(self.num_components, len(samples))
+            gmm = GaussianMixture(
+                n_components=num_components,
+                covariance_type="full",
+                random_state=self.random_state,
+            )
+            gmm.fit(data)
+            self.class_gmms[class_id] = gmm
+            fitted_classes.append(class_id)
+
+            lengths = data[:, 2]
+            curvatures = data[:, 4]
+            self.class_stats[class_id] = {
+                "length_min": float(lengths.min()),
+                "length_max": float(lengths.max()),
+                "curvature_min": float(curvatures.min()),
+                "curvature_max": float(curvatures.max()),
+            }
+
+        print(
+            f"[GMMRandomGenerator] Fitted GMMs for {len(fitted_classes)} classes: {fitted_classes}"
+        )
+        self.class_spur_counts = per_class_spurs
+        return self
+
+    # ------------------------------------------------------------------
+    # Sampling
+    # ------------------------------------------------------------------
+
+    def sample_counts(self, class_id: int, num_samples: int = 1) -> np.ndarray:
+        if class_id not in self.count_models or self.count_models[class_id] is None:
+            raise ValueError(f"Count model not available for class {class_id}")
+
+        model = self.count_models[class_id]
+        samples = []
+
+        # Keep sampling until we get samples that include the target class
+        for _ in range(num_samples):
+            while True:
+                sample = model.sample(1)[0]
+                if sample[class_id] > 0:
+                    samples.append(sample)
+                    break
+
+        return np.array(samples)
+
+    def sample_objects(self, counts: np.ndarray) -> Tuple[List[dict], np.ndarray]:
+        objects: List[dict] = []
+        classes: List[int] = []
+
+        for class_id in range(self.num_classes):
+            if class_id == 0:
+                continue
+            if class_id not in self.class_gmms:
+                continue
+            num = int(counts[class_id])
+            if num <= 0:
+                continue
+
+            samples, _ = self.class_gmms[class_id].sample(num)
+            for sample in samples:
+                obj = self._from_gmm_feature(class_id, sample)
+                objects.append(obj)
+                classes.append(class_id)
+
+        classes_arr = np.array(classes, dtype=np.int64)
+        return objects, classes_arr
+
+    def sample_image(
+        self,
+        image_shape: Tuple[int, int],
+        colour_map: Optional[Dict[Tuple[int, int, int], int]] = None,
+        overlap_iters: int = 40,
+        overlap_margin: float = 0.05,
+    ) -> Tuple[np.ndarray, str]:
+        # Randomly select a class to condition on
+        available_classes = [
+            c
+            for c in range(1, self.num_classes)
+            if c in self.count_models and self.count_models[c] is not None
+        ]
+
+        if not available_classes:
+            return np.zeros(image_shape, dtype=np.uint8), "Empty image"
+
+        chosen_class = random.choice(available_classes)
+        counts = self.sample_counts(chosen_class, 1)[0]
+        objects, classes = self.sample_objects(counts)
+
+        if len(objects) == 0:
+            return np.zeros(image_shape, dtype=np.uint8), "Empty image"
+
+        objects = self._resolve_overlaps(objects, overlap_iters, overlap_margin)
+
+        image = generate_multiclass_scribble(
+            image_shape=image_shape,
+            objects=objects,
+            classes=classes,
+            colour_map=colour_map,
+        )
+
+        # Generate description string
+        class_counts: Dict[int, int] = {}
+        for class_id in classes:
+            class_counts[class_id] = class_counts.get(class_id, 0) + 1
+
+        description_parts = []
+        for class_id in sorted(class_counts.keys()):
+            count = class_counts[class_id]
+            class_name = self.class_names[class_id]
+            if count == 1:
+                description_parts.append(f'1 object of class "{class_name}"')
+            else:
+                description_parts.append(f'{count} objects of class "{class_name}"')
+
+        description = "; ".join(description_parts)
+        return image, description
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _iter_objects(
+        self, dataset: ScribbleDataset
+    ) -> Iterable[Tuple[List[int], List[np.ndarray]]]:
+        for filename in dataset.filenames:
+            filepath = os.path.join(dataset.data_dir, filename)
+
+            if dataset.is_rgb:
+                if dataset.colour_map is None:
+                    raise ValueError("colour_map must be provided for RGB annotations")
+                mask = cv2.imread(str(filepath), cv2.IMREAD_COLOR)
+                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
+                mask = rgb_to_indexed(mask, dataset.colour_map)
+            else:
+                mask = cv2.imread(str(filepath), cv2.IMREAD_GRAYSCALE)
+
+            class_masks = extract_class_masks(mask)
+            class_ids: List[int] = []
+            features: List[np.ndarray] = []
+
+            for class_id, class_mask in class_masks.items():
+                objects = get_objects(class_mask)
+                for obj_mask in objects:
+                    feat = extract_object_features(obj_mask)
+                    class_ids.append(int(class_id))
+                    features.append(feat)
+
+            yield class_ids, features
+
+    def _to_gmm_feature(self, feat: np.ndarray) -> Tuple[np.ndarray, int]:
+        data = unpack_feature_vector(feat)
+
+        start_x = data["start_x"]
+        start_y = data["start_y"]
+        end_x = data["end_x"]
+        end_y = data["end_y"]
+
+        center_x = (start_x + end_x) / 2.0
+        center_y = (start_y + end_y) / 2.0
+
+        dx = end_x - start_x
+        dy = end_y - start_y
+        length = math.sqrt(dx * dx + dy * dy)
+        angle = math.atan2(dy, dx)
+
+        curvature = data["curvature"]
+        num_spurs = int(round(data["num_spurs"]))
+
+        return (
+            np.array([center_x, center_y, length, angle, curvature], dtype=np.float32),
+            num_spurs,
+        )
+
+    def _from_gmm_feature(self, class_id: int, sample: np.ndarray) -> dict:
+        center_x, center_y, length, angle, curvature = sample.tolist()
+
+        stats = self.class_stats.get(class_id, {})
+        length_min = stats.get("length_min", 0.0)
+        length_max = stats.get("length_max", 2.0)
+        curvature_min = stats.get("curvature_min", -1.0)
+        curvature_max = stats.get("curvature_max", 1.0)
+
+        length = float(
+            np.clip(
+                abs(length), max(1e-4, length_min), max(length_min + 1e-4, length_max)
+            )
+        )
+        curvature = float(np.clip(curvature, curvature_min, curvature_max))
+
+        center_x = float(np.clip(center_x, -1.0, 1.0))
+        center_y = float(np.clip(center_y, -1.0, 1.0))
+
+        cos_angle = math.cos(angle)
+        sin_angle = math.sin(angle)
+
+        length = self._fit_length_to_bounds(
+            center_x, center_y, length, cos_angle, sin_angle
+        )
+
+        half = 0.5 * length
+        start_x = center_x - half * cos_angle
+        start_y = center_y - half * sin_angle
+        end_x = center_x + half * cos_angle
+        end_y = center_y + half * sin_angle
+
+        spur_list = self.class_spur_counts.get(class_id, [])
+        num_spurs = int(random.choice(spur_list)) if spur_list else 0
+
+        return {
+            "start_x": start_x,
+            "start_y": start_y,
+            "end_x": end_x,
+            "end_y": end_y,
+            "num_spurs": num_spurs,
+            "curvature": curvature,
+            "cos_angle": cos_angle,
+            "sin_angle": sin_angle,
+        }
+
+    def _fit_length_to_bounds(
+        self,
+        center_x: float,
+        center_y: float,
+        length: float,
+        cos_angle: float,
+        sin_angle: float,
+    ) -> float:
+        half = 0.5 * length
+
+        max_half = half
+        if abs(cos_angle) > 1e-6:
+            max_half = min(max_half, (1.0 - abs(center_x)) / abs(cos_angle))
+        if abs(sin_angle) > 1e-6:
+            max_half = min(max_half, (1.0 - abs(center_y)) / abs(sin_angle))
+
+        max_half = max(0.0, max_half)
+        return max(2e-4, 2.0 * max_half)
+
+    def _resolve_overlaps(
+        self,
+        objects: List[dict],
+        max_iters: int,
+        min_margin: float,
+    ) -> List[dict]:
+        centers = np.array(
+            [
+                [
+                    (obj["start_x"] + obj["end_x"]) / 2.0,
+                    (obj["start_y"] + obj["end_y"]) / 2.0,
+                ]
+                for obj in objects
+            ],
+            dtype=np.float32,
+        )
+
+        radii = np.array(
+            [
+                0.5
+                * math.sqrt(
+                    (obj["end_x"] - obj["start_x"]) ** 2
+                    + (obj["end_y"] - obj["start_y"]) ** 2
+                )
+                for obj in objects
+            ],
+            dtype=np.float32,
+        )
+
+        if len(objects) < 2:
+            return objects
+
+        # Gentle centering force strength (0 = no centering, 1 = full centering)
+        centering_strength = 0.1
+
+        # Iteratively resolve overlaps and push toward center
+        for iteration in range(max_iters):
+            overlaps_found = False
+
+            # Resolve overlaps between object pairs
+            for i in range(len(objects)):
+                for j in range(i + 1, len(objects)):
+                    # Calculate distance between centers
+                    dx = centers[j, 0] - centers[i, 0]
+                    dy = centers[j, 1] - centers[i, 1]
+                    dist = math.sqrt(dx * dx + dy * dy)
+
+                    # Calculate minimum required distance
+                    min_dist = radii[i] + radii[j] + min_margin
+
+                    if dist < min_dist and dist > 1e-6:
+                        overlaps_found = True
+
+                        # Move objects apart
+                        direction_x = dx / dist
+                        direction_y = dy / dist
+
+                        overlap = min_dist - dist
+                        displacement = overlap / 2.0 + 1e-4
+
+                        centers[i, 0] -= direction_x * displacement
+                        centers[i, 1] -= direction_y * displacement
+                        centers[j, 0] += direction_x * displacement
+                        centers[j, 1] += direction_y * displacement
+
+            # Gently push objects toward center of image
+            for i in range(len(objects)):
+                # Center of image is at (0, 0) in normalized coordinates
+                # Objects closer to edges experience stronger centering force
+                dist_from_center = math.sqrt(centers[i, 0] ** 2 + centers[i, 1] ** 2)
+                max_dist = math.sqrt(
+                    2.0
+                )  # diagonal distance in normalized [-1, 1] space
+                adaptive_strength = centering_strength * (dist_from_center / max_dist)
+
+                centers[i, 0] *= 1.0 - adaptive_strength
+                centers[i, 1] *= 1.0 - adaptive_strength
+
+            # if not overlaps_found:
+            #     break
+
+        # Update objects with new centers
+        for idx, obj in enumerate(objects):
+            center_x = float(np.clip(centers[idx, 0], -1.0, 1.0))
+            center_y = float(np.clip(centers[idx, 1], -1.0, 1.0))
+
+            cos_angle = obj["cos_angle"]
+            sin_angle = obj["sin_angle"]
+            length = (
+                0.5
+                * math.sqrt(
+                    (obj["end_x"] - obj["start_x"]) ** 2
+                    + (obj["end_y"] - obj["start_y"]) ** 2
+                )
+                * 2.0
+            )
+
+            length = self._fit_length_to_bounds(
+                center_x, center_y, length, cos_angle, sin_angle
+            )
+
+            half = 0.5 * length
+            obj["start_x"] = center_x - half * cos_angle
+            obj["start_y"] = center_y - half * sin_angle
+            obj["end_x"] = center_x + half * cos_angle
+            obj["end_y"] = center_y + half * sin_angle
+
+        return objects
+
+
+def build_parser() -> argparse.ArgumentParser:
+
+    parser = argparse.ArgumentParser(
+        description="Fit per-class GMMs and sample synthetic scribble annotations.",
+    )
+    parser.add_argument(
+        "--data-dir",
+        required=True,
+        help="Path to the training dataset directory.",
+    )
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        required=True,
+        help="Total number of classes (including background).",
+    )
+    parser.add_argument(
+        "--colour-map",
+        default=None,
+        help=(
+            "Colour map specified inline as 'R,G,B=class;...' or a path to a file "
+            "with one 'R,G,B,class' or 'R,G,B' entry per line. Required for RGB annotations."
+        ),
+    )
+    parser.add_argument(
+        "--output-dir",
+        default="./local/gmm-inference",
+        help="Directory to write generated samples.",
+    )
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=10,
+        help="Number of images to generate.",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=512,
+        help="Output image height.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=512,
+        help="Output image width.",
+    )
+    parser.add_argument(
+        "--num-components",
+        type=int,
+        default=5,
+        help="Number of GMM components per class.",
+    )
+    parser.add_argument(
+        "--overlap-iters",
+        type=int,
+        default=40,
+        help="Iterations for overlap resolution.",
+    )
+    parser.add_argument(
+        "--overlap-margin",
+        type=float,
+        default=0.05,
+        help="Minimum margin between objects in normalized coordinates.",
+    )
+    parser.add_argument(
+        "--class-names",
+        type=str,
+        default=None,
+        help=("Path to file with one class name per line."),
+    )
+    return parser
+
+
+def read_class_names(filepath: str) -> List[str]:
+    with open(filepath, "r") as f:
+        lines = f.read().splitlines()
+    return lines
+
+
+def train_gmm_model(data_dir: str) -> GMMRandomGenerator:
+    """
+    Trains and returns a GMMRandomGenerator model based on the dataset located in data_dir.
+
+    Args:
+        data_dir (str): Path to the training dataset directory. This directory should contain
+            a subdirectory "segmentation" with annotation files, a "colour_map.txt" file,
+            and a "class_labelling.txt" file.
+
+    Returns:
+        GMMRandomGenerator: The trained GMMRandomGenerator model.
+    """
+
+    colour_map = parse_colour_map(os.path.join(data_dir, "colour_map.txt"))
+    num_classes = len(colour_map)
+
+    dataset = ScribbleDataset(
+        num_classes=num_classes,
+        data_dir=os.path.join(data_dir, "segmentation"),
+        colour_map=colour_map,
+    )
+
+    class_names = read_class_names(os.path.join(data_dir, "class_labelling.txt"))
+
+    generator = GMMRandomGenerator(
+        num_classes=num_classes,
+        num_components=5,
+        class_names=class_names,
+    ).fit(dataset)
+
+    return generator
+
+
+def main(argv: Optional[List[str]] = None) -> None:
+    parser = build_parser()
+    args = parser.parse_args(argv)
+
+    print(f"\n{'='*70}")
+    print(f"{'GMMRandomGenerator - Scribble Annotation Generation':^70}")
+    print(f"{'='*70}\n")
+
+    print(f"Configuration:")
+    print(f"  Data directory: {args.data_dir}")
+    print(f"  Number of classes: {args.num_classes}")
+    print(f"  GMM components: {args.num_components}")
+    print(f"  Output directory: {args.output_dir}")
+    print(f"  Image size: {args.width}x{args.height}")
+    print(f"  Samples to gen: {args.num_samples}\n")
+
+    colour_map = parse_colour_map(args.colour_map) if args.colour_map else None
+
+    print("[1/3] Loading dataset...")
+    dataset = ScribbleDataset(
+        num_classes=args.num_classes,
+        data_dir=args.data_dir,
+        colour_map=colour_map,
+    )
+    print(f"  Loaded {len(dataset.filenames)} annotation files\n")
+
+    print("[2/3] Fitting GMM models...")
+    class_names = read_class_names(args.class_names) if args.class_names else None
+
+    generator = GMMRandomGenerator(
+        num_classes=args.num_classes,
+        num_components=args.num_components,
+        class_names=class_names,
+    ).fit(dataset)
+    print()
+
+    print("[3/3] Generating samples...")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    for idx in range(args.num_samples):
+        image, description = generator.sample_image(
+            image_shape=(args.height, args.width),
+            colour_map=colour_map,
+            overlap_iters=args.overlap_iters,
+            overlap_margin=args.overlap_margin,
+        )
+
+        output_path = os.path.join(args.output_dir, f"gmm_sample_{idx:04d}.png")
+        if image.ndim == 3:
+            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+            cv2.imwrite(output_path, image_bgr)
+        else:
+            cv2.imwrite(output_path, image)
+
+        print(f"  Sample {idx + 1}: {description}")
+
+        if (idx + 1) % 10 == 0 or idx == args.num_samples - 1:
+            print(f"  Generated {idx + 1}/{args.num_samples} samples")
+
+    print(f"\n{'='*70}")
+    print(f"✓ Generation complete! Samples saved to: {args.output_dir}")
+    print(f"{'='*70}\n")
+
+
+if __name__ == "__main__":
+    main()
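Taken together, the new module fits per-class `GaussianMixture` models over object geometry plus empirical per-class count models (`CountModelEmpirical`), then samples object counts and features, resolves overlaps, and renders with `generate_multiclass_scribble`. A hedged end-to-end sketch based on the code above; the module's import path is an assumption (the diff does not name the new file), and the dataset path is made up:

```python
# Import path is assumed; the diff does not show where this module lives.
from scribble_annotation_generator.gmm import GMMRandomGenerator, train_gmm_model

# Per the docstring above, data_dir should contain a "segmentation"
# subdirectory, "colour_map.txt", and "class_labelling.txt" (path made up).
generator = train_gmm_model("./my_scribble_dataset")

# Sample one 512x512 annotation; returns the rendered image and a text
# description such as '2 objects of class "class_1"'.
image, description = generator.sample_image(image_shape=(512, 512))
print(description)
```

The same pipeline is exposed as a script entry point via `build_parser()`/`main()`, with `--data-dir`, `--num-classes`, `--output-dir`, and the image-size and GMM flags defined above.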