scribble-annotation-generator 0.0.1__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scribble-annotation-generator
3
- Version: 0.0.1
3
+ Version: 0.1.1
4
4
  Summary: Programmatically generate semi-realistic synthetic scribble annotations based on statistics from existing scribble datasets
5
5
  Project-URL: Homepage, https://github.com/alexsenden/scribble-annotation-generator
6
6
  Project-URL: Repository, https://github.com/alexsenden/scribble-annotation-generator
@@ -24,6 +24,7 @@ Requires-Python: >=3.8
24
24
  Requires-Dist: numpy
25
25
  Requires-Dist: opencv-python
26
26
  Requires-Dist: scikit-image
27
+ Requires-Dist: scikit-learn
27
28
  Requires-Dist: scipy
28
29
  Description-Content-Type: text/markdown
29
30
 
@@ -31,7 +31,7 @@ classifiers = [
31
31
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
32
32
  ]
33
33
 
34
- dependencies = ["numpy", "scipy", "scikit-image", "opencv-python"]
34
+ dependencies = ["numpy", "scipy", "scikit-image", "opencv-python", "scikit-learn"]
35
35
 
36
36
  # Version is automatically provided by hatch-vcs
37
37
  dynamic = ["version"]
@@ -327,13 +327,14 @@ def generate_sample(
327
327
 
328
328
 
329
329
  def generate_crop_field_dataset(
330
- output_dir: str,
331
330
  colour_map: dict,
331
+ output_dir: str = None,
332
332
  num_samples: int = NUM_SAMPLES_TO_GENERATE,
333
333
  min_rows: int = 4,
334
334
  max_rows: int = 6,
335
335
  ):
336
- os.makedirs(output_dir, exist_ok=True)
336
+ if output_dir is not None:
337
+ os.makedirs(output_dir, exist_ok=True)
337
338
 
338
339
  for i in range(num_samples):
339
340
  num_rows = random.randint(min_rows, max_rows)
@@ -0,0 +1,654 @@
1
+ import argparse
2
+ import math
3
+ import random
4
+ from collections import Counter
5
+ from dataclasses import dataclass
6
+ from typing import Dict, Iterable, List, Optional, Tuple
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from sklearn.mixture import GaussianMixture
11
+
12
+ import os
13
+
14
+ from scribble_annotation_generator.dataset import ScribbleDataset
15
+ from scribble_annotation_generator.cli import parse_colour_map
16
+ from scribble_annotation_generator.utils import (
17
+ extract_class_masks,
18
+ extract_object_features,
19
+ generate_multiclass_scribble,
20
+ get_objects,
21
+ rgb_to_indexed,
22
+ unpack_feature_vector,
23
+ )
24
+
25
+
26
@dataclass
class CountModelEmpirical:
    """Empirical distribution over count vectors.

    Every distinct count vector seen while fitting becomes one outcome whose
    probability is its relative frequency among the observations.
    """

    # Distinct count vectors, one per row, in first-seen order.
    counts: np.ndarray
    # Relative frequency of the matching row of ``counts``.
    probs: np.ndarray

    @classmethod
    def fit(cls, count_vectors: Iterable[np.ndarray]) -> "CountModelEmpirical":
        """Build the distribution from observed count vectors.

        Args:
            count_vectors: One 1-D integer vector per observation. Anything
                ``np.asarray`` accepts (arrays, lists, tuples) is allowed.

        Returns:
            A model holding the distinct vectors and their relative
            frequencies. Fitting on no observations yields an empty model
            that cannot be sampled from.
        """
        frequencies = Counter(
            tuple(np.asarray(vec).tolist()) for vec in count_vectors
        )
        keys = list(frequencies)
        values = np.array([frequencies[k] for k in keys], dtype=np.float64)
        # ``max`` only protects the division when nothing was observed.
        probs = values / max(values.sum(), 1.0)
        counts = np.array(keys, dtype=np.int64)
        return cls(counts=counts, probs=probs)

    def sample(self, num_samples: int = 1) -> np.ndarray:
        """Draw ``num_samples`` count vectors according to ``probs``.

        Uses NumPy's global random state.

        Raises:
            ValueError: If the model holds no count vectors.
        """
        if len(self.counts) == 0:
            raise ValueError("Cannot sample from an empty count model")
        idx = np.random.choice(len(self.counts), size=num_samples, p=self.probs)
        return self.counts[idx]
46
+
47
+
48
class GMMRandomGenerator:
    """Fit per-class GMMs over object geometry and an empirical count model.

    Each object is summarised as ``[center_x, center_y, length, angle,
    curvature]``. One Gaussian mixture is fitted per class over those vectors,
    and one empirical count model per non-background class records which
    per-class object counts were observed together in an image.
    """

    def __init__(
        self,
        num_classes: int,
        num_components: int = 5,
        random_state: int = 0,
        class_names: Optional[List[str]] = None,
        colour_map: Optional[Dict[Tuple[int, int, int], int]] = None,
    ) -> None:
        """Store the configuration; no model exists until ``fit`` is called.

        Args:
            num_classes: Total number of classes, class 0 being background.
            num_components: Upper bound on mixture components per class.
            random_state: Seed handed to every ``GaussianMixture``.
            class_names: Names indexed by class id; defaults to ``class_<i>``.
            colour_map: Passed through to ``generate_multiclass_scribble``.
        """
        # NOTE: typing.Dict/Tuple are used in the signature because builtin
        # generics (``dict[...]``) are evaluated at definition time and fail
        # on Python 3.8.
        self.num_classes = num_classes
        self.num_components = num_components
        self.random_state = random_state
        self.class_names = class_names or [f"class_{i}" for i in range(num_classes)]
        self.colour_map = colour_map

        self.class_gmms: Dict[int, GaussianMixture] = {}
        self.class_stats: Dict[int, Dict[str, float]] = {}
        self.class_spur_counts: Dict[int, List[int]] = {}
        # Per-class count models
        self.count_models: Dict[int, Optional[CountModelEmpirical]] = {}

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------

    def fit(self, dataset: "ScribbleDataset") -> "GMMRandomGenerator":
        """Fit the count models and the per-class mixtures on ``dataset``.

        Args:
            dataset: Source of annotation files (see ``_iter_objects`` for
                the attributes that are read).

        Returns:
            ``self``, so construction and fitting can be chained.

        Raises:
            ValueError: If an annotation contains a class id outside
                ``[0, num_classes)``.
        """
        print(
            f"[GMMRandomGenerator] Starting fit on dataset with {len(dataset.filenames)} files..."
        )

        # Start from a clean slate so a refit never keeps stale models for
        # classes that are missing from the new dataset.
        self.class_gmms = {}
        self.class_stats = {}
        self.count_models = {}

        per_class_samples: Dict[int, List[np.ndarray]] = {
            i: [] for i in range(self.num_classes)
        }
        per_class_spurs: Dict[int, List[int]] = {i: [] for i in range(self.num_classes)}
        count_vectors: List[np.ndarray] = []

        for class_ids, features in self._iter_objects(dataset):
            counts = np.zeros(self.num_classes, dtype=np.int64)
            for class_id in class_ids:
                if not 0 <= class_id < self.num_classes:
                    raise ValueError(
                        f"Annotation contains class id {class_id}, expected a value "
                        f"in [0, {self.num_classes})"
                    )
                counts[class_id] += 1
            count_vectors.append(counts)

            for class_id, feat in zip(class_ids, features):
                gmm_feat, spur_count = self._to_gmm_feature(feat)
                per_class_samples[class_id].append(gmm_feat)
                per_class_spurs[class_id].append(spur_count)

        # Fit per-class count models (conditional on class presence)
        for class_id in range(1, self.num_classes):  # Skip background class (0)
            # Only images that actually contain this class contribute.
            observed = [vec for vec in count_vectors if vec[class_id] > 0]

            if observed:
                self.count_models[class_id] = CountModelEmpirical.fit(observed)
                print(
                    f"[GMMRandomGenerator] Fitted count model for class {class_id} with {len(observed)} observations"
                )
            else:
                # Never-seen class: pretend it occurs alone, exactly once.
                fallback_counts = np.zeros(self.num_classes, dtype=np.int64)
                fallback_counts[class_id] = 1
                self.count_models[class_id] = CountModelEmpirical.fit(
                    [fallback_counts]
                )
                print(
                    f"[GMMRandomGenerator] No samples for class {class_id}, using fallback singleton count"
                )

        fitted_classes = []
        for class_id, samples in per_class_samples.items():
            if len(samples) < 2:
                # Classes with fewer than two objects get no mixture.
                continue
            data = np.stack(samples, axis=0)
            # A mixture cannot have more components than samples.
            num_components = min(self.num_components, len(samples))
            gmm = GaussianMixture(
                n_components=num_components,
                covariance_type="full",
                random_state=self.random_state,
            )
            gmm.fit(data)
            self.class_gmms[class_id] = gmm
            fitted_classes.append(class_id)

            # Observed ranges, used later to clip sampled values.
            lengths = data[:, 2]
            curvatures = data[:, 4]
            self.class_stats[class_id] = {
                "length_min": float(lengths.min()),
                "length_max": float(lengths.max()),
                "curvature_min": float(curvatures.min()),
                "curvature_max": float(curvatures.max()),
            }

        print(
            f"[GMMRandomGenerator] Fitted GMMs for {len(fitted_classes)} classes: {fitted_classes}"
        )
        self.class_spur_counts = per_class_spurs
        return self

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    def sample_counts(self, class_id: int, num_samples: int = 1) -> np.ndarray:
        """Draw count vectors that are guaranteed to contain ``class_id``.

        The class's count model is restricted to the vectors in which
        ``class_id`` occurs and renormalised before drawing. This yields the
        same distribution as rejection sampling, but cannot loop forever when
        no stored vector contains the class.

        Args:
            class_id: Class every returned vector must contain.
            num_samples: Number of vectors to draw.

        Returns:
            Array with one sampled count vector per row.

        Raises:
            ValueError: If there is no count model for ``class_id`` or the
                model has no vector in which the class occurs.
        """
        model = self.count_models.get(class_id)
        if model is None:
            raise ValueError(f"Count model not available for class {class_id}")

        support = np.asarray(model.counts)
        weights = np.asarray(model.probs, dtype=np.float64)
        if support.ndim != 2 or not 0 <= class_id < support.shape[1]:
            raise ValueError(
                f"Count model for class {class_id} has no usable count vectors"
            )

        keep = support[:, class_id] > 0
        mass = float(weights[keep].sum())
        if mass <= 0.0:
            raise ValueError(
                f"Count model for class {class_id} never contains that class"
            )

        candidates = support[keep]
        idx = np.random.choice(
            len(candidates), size=num_samples, p=weights[keep] / mass
        )
        return candidates[idx]

    def sample_objects(self, counts: np.ndarray) -> Tuple[List[dict], np.ndarray]:
        """Sample object descriptions for the requested per-class counts.

        Background (class 0) and classes without a fitted mixture are
        skipped, so fewer objects than requested may come back.

        Args:
            counts: Number of objects wanted, indexed by class id.

        Returns:
            The object dictionaries and a parallel array of their class ids.
        """
        objects: List[dict] = []
        classes: List[int] = []

        for class_id in range(1, self.num_classes):
            gmm = self.class_gmms.get(class_id)
            if gmm is None:
                continue
            num = int(counts[class_id])
            if num <= 0:
                continue

            samples, _ = gmm.sample(num)
            for sample in samples:
                objects.append(self._from_gmm_feature(class_id, sample))
                classes.append(class_id)

        classes_arr = np.array(classes, dtype=np.int64)
        return objects, classes_arr

    def sample_image(
        self,
        image_shape: Tuple[int, int],
        overlap_iters: int = 40,
        overlap_margin: float = 0.05,
    ) -> Tuple[np.ndarray, str]:
        """Generate one synthetic annotation and a textual description of it.

        Args:
            image_shape: Shape forwarded to ``generate_multiclass_scribble``.
            overlap_iters: Iterations of overlap resolution.
            overlap_margin: Minimum gap kept between objects.

        Returns:
            The rendered image and a description such as
            ``2 objects of class "x"``. When nothing can be sampled, an
            all-zero ``uint8`` array of ``image_shape`` and ``"Empty image"``.
        """
        # Randomly select a class to condition on
        available_classes = [
            c
            for c in range(1, self.num_classes)
            if self.count_models.get(c) is not None
        ]
        if not available_classes:
            return np.zeros(image_shape, dtype=np.uint8), "Empty image"

        chosen_class = random.choice(available_classes)
        counts = self.sample_counts(chosen_class, 1)[0]
        objects, classes = self.sample_objects(counts)
        if len(objects) == 0:
            return np.zeros(image_shape, dtype=np.uint8), "Empty image"

        objects = self._resolve_overlaps(objects, overlap_iters, overlap_margin)

        image = generate_multiclass_scribble(
            image_shape=image_shape,
            objects=objects,
            classes=classes,
            colour_map=self.colour_map,
        )

        # Generate description string
        tally = Counter(classes.tolist())
        description_parts = []
        for class_id in sorted(tally):
            count = tally[class_id]
            # A user-supplied name list may be shorter than num_classes;
            # fall back to the default naming instead of raising IndexError.
            if 0 <= class_id < len(self.class_names):
                class_name = self.class_names[class_id]
            else:
                class_name = f"class_{class_id}"
            noun = "object" if count == 1 else "objects"
            description_parts.append(f'{count} {noun} of class "{class_name}"')

        return image, "; ".join(description_parts)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _iter_objects(
        self, dataset: "ScribbleDataset"
    ) -> Iterable[Tuple[List[int], List[np.ndarray]]]:
        """Yield, per annotation file, the class ids and features of its objects.

        Reads ``dataset.filenames``, ``dataset.data_dir``, ``dataset.is_rgb``
        and ``dataset.colour_map``.

        Raises:
            ValueError: If the annotations are RGB but no colour map is set.
            OSError: If an annotation file cannot be read or decoded.
        """
        for filename in dataset.filenames:
            filepath = os.path.join(dataset.data_dir, filename)

            is_rgb = dataset.is_rgb
            if is_rgb and dataset.colour_map is None:
                raise ValueError("colour_map must be provided for RGB annotations")

            read_flag = cv2.IMREAD_COLOR if is_rgb else cv2.IMREAD_GRAYSCALE
            mask = cv2.imread(str(filepath), read_flag)
            if mask is None:
                # cv2.imread reports failure by returning None, not raising.
                raise OSError(f"Could not read annotation image: {filepath}")

            if is_rgb:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
                mask = rgb_to_indexed(mask, dataset.colour_map)

            class_ids: List[int] = []
            features: List[np.ndarray] = []
            for class_id, class_mask in extract_class_masks(mask).items():
                for obj_mask in get_objects(class_mask):
                    class_ids.append(int(class_id))
                    features.append(extract_object_features(obj_mask))

            yield class_ids, features

    def _to_gmm_feature(self, feat: np.ndarray) -> Tuple[np.ndarray, int]:
        """Convert an extracted feature vector into the GMM representation.

        Returns:
            ``[center_x, center_y, length, angle, curvature]`` as ``float32``
            and the rounded spur count, which is modelled separately.
        """
        data = unpack_feature_vector(feat)

        start_x = data["start_x"]
        start_y = data["start_y"]
        end_x = data["end_x"]
        end_y = data["end_y"]

        center_x = (start_x + end_x) / 2.0
        center_y = (start_y + end_y) / 2.0

        dx = end_x - start_x
        dy = end_y - start_y
        length = math.sqrt(dx * dx + dy * dy)
        angle = math.atan2(dy, dx)

        gmm_feature = np.array(
            [center_x, center_y, length, angle, data["curvature"]],
            dtype=np.float32,
        )
        return gmm_feature, int(round(data["num_spurs"]))

    def _from_gmm_feature(self, class_id: int, sample: np.ndarray) -> dict:
        """Turn one GMM sample back into an object description.

        Length and curvature are clipped to the range observed for the class,
        the centre is clipped to ``[-1, 1]`` and the segment is shortened if
        it would leave that square. The spur count is drawn from the counts
        observed for the class (0 when none were recorded).
        """
        center_x, center_y, length, angle, curvature = sample.tolist()

        stats = self.class_stats.get(class_id, {})
        length_min = stats.get("length_min", 0.0)
        length_max = stats.get("length_max", 2.0)
        curvature_min = stats.get("curvature_min", -1.0)
        curvature_max = stats.get("curvature_max", 1.0)

        # The bounds are nudged so the clip interval is never empty and the
        # length stays strictly positive.
        length = float(
            np.clip(
                abs(length), max(1e-4, length_min), max(length_min + 1e-4, length_max)
            )
        )
        curvature = float(np.clip(curvature, curvature_min, curvature_max))

        center_x = float(np.clip(center_x, -1.0, 1.0))
        center_y = float(np.clip(center_y, -1.0, 1.0))

        cos_angle = math.cos(angle)
        sin_angle = math.sin(angle)

        length = self._fit_length_to_bounds(
            center_x, center_y, length, cos_angle, sin_angle
        )
        half = 0.5 * length

        spur_list = self.class_spur_counts.get(class_id, [])
        num_spurs = int(random.choice(spur_list)) if spur_list else 0

        return {
            "start_x": center_x - half * cos_angle,
            "start_y": center_y - half * sin_angle,
            "end_x": center_x + half * cos_angle,
            "end_y": center_y + half * sin_angle,
            "num_spurs": num_spurs,
            "curvature": curvature,
            "cos_angle": cos_angle,
            "sin_angle": sin_angle,
        }

    def _fit_length_to_bounds(
        self,
        center_x: float,
        center_y: float,
        length: float,
        cos_angle: float,
        sin_angle: float,
    ) -> float:
        """Shorten ``length`` so both endpoints stay inside ``[-1, 1]``.

        Returns:
            The possibly reduced length, never below ``2e-4``.
        """
        max_half = 0.5 * length

        # An (almost) axis-aligned segment cannot leave through the sides it
        # runs parallel to, so that axis is ignored.
        if abs(cos_angle) > 1e-6:
            max_half = min(max_half, (1.0 - abs(center_x)) / abs(cos_angle))
        if abs(sin_angle) > 1e-6:
            max_half = min(max_half, (1.0 - abs(center_y)) / abs(sin_angle))

        max_half = max(0.0, max_half)
        return max(2e-4, 2.0 * max_half)

    def _resolve_overlaps(
        self,
        objects: List[dict],
        max_iters: int,
        min_margin: float,
    ) -> List[dict]:
        """Push overlapping objects apart while pulling them towards the centre.

        Each object is approximated by a circle around its midpoint whose
        radius is half the segment length. The dictionaries are updated in
        place and the same list is returned.

        Args:
            objects: Object descriptions as built by ``_from_gmm_feature``.
            max_iters: Number of relaxation iterations; all of them are run.
            min_margin: Extra gap required between two circles.

        Returns:
            ``objects`` with updated start and end points.
        """
        num_objects = len(objects)
        if num_objects < 2:
            return objects

        centers = np.array(
            [
                [
                    (obj["start_x"] + obj["end_x"]) / 2.0,
                    (obj["start_y"] + obj["end_y"]) / 2.0,
                ]
                for obj in objects
            ],
            dtype=np.float32,
        )
        radii = np.array(
            [
                0.5
                * math.sqrt(
                    (obj["end_x"] - obj["start_x"]) ** 2
                    + (obj["end_y"] - obj["start_y"]) ** 2
                )
                for obj in objects
            ],
            dtype=np.float32,
        )

        # Gentle centering force strength (0 = no centering, 1 = full centering)
        centering_strength = 0.1
        # Diagonal distance in normalized [-1, 1] space
        max_dist = math.sqrt(2.0)

        # There is deliberately no early exit: the centering step below must
        # run on every iteration, even once no pair overlaps any more.
        for _ in range(max_iters):
            # Resolve overlaps between object pairs
            for i in range(num_objects):
                for j in range(i + 1, num_objects):
                    dx = centers[j, 0] - centers[i, 0]
                    dy = centers[j, 1] - centers[i, 1]
                    dist = math.sqrt(dx * dx + dy * dy)

                    min_dist = radii[i] + radii[j] + min_margin

                    # Coincident centres have no separation direction and
                    # are left where they are.
                    if not (1e-6 < dist < min_dist):
                        continue

                    direction_x = dx / dist
                    direction_y = dy / dist

                    overlap = min_dist - dist
                    displacement = overlap / 2.0 + 1e-4

                    centers[i, 0] -= direction_x * displacement
                    centers[i, 1] -= direction_y * displacement
                    centers[j, 0] += direction_x * displacement
                    centers[j, 1] += direction_y * displacement

            # Gently push objects toward the image centre at (0, 0); objects
            # closer to the edges experience a stronger pull.
            for i in range(num_objects):
                dist_from_center = math.sqrt(centers[i, 0] ** 2 + centers[i, 1] ** 2)
                adaptive_strength = centering_strength * (dist_from_center / max_dist)

                centers[i, 0] *= 1.0 - adaptive_strength
                centers[i, 1] *= 1.0 - adaptive_strength

        # Rebuild each segment around its new centre, keeping the direction
        # and shortening it only where it would leave the [-1, 1] square.
        for idx, obj in enumerate(objects):
            center_x = float(np.clip(centers[idx, 0], -1.0, 1.0))
            center_y = float(np.clip(centers[idx, 1], -1.0, 1.0))

            cos_angle = obj["cos_angle"]
            sin_angle = obj["sin_angle"]
            length = math.sqrt(
                (obj["end_x"] - obj["start_x"]) ** 2
                + (obj["end_y"] - obj["start_y"]) ** 2
            )
            length = self._fit_length_to_bounds(
                center_x, center_y, length, cos_angle, sin_angle
            )

            half = 0.5 * length
            obj["start_x"] = center_x - half * cos_angle
            obj["start_y"] = center_y - half * sin_angle
            obj["end_x"] = center_x + half * cos_angle
            obj["end_y"] = center_y + half * sin_angle

        return objects
473
+
474
+
475
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for GMM fitting and sampling.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in that order, so the help output lists them in the
    order given here.
    """
    option_table = (
        (
            "--data-dir",
            {
                "required": True,
                "help": "Path to the training dataset directory.",
            },
        ),
        (
            "--num-classes",
            {
                "type": int,
                "required": True,
                "help": "Total number of classes (including background).",
            },
        ),
        (
            "--colour-map",
            {
                "default": None,
                "help": (
                    "Colour map specified inline as 'R,G,B=class;...' or a path to a file "
                    "with one 'R,G,B,class' or 'R,G,B' entry per line. Required for RGB annotations."
                ),
            },
        ),
        (
            "--output-dir",
            {
                "default": "./local/gmm-inference",
                "help": "Directory to write generated samples.",
            },
        ),
        (
            "--num-samples",
            {
                "type": int,
                "default": 10,
                "help": "Number of images to generate.",
            },
        ),
        (
            "--height",
            {"type": int, "default": 512, "help": "Output image height."},
        ),
        (
            "--width",
            {"type": int, "default": 512, "help": "Output image width."},
        ),
        (
            "--num-components",
            {
                "type": int,
                "default": 5,
                "help": "Number of GMM components per class.",
            },
        ),
        (
            "--overlap-iters",
            {
                "type": int,
                "default": 40,
                "help": "Iterations for overlap resolution.",
            },
        ),
        (
            "--overlap-margin",
            {
                "type": float,
                "default": 0.05,
                "help": "Minimum margin between objects in normalized coordinates.",
            },
        ),
        (
            "--class-names",
            {
                "type": str,
                "default": None,
                "help": "Path to file with one class name per line.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Fit per-class GMMs and sample synthetic scribble annotations.",
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli
547
+
548
+
549
def read_class_names(filepath: str) -> List[str]:
    """Read class names from a text file, one name per line.

    The file is decoded as UTF-8 instead of the platform's default encoding,
    so names containing non-ASCII characters load identically on every
    system.

    Args:
        filepath: Path to the class-name file.

    Returns:
        The lines of the file in order, without line terminators. Blank
        lines are kept, so positions in the list match line positions in
        the file.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        return handle.read().splitlines()
553
+
554
+
555
def train_gmm_model(
    data_dir: str,
    num_components: int = 5,
    random_state: int = 0,
) -> "GMMRandomGenerator":
    """
    Trains and returns a GMMRandomGenerator model based on the dataset located in data_dir.

    Args:
        data_dir (str): Path to the training dataset directory. This directory should contain
            a subdirectory "segmentation" with annotation files, a "colour_map.txt" file,
            and a "class_labelling.txt" file.
        num_components (int): Maximum number of GMM components per class. Defaults to 5,
            the value that used to be hard-coded.
        random_state (int): Seed passed to the generator. Defaults to 0, the generator's
            own default.

    Returns:
        GMMRandomGenerator: The trained GMMRandomGenerator model.
    """

    colour_map = parse_colour_map(os.path.join(data_dir, "colour_map.txt"))
    # The number of classes is taken from the colour map: one per entry.
    num_classes = len(colour_map)

    dataset = ScribbleDataset(
        num_classes=num_classes,
        data_dir=os.path.join(data_dir, "segmentation"),
        colour_map=colour_map,
    )

    class_names = read_class_names(os.path.join(data_dir, "class_labelling.txt"))

    generator = GMMRandomGenerator(
        num_classes=num_classes,
        num_components=num_components,
        random_state=random_state,
        class_names=class_names,
        colour_map=colour_map,
    ).fit(dataset)

    return generator
587
+
588
+
589
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: fit the generator and write sample images.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Raises:
        OSError: If a generated sample cannot be written to disk.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    print(f"\n{'='*70}")
    print(f"{'GMMRandomGenerator - Scribble Annotation Generation':^70}")
    print(f"{'='*70}\n")

    print("Configuration:")
    print(f" Data directory: {args.data_dir}")
    print(f" Number of classes: {args.num_classes}")
    print(f" GMM components: {args.num_components}")
    print(f" Output directory: {args.output_dir}")
    print(f" Image size: {args.width}x{args.height}")
    print(f" Samples to gen: {args.num_samples}\n")

    colour_map = parse_colour_map(args.colour_map) if args.colour_map else None

    print("[1/3] Loading dataset...")
    dataset = ScribbleDataset(
        num_classes=args.num_classes,
        data_dir=args.data_dir,
        colour_map=colour_map,
    )
    print(f" Loaded {len(dataset.filenames)} annotation files\n")

    print("[2/3] Fitting GMM models...")
    class_names = read_class_names(args.class_names) if args.class_names else None

    generator = GMMRandomGenerator(
        num_classes=args.num_classes,
        num_components=args.num_components,
        class_names=class_names,
        colour_map=colour_map,
    ).fit(dataset)
    print()

    print("[3/3] Generating samples...")
    os.makedirs(args.output_dir, exist_ok=True)

    for idx in range(args.num_samples):
        image, description = generator.sample_image(
            image_shape=(args.height, args.width),
            overlap_iters=args.overlap_iters,
            overlap_margin=args.overlap_margin,
        )

        output_path = os.path.join(args.output_dir, f"gmm_sample_{idx:04d}.png")
        # OpenCV writes BGR, so colour images are converted from RGB first.
        if image.ndim == 3:
            payload = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            payload = image

        # cv2.imwrite signals failure through its return value rather than
        # an exception; without this check a failed write would still be
        # reported as a saved sample.
        if not cv2.imwrite(output_path, payload):
            raise OSError(f"Failed to write sample image: {output_path}")

        print(f" Sample {idx + 1}: {description}")

        if (idx + 1) % 10 == 0 or idx == args.num_samples - 1:
            print(f" Generated {idx + 1}/{args.num_samples} samples")

    print(f"\n{'='*70}")
    print(f"✓ Generation complete! Samples saved to: {args.output_dir}")
    print(f"{'='*70}\n")
651
+
652
+
653
+ if __name__ == "__main__":
654
+ main()