chapman-datagen 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chapman_datagen-0.1.0/PKG-INFO +21 -0
- chapman_datagen-0.1.0/README.md +3 -0
- chapman_datagen-0.1.0/chapman_datagen/__init__.py +29 -0
- chapman_datagen-0.1.0/chapman_datagen/chunker.py +126 -0
- chapman_datagen-0.1.0/chapman_datagen/cli.py +193 -0
- chapman_datagen-0.1.0/chapman_datagen/config.py +156 -0
- chapman_datagen-0.1.0/chapman_datagen/generator.py +129 -0
- chapman_datagen-0.1.0/chapman_datagen/pipeline.py +466 -0
- chapman_datagen-0.1.0/chapman_datagen/uploader.py +47 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/PKG-INFO +21 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/SOURCES.txt +15 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/dependency_links.txt +1 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/entry_points.txt +2 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/requires.txt +7 -0
- chapman_datagen-0.1.0/chapman_datagen.egg-info/top_level.txt +1 -0
- chapman_datagen-0.1.0/pyproject.toml +32 -0
- chapman_datagen-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: chapman-datagen
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A highly generalizable prompt-based dataset generator and Hugging Face uploader
|
|
5
|
+
License: MIT
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: openai>=1.0.0
|
|
12
|
+
Requires-Dist: pandas>=1.5.0
|
|
13
|
+
Requires-Dist: pyarrow>=10.0.0
|
|
14
|
+
Requires-Dist: huggingface_hub>=0.15.0
|
|
15
|
+
Requires-Dist: tqdm>=4.65.0
|
|
16
|
+
Requires-Dist: jinja2>=3.0.0
|
|
17
|
+
Requires-Dist: datasets>=2.0.0
|
|
18
|
+
|
|
19
|
+
# chapman-datagen
|
|
20
|
+
|
|
21
|
+
A highly generalizable prompt-based dataset generator and Hugging Face uploader.
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from chapman_datagen.config import (
|
|
2
|
+
PipelineConfig,
|
|
3
|
+
APIConfig,
|
|
4
|
+
InputConfig,
|
|
5
|
+
StepConfig,
|
|
6
|
+
OutputConfig,
|
|
7
|
+
HuggingFaceConfig
|
|
8
|
+
)
|
|
9
|
+
from chapman_datagen.pipeline import run_pipeline, run_generation_pipeline_async
|
|
10
|
+
from chapman_datagen.generator import Generator
|
|
11
|
+
from chapman_datagen.chunker import chunk_jsonl_to_parquet, write_parquet_chunks
|
|
12
|
+
from chapman_datagen.uploader import upload_to_huggingface
|
|
13
|
+
|
|
14
|
+
__version__ = "0.1.0"
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"PipelineConfig",
|
|
18
|
+
"APIConfig",
|
|
19
|
+
"InputConfig",
|
|
20
|
+
"StepConfig",
|
|
21
|
+
"OutputConfig",
|
|
22
|
+
"HuggingFaceConfig",
|
|
23
|
+
"run_pipeline",
|
|
24
|
+
"run_generation_pipeline_async",
|
|
25
|
+
"Generator",
|
|
26
|
+
"chunk_jsonl_to_parquet",
|
|
27
|
+
"write_parquet_chunks",
|
|
28
|
+
"upload_to_huggingface"
|
|
29
|
+
]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import io
|
|
3
|
+
import math
|
|
4
|
+
import shutil
|
|
5
|
+
import logging
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
def estimate_rows_per_chunk(df: pd.DataFrame, max_chunk_size_mb: float) -> int:
    """
    Guess how many rows of ``df`` fit into one Parquet file of ``max_chunk_size_mb``.

    A sample of up to the first 100 rows is serialised to Parquet in memory and the
    resulting byte count is used as the average on-disk cost per row.

    Args:
        df: Frame whose rows are going to be chunked.
        max_chunk_size_mb: Upper bound for a single chunk, in megabytes.

    Returns:
        The estimated row count per chunk (never below 1), or 1000 when ``df``
        has no rows to sample.
    """
    sample_len = min(len(df), 100)
    if not sample_len:
        return 1000

    sample = df.iloc[:sample_len]
    sink = io.BytesIO()
    try:
        sample.to_parquet(sink, index=False)
    except Exception as e:
        logger.warning(f"Failed to estimate Parquet size in memory: {e}. Falling back to default row estimation.")
        # In-memory footprint scaled by an assumed 30% compression ratio
        sample_bytes = sample.memory_usage(deep=True).sum() * 0.3
    else:
        sample_bytes = len(sink.getvalue())

    bytes_per_row = sample_bytes / sample_len
    byte_budget = max_chunk_size_mb * 1024 * 1024

    # Keep 15% headroom because row sizes (text lengths) vary
    headroom = 0.85
    rows = int((byte_budget * headroom) / bytes_per_row)

    # A chunk always holds at least one row
    return rows if rows > 1 else 1
|
|
39
|
+
|
|
40
|
+
def write_parquet_chunks(df: pd.DataFrame, output_dir: str, max_chunk_size_mb: float) -> list[str]:
    """
    Write ``df`` into ``output_dir`` as Parquet files of at most ``max_chunk_size_mb`` each.

    Any existing ``output_dir`` is deleted first so that no stale files remain.
    The frame is cut into slices of an estimated row count; every slice is written
    and its size on disk is checked. A slice that is still too large is removed and
    split in half, and the halves are processed next so the original row order is
    kept. A single row that exceeds the limit on its own is written as-is.

    Args:
        df: Data to write.
        output_dir: Target directory (recreated on every call).
        max_chunk_size_mb: Maximum size of one Parquet file, in megabytes.

    Returns:
        Paths of the written files, named ``data-XXXXX-of-YYYYY.parquet`` in row order.
        An empty ``df`` yields one empty file.
    """
    # Function-scope import: the module header does not import ``collections``.
    from collections import deque

    if os.path.exists(output_dir):
        # Clear existing directory to avoid stale files
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    if df.empty:
        logger.warning("Empty DataFrame provided. Writing empty parquet file.")
        empty_path = os.path.join(output_dir, "data-00000-of-00001.parquet")
        df.to_parquet(empty_path, index=False)
        return [empty_path]

    rows_per_chunk = estimate_rows_per_chunk(df, max_chunk_size_mb)
    logger.info(f"Estimated rows per chunk: {rows_per_chunk} (for target size {max_chunk_size_mb} MB)")

    # Queue of (dataframe, chunk_id) pairs still to be written. A deque gives O(1)
    # pops and pushes at the front, which the split-and-retry step relies on.
    # ``iloc`` clamps the end of the last slice to the frame length.
    write_queue = deque(
        (df.iloc[start:start + rows_per_chunk], chunk_id)
        for chunk_id, start in enumerate(range(0, len(df), rows_per_chunk))
    )
    written_files = []
    max_chunk_bytes = max_chunk_size_mb * 1024 * 1024

    while write_queue:
        current_df, chunk_id = write_queue.popleft()

        # Written under a temporary name first; the final name needs the total count
        temp_path = os.path.join(output_dir, f"data-chunk-{chunk_id}.parquet")
        current_df.to_parquet(temp_path, index=False)
        file_size = os.path.getsize(temp_path)

        if file_size > max_chunk_bytes and len(current_df) > 1:
            # The chunk exceeded the max size! Split it in half.
            os.remove(temp_path)
            mid = len(current_df) // 2
            logger.warning(
                f"Chunk {chunk_id} with {len(current_df)} rows size on disk ({file_size / (1024*1024):.2f} MB) "
                f"exceeded maximum size limit {max_chunk_size_mb} MB. Splitting in half."
            )
            # Push the second half first so the first half is processed next
            write_queue.appendleft((current_df.iloc[mid:], f"{chunk_id}b"))
            write_queue.appendleft((current_df.iloc[:mid], f"{chunk_id}a"))
        else:
            written_files.append((temp_path, current_df))

    # Rename files to follow standard dataset format: data-00000-of-0000N.parquet
    total_chunks = len(written_files)
    final_paths = []

    for i, (temp_path, chunk_df) in enumerate(written_files):
        filename = f"data-{i:05d}-of-{total_chunks:05d}.parquet"
        final_path = os.path.join(output_dir, filename)
        os.rename(temp_path, final_path)

        actual_size_mb = os.path.getsize(final_path) / (1024 * 1024)
        logger.info(f"Wrote chunk {i+1}/{total_chunks}: {filename} (size: {actual_size_mb:.2f} MB, rows: {len(chunk_df)})")
        final_paths.append(final_path)

    return final_paths
|
|
113
|
+
|
|
114
|
+
def chunk_jsonl_to_parquet(jsonl_path: str, output_dir: str, max_chunk_size_mb: float) -> list[str]:
    """
    Load a JSONL file into a DataFrame and write it out as size-bounded Parquet chunks.

    Args:
        jsonl_path: Path of the JSON-lines file to read.
        output_dir: Directory handed to ``write_parquet_chunks``.
        max_chunk_size_mb: Maximum size of one Parquet file, in megabytes.

    Returns:
        The list of file paths returned by ``write_parquet_chunks``.

    Raises:
        ValueError: If the JSONL file cannot be read; the underlying error is
            attached as the exception's cause.
    """
    logger.info(f"Reading JSONL file from {jsonl_path}...")
    try:
        df = pd.read_json(jsonl_path, lines=True)
    except Exception as e:
        logger.error(f"Failed to read JSONL file: {e}")
        # Chain explicitly so the original traceback is reported as the direct cause
        raise ValueError(f"Could not load JSONL file at {jsonl_path}: {e}") from e

    return write_parquet_chunks(df, output_dir, max_chunk_size_mb)
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import asyncio
|
|
6
|
+
|
|
7
|
+
from chapman_datagen.config import PipelineConfig
|
|
8
|
+
from chapman_datagen.pipeline import run_pipeline, run_generation_pipeline_async
|
|
9
|
+
from chapman_datagen.chunker import chunk_jsonl_to_parquet
|
|
10
|
+
from chapman_datagen.uploader import upload_to_huggingface
|
|
11
|
+
|
|
12
|
+
def setup_logging(verbose: bool):
    """
    Configure root logging for the CLI.

    The root level is DEBUG when ``verbose`` is true and INFO otherwise. In
    non-verbose mode the loggers of a few third-party libraries are raised to
    WARNING to keep the output readable.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    if verbose:
        # Debug mode keeps third-party logs at their default verbosity
        return

    for noisy_logger in ("openai", "httpx", "urllib3", "huggingface_hub"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
|
25
|
+
|
|
26
|
+
def handle_run(args):
    """
    Handles 'run' subcommand: executes input -> generate -> chunk -> upload.

    Loads the pipeline configuration from ``args.config``, applies the command-line
    overrides (``api_key``, ``base_url``, ``hf_token``, ``repo_id``) and passes the
    result to ``run_pipeline``. Exits with status 1 when the config file is missing.
    Hugging Face overrides are only applied when the config already has a
    'huggingface' section; otherwise a warning is written to stderr.
    """
    if not os.path.exists(args.config):
        print(f"Error: Config file not found at '{args.config}'", file=sys.stderr)
        sys.exit(1)

    config = PipelineConfig.from_json_file(args.config)

    # Overrides
    if args.api_key:
        config.api.api_key = args.api_key
    if args.base_url:
        config.api.base_url = args.base_url
    if args.hf_token:
        if config.huggingface:
            config.huggingface.token = args.hf_token
        else:
            # Diagnostics go to stderr, like the error above, so stdout stays clean
            print(
                "Warning: Hugging Face token provided but 'huggingface' section missing in config.json. Upload will be skipped.",
                file=sys.stderr,
            )
    if args.repo_id:
        if config.huggingface:
            config.huggingface.repo_id = args.repo_id
        else:
            print(
                "Warning: Hugging Face repo_id provided but 'huggingface' section missing in config.json. Upload will be skipped.",
                file=sys.stderr,
            )

    run_pipeline(config)
|
|
53
|
+
|
|
54
|
+
def handle_generate(args):
    """
    Handles 'generate' subcommand: runs only the input and step phases, producing the output JSONL.

    Exits with status 1 when ``args.config`` does not point to an existing file.
    ``args.api_key`` and ``args.base_url`` override the corresponding API settings
    when they are set.
    """
    config_path = args.config
    if not os.path.exists(config_path):
        print(f"Error: Config file not found at '{args.config}'", file=sys.stderr)
        sys.exit(1)

    config = PipelineConfig.from_json_file(config_path)

    # Command-line values win over the file for these API settings
    for setting in ("api_key", "base_url"):
        override = getattr(args, setting)
        if override:
            setattr(config.api, setting, override)

    logging.info("Starting generation phase only...")
    asyncio.run(run_generation_pipeline_async(config))
    logging.info(f"Generation phase complete. Output saved to {config.output.jsonl_path}")
|
|
72
|
+
|
|
73
|
+
def handle_process(args):
    """
    Handles 'process' subcommand: converts JSONL to Parquet chunks.

    Values come from ``--jsonl``, ``--parquet-dir`` and ``--max-size``. When any of
    them is missing (or the JSONL path does not exist) and ``--config`` names an
    existing file, the missing values are filled in from the config's ``output``
    section. Exits with status 1 if no usable JSONL path or Parquet directory can be
    determined. A missing or non-positive chunk size falls back to 20.0 MB.
    """
    jsonl_path = args.jsonl
    parquet_dir = args.parquet_dir
    max_chunk_size = args.max_chunk_size

    # Consult the config whenever ANY setting is unresolved, not only the JSONL
    # path; otherwise "--jsonl x --config y" would ignore the config's parquet_dir.
    unresolved = (
        not jsonl_path
        or not os.path.exists(jsonl_path)
        or not parquet_dir
        or not max_chunk_size
    )
    if unresolved and args.config and os.path.exists(args.config):
        config = PipelineConfig.from_json_file(args.config)
        jsonl_path = jsonl_path or config.output.jsonl_path
        parquet_dir = parquet_dir or config.output.parquet_dir
        max_chunk_size = max_chunk_size or config.output.max_parquet_chunk_size_mb

    if not jsonl_path or not os.path.exists(jsonl_path):
        print("Error: Valid JSONL path must be provided via --jsonl or config.json", file=sys.stderr)
        sys.exit(1)
    if not parquet_dir:
        print("Error: Output Parquet directory must be provided via --parquet-dir or config.json", file=sys.stderr)
        sys.exit(1)
    if max_chunk_size is None or max_chunk_size <= 0:
        max_chunk_size = 20.0

    logging.info(f"Starting Parquet chunking phase for {jsonl_path}...")
    chunk_jsonl_to_parquet(
        jsonl_path=jsonl_path,
        output_dir=parquet_dir,
        max_chunk_size_mb=max_chunk_size
    )
    logging.info(f"Parquet chunking phase complete. Saved to: {parquet_dir}")
|
|
105
|
+
|
|
106
|
+
def handle_upload(args):
    """
    Handles 'upload' subcommand: uploads Parquet chunks directory to HF.

    Values come from ``--parquet-dir``, ``--repo-id`` and ``--hf-token``. When any of
    them is missing (or the directory does not exist) and ``--config`` names an
    existing file, the missing values are filled in from the config (``output`` and,
    if present, ``huggingface`` sections). Exits with status 1 if no existing
    Parquet directory or no repo id can be determined.
    """
    parquet_dir = args.parquet_dir
    repo_id = args.repo_id
    hf_token = args.hf_token

    # Consult the config whenever ANY setting is unresolved, not only the directory;
    # otherwise "--parquet-dir x --config y" would ignore the config's repo_id/token.
    unresolved = (
        not parquet_dir
        or not os.path.exists(parquet_dir)
        or not repo_id
        or not hf_token
    )
    if unresolved and args.config and os.path.exists(args.config):
        config = PipelineConfig.from_json_file(args.config)
        parquet_dir = parquet_dir or config.output.parquet_dir
        if config.huggingface:
            repo_id = repo_id or config.huggingface.repo_id
            hf_token = hf_token or config.huggingface.token

    if not parquet_dir or not os.path.exists(parquet_dir):
        print("Error: Valid Parquet directory must be provided via --parquet-dir or config.json", file=sys.stderr)
        sys.exit(1)
    if not repo_id:
        print("Error: Hugging Face repo_id must be provided via --repo-id or config.json", file=sys.stderr)
        sys.exit(1)

    logging.info(f"Starting Hugging Face upload phase for {parquet_dir}...")
    repo_url = upload_to_huggingface(
        folder_path=parquet_dir,
        repo_id=repo_id,
        token=hf_token
    )
    logging.info(f"Upload complete! Repository URL: {repo_url}")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def main():
    """
    Command-line entry point.

    Builds the argument parser with the ``run``, ``generate``, ``process`` and
    ``upload`` subcommands, parses the process arguments and dispatches to the
    matching handler. Without a subcommand the help text is printed and the
    process exits with status 0.
    """
    cli = argparse.ArgumentParser(
        description="chapman-datagen: A Highly Generalizable Dataset Generator and Hugging Face Uploader."
    )
    cli.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")

    commands = cli.add_subparsers(dest="command", help="Available subcommands")
    run_cmd = commands.add_parser("run", help="Run the entire pipeline (generate -> chunk -> upload)")
    generate_cmd = commands.add_parser("generate", help="Only generate the dataset (output raw JSONL)")
    process_cmd = commands.add_parser("process", help="Convert JSONL into size-bounded Parquet chunks")
    upload_cmd = commands.add_parser("upload", help="Upload Parquet directory to Hugging Face")

    # 'run' and 'generate' share the config path and the API overrides
    for generating_cmd in (run_cmd, generate_cmd):
        generating_cmd.add_argument("--config", "-c", default="config.json", help="Path to config.json file (default: config.json)")
        generating_cmd.add_argument("--api-key", help="Override OpenAI API key")
        generating_cmd.add_argument("--base-url", help="Override OpenAI API base URL")
    run_cmd.add_argument("--hf-token", help="Override Hugging Face Hub token")
    run_cmd.add_argument("--repo-id", help="Override Hugging Face dataset repository ID")

    # 'process': JSONL -> Parquet
    process_cmd.add_argument("--config", "-c", help="Path to config.json file (optional fallback)")
    process_cmd.add_argument("--jsonl", "-j", help="Path to JSONL source file")
    process_cmd.add_argument("--parquet-dir", "-p", help="Output directory for Parquet chunks")
    process_cmd.add_argument("--max-size", type=float, dest="max_chunk_size", help="Max size of Parquet chunk in MB")

    # 'upload': Parquet directory -> Hugging Face
    upload_cmd.add_argument("--config", "-c", help="Path to config.json file (optional fallback)")
    upload_cmd.add_argument("--parquet-dir", "-p", help="Path to Parquet chunks directory")
    upload_cmd.add_argument("--repo-id", "-r", help="Hugging Face repo name (e.g. username/dataset)")
    upload_cmd.add_argument("--hf-token", help="Hugging Face Hub API Token")

    parsed = cli.parse_args()
    if not parsed.command:
        cli.print_help()
        sys.exit(0)

    setup_logging(parsed.verbose)

    handlers = {
        "run": handle_run,
        "generate": handle_generate,
        "process": handle_process,
        "upload": handle_upload,
    }
    handler = handlers.get(parsed.command)
    if handler is not None:
        handler(parsed)
|
|
191
|
+
|
|
192
|
+
if __name__ == "__main__":
|
|
193
|
+
main()
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import List, Dict, Any, Optional
|
|
5
|
+
|
|
6
|
+
@dataclass
class APIConfig:
    """
    Settings for the OpenAI-compatible API used for generation.

    After construction, an unset ``api_key`` is taken from ``OPENAI_API_KEY`` and a
    ``base_url`` left at its default is replaced by ``OPENAI_BASE_URL`` when that
    variable is set.

    Raises:
        ValueError: If ``api_type`` is not 'chat' or 'completion', or if
            ``concurrency_limit`` is not positive.
    """

    api_key: Optional[str] = None
    base_url: str = "https://api.openai.com/v1"
    api_type: str = "chat"  # "chat" or "completion"
    model: str = "gpt-4o-mini"
    temperature: float = 0.7
    max_tokens: int = 1000
    concurrency_limit: int = 10

    def __post_init__(self):
        # Environment fallbacks
        self.api_key = self.api_key or os.environ.get("OPENAI_API_KEY")
        env_base_url = os.environ.get("OPENAI_BASE_URL")
        if env_base_url and self.base_url == "https://api.openai.com/v1":
            self.base_url = env_base_url

        # Validation
        if self.api_type not in ("chat", "completion"):
            raise ValueError(f"api_type must be either 'chat' or 'completion', got '{self.api_type}'")
        if not self.concurrency_limit > 0:
            raise ValueError("concurrency_limit must be greater than 0")
|
|
28
|
+
|
|
29
|
+
@dataclass
class InputConfig:
    """
    Describes where the pipeline's input records come from.

    ``type`` selects the source; each source requires its own field to be set.

    Raises:
        ValueError: If ``type`` is unknown, the field required by ``type`` is
            missing, ``num_prompts`` is not positive for synthetic input, or
            ``limit`` is given but not positive.
    """

    type: str  # "list", "file", "synthetic", "huggingface"
    data: Optional[List[Dict[str, Any]]] = None  # for type "list"
    path: Optional[str] = None  # for type "file"
    criteria: Optional[str] = None  # for type "synthetic"
    num_prompts: Optional[int] = None  # for type "synthetic"
    generator_model: Optional[str] = None  # for type "synthetic"
    repo_id: Optional[str] = None  # for type "huggingface"
    split: Optional[str] = None  # for type "huggingface"
    subset: Optional[str] = None  # for type "huggingface"
    limit: Optional[int] = None  # general limit for loading samples

    def __post_init__(self):
        valid_types = ("list", "file", "synthetic", "huggingface")
        if self.type not in valid_types:
            raise ValueError(f"input.type must be one of {valid_types}, got '{self.type}'")

        # Field that must be set for each input type, with the error raised when it is not
        required_field = {
            "list": ("data", "input.data must be provided when input.type is 'list'"),
            "file": ("path", "input.path must be provided when input.type is 'file'"),
            "huggingface": ("repo_id", "input.repo_id must be provided when input.type is 'huggingface'"),
            "synthetic": ("criteria", "input.criteria must be provided when input.type is 'synthetic'"),
        }
        attr_name, missing_message = required_field[self.type]
        if not getattr(self, attr_name):
            raise ValueError(missing_message)

        if self.type == "synthetic" and (not self.num_prompts or self.num_prompts <= 0):
            raise ValueError("input.num_prompts must be greater than 0 for synthetic input")

        if self.limit is not None and self.limit <= 0:
            raise ValueError("input.limit must be greater than 0")
|
|
60
|
+
|
|
61
|
+
@dataclass
class StepConfig:
    """
    Configuration of a single pipeline step.

    Only ``type == "llm"`` is accepted. ``model``, ``temperature`` and ``max_tokens``
    default to None.

    Raises:
        ValueError: If ``type`` is not 'llm', ``prompt_template`` or ``output_field``
            is empty, or ``cot_field`` / ``loop_field`` is an empty string.
    """

    prompt_template: str
    type: str = "llm"
    system_prompt: Optional[str] = None
    output_field: str = "response"
    cot_field: Optional[str] = None  # field to save reasoning/thinking
    loop_field: Optional[str] = None  # field to loop over (e.g. "messages")
    loop_filter_role: Optional[str] = None  # role to filter and rewrite (e.g. "assistant")
    model: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    def __post_init__(self):
        # The first failing rule wins; None means the step is valid
        problem = None
        if self.type != "llm":
            problem = f"step.type currently only supports 'llm', got '{self.type}'"
        elif not self.prompt_template:
            problem = "step.prompt_template cannot be empty"
        elif not self.output_field:
            problem = "step.output_field cannot be empty"
        elif self.cot_field == "":
            problem = "step.cot_field cannot be an empty string"
        elif self.loop_field == "":
            problem = "step.loop_field cannot be an empty string"

        if problem is not None:
            raise ValueError(problem)
|
|
85
|
+
|
|
86
|
+
@dataclass
class OutputConfig:
    """
    Output locations of the pipeline.

    Raises:
        ValueError: If ``jsonl_path`` or ``parquet_dir`` is empty, or if
            ``max_parquet_chunk_size_mb`` is not positive.
    """

    jsonl_path: str
    parquet_dir: str
    max_parquet_chunk_size_mb: float = 20.0

    def __post_init__(self):
        # Both locations are mandatory; they are checked in declaration order
        empty_messages = {
            "jsonl_path": "output.jsonl_path cannot be empty",
            "parquet_dir": "output.parquet_dir cannot be empty",
        }
        for attr_name, message in empty_messages.items():
            if not getattr(self, attr_name):
                raise ValueError(message)

        if not self.max_parquet_chunk_size_mb > 0:
            raise ValueError("output.max_parquet_chunk_size_mb must be greater than 0")
|
|
99
|
+
|
|
100
|
+
@dataclass
class HuggingFaceConfig:
    """
    Hugging Face upload settings.

    An unset ``token`` is taken from the ``HF_TOKEN`` environment variable.

    Raises:
        ValueError: If ``repo_id`` is empty.
    """

    repo_id: str
    token: Optional[str] = None
    private: bool = False

    def __post_init__(self):
        if not self.repo_id:
            raise ValueError("huggingface.repo_id cannot be empty")
        # Environment fallback for the token
        self.token = self.token or os.environ.get("HF_TOKEN")
|
|
111
|
+
|
|
112
|
+
@dataclass
class PipelineConfig:
    """
    Complete pipeline configuration: API settings, input source, steps, outputs and
    the optional Hugging Face upload section.
    """

    api: APIConfig
    input: InputConfig
    steps: List[StepConfig]
    output: OutputConfig
    huggingface: Optional[HuggingFaceConfig] = field(default=None)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PipelineConfig":
        """
        Build a configuration from a plain dictionary (e.g. parsed JSON).

        The 'api' section is optional and falls back to the ``APIConfig`` defaults
        when absent or null. 'input', 'steps' and 'output' are required;
        'huggingface' is optional.

        Raises:
            ValueError: If a required section is missing or empty, if 'steps' is not
                a list, or if a section is not a JSON object. Each section's own
                validation errors propagate as well.
        """
        # Parse api. ``or {}`` also covers an explicit ``"api": null``, which would
        # otherwise fail with a TypeError when unpacked below.
        api_data = d.get("api") or {}
        cls._require_mapping("api", api_data)
        api = APIConfig(**api_data)

        # Parse input
        input_data = d.get("input")
        if not input_data:
            raise ValueError("Pipeline configuration must have an 'input' section")
        cls._require_mapping("input", input_data)
        input_cfg = InputConfig(**input_data)

        # Parse steps
        steps_data = d.get("steps")
        if not steps_data:
            raise ValueError("Pipeline configuration must have a 'steps' list")
        if not isinstance(steps_data, list):
            raise ValueError("'steps' must be a list of step configurations")
        for step_data in steps_data:
            cls._require_mapping("steps[]", step_data)
        steps = [StepConfig(**s) for s in steps_data]

        # Parse output
        output_data = d.get("output")
        if not output_data:
            raise ValueError("Pipeline configuration must have an 'output' section")
        cls._require_mapping("output", output_data)
        output = OutputConfig(**output_data)

        # Parse huggingface
        hf_data = d.get("huggingface")
        if hf_data:
            cls._require_mapping("huggingface", hf_data)
        hf = HuggingFaceConfig(**hf_data) if hf_data else None

        return cls(api=api, input=input_cfg, steps=steps, output=output, huggingface=hf)

    @staticmethod
    def _require_mapping(section: str, value: Any) -> None:
        """Raise ValueError when a configuration section is not a JSON object."""
        if not isinstance(value, dict):
            raise ValueError(
                f"'{section}' must be a JSON object, got {type(value).__name__}"
            )

    @classmethod
    def from_json_file(cls, path: str) -> "PipelineConfig":
        """Read the JSON file at ``path`` (UTF-8) and build a configuration from it."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import random
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, List, Dict, Any
|
|
5
|
+
from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APIStatusError
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
class Generator:
    """
    Async text generator backed by an OpenAI-compatible endpoint.

    ``api_type`` selects the endpoint used by :meth:`generate`: ``"chat"`` for the
    Chat Completions API or ``"completion"`` for the Text Completions API.
    """

    # Status codes treated as permanent failures: retrying the same request cannot help.
    _NON_RETRYABLE_STATUS = (400, 401, 403, 404)
    _SUPPORTED_API_TYPES = ("chat", "completion")

    def __init__(self, api_key: Optional[str], base_url: str, api_type: str = "chat"):
        """
        Create the underlying ``AsyncOpenAI`` client.

        Args:
            api_key: API key; when empty, ``OPENAI_API_KEY`` from the environment is used.
            base_url: Base URL of the API endpoint.
            api_type: ``"chat"`` or ``"completion"`` (checked when generating).
        """
        import os

        self.api_type = api_type

        # Resolve the key explicitly: an explicit placeholder passed to the client
        # would otherwise shadow OPENAI_API_KEY.
        resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not resolved_key:
            logger.warning("No API Key provided in configuration or OPENAI_API_KEY environment variable. Request might fail.")

        self.client = AsyncOpenAI(
            api_key=resolved_key or "mock-key",  # Fallback to prevent client validation crash if mocking/stubbing
            base_url=base_url
        )

    @staticmethod
    def _build_chat_messages(
        prompt: str,
        system_prompt: Optional[str],
        messages: Optional[List[Dict[str, str]]],
    ) -> List[Dict[str, str]]:
        """Assemble the chat message list: system prompt, prior messages, then the prompt."""
        chat_messages = []
        if system_prompt:
            chat_messages.append({"role": "system", "content": system_prompt})

        if not messages:
            chat_messages.append({"role": "user", "content": prompt})
            return chat_messages

        # Append the pre-structured chat messages directly
        chat_messages.extend(
            {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            for msg in messages
        )
        # Only append prompt as a separate instruction if it differs from the last message's content
        if prompt and prompt != chat_messages[-1].get("content"):
            chat_messages.append({"role": "user", "content": prompt})
        return chat_messages

    @staticmethod
    def _extract_reasoning(message: Any) -> str:
        """Return the reasoning text attached to a chat message, or '' if there is none."""
        # Looks for 'reasoning_content', then 'reasoning', first as attributes and
        # then in the message's extra fields.
        reasoning = getattr(message, "reasoning_content", None) or getattr(message, "reasoning", None)
        if not reasoning and hasattr(message, "model_extra"):
            extra = message.model_extra or {}
            reasoning = extra.get("reasoning_content") or extra.get("reasoning")
        return reasoning or ""

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        model: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_tokens: int = 1000,
        max_retries: int = 5,
        initial_backoff: float = 1.0,
    ) -> dict[str, str]:
        """
        Executes a prompt generation using either Chat Completions or Text Completions API.

        Failed requests are retried with exponential backoff and jitter, up to
        ``max_retries`` times; API errors with status 400, 401, 403 or 404 are not
        retried. In completion mode ``messages`` is ignored and the system prompt,
        if any, is prepended to the prompt.

        Returns:
            A dictionary containing "content" and "reasoning_content" keys
            ("reasoning_content" is always "" in completion mode).

        Raises:
            ValueError: Immediately, if ``api_type`` is unsupported.
            Exception: The last request error once retries are exhausted, or a
                non-retryable API error.
        """
        # Configuration errors are not transient: fail before entering the retry loop.
        if self.api_type not in self._SUPPORTED_API_TYPES:
            raise ValueError(f"Unsupported api_type: {self.api_type}")

        # The request payload does not change between attempts, so build it once.
        is_chat = self.api_type == "chat"
        if is_chat:
            chat_messages = self._build_chat_messages(prompt, system_prompt, messages)
        else:
            full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt

        retries = 0
        backoff = initial_backoff

        while True:
            try:
                if is_chat:
                    response = await self.client.chat.completions.create(
                        model=model,
                        messages=chat_messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        frequency_penalty=1.2,
                        presence_penalty=0.5,
                    )
                    message = response.choices[0].message
                    return {
                        "content": message.content or "",
                        "reasoning_content": self._extract_reasoning(message),
                    }

                response = await self.client.completions.create(
                    model=model,
                    prompt=full_prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                return {
                    "content": response.choices[0].text or "",
                    "reasoning_content": "",
                }

            except Exception as e:
                # If we have reached max retries, propagate the error
                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached. Raising error: {e}")
                    raise

                # Check if it's a non-retryable API error (e.g. 401 Unauthorized, 404 Not Found)
                if isinstance(e, APIStatusError) and e.status_code in self._NON_RETRYABLE_STATUS:
                    logger.error(f"Non-retryable API Error {e.status_code}: {e.message}. Aborting request.")
                    raise

                # Determine backoff time with jitter
                sleep_time = backoff + (random.random() * 0.5 * backoff)
                logger.warning(
                    f"Request failed due to: {e}. Retrying in {sleep_time:.2f} seconds (Attempt {retries + 1}/{max_retries})..."
                )

                await asyncio.sleep(sleep_time)
                retries += 1
                backoff *= 2  # Exponential backoff
|
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import asyncio
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from typing import List, Dict, Any, Optional
|
|
7
|
+
import tqdm.asyncio
|
|
8
|
+
|
|
9
|
+
from chapman_datagen.config import PipelineConfig, StepConfig
|
|
10
|
+
from chapman_datagen.generator import Generator
|
|
11
|
+
from chapman_datagen.chunker import chunk_jsonl_to_parquet
|
|
12
|
+
from chapman_datagen.uploader import upload_to_huggingface
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
def _run_coroutine_blocking(coro):
    """Run *coro* to completion from synchronous code and return its result."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run can own a fresh one.
        return asyncio.run(coro)

    # A loop is already running in this thread. Scheduling the coroutine onto
    # that same loop and then blocking on its future would deadlock, because
    # the loop can never advance while we wait. Drive it on a private loop in
    # a short-lived worker thread instead.
    # NOTE(review): the coroutine then executes on a different event loop than
    # the caller's; confirm the generator's HTTP client tolerates that.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _strip_code_fence(text: str) -> str:
    """Return *text* stripped, without a surrounding Markdown ``` fence if present."""
    text = text.strip()
    if not text.startswith("```"):
        return text

    lines = text.splitlines()
    if lines and lines[0].startswith("```"):
        lines = lines[1:]
    # Guard against a response that consisted of the opening fence only.
    if lines and lines[-1].startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines).strip()


def load_inputs(config: PipelineConfig, generator: Generator) -> List[Dict[str, Any]]:
    """
    Load the input records described by ``config.input``.

    Supported ``config.input.type`` values:

    * ``"list"``        - records given inline in ``config.input.data``.
    * ``"file"``        - a local CSV, Parquet, JSON or JSONL file read with pandas
                          (any other extension is read as JSON Lines).
    * ``"huggingface"`` - a dataset split loaded with the ``datasets`` library.
    * ``"synthetic"``   - prompts produced by asking the LLM for a JSON array.

    Args:
        config: Pipeline configuration; only ``config.input`` and
            ``config.api.model`` are read here.
        generator: Client whose ``generate`` coroutine is used for the
            ``"synthetic"`` input type.

    Returns:
        A list of record dicts, truncated to ``config.input.limit`` when that
        limit is a positive number.

    Raises:
        FileNotFoundError: The configured input file path is empty or missing.
        ImportError: ``datasets`` is not installed for a ``"huggingface"`` input.
        ValueError: The input type is unknown, or reading/parsing the input failed.
    """
    input_cfg = config.input
    logger.info(f"Loading inputs of type: {input_cfg.type}")

    records: List[Dict[str, Any]] = []

    if input_cfg.type == "list":
        records = input_cfg.data or []

    elif input_cfg.type == "file":
        path = input_cfg.path
        if not path or not os.path.exists(path):
            raise FileNotFoundError(f"Input file not found at: {path}")

        # Support CSV, Parquet, JSON, and JSONL via Pandas
        ext = os.path.splitext(path)[1].lower()
        try:
            if ext == ".csv":
                df = pd.read_csv(path)
            elif ext == ".parquet":
                df = pd.read_parquet(path)
            elif ext in (".jsonl", ".json"):
                # Try JSON Lines first, then fall back to a single JSON document.
                try:
                    df = pd.read_json(path, lines=True)
                except Exception:
                    df = pd.read_json(path, lines=False)
            else:
                # Default fallback
                df = pd.read_json(path, lines=True)

            # Convert NaN to None for clean JSON serialization. The cast to
            # object is required: without it numeric columns keep NaN, which
            # json.dumps would emit as the non-standard token `NaN`.
            df = df.astype(object).where(df.notna(), None)
            records = df.to_dict(orient="records")
        except Exception as e:
            raise ValueError(f"Failed to read input file {path}: {e}") from e

    elif input_cfg.type == "huggingface":
        # Load from Hugging Face Hub using the `datasets` library
        try:
            from datasets import load_dataset
        except ImportError as e:
            raise ImportError(
                "The 'datasets' library is required to load Hugging Face datasets. "
                "Please install it using: pip install datasets"
            ) from e

        logger.info(f"Loading Hugging Face dataset '{input_cfg.repo_id}' (split: '{input_cfg.split or 'train'}', subset: '{input_cfg.subset or 'default'}')...")
        try:
            ds = load_dataset(
                input_cfg.repo_id,
                name=input_cfg.subset,
                split=input_cfg.split or "train"
            )
            records = ds.to_list()
        except Exception as e:
            raise ValueError(f"Failed to load dataset '{input_cfg.repo_id}' from Hugging Face: {e}") from e

    elif input_cfg.type == "synthetic":
        # Generate prompts using the LLM based on criteria
        criteria = input_cfg.criteria
        num_prompts = input_cfg.num_prompts or 10
        model = input_cfg.generator_model or config.api.model

        logger.info(f"Generating {num_prompts} synthetic prompts based on criteria: '{criteria}'")

        system_prompt = "You are a synthetic dataset generator."
        prompt = (
            f"We are generating a dataset based on the following criteria:\n"
            f"'{criteria}'\n\n"
            f"Please generate a JSON list containing {num_prompts} distinct user prompts that match this criteria. "
            f"The output must be a valid JSON array of strings, for example:\n"
            f"[\n \"Prompt 1\",\n \"Prompt 2\"\n]\n"
            f"Return ONLY the raw JSON (no markdown block, no explanation, no headers)."
        )

        # This function is synchronous but may be called from inside a running
        # event loop, so the coroutine is driven by a helper that never blocks
        # the loop it would need in order to make progress.
        response_dict = _run_coroutine_blocking(
            generator.generate(prompt, system_prompt=system_prompt, model=model)
        )
        response_text = _strip_code_fence(response_dict["content"])

        try:
            parsed = json.loads(response_text)
            if not isinstance(parsed, list):
                raise ValueError("Parsed JSON is not a list")

            # Dict items are used as records directly; anything else becomes
            # a single-field record holding its string form.
            for item in parsed:
                if isinstance(item, dict):
                    records.append(item)
                else:
                    records.append({"prompt": str(item)})
        except Exception as e:
            logger.error(f"Failed to parse synthetic prompts LLM response: {response_text}")
            raise ValueError(f"Failed to parse generated prompts as JSON: {e}") from e
    else:
        raise ValueError(f"Unknown input type: {input_cfg.type}")

    # Apply general sample limit if specified
    if input_cfg.limit and input_cfg.limit > 0:
        logger.info(f"Limiting loaded records to the first {input_cfg.limit} samples.")
        records = records[:input_cfg.limit]

    return records
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def extract_cot_from_text(text: str) -> tuple[str, str]:
    """
    Split inline chain-of-thought markup out of a model response.

    ``<think>...</think>`` blocks are looked for first; ``<reasoning>...</reasoning>``
    blocks are only considered when no ``<think>`` block is present.

    Args:
        text: Raw response text that may contain reasoning blocks.

    Returns:
        A tuple ``(cleaned_text, extracted_cot)``. ``cleaned_text`` is the input
        with every block of the matched tag removed and outer whitespace
        stripped. ``extracted_cot`` is the stripped contents of those blocks,
        joined by a blank line when there are several. If neither tag is
        found, ``text`` is returned unchanged together with an empty string.
    """
    import re

    for tag in ("think", "reasoning"):
        # DOTALL lets a block span multiple lines; the lazy quantifier keeps
        # adjacent blocks separate instead of swallowing the text between them.
        pattern = re.compile(rf"<{tag}>(.*?)</{tag}>", re.DOTALL)
        blocks = pattern.findall(text)
        if blocks:
            # Every block is removed from the text below, so keep all of their
            # contents rather than only the first one.
            cot = "\n\n".join(part.strip() for part in blocks if part.strip())
            cleaned = pattern.sub("", text).strip()
            return cleaned, cot

    return text, ""
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
async def execute_step_for_record(
    record: Dict[str, Any],
    step: StepConfig,
    generator: Generator,
    global_model: str,
    global_temp: float,
    global_max_tokens: int,
    semaphore: asyncio.Semaphore
) -> Dict[str, Any]:
    """
    Run one pipeline step against one record and return the updated record.

    Two modes exist:

    * Loop mode (``step.loop_field`` set): ``record[step.loop_field]`` must be a
      list. Dict items whose ``"role"`` equals ``step.loop_filter_role`` are
      regenerated one after another, each call receiving the already rewritten
      items before it as message history. All other items are carried over
      unchanged. The rewritten list is stored in ``record[step.output_field]``.
    * Standard mode: the prompt template is formatted with the record's keys,
      a single generation is made, and the text is stored in
      ``record[step.output_field]``.

    When ``step.cot_field`` is set, reasoning text (returned separately by the
    generator, or parsed from inline tags in the content) is stored under it.

    Args:
        record: Record to process; it is modified in place.
        step: Step settings (templates, field names, optional overrides).
        generator: Client exposing the ``generate`` coroutine.
        global_model: Model used when the step does not name one.
        global_temp: Temperature used when the step does not set one.
        global_max_tokens: Token limit used when the step does not set one.
        semaphore: Limits how many ``generate`` calls run at the same time.

    Returns:
        The same ``record`` object after the step has been applied.

    Raises:
        KeyError: In standard mode, when the prompt template references a key
            that is missing from the record.
    """
    import time

    # Override settings if step defines them
    model = step.model or global_model
    temp = step.temperature if step.temperature is not None else global_temp
    max_tokens = step.max_tokens if step.max_tokens is not None else global_max_tokens

    if step.loop_field:
        # Loop mode: rewrite elements of a list field
        items = record.get(step.loop_field)
        if not items or not isinstance(items, list):
            logger.warning(f"loop_field '{step.loop_field}' not found or is not a list in record. Skipping loop execution.")
            return record

        rewritten_items = []
        cots = []

        # Turns are processed sequentially so each one sees the rewritten
        # versions of the turns before it.
        for idx, item in enumerate(items):
            if not isinstance(item, dict):
                rewritten_items.append(item)
                continue

            role = item.get("role")
            if role != step.loop_filter_role:
                # Keep non-matching items (e.g. user messages) exactly unchanged
                rewritten_items.append(item.copy())
                continue

            # Snapshot of the history up to (not including) the current item,
            # using the rewritten content for earlier turns.
            prefix_messages = rewritten_items.copy()

            # Template variables: every record key plus loop helper values.
            local_context = record.copy()
            local_context["messages"] = json.dumps(prefix_messages)
            local_context["original_response"] = item.get("content", "")
            # The preceding item may be a non-dict that was passed through
            # untouched; only dict items carry a "content" entry.
            previous = prefix_messages[-1] if prefix_messages else None
            local_context["current_prompt"] = (
                previous.get("content", "") if isinstance(previous, dict) else ""
            )

            # Format system prompt and template
            try:
                formatted_prompt = step.prompt_template.format(**local_context)
            except KeyError as e:
                # If formatting fails, fall back to passing the original response directly
                logger.warning(f"Prompt template references missing key {e}; falling back to the original response text.")
                formatted_prompt = f"Original Response: {item.get('content', '')}"

            formatted_system = None
            if step.system_prompt:
                try:
                    formatted_system = step.system_prompt.format(**local_context)
                except KeyError:
                    # Use the system prompt literally when it cannot be formatted.
                    formatted_system = step.system_prompt

            logger.info(f"Loop Step: Processing conversation turn {idx+1}/{len(items)} for role '{role}'...")
            # The measured duration includes time spent waiting for the semaphore.
            start_time = time.time()
            async with semaphore:
                res_dict = await generator.generate(
                    prompt=formatted_prompt,
                    system_prompt=formatted_system,
                    messages=prefix_messages,  # Pass native messages directly to Chat API
                    model=model,
                    temperature=temp,
                    max_tokens=max_tokens
                )
            duration = time.time() - start_time

            content = res_dict["content"]
            reasoning_content = res_dict["reasoning_content"]
            logger.info(f"Loop Step: Turn {idx+1} complete in {duration:.2f}s (generated {len(content)} chars, CoT: {len(reasoning_content) if reasoning_content else 0} chars).")

            if step.cot_field:
                if not reasoning_content:
                    content, reasoning_content = extract_cot_from_text(content)
                else:
                    # API returned reasoning_content separately, but content may still
                    # contain inline <think> tags - strip them so they don't pollute
                    # the conversation history passed to subsequent turns.
                    content, _ = extract_cot_from_text(content)
                if reasoning_content:
                    cots.append(reasoning_content)
            else:
                # Even without a cot_field, always strip think tags from content
                # before storing it back into the conversation history.
                content, _ = extract_cot_from_text(content)

            # Store the rewritten response (CoT stripped) on a copy of the item
            new_item = item.copy()
            new_item["content"] = content
            rewritten_items.append(new_item)

        # Update the main record with the fully rewritten list
        record[step.output_field] = rewritten_items
        if step.cot_field and cots:
            # Join the per-turn CoTs into one string
            record[step.cot_field] = "\n\n=== Step CoT ===\n\n".join(cots)
        return record

    else:
        # Standard mode: single generation
        async with semaphore:
            # Format the templates using the current record values
            try:
                formatted_prompt = step.prompt_template.format(**record)
            except KeyError as e:
                logger.error(f"Formatting failed for template '{step.prompt_template}' with record keys: {list(record.keys())}. Missing key: {e}")
                raise KeyError(f"Missing variable {e} in record for step template '{step.prompt_template}'") from e

            formatted_system = None
            if step.system_prompt:
                try:
                    formatted_system = step.system_prompt.format(**record)
                except KeyError:
                    formatted_system = step.system_prompt  # Fallback to literal string if formatting fails

            # Generate response (dict with "content" and "reasoning_content" keys)
            res_dict = await generator.generate(
                prompt=formatted_prompt,
                system_prompt=formatted_system,
                model=model,
                temperature=temp,
                max_tokens=max_tokens
            )

            response_text = res_dict["content"]
            reasoning_content = res_dict["reasoning_content"]

            if step.cot_field:
                if not reasoning_content:
                    # No separate reasoning was returned: parse it from inline tags
                    response_text, reasoning_content = extract_cot_from_text(response_text)
                record[step.cot_field] = reasoning_content

            record[step.output_field] = response_text

        return record
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
async def process_record_pipeline(
    record: Dict[str, Any],
    steps: List[StepConfig],
    generator: Generator,
    config: PipelineConfig,
    semaphore: asyncio.Semaphore,
    output_file: Any,
    write_lock: asyncio.Lock
) -> Dict[str, Any]:
    """
    Apply every step, in order, to a copy of ``record`` and persist the result.

    After the last step the record is serialized as one JSON line, written to
    ``output_file`` and flushed, all while holding ``write_lock``.

    Args:
        record: Input record; a shallow copy is what the steps operate on.
        steps: Steps to run, in order.
        generator: Client handed through to each step.
        config: Supplies the default model, temperature and token limit
            (``config.api``) for steps.
        semaphore: Concurrency limiter handed through to each step.
        output_file: Writable text file object receiving the JSON line.
        write_lock: Lock guarding writes to ``output_file``.

    Returns:
        The processed copy of the record.
    """
    working = record.copy()

    for current_step in steps:
        working = await execute_step_for_record(
            working,
            current_step,
            generator,
            config.api.model,
            config.api.temperature,
            config.api.max_tokens,
            semaphore,
        )

    # One line per record, flushed immediately.
    async with write_lock:
        output_file.write(f"{json.dumps(working)}\n")
        output_file.flush()

    return working
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
async def run_generation_pipeline_async(config: PipelineConfig) -> List[Dict[str, Any]]:
    """
    Run the generation phase: load inputs, process them concurrently, write JSONL.

    The JSONL file at ``config.output.jsonl_path`` is opened in write mode, so
    any content from an earlier run is replaced. Records are appended to it as
    each one finishes.

    Args:
        config: Full pipeline configuration (API, input, steps, output).

    Returns:
        The processed records in completion order, or an empty list when no
        input records were loaded (in which case the output file is not touched).
    """
    api_cfg = config.api

    # 1. Build the generator client
    client = Generator(
        api_key=api_cfg.api_key,
        base_url=api_cfg.base_url,
        api_type=api_cfg.api_type
    )

    # 2. Load inputs
    inputs = load_inputs(config, client)
    if not inputs:
        logger.warning("No input records found to process.")
        return []

    logger.info(f"Loaded {len(inputs)} records. Starting execution steps...")

    # Make sure the directory that will hold the JSONL file exists
    jsonl_path = config.output.jsonl_path
    target_dir = os.path.dirname(os.path.abspath(jsonl_path)) or "."
    os.makedirs(target_dir, exist_ok=True)

    # 3. Shared primitives: a semaphore to cap concurrent generations and a
    #    lock to serialize file writes
    limiter = asyncio.Semaphore(api_cfg.concurrency_limit)
    file_lock = asyncio.Lock()

    finished = []
    with open(jsonl_path, "w", encoding="utf-8") as sink:
        pending = []
        for item in inputs:
            pending.append(
                process_record_pipeline(
                    record=item,
                    steps=config.steps,
                    generator=client,
                    config=config,
                    semaphore=limiter,
                    output_file=sink,
                    write_lock=file_lock
                )
            )

        logger.info(f"Running pipeline with concurrency limit={config.api.concurrency_limit}...")
        # tqdm's as_completed yields awaitables as they finish and draws a progress bar
        progress = tqdm.asyncio.tqdm.as_completed(pending, total=len(inputs), desc="Generating dataset")
        for completed in progress:
            finished.append(await completed)

    logger.info(f"Generation phase complete. Wrote {len(finished)} records to {jsonl_path}")
    return finished
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def run_pipeline(config: PipelineConfig) -> List[Dict[str, Any]]:
    """
    Programmatic entrypoint that runs the whole dataset pipeline.

    Phases:
        1. Run the generation steps asynchronously (incremental JSONL writing).
        2. Chunk the JSONL into Parquet files using
           ``config.output.max_parquet_chunk_size_mb``.
        3. Upload the Parquet directory to Hugging Face when
           ``config.huggingface`` is set.

    This function is synchronous. When it is called while an event loop is
    already running in the current thread, the generation phase is run
    re-entrantly on that loop if ``nest_asyncio`` is installed, and otherwise
    on a private loop in a worker thread.

    Args:
        config: Full pipeline configuration.

    Returns:
        The records produced by the generation phase.
    """
    logger.info("Starting chapman-datagen pipeline...")

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop and loop.is_running():
        # An event loop is already active in this thread (e.g. a notebook).
        try:
            import nest_asyncio
        except ImportError:
            # Without nest_asyncio, run_until_complete on the running loop
            # raises "This event loop is already running". Use a worker thread
            # with its own loop and block until it finishes.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                results = pool.submit(
                    asyncio.run, run_generation_pipeline_async(config)
                ).result()
        else:
            nest_asyncio.apply()
            results = loop.run_until_complete(run_generation_pipeline_async(config))
    else:
        results = asyncio.run(run_generation_pipeline_async(config))

    # Chunk the JSONL output into Parquet files
    logger.info("Starting Parquet chunking phase...")
    parquet_files = chunk_jsonl_to_parquet(
        jsonl_path=config.output.jsonl_path,
        output_dir=config.output.parquet_dir,
        max_chunk_size_mb=config.output.max_parquet_chunk_size_mb
    )
    logger.info(f"Successfully chunked into {len(parquet_files)} Parquet files.")

    # Upload to HF if configured
    if config.huggingface:
        logger.info("Starting Hugging Face uploading phase...")
        repo_url = upload_to_huggingface(
            folder_path=config.output.parquet_dir,
            repo_id=config.huggingface.repo_id,
            token=config.huggingface.token,
            private=config.huggingface.private
        )
        logger.info(f"Pipeline execution complete. Dataset uploaded to: {repo_url}")
    else:
        logger.info("No Hugging Face repository configured. Skipping upload phase.")
        logger.info("Pipeline execution complete.")

    return results
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from huggingface_hub import HfApi
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
def upload_to_huggingface(
    folder_path: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False
) -> str:
    """
    Ensure a Hugging Face dataset repository exists and upload a folder into it.

    Args:
        folder_path: Local directory whose files are uploaded.
        repo_id: Target dataset repository identifier.
        token: Access token passed to ``HfApi``. When omitted, token resolution
            is left to ``HfApi`` itself (presumably the HF_TOKEN environment
            variable or the local cache).
        private: Passed to ``create_repo`` as the repository's ``private`` flag.

    Returns:
        The dataset page URL, ``https://huggingface.co/datasets/<repo_id>``.

    Raises:
        ValueError: Creating/verifying the repository or uploading the folder failed.
    """
    hub = HfApi(token=token)
    dataset_kind = "dataset"

    logger.info(f"Authenticating and checking/creating Hugging Face dataset repository: {repo_id}")
    try:
        # exist_ok=True makes this a no-op check when the repo is already there
        hub.create_repo(repo_id=repo_id, repo_type=dataset_kind, private=private, exist_ok=True)
    except Exception as create_err:
        logger.error(f"Failed to create/verify Hugging Face repository {repo_id}: {create_err}")
        raise ValueError(f"Hugging Face repository verification failed: {create_err}")

    logger.info(f"Uploading files from directory {folder_path} to HF dataset repo {repo_id}...")
    try:
        hub.upload_folder(folder_path=folder_path, repo_id=repo_id, repo_type=dataset_kind)
    except Exception as upload_err:
        logger.error(f"Failed to upload folder {folder_path} to Hugging Face: {upload_err}")
        raise ValueError(f"Hugging Face upload failed: {upload_err}")

    dataset_url = f"https://huggingface.co/datasets/{repo_id}"
    logger.info(f"Successfully uploaded dataset to: {dataset_url}")
    return dataset_url
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: chapman-datagen
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A highly generalizable prompt-based dataset generator and Hugging Face uploader
|
|
5
|
+
License: MIT
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: openai>=1.0.0
|
|
12
|
+
Requires-Dist: pandas>=1.5.0
|
|
13
|
+
Requires-Dist: pyarrow>=10.0.0
|
|
14
|
+
Requires-Dist: huggingface_hub>=0.15.0
|
|
15
|
+
Requires-Dist: tqdm>=4.65.0
|
|
16
|
+
Requires-Dist: jinja2>=3.0.0
|
|
17
|
+
Requires-Dist: datasets>=2.0.0
|
|
18
|
+
|
|
19
|
+
# chapman-datagen
|
|
20
|
+
|
|
21
|
+
A highly generalizable prompt-based dataset generator and Hugging Face uploader.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
chapman_datagen/__init__.py
|
|
4
|
+
chapman_datagen/chunker.py
|
|
5
|
+
chapman_datagen/cli.py
|
|
6
|
+
chapman_datagen/config.py
|
|
7
|
+
chapman_datagen/generator.py
|
|
8
|
+
chapman_datagen/pipeline.py
|
|
9
|
+
chapman_datagen/uploader.py
|
|
10
|
+
chapman_datagen.egg-info/PKG-INFO
|
|
11
|
+
chapman_datagen.egg-info/SOURCES.txt
|
|
12
|
+
chapman_datagen.egg-info/dependency_links.txt
|
|
13
|
+
chapman_datagen.egg-info/entry_points.txt
|
|
14
|
+
chapman_datagen.egg-info/requires.txt
|
|
15
|
+
chapman_datagen.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
chapman_datagen
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "chapman-datagen"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "A highly generalizable prompt-based dataset generator and Hugging Face uploader"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
classifiers = [
|
|
13
|
+
"Programming Language :: Python :: 3",
|
|
14
|
+
"License :: OSI Approved :: MIT License",
|
|
15
|
+
"Operating System :: OS Independent",
|
|
16
|
+
]
|
|
17
|
+
dependencies = [
|
|
18
|
+
"openai>=1.0.0",
|
|
19
|
+
"pandas>=1.5.0",
|
|
20
|
+
"pyarrow>=10.0.0",
|
|
21
|
+
"huggingface_hub>=0.15.0",
|
|
22
|
+
"tqdm>=4.65.0",
|
|
23
|
+
"jinja2>=3.0.0",
|
|
24
|
+
"datasets>=2.0.0",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.scripts]
|
|
28
|
+
chapman-datagen = "chapman_datagen.cli:main"
|
|
29
|
+
|
|
30
|
+
[tool.setuptools.packages.find]
|
|
31
|
+
where = ["."]
|
|
32
|
+
include = ["chapman_datagen*"]
|