chatan-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chatan/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """Chatan: Create synthetic datasets with LLM generators and samplers."""
+
+ __version__ = "0.0.1"
+
+ from .dataset import dataset
+ from .generator import generator
+ from .sampler import sample
+
+ __all__ = ["dataset", "generator", "sample"]
chatan/dataset.py ADDED
@@ -0,0 +1,131 @@
+ """Dataset creation and manipulation."""
+
+ from typing import Dict, Any, Union, Optional, List, Callable
+ import pandas as pd
+ from datasets import Dataset as HFDataset
+ from .generator import GeneratorFunction
+ from .sampler import SampleFunction
+
+
+ class Dataset:
+     """Synthetic dataset generator."""
+
+     def __init__(self, schema: Union[Dict[str, Any], str], n: int = 100):
+         """
+         Initialize dataset with schema.
+
+         Args:
+             schema: Either a dict mapping column names to generators/samplers,
+                 or a string prompt for high-level dataset generation
+             n: Number of samples to generate
+         """
+         if isinstance(schema, str):
+             # High-level prompting - we'll implement this later
+             raise NotImplementedError("High-level prompting not yet implemented")
+
+         self.schema = schema
+         self.n = n
+         self._data = None
+
+     def generate(self, n: Optional[int] = None) -> pd.DataFrame:
+         """Generate the dataset."""
+         num_samples = n or self.n
+
+         # Build dependency graph
+         dependencies = self._build_dependency_graph()
+         execution_order = self._topological_sort(dependencies)
+
+         # Generate data
+         data = []
+         for i in range(num_samples):
+             row = {}
+             for column in execution_order:
+                 value = self._generate_value(column, row)
+                 row[column] = value
+             data.append(row)
+
+         self._data = pd.DataFrame(data)
+         return self._data
+
+     def _build_dependency_graph(self) -> Dict[str, List[str]]:
+         """Build dependency graph from schema."""
+         dependencies = {}
+
+         for column, func in self.schema.items():
+             deps = []
+             if isinstance(func, GeneratorFunction):
+                 # Extract column references from prompt template
+                 import re
+                 template = func.prompt_template
+                 deps = re.findall(r'\{(\w+)\}', template)
+
+             dependencies[column] = [dep for dep in deps if dep in self.schema]
+
+         return dependencies
+
+     def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]:
+         """Topologically sort columns by dependencies."""
+         visited = set()
+         temp_visited = set()
+         result = []
+
+         def visit(column):
+             if column in temp_visited:
+                 raise ValueError(f"Circular dependency detected involving {column}")
+             if column in visited:
+                 return
+
+             temp_visited.add(column)
+             for dep in dependencies.get(column, []):
+                 visit(dep)
+             temp_visited.remove(column)
+             visited.add(column)
+             result.append(column)
+
+         for column in self.schema:
+             visit(column)
+
+         return result
+
+     def _generate_value(self, column: str, context: Dict[str, Any]) -> Any:
+         """Generate a single value for a column."""
+         func = self.schema[column]
+
+         if isinstance(func, (GeneratorFunction, SampleFunction)):
+             return func(context)
+         elif callable(func):
+             return func(context)
+         else:
+             # Static value
+             return func
+
+     def to_pandas(self) -> pd.DataFrame:
+         """Convert to pandas DataFrame."""
+         if self._data is None:
+             self.generate()
+         return self._data
+
+     def to_huggingface(self) -> HFDataset:
+         """Convert to HuggingFace Dataset."""
+         if self._data is None:
+             self.generate()
+         return HFDataset.from_pandas(self._data)
+
+     def save(self, path: str, format: str = "parquet") -> None:
+         """Save dataset to file."""
+         if self._data is None:
+             self.generate()
+
+         if format == "parquet":
+             self._data.to_parquet(path)
+         elif format == "csv":
+             self._data.to_csv(path, index=False)
+         elif format == "json":
+             self._data.to_json(path, orient="records")
+         else:
+             raise ValueError(f"Unsupported format: {format}")
+
+
+ def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset:
+     """Create a synthetic dataset."""
+     return Dataset(schema, n)
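The dict-schema path above resolves column order by pulling `{column}` references out of GeneratorFunction prompt templates and topologically sorting them; plain callables and SampleFunctions declare no dependencies and are evaluated in schema (insertion) order, each receiving the partially built row as `context`. A minimal sketch of that path, assuming the package is installed; it uses only samplers and a plain callable, so no LLM call is made, and the column names are illustrative:

```
from chatan import dataset, sample

ds = dataset({
    "id": sample.uuid(),                                   # fresh UUID per row
    "difficulty": sample.choice(["easy", "medium", "hard"]),
    "points": sample.range(1, 10),
    # a plain callable receives the partially built row (context)
    "label": lambda row: f"{row['difficulty']}-{row['points']}",
}, n=5)

df = ds.generate()   # pandas DataFrame with columns id, difficulty, points, label
print(df)
```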
chatan/generator.py ADDED
@@ -0,0 +1,114 @@
+ """LLM generators for synthetic data creation."""
+
+ from typing import Dict, Any, Optional, Union, List
+ import openai
+ import anthropic
+ from abc import ABC, abstractmethod
+
+
+ class BaseGenerator(ABC):
+     """Base class for LLM generators."""
+
+     @abstractmethod
+     def generate(self, prompt: str, **kwargs) -> str:
+         """Generate content from a prompt."""
+         pass
+
+
+ class OpenAIGenerator(BaseGenerator):
+     """OpenAI GPT generator."""
+
+     def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs):
+         self.client = openai.OpenAI(api_key=api_key)
+         self.model = model
+         self.default_kwargs = kwargs
+
+     def generate(self, prompt: str, **kwargs) -> str:
+         """Generate content using OpenAI API."""
+         merged_kwargs = {**self.default_kwargs, **kwargs}
+
+         response = self.client.chat.completions.create(
+             model=self.model,
+             messages=[{"role": "user", "content": prompt}],
+             **merged_kwargs
+         )
+         return response.choices[0].message.content.strip()
+
+
+ class AnthropicGenerator(BaseGenerator):
+     """Anthropic Claude generator."""
+
+     def __init__(self, api_key: str, model: str = "claude-3-sonnet-20240229", **kwargs):
+         self.client = anthropic.Anthropic(api_key=api_key)
+         self.model = model
+         self.default_kwargs = kwargs
+
+     def generate(self, prompt: str, **kwargs) -> str:
+         """Generate content using Anthropic API."""
+         merged_kwargs = {**self.default_kwargs, **kwargs}
+
+         response = self.client.messages.create(
+             model=self.model,
+             messages=[{"role": "user", "content": prompt}],
+             max_tokens=merged_kwargs.pop("max_tokens", 1000),
+             **merged_kwargs
+         )
+         return response.content[0].text.strip()
+
+
+ class GeneratorFunction:
+     """Callable generator function for use in dataset schemas."""
+
+     def __init__(
+         self,
+         generator: BaseGenerator,
+         prompt_template: str,
+         variables: Optional[Dict[str, Any]] = None,
+     ):
+         self.generator = generator
+         self.prompt_template = prompt_template
+         self.variables = variables or {}
+
+     def __call__(self, context: Dict[str, Any]) -> str:
+         """Generate content with context substitution."""
+         merged = dict(context)
+         for key, value in self.variables.items():
+             merged[key] = value(context) if callable(value) else value
+
+         prompt = self.prompt_template.format(**merged)
+         result = self.generator.generate(prompt)
+         return result.strip() if isinstance(result, str) else result
+
+
+ class GeneratorClient:
+     """Main interface for creating generators."""
+
+     def __init__(self, provider: str, api_key: str, **kwargs):
+         if provider.lower() == "openai":
+             self._generator = OpenAIGenerator(api_key, **kwargs)
+         elif provider.lower() == "anthropic":
+             self._generator = AnthropicGenerator(api_key, **kwargs)
+         else:
+             raise ValueError(f"Unsupported provider: {provider}")
+
+     def __call__(self, prompt_template: str, **variables) -> GeneratorFunction:
+         """Create a generator function.
+
+         Parameters
+         ----------
+         prompt_template:
+             Template string for the prompt.
+         **variables:
+             Optional variables to include when formatting the prompt. If a value
+             is callable it will be invoked with the row context when the
+             generator function is executed.
+         """
+         return GeneratorFunction(self._generator, prompt_template, variables)
+
+
+ # Factory function
+ def generator(provider: str = "openai", api_key: Optional[str] = None, **kwargs) -> GeneratorClient:
+     """Create a generator client."""
+     if api_key is None:
+         raise ValueError("API key is required")
+     return GeneratorClient(provider, api_key, **kwargs)
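GeneratorFunction is where prompt formatting happens: the row context is merged with any extra variables (callable values are invoked with the context first), then `str.format` fills the template before the underlying generator is called. A hedged sketch using a stub generator, so it runs without an API key; `EchoGenerator` is not part of the package, it just returns the formatted prompt:

```
from chatan.generator import BaseGenerator, GeneratorFunction

class EchoGenerator(BaseGenerator):
    """Stub generator (illustrative only): returns the formatted prompt unchanged."""
    def generate(self, prompt: str, **kwargs) -> str:
        return prompt

gen_fn = GeneratorFunction(
    EchoGenerator(),
    "Summarize {topic} in a {tone} tone",
    variables={"tone": "friendly"},   # static extra variable
)

# {topic} is filled from the row context, {tone} from the extra variables
print(gen_fn({"topic": "sparse attention"}))
# Summarize sparse attention in a friendly tone
```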
chatan/sampler.py ADDED
@@ -0,0 +1,143 @@
+ """Sampling functions for synthetic data creation."""
+
+ import random
+ import uuid
+ from datetime import datetime, timedelta
+ from typing import Any, Dict, List, Union, Optional, Callable
+ import pandas as pd
+ from datasets import Dataset as HFDataset
+
+
+ class SampleFunction:
+     """Base class for sampling functions."""
+
+     def __call__(self, context: Dict[str, Any] = None) -> Any:
+         """Generate a sample value."""
+         raise NotImplementedError
+
+
+ class ChoiceSampler(SampleFunction):
+     """Sample from a list of choices."""
+
+     def __init__(self, choices: Union[List[Any], Dict[str, Any]]):
+         if isinstance(choices, dict):
+             self.choices = list(choices.keys())
+             self.weights = list(choices.values())
+         else:
+             self.choices = choices
+             self.weights = None
+
+     def __call__(self, context: Dict[str, Any] = None) -> Any:
+         if self.weights:
+             return random.choices(self.choices, weights=self.weights, k=1)[0]
+         return random.choice(self.choices)
+
+
+ class WeightedSampler(SampleFunction):
+     """Sample from weighted choices."""
+
+     def __init__(self, choices: Dict[str, float]):
+         self.choices = list(choices.keys())
+         self.weights = list(choices.values())
+
+     def __call__(self, context: Dict[str, Any] = None) -> Any:
+         return random.choices(self.choices, weights=self.weights, k=1)[0]
+
+
+ class UUIDSampler(SampleFunction):
+     """Generate UUID strings."""
+
+     def __call__(self, context: Dict[str, Any] = None) -> str:
+         return str(uuid.uuid4())
+
+
+ class DatetimeSampler(SampleFunction):
+     """Sample random datetimes."""
+
+     def __init__(self, start: str, end: str, format: str = "%Y-%m-%d"):
+         self.start = datetime.strptime(start, format)
+         self.end = datetime.strptime(end, format)
+         self.delta = self.end - self.start
+
+     def __call__(self, context: Dict[str, Any] = None) -> datetime:
+         random_days = random.randint(0, self.delta.days)
+         return self.start + timedelta(days=random_days)
+
+
+ class RangeSampler(SampleFunction):
+     """Sample from numeric ranges."""
+
+     def __init__(self, start: Union[int, float], end: Union[int, float],
+                  step: Optional[Union[int, float]] = None):
+         self.start = start
+         self.end = end
+         self.step = step
+         self.is_int = isinstance(start, int) and isinstance(end, int)
+
+     def __call__(self, context: Dict[str, Any] = None) -> Union[int, float]:
+         if self.is_int:
+             return random.randint(self.start, self.end)
+         return random.uniform(self.start, self.end)
+
+
+ class DatasetSampler(SampleFunction):
+     """Sample from existing dataset columns."""
+
+     def __init__(self, dataset: Union[pd.DataFrame, HFDataset, Dict],
+                  column: str, default: Optional[SampleFunction] = None):
+         if isinstance(dataset, pd.DataFrame):
+             self.values = dataset[column].tolist()
+         elif isinstance(dataset, HFDataset):
+             self.values = dataset[column]
+         elif isinstance(dataset, dict):
+             self.values = dataset[column]
+         else:
+             raise ValueError("Unsupported dataset type")
+
+         self.default = default
+
+     def __call__(self, context: Dict[str, Any] = None) -> Any:
+         if not self.values and self.default:
+             return self.default(context)
+         return random.choice(self.values)
+
+
+ # Factory functions for the sample namespace
+ class SampleNamespace:
+     """Namespace for sampling functions."""
+
+     @staticmethod
+     def choice(choices: Union[List[Any], Dict[str, Any]]) -> ChoiceSampler:
+         """Sample from choices."""
+         return ChoiceSampler(choices)
+
+     @staticmethod
+     def weighted(choices: Dict[str, float]) -> WeightedSampler:
+         """Sample from weighted choices."""
+         return WeightedSampler(choices)
+
+     @staticmethod
+     def uuid() -> UUIDSampler:
+         """Generate UUIDs."""
+         return UUIDSampler()
+
+     @staticmethod
+     def datetime(start: str, end: str, format: str = "%Y-%m-%d") -> DatetimeSampler:
+         """Sample random datetimes."""
+         return DatetimeSampler(start, end, format)
+
+     @staticmethod
+     def range(start: Union[int, float], end: Union[int, float],
+               step: Optional[Union[int, float]] = None) -> RangeSampler:
+         """Sample from numeric ranges."""
+         return RangeSampler(start, end, step)
+
+     @staticmethod
+     def from_dataset(dataset: Union[pd.DataFrame, HFDataset, Dict],
+                      column: str, default: Optional[SampleFunction] = None) -> DatasetSampler:
+         """Sample from existing dataset."""
+         return DatasetSampler(dataset, column, default)
+
+
+ # Export the sample namespace
+ sample = SampleNamespace()
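Each factory on the `sample` namespace returns a SampleFunction instance; calling that instance (optionally with a row context, which these samplers ignore) yields one value. A quick sketch, assuming the package is installed; outputs are random:

```
from chatan import sample

print(sample.uuid()())                                # random UUID string
print(sample.choice(["red", "green", "blue"])())      # uniform choice
print(sample.weighted({"cat": 0.7, "dog": 0.3})())    # weighted choice
print(sample.range(1, 6)())                           # int bounds -> randint
print(sample.range(0.0, 1.0)())                       # float bounds -> uniform
print(sample.datetime("2024-01-01", "2024-12-31")())  # datetime within range
```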
chatan-0.0.1.dist-info/METADATA ADDED
@@ -0,0 +1,83 @@
+ Metadata-Version: 2.4
+ Name: chatan
+ Version: 0.0.1
+ Summary: Create synthetic datasets with LLM generators and samplers
+ Project-URL: Documentation, https://github.com/cdreetz/chatan#readme
+ Project-URL: Issues, https://github.com/cdreetz/chatan/issues
+ Project-URL: Source, https://github.com/cdreetz/chatan
+ Author-email: Christian Reetz <cdreetz@gmail.com>
+ License-Expression: MIT
+ Keywords: dataset generation,llm,machine learning,synthetic data
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.8
+ Requires-Dist: anthropic>=0.7.0
+ Requires-Dist: datasets>=2.0.0
+ Requires-Dist: numpy>=1.20.0
+ Requires-Dist: openai>=1.0.0
+ Requires-Dist: pandas>=1.3.0
+ Requires-Dist: pydantic>=2.0.0
+ Description-Content-Type: text/markdown
+
+ ## Examples
+
+ Prompt a dataset (high-level prompting is not yet implemented in 0.0.1)
+
+ ```
+ import chatan
+
+ gen = chatan.generator("openai", api_key="YOUR_OPENAI_API_KEY")
+ ds = chatan.dataset("create a QA dataset for finetuning an LLM on pharmacology")
+ ```
+
+ Creating datasets with different data mixes
+
+ ```
+ from chatan import dataset, generator, sample
+
+ gen = generator("openai", api_key="YOUR_OPENAI_API_KEY")
+ # gen = generator("anthropic", api_key="YOUR_ANTHROPIC_API_KEY")
+
+ mix = {
+     "implementation": "Can you implement a matmul kernel in Triton",
+     "conversion": "Convert this pytorch model to Triton",
+     "explanation": "What memory access optimizations are being used here?"
+ }
+
+ ds = dataset({
+     "id": sample.uuid(),
+     "task": sample.choice(mix),
+     "prompt": gen("write a prompt for {task}"),
+     "response": gen("write a response to {prompt}"),
+ })
+ ```
+
+ Augment datasets
+
+ ```
+ from chatan import dataset, generator, sample
+ from datasets import load_dataset
+
+ gen = generator("openai", api_key="YOUR_OPENAI_API_KEY")
+ hf_data = load_dataset("GPU_MODE/KernelBook", split="train")
+
+ ds = dataset({
+     "id": sample.from_dataset(hf_data, "id", default=sample.uuid()),
+     "prompt": sample.from_dataset(hf_data, "prompt", default=gen("provide a variation of this prompt")),
+     "response": gen("write a response to {prompt}")
+ })
+ ```
chatan-0.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+ chatan/__init__.py,sha256=KRSqdfRz3wKEnBJrOaHcptN5bSOrMj3uqaLXym-PrEQ,233
+ chatan/dataset.py,sha256=t6RrrQchsD2dROD886IfnlXWnn-F-IAWMcEIK0SS3xg,4358
+ chatan/generator.py,sha256=zAOVxd2iY_KMK21dCMeFhvsJu2q8NEwch5jGJP-Me9s,3914
+ chatan/sampler.py,sha256=0X6AVQK20py4SwKnsppZC2yAZnP_jBhRxt9MfT1e-k4,4812
+ chatan-0.0.1.dist-info/METADATA,sha256=poFjbYZG9CnvxHy_U2GKHaitTdfK3xbGB9q1WtUaaDQ,2549
+ chatan-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ chatan-0.0.1.dist-info/RECORD,,
chatan-0.0.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any