chatan 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatan/__init__.py +9 -0
- chatan/dataset.py +131 -0
- chatan/generator.py +114 -0
- chatan/sampler.py +143 -0
- chatan-0.0.1.dist-info/METADATA +83 -0
- chatan-0.0.1.dist-info/RECORD +7 -0
- chatan-0.0.1.dist-info/WHEEL +4 -0
chatan/__init__.py
ADDED
chatan/dataset.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
"""Dataset creation and manipulation."""
|
2
|
+
|
3
|
+
from typing import Dict, Any, Union, Optional, List, Callable
|
4
|
+
import pandas as pd
|
5
|
+
from datasets import Dataset as HFDataset
|
6
|
+
from .generator import GeneratorFunction
|
7
|
+
from .sampler import SampleFunction
|
8
|
+
|
9
|
+
|
10
|
+
class Dataset:
    """Synthetic dataset generator.

    A dataset is described by a *schema*: a mapping from column name to the
    thing that produces that column's value for each row. Schema values may be
    ``GeneratorFunction`` / ``SampleFunction`` instances, any other callable
    (called with the partially built row dict), or a static value that is
    copied into every row.
    """

    def __init__(self, schema: Union[Dict[str, Any], str], n: int = 100):
        """
        Initialize dataset with schema.

        Args:
            schema: Either a dict mapping column names to generators/samplers,
                or a string prompt for high-level dataset generation
            n: Number of samples to generate

        Raises:
            NotImplementedError: If ``schema`` is a string; high-level
                prompting is not implemented yet.
        """
        if isinstance(schema, str):
            # High-level prompting - we'll implement this later
            raise NotImplementedError("High-level prompting not yet implemented")

        self.schema = schema
        self.n = n
        # Cached DataFrame from the most recent generate() call.
        self._data = None

    def generate(self, n: Optional[int] = None) -> pd.DataFrame:
        """Generate the dataset.

        Args:
            n: Number of rows to generate. ``None`` uses the ``n`` given at
                construction; an explicit ``0`` is honoured.

        Returns:
            A DataFrame with one column per schema key, in schema order. The
            result is also cached for ``to_pandas``/``to_huggingface``/``save``.

        Raises:
            ValueError: If prompt templates reference each other in a cycle.
        """
        # ``n or self.n`` would silently turn an explicit 0 into the default.
        num_samples = self.n if n is None else n

        # Columns must be produced after the columns their prompts reference.
        dependencies = self._build_dependency_graph()
        execution_order = self._topological_sort(dependencies)

        data = []
        for _ in range(num_samples):
            row: Dict[str, Any] = {}
            for column in execution_order:
                row[column] = self._generate_value(column, row)
            data.append(row)

        # Rows are built in dependency order; pin the column order to the
        # schema and keep the columns even when zero rows were generated.
        self._data = pd.DataFrame(data, columns=list(self.schema))
        return self._data

    def _build_dependency_graph(self) -> Dict[str, List[str]]:
        """Build dependency graph from schema.

        Returns:
            Mapping of column name -> list of other schema columns referenced
            by that column's prompt template. Columns that are not
            ``GeneratorFunction`` instances have no dependencies.
        """
        import string

        formatter = string.Formatter()
        dependencies: Dict[str, List[str]] = {}

        for column, func in self.schema.items():
            deps: List[str] = []
            if isinstance(func, GeneratorFunction):
                # Parse with the same parser str.format uses, so escaped
                # braces ("{{x}}") are ignored while fields with conversions,
                # format specs, attribute or index access ("{x!r}", "{x:>10}",
                # "{x.y}", "{x[0]}") still register a dependency on ``x``.
                for _, field_name, _, _ in formatter.parse(func.prompt_template):
                    if not field_name:
                        continue
                    base = field_name.split(".", 1)[0].split("[", 1)[0]
                    if base in self.schema and base not in deps:
                        deps.append(base)
            dependencies[column] = deps

        return dependencies

    def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]:
        """Topologically sort columns by dependencies.

        Depth-first search over ``dependencies``; every column appears after
        all columns it depends on.

        Raises:
            ValueError: If a dependency cycle is found.
        """
        visited = set()
        temp_visited = set()  # columns on the current DFS path
        result = []

        def visit(column):
            if column in temp_visited:
                raise ValueError(f"Circular dependency detected involving {column}")
            if column in visited:
                return

            temp_visited.add(column)
            for dep in dependencies.get(column, []):
                visit(dep)
            temp_visited.remove(column)
            visited.add(column)
            result.append(column)

        for column in self.schema:
            visit(column)

        return result

    def _generate_value(self, column: str, context: Dict[str, Any]) -> Any:
        """Generate a single value for a column.

        Callables (including ``GeneratorFunction`` and ``SampleFunction``,
        which both define ``__call__``) are invoked with the row built so
        far. Anything else is a static value returned as-is.
        """
        func = self.schema[column]
        if callable(func):
            return func(context)
        # Static value
        return func

    def to_pandas(self) -> pd.DataFrame:
        """Convert to pandas DataFrame, generating it first if needed."""
        if self._data is None:
            self.generate()
        return self._data

    def to_huggingface(self) -> HFDataset:
        """Convert to HuggingFace Dataset, generating it first if needed."""
        if self._data is None:
            self.generate()
        return HFDataset.from_pandas(self._data)

    def save(self, path: str, format: str = "parquet") -> None:
        """Save dataset to file, generating it first if needed.

        Args:
            path: Destination file path.
            format: ``"parquet"``, ``"csv"`` (written without the index) or
                ``"json"`` (a list of row records).

        Raises:
            ValueError: For any other ``format``.
        """
        if self._data is None:
            self.generate()

        if format == "parquet":
            self._data.to_parquet(path)
        elif format == "csv":
            self._data.to_csv(path, index=False)
        elif format == "json":
            self._data.to_json(path, orient="records")
        else:
            raise ValueError(f"Unsupported format: {format}")
|
127
|
+
|
128
|
+
|
129
|
+
def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset:
    """Create a synthetic dataset.

    Convenience wrapper around :class:`Dataset`. See its constructor for the
    meaning of ``schema`` and ``n``.
    """
    created = Dataset(schema=schema, n=n)
    return created
|
chatan/generator.py
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
"""LLM generators for synthetic data creation."""
|
2
|
+
|
3
|
+
from typing import Dict, Any, Optional, Union, List
|
4
|
+
import openai
|
5
|
+
import anthropic
|
6
|
+
from abc import ABC, abstractmethod
|
7
|
+
|
8
|
+
|
9
|
+
class BaseGenerator(ABC):
    """Abstract interface that every LLM backend implements.

    Subclasses must override :meth:`generate`. The base class itself cannot
    be instantiated.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Return the model's completion for ``prompt``.

        Extra keyword arguments are backend-specific request options.
        """
|
16
|
+
|
17
|
+
|
18
|
+
class OpenAIGenerator(BaseGenerator):
    """OpenAI chat-completions generator.

    Constructor args:
        api_key: Key passed to ``openai.OpenAI``.
        model: Default chat model. A ``model=`` keyword passed to
            :meth:`generate` overrides it for that call.
        **kwargs: Default request options (e.g. ``temperature``) merged into
            every request. Per-call keyword arguments win on conflict.
    """

    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs):
        self.client = openai.OpenAI(api_key=api_key)
        self.model = model
        self.default_kwargs = kwargs

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate content using OpenAI API.

        Sends ``prompt`` as a single user message and returns the first
        choice's text with surrounding whitespace stripped. Returns ``""``
        when the response carries no text content.
        """
        merged_kwargs = {**self.default_kwargs, **kwargs}
        # Pull ``model`` out so a caller-supplied override doesn't collide with
        # the explicit keyword below (that would be a duplicate-kwarg TypeError).
        model = merged_kwargs.pop("model", self.model)

        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            **merged_kwargs
        )
        # ``message.content`` is Optional in the SDK (e.g. refusals/tool calls),
        # so guard against None before stripping.
        content = response.choices[0].message.content
        return content.strip() if content else ""
|
36
|
+
|
37
|
+
|
38
|
+
class AnthropicGenerator(BaseGenerator):
    """Anthropic Claude (Messages API) generator.

    Constructor args:
        api_key: Key passed to ``anthropic.Anthropic``.
        model: Default model. A ``model=`` keyword passed to :meth:`generate`
            overrides it for that call.
        **kwargs: Default request options merged into every request. Per-call
            keyword arguments win on conflict.
    """

    def __init__(self, api_key: str, model: str = "claude-3-sonnet-20240229", **kwargs):
        self.client = anthropic.Anthropic(api_key=api_key)
        self.model = model
        self.default_kwargs = kwargs

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate content using Anthropic API.

        Sends ``prompt`` as a single user message and returns the
        concatenated text of the response's text blocks, stripped. Returns
        ``""`` if the response has no text blocks.
        """
        merged_kwargs = {**self.default_kwargs, **kwargs}
        # The Messages API requires max_tokens; default to 1000. Popping these
        # up front (instead of inside the call expression) guarantees they are
        # not forwarded a second time through **merged_kwargs.
        max_tokens = merged_kwargs.pop("max_tokens", 1000)
        model = merged_kwargs.pop("model", self.model)

        response = self.client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            **merged_kwargs
        )
        # The response can hold zero or several content blocks, and not all of
        # them are text; concatenate only the text blocks.
        text = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        return text.strip()
|
57
|
+
|
58
|
+
|
59
|
+
class GeneratorFunction:
    """Callable generator function for use in dataset schemas.

    Binds a generator (any object with a ``generate(prompt) -> str`` method)
    to a ``str.format``-style prompt template. Each call fills the template
    from the row context plus ``variables`` and returns the stripped result.
    """

    def __init__(
        self,
        generator: "BaseGenerator",
        prompt_template: str,
        variables: Optional[Dict[str, Any]] = None,
    ):
        self.generator = generator
        self.prompt_template = prompt_template
        self.variables = variables or {}

    def __call__(self, context: Optional[Dict[str, Any]] = None) -> str:
        """Generate content with context substitution.

        Args:
            context: Row values available to the template. Defaults to an
                empty mapping, so templates without placeholders can be called
                directly (matching ``SampleFunction``'s optional context).

        Returns:
            The generator's output, stripped of surrounding whitespace when it
            is a string.

        Raises:
            KeyError: If the template references a name found in neither
                ``context`` nor ``variables``.
        """
        if context is None:
            context = {}

        # Copy so the caller's row dict is never mutated.
        merged = dict(context)
        # Extra variables override row values; callables are evaluated per row.
        for key, value in self.variables.items():
            merged[key] = value(context) if callable(value) else value

        prompt = self.prompt_template.format(**merged)
        result = self.generator.generate(prompt)
        return result.strip() if isinstance(result, str) else result
|
81
|
+
|
82
|
+
|
83
|
+
class GeneratorClient:
    """Main interface for creating generators.

    Holds one configured LLM backend and turns prompt templates into
    :class:`GeneratorFunction` objects bound to it.
    """

    def __init__(self, provider: str, api_key: str, **kwargs):
        # Provider names are matched case-insensitively.
        backends = {
            "openai": OpenAIGenerator,
            "anthropic": AnthropicGenerator,
        }
        backend_cls = backends.get(provider.lower())
        if backend_cls is None:
            raise ValueError(f"Unsupported provider: {provider}")
        self._generator = backend_cls(api_key, **kwargs)

    def __call__(self, prompt_template: str, **variables) -> GeneratorFunction:
        """Create a generator function.

        Parameters
        ----------
        prompt_template:
            Template string for the prompt.
        **variables:
            Optional variables to include when formatting the prompt. If a value
            is callable it will be invoked with the row context when the
            generator function is executed.
        """
        return GeneratorFunction(
            generator=self._generator,
            prompt_template=prompt_template,
            variables=variables,
        )
|
107
|
+
|
108
|
+
|
109
|
+
# Factory function
def generator(provider: str = "openai", api_key: Optional[str] = None, **kwargs) -> GeneratorClient:
    """Create a generator client.

    Args:
        provider: Backend name, ``"openai"`` or ``"anthropic"``
            (case-insensitive).
        api_key: API key for the provider. When omitted, the
            ``<PROVIDER>_API_KEY`` environment variable (e.g.
            ``OPENAI_API_KEY``) is used instead.
        **kwargs: Forwarded to the backend generator (e.g. ``model``).

    Raises:
        ValueError: If no API key was passed or found in the environment, or
            if the provider is unsupported.
    """
    if api_key is None:
        import os

        # An empty environment value counts as "not set".
        api_key = os.environ.get(f"{provider.upper()}_API_KEY") or None
    if api_key is None:
        raise ValueError("API key is required")
    return GeneratorClient(provider, api_key, **kwargs)
|
chatan/sampler.py
ADDED
@@ -0,0 +1,143 @@
|
|
1
|
+
"""Sampling functions for synthetic data creation."""
|
2
|
+
|
3
|
+
import random
|
4
|
+
import uuid
|
5
|
+
from datetime import datetime, timedelta
|
6
|
+
from typing import Any, Dict, List, Union, Optional, Callable
|
7
|
+
import pandas as pd
|
8
|
+
from datasets import Dataset as HFDataset
|
9
|
+
|
10
|
+
|
11
|
+
class SampleFunction:
    """Base class for sampling functions.

    Subclasses override ``__call__``. The optional ``context`` is the row
    built so far, which most samplers ignore.
    """

    def __call__(self, context: Dict[str, Any] = None) -> Any:
        """Return one sampled value; concrete samplers must override this."""
        raise NotImplementedError()
|
17
|
+
|
18
|
+
|
19
|
+
class ChoiceSampler(SampleFunction):
    """Sample from a collection of choices.

    ``choices`` may be a dict mapping each choice to its relative weight, or
    any other iterable (list, tuple, set, range, generator, ...) whose items
    are sampled uniformly.
    """

    def __init__(self, choices: Union[List[Any], Dict[str, Any]]):
        if isinstance(choices, dict):
            self.choices = list(choices.keys())
            self.weights = list(choices.values())
        else:
            # Copy into a list so later mutation of the caller's container
            # cannot change what this sampler draws from, and so non-sequence
            # iterables (sets, generators) work with random.choice.
            self.choices = list(choices)
            self.weights = None

    def __call__(self, context: Dict[str, Any] = None) -> Any:
        """Return one choice; ``context`` is ignored."""
        if self.weights:
            return random.choices(self.choices, weights=self.weights, k=1)[0]
        return random.choice(self.choices)
|
34
|
+
|
35
|
+
|
36
|
+
class WeightedSampler(SampleFunction):
    """Sample one key of a ``{choice: weight}`` mapping, in proportion to its weight."""

    def __init__(self, choices: Dict[str, float]):
        keys, weights = [], []
        for key, weight in choices.items():
            keys.append(key)
            weights.append(weight)
        self.choices = keys
        self.weights = weights

    def __call__(self, context: Dict[str, Any] = None) -> Any:
        """Return one weighted pick; ``context`` is ignored."""
        (picked,) = random.choices(self.choices, weights=self.weights, k=1)
        return picked
|
45
|
+
|
46
|
+
|
47
|
+
class UUIDSampler(SampleFunction):
    """Produce a fresh random (``uuid4``) UUID string on every call."""

    def __call__(self, context: Dict[str, Any] = None) -> str:
        generated = uuid.uuid4()
        return f"{generated}"
|
52
|
+
|
53
|
+
|
54
|
+
class DatetimeSampler(SampleFunction):
    """Sample random datetimes between ``start`` and ``end`` (inclusive).

    Both bounds are parsed with ``datetime.strptime(value, format)``.

    If the span is a whole number of days (always the case with the default
    date-only format), a random whole-day offset from ``start`` is used, so
    results keep ``start``'s time of day. Otherwise a random whole-second
    offset is used, so sub-day ranges are actually sampled instead of always
    yielding ``start``.

    Raises:
        ValueError: If a bound does not match ``format`` or ``end`` is earlier
            than ``start``.
    """

    def __init__(self, start: str, end: str, format: str = "%Y-%m-%d"):
        self.start = datetime.strptime(start, format)
        self.end = datetime.strptime(end, format)
        self.delta = self.end - self.start
        # Fail fast: a reversed range could never be sampled.
        if self.delta < timedelta(0):
            raise ValueError(f"end ({end}) must not be earlier than start ({start})")

    def __call__(self, context: Dict[str, Any] = None) -> datetime:
        """Return one datetime in ``[start, end]``; ``context`` is ignored."""
        if self.delta.seconds == 0 and self.delta.microseconds == 0:
            random_days = random.randint(0, self.delta.days)
            return self.start + timedelta(days=random_days)
        # int() floors the (non-negative) span, so the result never passes end.
        random_seconds = random.randint(0, int(self.delta.total_seconds()))
        return self.start + timedelta(seconds=random_seconds)
|
65
|
+
|
66
|
+
|
67
|
+
class RangeSampler(SampleFunction):
    """Sample numbers between ``start`` and ``end`` (inclusive).

    Without ``step``, integer bounds use ``random.randint(start, end)`` and
    any other bounds use ``random.uniform(start, end)``.

    With ``step``, the result is ``start + k * step`` for a uniformly chosen
    integer ``k >= 0`` such that the value does not exceed ``end``. ``end`` is
    included when it lies on the grid. Integer bounds with an integer step
    yield ints.

    Raises:
        ValueError: If ``step`` is given and is not positive.
    """

    def __init__(self, start: Union[int, float], end: Union[int, float],
                 step: Optional[Union[int, float]] = None):
        self.start = start
        self.end = end
        self.step = step
        self.is_int = isinstance(start, int) and isinstance(end, int)
        if step is not None and step <= 0:
            raise ValueError("step must be positive")

    def __call__(self, context: Dict[str, Any] = None) -> Union[int, float]:
        """Return one number from the range; ``context`` is ignored."""
        if self.step is None:
            if self.is_int:
                return random.randint(self.start, self.end)
            return random.uniform(self.start, self.end)
        # randint raises ValueError if end < start (negative max index).
        return self.start + random.randint(0, self._max_step_index()) * self.step

    def _max_step_index(self) -> int:
        """Largest integer k with ``start + k * step <= end`` (negative if end < start)."""
        span = self.end - self.start
        if self.is_int and isinstance(self.step, int):
            return span // self.step
        import math

        # Small tolerance so float grids such as 0..1 step 0.1 include the end.
        return math.floor(span / self.step + 1e-9)
|
81
|
+
|
82
|
+
|
83
|
+
class DatasetSampler(SampleFunction):
    """Sample values from a column of an existing dataset.

    Args:
        dataset: A pandas DataFrame, a HuggingFace ``Dataset``, or a dict
            mapping column names to sequences of values.
        column: Name of the column to draw from.
        default: Optional fallback sampler, used when ``column`` is missing
            from ``dataset`` or holds no values.

    Raises:
        ValueError: For an unsupported ``dataset`` type.
        KeyError: If ``column`` is missing and no ``default`` was given.
    """

    def __init__(self, dataset: Union[pd.DataFrame, HFDataset, Dict],
                 column: str, default: Optional[SampleFunction] = None):
        try:
            if isinstance(dataset, pd.DataFrame):
                # tolist() yields plain Python scalars rather than numpy ones.
                values = dataset[column].tolist()
            elif isinstance(dataset, (HFDataset, dict)):
                # Materialize into a list: random.choice needs len/indexing,
                # and the emptiness check in __call__ must not depend on the
                # container type (e.g. a numpy array's truth value is ambiguous).
                values = list(dataset[column])
            else:
                raise ValueError("Unsupported dataset type")
        except KeyError:
            if default is None:
                raise
            # Missing column: every draw falls back to ``default``.
            values = []

        self.values = values
        self.default = default

    def __call__(self, context: Dict[str, Any] = None) -> Any:
        """Return a random value from the column, or ``default(context)`` if it is empty."""
        if not self.values and self.default is not None:
            return self.default(context)
        return random.choice(self.values)
|
103
|
+
|
104
|
+
|
105
|
+
# Factory functions for the sample namespace
class SampleNamespace:
    """Namespace exposing sampler constructors as ``sample.<name>(...)``.

    Each method is a static factory that builds and returns the matching
    ``SampleFunction`` subclass.
    """

    @staticmethod
    def choice(choices: Union[List[Any], Dict[str, Any]]) -> ChoiceSampler:
        """Build a sampler over ``choices`` (a dict's values act as weights)."""
        sampler = ChoiceSampler(choices=choices)
        return sampler

    @staticmethod
    def weighted(choices: Dict[str, float]) -> WeightedSampler:
        """Build a sampler that picks dict keys in proportion to their values."""
        sampler = WeightedSampler(choices=choices)
        return sampler

    @staticmethod
    def uuid() -> UUIDSampler:
        """Build a sampler that emits random UUID strings."""
        sampler = UUIDSampler()
        return sampler

    @staticmethod
    def datetime(start: str, end: str, format: str = "%Y-%m-%d") -> DatetimeSampler:
        """Build a sampler of datetimes between two ``format``-encoded bounds."""
        sampler = DatetimeSampler(start=start, end=end, format=format)
        return sampler

    @staticmethod
    def range(start: Union[int, float], end: Union[int, float],
              step: Optional[Union[int, float]] = None) -> RangeSampler:
        """Build a sampler of numbers between ``start`` and ``end``."""
        sampler = RangeSampler(start=start, end=end, step=step)
        return sampler

    @staticmethod
    def from_dataset(dataset: Union[pd.DataFrame, HFDataset, Dict],
                     column: str, default: Optional[SampleFunction] = None) -> DatasetSampler:
        """Build a sampler that draws values from one column of ``dataset``."""
        sampler = DatasetSampler(dataset=dataset, column=column, default=default)
        return sampler


# Export the sample namespace (module-level singleton used as ``sample.*``).
sample = SampleNamespace()
|
@@ -0,0 +1,83 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: chatan
|
3
|
+
Version: 0.0.1
|
4
|
+
Summary: Create synthetic datasets with LLM generators and samplers
|
5
|
+
Project-URL: Documentation, https://github.com/cdreetz/chatan#readme
|
6
|
+
Project-URL: Issues, https://github.com/cdreetz/chatan/issues
|
7
|
+
Project-URL: Source, https://github.com/cdreetz/chatan
|
8
|
+
Author-email: Christian Reetz <cdreetz@gmail.com>
|
9
|
+
License-Expression: MIT
|
10
|
+
Keywords: dataset generation,llm,machine learning,synthetic data
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
12
|
+
Classifier: Intended Audience :: Developers
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
14
|
+
Classifier: Operating System :: OS Independent
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
16
|
+
Classifier: Programming Language :: Python :: 3.8
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
22
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
+
Requires-Python: >=3.8
|
24
|
+
Requires-Dist: anthropic>=0.7.0
|
25
|
+
Requires-Dist: datasets>=2.0.0
|
26
|
+
Requires-Dist: numpy>=1.20.0
|
27
|
+
Requires-Dist: openai>=1.0.0
|
28
|
+
Requires-Dist: pandas>=1.3.0
|
29
|
+
Requires-Dist: pydantic>=2.0.0
|
30
|
+
Description-Content-Type: text/markdown
|
31
|
+
|
32
|
+
## Examples
|
33
|
+
|
34
|
+
Prompt a dataset (planned — passing a string prompt currently raises `NotImplementedError`)
|
35
|
+
|
36
|
+
```
|
37
|
+
import chatan
|
38
|
+
|
39
|
+
gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY")
|
40
|
+
ds = chatan.dataset("create a QA dataset for finetuning an LLM on pharmacology")
|
41
|
+
```
|
42
|
+
|
43
|
+
Creating datasets with different data mixes
|
44
|
+
|
45
|
+
```
|
46
|
+
import uuid
|
47
|
+
from chatan import dataset, generator, sample
|
48
|
+
|
49
|
+
gen = generator("openai", "YOUR_OPENAI_API_KEY")
|
50
|
+
# gen = generator("anthropic", "YOUR_ANTHROPIC_API_KEY")
|
51
|
+
|
52
|
+
mix = {
|
53
|
+
"implementation": "Can you implement a matmul kernel in Triton",
|
54
|
+
"conversion": "Convert this pytorch model to Triton",
|
55
|
+
"explanation": "What memory access optimizations are being used here?"
|
56
|
+
}
|
57
|
+
|
58
|
+
ds = dataset({
|
59
|
+
"id": sample.uuid(),
|
60
|
+
"task": sample.choice(list(mix.values())),
|
61
|
+
"prompt": gen("write a prompt for {task}"),
|
62
|
+
"response": gen("write a response to {prompt}"),
|
63
|
+
})
|
64
|
+
```
|
65
|
+
|
66
|
+
Augment datasets
|
67
|
+
|
68
|
+
```
|
69
|
+
import uuid
|
70
|
+
from chatan import dataset, generator, sample
|
71
|
+
from datasets import load_dataset
|
72
|
+
|
73
|
+
gen = generator("openai", "YOUR_OPENAI_API_KEY")
|
74
|
+
hf_data = load_dataset("GPU_MODE/KernelBook", split="train")
|
75
|
+
|
76
|
+
ds = dataset({
|
77
|
+
"id": sample.from_dataset(hf_data, "id", default=sample.uuid()),
|
78
|
+
"prompt": sample.from_dataset(hf_data, "prompt", aug=gen("provide a variation of this prompt")),
|
79
|
+
"response": gen("write a response to {prompt}")
|
80
|
+
|
81
|
+
})
|
82
|
+
|
83
|
+
```
|
@@ -0,0 +1,7 @@
|
|
1
|
+
chatan/__init__.py,sha256=KRSqdfRz3wKEnBJrOaHcptN5bSOrMj3uqaLXym-PrEQ,233
|
2
|
+
chatan/dataset.py,sha256=t6RrrQchsD2dROD886IfnlXWnn-F-IAWMcEIK0SS3xg,4358
|
3
|
+
chatan/generator.py,sha256=zAOVxd2iY_KMK21dCMeFhvsJu2q8NEwch5jGJP-Me9s,3914
|
4
|
+
chatan/sampler.py,sha256=0X6AVQK20py4SwKnsppZC2yAZnP_jBhRxt9MfT1e-k4,4812
|
5
|
+
chatan-0.0.1.dist-info/METADATA,sha256=poFjbYZG9CnvxHy_U2GKHaitTdfK3xbGB9q1WtUaaDQ,2549
|
6
|
+
chatan-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
chatan-0.0.1.dist-info/RECORD,,
|