cudag 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudag/__init__.py +334 -0
- cudag/annotation/__init__.py +77 -0
- cudag/annotation/codegen.py +648 -0
- cudag/annotation/config.py +545 -0
- cudag/annotation/loader.py +342 -0
- cudag/annotation/scaffold.py +121 -0
- cudag/annotation/transcription.py +296 -0
- cudag/cli/__init__.py +5 -0
- cudag/cli/main.py +315 -0
- cudag/cli/new.py +873 -0
- cudag/core/__init__.py +364 -0
- cudag/core/button.py +137 -0
- cudag/core/canvas.py +222 -0
- cudag/core/config.py +70 -0
- cudag/core/coords.py +233 -0
- cudag/core/data_grid.py +804 -0
- cudag/core/dataset.py +678 -0
- cudag/core/distribution.py +136 -0
- cudag/core/drawing.py +75 -0
- cudag/core/fonts.py +156 -0
- cudag/core/generator.py +163 -0
- cudag/core/grid.py +367 -0
- cudag/core/grounding_task.py +247 -0
- cudag/core/icon.py +207 -0
- cudag/core/iconlist_task.py +301 -0
- cudag/core/models.py +1251 -0
- cudag/core/random.py +130 -0
- cudag/core/renderer.py +190 -0
- cudag/core/screen.py +402 -0
- cudag/core/scroll_task.py +254 -0
- cudag/core/scrollable_grid.py +447 -0
- cudag/core/state.py +110 -0
- cudag/core/task.py +293 -0
- cudag/core/taskbar.py +350 -0
- cudag/core/text.py +212 -0
- cudag/core/utils.py +82 -0
- cudag/data/surnames.txt +5000 -0
- cudag/modal_apps/__init__.py +4 -0
- cudag/modal_apps/archive.py +103 -0
- cudag/modal_apps/extract.py +138 -0
- cudag/modal_apps/preprocess.py +529 -0
- cudag/modal_apps/upload.py +317 -0
- cudag/prompts/SYSTEM_PROMPT.txt +104 -0
- cudag/prompts/__init__.py +33 -0
- cudag/prompts/system.py +43 -0
- cudag/prompts/tools.py +382 -0
- cudag/py.typed +0 -0
- cudag/schemas/filesystem.json +90 -0
- cudag/schemas/test_record.schema.json +113 -0
- cudag/schemas/train_record.schema.json +90 -0
- cudag/server/__init__.py +21 -0
- cudag/server/app.py +232 -0
- cudag/server/services/__init__.py +9 -0
- cudag/server/services/generator.py +128 -0
- cudag/templates/scripts/archive.sh +35 -0
- cudag/templates/scripts/build.sh +13 -0
- cudag/templates/scripts/extract.sh +54 -0
- cudag/templates/scripts/generate.sh +116 -0
- cudag/templates/scripts/pre-commit.sh +44 -0
- cudag/templates/scripts/preprocess.sh +46 -0
- cudag/templates/scripts/upload.sh +63 -0
- cudag/templates/scripts/verify.py +428 -0
- cudag/validation/__init__.py +35 -0
- cudag/validation/validate.py +508 -0
- cudag-0.3.10.dist-info/METADATA +570 -0
- cudag-0.3.10.dist-info/RECORD +69 -0
- cudag-0.3.10.dist-info/WHEEL +4 -0
- cudag-0.3.10.dist-info/entry_points.txt +2 -0
- cudag-0.3.10.dist-info/licenses/LICENSE +66 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Distribution sampling utilities for task generation.
|
|
6
|
+
|
|
7
|
+
Provides weighted random sampling from configured distributions,
|
|
8
|
+
commonly used for controlling task type distribution in datasets.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from random import Random
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from cudag.core.dataset import DatasetConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class DistributionSampler:
    """Weighted random sampling from a configured distribution.

    This class encapsulates the pattern of sampling from task-specific
    distributions defined in dataset configuration. Distributions are
    dictionaries mapping type names to non-negative probability weights
    that must sum to approximately 1.0 (between 0.99 and 1.01 inclusive).

    Sampling is proportional to the weights. Each draw is scaled by the
    weights' actual total, so a distribution that sums to e.g. 0.99 does not
    silently hand the missing mass to its last entry. Types with zero weight
    are never returned.

    Attributes:
        distribution: Mapping of type names to probability weights. Types
            are visited in the dict's insertion order when sampling.

    Example:
        >>> sampler = DistributionSampler({
        ...     "normal": 0.8,
        ...     "edge_case": 0.15,
        ...     "adversarial": 0.05
        ... })
        >>> rng = Random(42)
        >>> sampler.sample(rng)
        'normal'
    """

    distribution: dict[str, float]

    def __post_init__(self) -> None:
        """Validate the configured distribution.

        Raises:
            ValueError: If the distribution is empty, contains a negative
                weight, or its weights do not sum to approximately 1.0.
        """
        if not self.distribution:
            raise ValueError("Distribution cannot be empty")

        # Negative weights could otherwise cancel out in the sum check and
        # produce nonsensical cumulative thresholds during sampling.
        negative = {name: w for name, w in self.distribution.items() if w < 0}
        if negative:
            raise ValueError(
                f"Distribution weights must be non-negative, got {negative}. "
                f"Distribution: {self.distribution}"
            )

        total = sum(self.distribution.values())
        if not (0.99 <= total <= 1.01):
            raise ValueError(
                f"Distribution probabilities must sum to 1.0, got {total:.4f}. "
                f"Distribution: {self.distribution}"
            )

    def sample(self, rng: Random) -> str:
        """Sample a distribution type in proportion to its weight.

        Consumes exactly one ``rng.random()`` call per sample.

        Args:
            rng: Random number generator

        Returns:
            Sampled distribution type name
        """
        total = sum(self.distribution.values())
        # Scale the uniform draw by the real total so weights that only
        # approximately sum to 1.0 are still honoured proportionally.
        threshold = rng.random() * total
        cumulative = 0.0
        # Default to the last key (original behaviour) in case no weight is
        # positive; it is replaced by the last drawable type during the loop.
        fallback = list(self.distribution)[-1]
        for dist_type, weight in self.distribution.items():
            if weight <= 0:
                # Zero-weight types can never be chosen, not even as fallback.
                continue
            fallback = dist_type
            cumulative += weight
            if threshold < cumulative:
                return dist_type
        # Floating-point rounding can leave threshold >= cumulative; fall back
        # to the last type that actually has probability mass.
        return fallback

    def sample_n(self, rng: Random, n: int) -> list[str]:
        """Sample n distribution types.

        Args:
            rng: Random number generator
            n: Number of samples to generate

        Returns:
            List of sampled distribution type names
        """
        return [self.sample(rng) for _ in range(n)]

    @classmethod
    def from_config(
        cls,
        config: DatasetConfig,
        task_type: str,
        default: dict[str, float] | None = None,
    ) -> DistributionSampler:
        """Create sampler from dataset configuration.

        Args:
            config: Dataset configuration object; its
                ``get_distribution(task_type)`` result is used.
            task_type: Task type to get distribution for
            default: Default distribution used when the config returns an
                empty/falsy distribution.

        Returns:
            Configured DistributionSampler instance

        Raises:
            ValueError: If no distribution found and no default provided, or
                if the resulting distribution fails validation.
        """
        dist = config.get_distribution(task_type)
        if not dist and default:
            dist = default
        if not dist:
            raise ValueError(
                f"No distribution found for task type: '{task_type}'. "
                "Either add a distribution to your config or provide a default."
            )
        return cls(dist)

    @classmethod
    def uniform(cls, types: list[str]) -> DistributionSampler:
        """Create a uniform distribution over the given types.

        Duplicate names are collapsed (first occurrence keeps its position),
        so every distinct type receives the same probability.

        Args:
            types: List of type names to distribute uniformly

        Returns:
            DistributionSampler with equal probability for each distinct type

        Raises:
            ValueError: If ``types`` is empty.

        Example:
            >>> sampler = DistributionSampler.uniform(["a", "b", "c"])
            >>> sorted(sampler.distribution)
            ['a', 'b', 'c']
            >>> sampler.distribution["a"] == 1 / 3
            True
        """
        if not types:
            raise ValueError("Types list cannot be empty")
        # Duplicates would otherwise shrink the dict and break the sum check.
        unique_types = list(dict.fromkeys(types))
        prob = 1.0 / len(unique_types)
        return cls({t: prob for t in unique_types})
|
cudag/core/drawing.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Drawing utilities for CUDAG framework."""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from PIL import Image, ImageDraw
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def render_scrollbar(
    content_height: int,
    visible_height: int,
    scroll_offset: int,
    width: int = 12,
    *,
    min_thumb: int = 30,
    track_color: tuple[int, int, int] = (240, 240, 240),
    thumb_color: tuple[int, int, int] = (100, 100, 100),
    thumb_width: int = 4,
) -> Image.Image:
    """Render a scrollbar track with a thumb indicating scroll position.

    The thumb height is proportional to ``visible_height / content_height``.
    It is never shorter than ``min_thumb`` and never taller than the track.
    Its vertical position is proportional to ``scroll_offset`` within the
    scrollable range. When the content fits in the viewport
    (``content_height <= visible_height``), only the track is drawn.

    Args:
        content_height: Total content height in pixels.
        visible_height: Visible viewport height; also the height of the
            returned image.
        scroll_offset: Current scroll offset in pixels. Values outside
            ``[0, content_height - visible_height]`` are clamped so the thumb
            always stays inside the track.
        width: Scrollbar width.
        min_thumb: Minimum thumb height in pixels (capped at the track height).
        track_color: RGB color for scrollbar track.
        thumb_color: RGB color for scrollbar thumb.
        thumb_width: Width of the thumb in pixels (capped at ``width``).

    Returns:
        RGB PIL Image of size ``(width, visible_height)``.

    Example:
        >>> scrollbar = render_scrollbar(
        ...     content_height=1000,
        ...     visible_height=400,
        ...     scroll_offset=200,
        ...     width=12,
        ... )
        >>> scrollbar.size
        (12, 400)
    """
    track_height = visible_height

    # Track background fills the whole image.
    track = Image.new("RGB", (width, track_height), color=track_color)
    draw = ImageDraw.Draw(track)

    if content_height <= visible_height:
        # No scrolling needed - no thumb
        return track

    # Thumb size proportional to the visible fraction of the content. Cap it
    # at the track height so a large min_thumb cannot overflow the image.
    ratio = visible_height / content_height
    thumb_height = min(track_height, max(min_thumb, int(track_height * ratio)))

    # Thumb position proportional to the (clamped) scroll offset.
    max_offset = max(content_height - visible_height, 1)
    offset = min(max(scroll_offset, 0), max_offset)
    travel = track_height - thumb_height
    thumb_y = int((offset / max_offset) * travel) if travel > 0 else 0

    thumb_w = min(thumb_width, width)
    if thumb_w <= 0 or thumb_height <= 0:
        # Nothing visible to draw (and Pillow rejects inverted boxes).
        return track

    # Centered thin rectangle. Pillow's rectangle bounds are inclusive of both
    # corners, so subtract 1 to paint exactly thumb_w x thumb_height pixels.
    thumb_x0 = (width - thumb_w) // 2
    draw.rectangle(
        [
            (thumb_x0, thumb_y),
            (thumb_x0 + thumb_w - 1, thumb_y + thumb_height - 1),
        ],
        fill=thumb_color,
    )

    return track
|
cudag/core/fonts.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Font loading utilities with platform-aware fallbacks.
|
|
6
|
+
|
|
7
|
+
This module provides utilities for loading fonts with automatic fallback
|
|
8
|
+
to system fonts when the primary font is not available.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import sys
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Sequence
|
|
16
|
+
|
|
17
|
+
from PIL import ImageFont
|
|
18
|
+
from PIL.ImageFont import FreeTypeFont
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Well-known system font files, keyed by ``sys.platform`` value and listed in
# order of preference. ``load_font`` tries these only after the caller's
# primary font and any explicit fallbacks have failed.
SYSTEM_FONTS: dict[str, list[str]] = dict(
    darwin=[
        "/System/Library/Fonts/Helvetica.ttc",
        "/System/Library/Fonts/SFNSText.ttf",
        "/Library/Fonts/Arial.ttf",
        "/System/Library/Fonts/SFNS.ttf",
    ],
    linux=[
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/TTF/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    ],
    win32=[
        "C:/Windows/Fonts/arial.ttf",
        "C:/Windows/Fonts/segoeui.ttf",
        "C:/Windows/Fonts/tahoma.ttf",
    ],
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_font(
    primary_path: Path | str,
    size: int,
    fallbacks: Sequence[Path | str] | None = None,
) -> FreeTypeFont:
    """Load a TrueType/OpenType font, trying successive candidates.

    Candidates are attempted in this order:

    1. ``primary_path``;
    2. each entry of ``fallbacks`` (if given);
    3. the entries of ``SYSTEM_FONTS`` for the current ``sys.platform``
       (if the platform is listed).

    The first candidate that ``ImageFont.truetype`` opens without an
    ``OSError`` is returned.

    Args:
        primary_path: Preferred font file (absolute, or relative to the CWD).
        size: Font size passed to ``ImageFont.truetype``.
        fallbacks: Extra font files to try before the platform system fonts.

    Returns:
        The first successfully loaded ``FreeTypeFont``.

    Raises:
        OSError: If every candidate fails. The message lists up to five of
            the individual failures.

    Example:
        Basic usage::

            font = load_font("assets/fonts/Inter.ttf", size=14)

        With explicit fallbacks::

            font = load_font(
                "assets/fonts/Inter.ttf",
                size=14,
                fallbacks=["/System/Library/Fonts/Helvetica.ttc"],
            )
    """
    candidates: list[Path] = [Path(primary_path)]
    candidates += [Path(p) for p in (fallbacks or ())]
    candidates += [Path(p) for p in SYSTEM_FONTS.get(sys.platform, ())]

    failures: list[str] = []
    for candidate in candidates:
        try:
            return ImageFont.truetype(str(candidate), size)
        except OSError as exc:
            failures.append(f"{candidate}: {exc}")

    # Only the first five failures are spelled out to keep the message short.
    header = f"Could not load any font. Tried {len(candidates)} paths:\n"
    listed = "\n".join(f" - {err}" for err in failures[:5])
    hidden = len(failures) - 5
    trailer = f"\n ... and {len(failures) - 5} more" if hidden > 0 else ""
    raise OSError(header + listed + trailer)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def load_font_family(
    regular: Path | str,
    size: int,
    *,
    bold: Path | str | None = None,
    italic: Path | str | None = None,
    bold_italic: Path | str | None = None,
    fallbacks: Sequence[Path | str] | None = None,
) -> dict[str, FreeTypeFont]:
    """Load a font family with multiple weights/styles.

    The regular face is loaded through ``load_font`` and therefore benefits
    from ``fallbacks`` and the platform system fonts. Bold, italic and
    bold-italic are loaded directly from their given paths, with no
    fallback chain of their own.

    If a variant is not given, or fails to load, the regular font is used
    in its place. A variant that was explicitly requested but could not be
    loaded emits a ``UserWarning``, so a mistyped path does not go
    unnoticed.

    Args:
        regular: Path to regular weight font (required).
        size: Font size in points.
        bold: Optional path to bold weight font.
        italic: Optional path to italic font.
        bold_italic: Optional path to bold-italic font.
        fallbacks: Optional fallbacks used only when loading the regular font.

    Returns:
        Dictionary with keys 'regular', 'bold', 'italic', 'bold_italic'.
        Missing or unloadable variants map to the regular font.

    Raises:
        OSError: If the regular font cannot be loaded from any candidate.

    Example:
        ::

            fonts = load_font_family(
                "fonts/Inter-Regular.ttf",
                size=14,
                bold="fonts/Inter-Bold.ttf",
            )
            draw.text((10, 10), "Normal", font=fonts["regular"])
            draw.text((10, 30), "Bold", font=fonts["bold"])
    """
    import warnings

    regular_font = load_font(regular, size, fallbacks)

    result: dict[str, FreeTypeFont] = {
        "regular": regular_font,
        "bold": regular_font,
        "italic": regular_font,
        "bold_italic": regular_font,
    }

    # Try to load variants; keep the regular font on failure (best effort),
    # but surface the failure instead of swallowing it silently.
    for key, path in (("bold", bold), ("italic", italic), ("bold_italic", bold_italic)):
        if path is None:
            continue
        try:
            result[key] = ImageFont.truetype(str(path), size)
        except OSError as exc:
            warnings.warn(
                f"Could not load {key} font from {path} ({exc}); "
                "falling back to the regular font.",
                stacklevel=2,
            )

    return result
|
cudag/core/generator.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Generator entry point helper for CUDAG projects.
|
|
6
|
+
|
|
7
|
+
This module provides a standard entry point for dataset generation that
|
|
8
|
+
handles boilerplate like argument parsing, config loading, and dataset naming.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
17
|
+
|
|
18
|
+
from cudag.core.dataset import DatasetBuilder, DatasetConfig
|
|
19
|
+
from cudag.core.utils import check_script_invocation, get_researcher_name
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from cudag.core.renderer import BaseRenderer
|
|
23
|
+
from cudag.core.task import BaseTask
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def run_generator(
    renderer: BaseRenderer[Any],
    tasks: list[BaseTask],
    *,
    config_path: Path | str = "config/dataset.yaml",
    description: str = "Generate training dataset",
    extra_args: list[tuple[str, dict[str, Any]]] | None = None,
    config_modifier: Callable[[DatasetConfig, argparse.Namespace], None] | None = None,
    post_build: Callable[[Path, BaseRenderer[Any]], None] | None = None,
    argv: list[str] | None = None,
) -> Path:
    """Standard dataset generation entry point.

    Handles common boilerplate:
    - Script invocation check via ``check_script_invocation()``
    - Argument parsing (--config, --seed, plus custom args)
    - Config loading from YAML
    - Dataset naming ({prefix}--{researcher}--{timestamp}); the researcher
      part is omitted when ``get_researcher_name()`` returns a falsy value
    - Dataset building and test generation

    Args:
        renderer: Initialized renderer instance for generating images.
        tasks: List of task instances to generate samples from.
        config_path: Default path to config YAML file. Defaults to
            "config/dataset.yaml".
        description: CLI description shown in --help. Defaults to
            "Generate training dataset".
        extra_args: Additional CLI arguments as list of (name, kwargs) tuples.
            Each tuple is passed to argparser.add_argument(name, **kwargs).
        config_modifier: Optional callback to modify config after loading.
            Called with (config, args) after config is loaded (and any --seed
            override applied) but before dataset name is generated.
        post_build: Optional callback after dataset and tests are built.
            Called with (output_dir, renderer). Use for debug images,
            validation, etc.
        argv: Optional explicit argument list to parse instead of
            ``sys.argv[1:]``. Useful for calling the generator from code or
            tests. Defaults to None, which keeps argparse's normal behaviour.

    Returns:
        Path to the generated dataset output directory (as returned by
        ``DatasetBuilder.build()``).

    Example:
        Basic usage::

            from cudag import run_generator
            from .renderer import MyRenderer
            from .tasks import ClickTask, ScrollTask

            def main() -> None:
                renderer = MyRenderer(assets_dir=Path("assets"))
                tasks = [ClickTask(config={}, renderer=renderer)]
                run_generator(renderer, tasks)

        With custom arguments::

            def main() -> None:
                renderer = MyRenderer(assets_dir=Path("assets"))
                tasks = [ClickTask(config={}, renderer=renderer)]
                run_generator(
                    renderer,
                    tasks,
                    extra_args=[
                        ("--debug", {"action": "store_true", "help": "Enable debug"}),
                    ],
                )

        With config modification::

            def modify_config(config, args):
                if args.debug:
                    config.task_counts = {"click-day": 10}

            def main() -> None:
                renderer = MyRenderer(assets_dir=Path("assets"))
                tasks = [ClickTask(config={}, renderer=renderer)]
                run_generator(
                    renderer,
                    tasks,
                    config_modifier=modify_config,
                )

        Programmatic invocation::

            run_generator(renderer, tasks, argv=["--seed", "7"])
    """
    check_script_invocation()

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--config",
        type=Path,
        default=Path(config_path),
        help="Path to dataset config YAML",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Override random seed from config",
    )

    # Add any extra arguments
    for name, kwargs in extra_args or ():
        parser.add_argument(name, **kwargs)

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    args = parser.parse_args(argv)

    # Load config
    config = DatasetConfig.from_yaml(args.config)
    if args.seed is not None:
        config.seed = args.seed

    # Allow custom config modification
    if config_modifier:
        config_modifier(config, args)

    # Build dataset name: {prefix}--{researcher}--{timestamp}
    # Using "--" delimiter to disambiguate from hyphens in expert names
    researcher = get_researcher_name()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name_parts = [config.name_prefix]
    if researcher:
        name_parts.append(researcher)
    name_parts.append(timestamp)
    dataset_name = "--".join(name_parts)

    config.output_dir = Path("datasets") / dataset_name

    print(f"Loaded config: {config.name_prefix}")
    print(f"Tasks: {config.task_counts}")

    # Build dataset
    builder = DatasetBuilder(config=config, tasks=tasks)
    output_dir = builder.build()

    # Build tests
    builder.build_tests()

    # Optional post-build callback
    if post_build:
        post_build(output_dir, renderer)

    print(f"\nDataset generated at: {output_dir}")

    return output_dir
|