chatan 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatan/__init__.py +4 -2
- chatan/dataset.py +50 -5
- chatan/evaluate.py +320 -0
- chatan/generator.py +51 -12
- chatan/viewer.py +581 -0
- chatan-0.1.3.dist-info/METADATA +124 -0
- chatan-0.1.3.dist-info/RECORD +10 -0
- chatan-0.1.3.dist-info/licenses/LICENSE +21 -0
- chatan-0.1.0.dist-info/METADATA +0 -83
- chatan-0.1.0.dist-info/RECORD +0 -7
- {chatan-0.1.0.dist-info → chatan-0.1.3.dist-info}/WHEEL +0 -0
chatan/__init__.py
CHANGED
@@ -1,9 +1,11 @@
 """Minos: Create synthetic datasets with LLM generators and samplers."""
 
-__version__ = "0.1.0"
+__version__ = "0.1.3"
 
 from .dataset import dataset
 from .generator import generator
 from .sampler import sample
+from .viewer import generate_with_viewer
+from .evaluate import evaluate, eval
 
-__all__ = ["dataset", "generator", "sample"]
+__all__ = ["dataset", "generator", "sample", "generate_with_viewer", "evaluate", "eval"]
chatan/dataset.py
CHANGED
@@ -3,8 +3,10 @@
 from typing import Dict, Any, Union, Optional, List, Callable
 import pandas as pd
 from datasets import Dataset as HFDataset
+from tqdm import tqdm
 from .generator import GeneratorFunction
 from .sampler import SampleFunction
+from .evaluate import DatasetEvaluator, EvaluationFunction
 
 
 class Dataset:
@@ -26,18 +28,61 @@ class Dataset:
         self.schema = schema
         self.n = n
         self._data = None
+
+    @property
+    def eval(self):
+        """Get dataset evaluator for method chaining."""
+        if self._data is None:
+            raise ValueError("Dataset must be generated before evaluation")
+        return DatasetEvaluator(self)
+
+    def evaluate(self, eval_schema: Dict[str, EvaluationFunction]) -> Dict[str, float]:
+        """
+        Evaluate multiple metrics on this dataset.
+
+        Args:
+            eval_schema: Dictionary mapping metric names to evaluation functions.
+
+        Returns:
+            Dictionary of metric names to computed scores
+        """
+        if self._data is None:
+            self.generate()
+
+        results = {}
+        for name, eval_function in eval_schema.items():
+            results[name] = eval_function(self._data)
+        return results
+
 
-    def generate(
-
+    def generate(
+        self, n: Optional[int] = None, progress: bool = True
+    ) -> pd.DataFrame:
+        """Generate the dataset.
+
+        Parameters
+        ----------
+        n:
+            Number of samples to generate. Defaults to the value provided at
+            initialization.
+        progress:
+            Whether to display a progress bar. Defaults to ``True``. Pass
+            ``False`` to disable the progress output.
+        """
         num_samples = n or self.n
-
+        show_progress = progress
+
         # Build dependency graph
         dependencies = self._build_dependency_graph()
         execution_order = self._topological_sort(dependencies)
-
+
         # Generate data
         data = []
-
+        iterator = range(num_samples)
+        if show_progress:
+            iterator = tqdm(iterator, desc="Generating", leave=False)
+
+        for _ in iterator:
             row = {}
             for column in execution_order:
                 value = self._generate_value(column, row)
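For orientation, a minimal sketch of how the new evaluation hooks on `Dataset` appear to be used, based on the `eval` property and `evaluate` method added above; the schema and column names are illustrative, not from the package:

```
import chatan

gen = chatan.generator("openai", "YOUR_API_KEY")

ds = chatan.dataset({
    "question": gen("write a short question about SQL"),
    "answer": gen("answer this question: {question}"),
    "reference": gen("answer this question concisely: {question}"),
})
df = ds.generate(n=5)  # ds.eval requires generated data; ds.evaluate() generates on demand

# Each metric maps a name to an EvaluationFunction built from two columns.
scores = ds.evaluate({
    "exact": ds.eval.exact_match("answer", "reference"),
    "similarity": ds.eval.semantic_similarity("answer", "reference"),
})
print(scores)  # dict of metric name -> float
```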
chatan/evaluate.py
ADDED
@@ -0,0 +1,320 @@
"""Evaluation functions for synthetic data quality assessment."""

from typing import Dict, Any, Union, List, Callable, Optional
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod


class BaseEvaluator(ABC):
    """Base class for evaluation metrics."""

    @abstractmethod
    def compute(self, predictions: List[Any], targets: List[Any], **kwargs) -> float:
        """Compute the evaluation metric."""
        pass


class ExactMatchEvaluator(BaseEvaluator):
    """Exact string match evaluator."""

    def compute(self, predictions: List[str], targets: List[str], **kwargs) -> float:
        """Compute exact match accuracy."""
        if len(predictions) != len(targets):
            raise ValueError("Predictions and targets must have same length")

        matches = sum(1 for p, t in zip(predictions, targets) if str(p) == str(t))
        return matches / len(predictions)


class SemanticSimilarityEvaluator(BaseEvaluator):
    """Semantic similarity evaluator using sentence transformers."""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        self.model_name = model
        self._model = None

    def _load_model(self):
        """Lazy load the sentence transformer model."""
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
                self._model = SentenceTransformer(self.model_name)
            except ImportError:
                raise ImportError(
                    "sentence-transformers is required for semantic similarity. "
                    "Install with: pip install sentence-transformers"
                )

    def compute(self, predictions: List[str], targets: List[str], **kwargs) -> float:
        """Compute mean cosine similarity between predictions and targets."""
        if len(predictions) != len(targets):
            raise ValueError("Predictions and targets must have same length")

        self._load_model()

        pred_embeddings = self._model.encode(predictions)
        target_embeddings = self._model.encode(targets)

        # Compute cosine similarities
        from sklearn.metrics.pairwise import cosine_similarity
        similarities = []
        for pred_emb, target_emb in zip(pred_embeddings, target_embeddings):
            sim = cosine_similarity([pred_emb], [target_emb])[0][0]
            similarities.append(sim)

        return float(np.mean(similarities))


class BLEUEvaluator(BaseEvaluator):
    """BLEU score evaluator."""

    def compute(self, predictions: List[str], targets: List[str], **kwargs) -> float:
        """Compute BLEU score."""
        try:
            from nltk.translate.bleu_score import sentence_bleu
            from nltk.tokenize import word_tokenize
            import nltk

            # Download required NLTK data
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt', quiet=True)

        except ImportError:
            raise ImportError(
                "nltk is required for BLEU score. "
                "Install with: pip install nltk"
            )

        if len(predictions) != len(targets):
            raise ValueError("Predictions and targets must have same length")

        scores = []
        for pred, target in zip(predictions, targets):
            pred_tokens = word_tokenize(str(pred).lower())
            target_tokens = [word_tokenize(str(target).lower())]

            if len(pred_tokens) == 0 or len(target_tokens[0]) == 0:
                scores.append(0.0)
            else:
                score = sentence_bleu(target_tokens, pred_tokens)
                scores.append(score)

        return float(np.mean(scores))


class EditDistanceEvaluator(BaseEvaluator):
    """Normalized edit distance evaluator."""

    def compute(self, predictions: List[str], targets: List[str], **kwargs) -> float:
        """Compute normalized edit distance (lower is better, normalized to 0-1)."""
        if len(predictions) != len(targets):
            raise ValueError("Predictions and targets must have same length")

        distances = []
        for pred, target in zip(predictions, targets):
            pred_str, target_str = str(pred), str(target)
            distance = self._levenshtein_distance(pred_str, target_str)
            max_len = max(len(pred_str), len(target_str))
            normalized_distance = distance / max_len if max_len > 0 else 0.0
            # Convert to similarity (1 - distance)
            similarity = 1.0 - normalized_distance
            distances.append(similarity)

        return float(np.mean(distances))

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Compute Levenshtein distance between two strings."""
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]


class LLMJudgeEvaluator(BaseEvaluator):
    """LLM-as-a-judge evaluator."""

    def __init__(self, generator_client, prompt_template: str = None):
        self.generator = generator_client
        self.prompt_template = prompt_template or (
            "Rate the quality of this response on a scale of 1-10:\n\n"
            "Question: {question}\n"
            "Response: {response}\n\n"
            "Rating (1-10):"
        )

    def compute(self, predictions: List[str], targets: List[str],
                questions: Optional[List[str]] = None, **kwargs) -> float:
        """Compute LLM judge scores."""
        if len(predictions) != len(targets):
            raise ValueError("Predictions and targets must have same length")

        scores = []
        for i, (pred, target) in enumerate(zip(predictions, targets)):
            question = questions[i] if questions else f"Question {i+1}"

            prompt = self.prompt_template.format(
                question=question,
                response=pred,
                target=target
            )

            try:
                response = self.generator._generator.generate(prompt)
                # Extract numeric score from response
                score = self._extract_score(response)
                scores.append(score)
            except Exception:
                # If generation fails, assign neutral score
                scores.append(0.5)

        return float(np.mean(scores))

    def _extract_score(self, response: str) -> float:
        """Extract numeric score from LLM response."""
        import re

        # Look for numbers in the response
        numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', response)

        if numbers:
            score = float(numbers[0])
            # Normalize to 0-1 scale if it looks like 1-10 scale
            if score > 1:
                score = score / 10.0
            return min(max(score, 0.0), 1.0)

        return 0.5  # Default neutral score


class EvaluationFunction:
    """Evaluation function for use in evaluation schemas."""

    def __init__(self, evaluator: BaseEvaluator, column_a: str, column_b: str, **kwargs):
        self.evaluator = evaluator
        self.column_a = column_a
        self.column_b = column_b
        self.kwargs = kwargs

    def __call__(self, data: pd.DataFrame) -> float:
        """Compute evaluation metric on dataset."""
        predictions = data[self.column_a].tolist()
        targets = data[self.column_b].tolist()
        return self.evaluator.compute(predictions, targets, **self.kwargs)


class DatasetEvaluator:
    """Dataset-specific evaluator for method chaining."""

    def __init__(self, dataset):
        self.dataset = dataset

    def exact_match(self, column_a: str, column_b: str, **kwargs) -> EvaluationFunction:
        """Create exact match evaluation function."""
        return EvaluationFunction(ExactMatchEvaluator(), column_a, column_b, **kwargs)

    def semantic_similarity(self, column_a: str, column_b: str,
                            model: str = "all-MiniLM-L6-v2", **kwargs) -> EvaluationFunction:
        """Create semantic similarity evaluation function."""
        evaluator = SemanticSimilarityEvaluator(model)
        return EvaluationFunction(evaluator, column_a, column_b, **kwargs)

    def bleu_score(self, column_a: str, column_b: str, **kwargs) -> EvaluationFunction:
        """Create BLEU score evaluation function."""
        return EvaluationFunction(BLEUEvaluator(), column_a, column_b, **kwargs)

    def edit_distance(self, column_a: str, column_b: str, **kwargs) -> EvaluationFunction:
        """Create edit distance evaluation function."""
        return EvaluationFunction(EditDistanceEvaluator(), column_a, column_b, **kwargs)

    def llm_judge(self, column_a: str, column_b: str,
                  generator_client, prompt_template: str = None, **kwargs) -> EvaluationFunction:
        """Create LLM judge evaluation function."""
        evaluator = LLMJudgeEvaluator(generator_client, prompt_template)
        return EvaluationFunction(evaluator, column_a, column_b, **kwargs)


# Standalone evaluation functions for schema use
class EvalNamespace:
    """Namespace for evaluation functions."""

    @staticmethod
    def exact_match(column_a: str, column_b: str, **kwargs) -> Callable:
        """Exact match evaluation for use in dataset schemas."""
        def eval_func(context: Dict[str, Any]) -> float:
            pred, target = context[column_a], context[column_b]
            return 1.0 if str(pred) == str(target) else 0.0
        return eval_func

    @staticmethod
    def semantic_similarity(column_a: str, column_b: str,
                            model: str = "all-MiniLM-L6-v2") -> Callable:
        """Semantic similarity evaluation for use in dataset schemas."""
        evaluator = SemanticSimilarityEvaluator(model)

        def eval_func(context: Dict[str, Any]) -> float:
            pred, target = context[column_a], context[column_b]
            return evaluator.compute([str(pred)], [str(target)])
        return eval_func

    @staticmethod
    def bleu_score(column_a: str, column_b: str) -> Callable:
        """BLEU score evaluation for use in dataset schemas."""
        evaluator = BLEUEvaluator()

        def eval_func(context: Dict[str, Any]) -> float:
            pred, target = context[column_a], context[column_b]
            return evaluator.compute([str(pred)], [str(target)])
        return eval_func

    @staticmethod
    def edit_distance(column_a: str, column_b: str) -> Callable:
        """Edit distance evaluation for use in dataset schemas."""
        evaluator = EditDistanceEvaluator()

        def eval_func(context: Dict[str, Any]) -> float:
            pred, target = context[column_a], context[column_b]
            return evaluator.compute([str(pred)], [str(target)])
        return eval_func


def evaluate(eval_schema: Dict[str, Union[EvaluationFunction, Callable]]) -> Dict[str, float]:
    """
    Evaluate multiple metrics across datasets.

    Args:
        eval_schema: Dictionary mapping metric names to evaluation functions or callables

    Returns:
        Dictionary of metric names to computed scores
    """
    results = {}
    for name, eval_func in eval_schema.items():
        if isinstance(eval_func, EvaluationFunction):
            # This shouldn't happen in normal usage - EvaluationFunction needs a dataset
            raise ValueError(f"EvaluationFunction for '{name}' requires dataset context")
        elif callable(eval_func):
            # Should be a callable that returns the result
            results[name] = eval_func()
        else:
            raise ValueError(f"Invalid evaluation function type for '{name}': {type(eval_func)}")
    return results


# Export the eval namespace
eval = EvalNamespace()
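A minimal sketch of the two standalone entry points this module exposes: the `eval` namespace for per-row metrics inside a schema, and the module-level `evaluate()` that aggregates zero-argument callables. It assumes the dataset passes the row context to plain callables in the schema, as the `eval_func(context)` signatures above suggest; all names and prompts are illustrative:

```
from chatan import dataset, generator, evaluate, eval as metrics

gen = generator("openai", "YOUR_API_KEY")

# Per-row metric as a schema column: scored while each row is generated.
ds = dataset({
    "answer": gen("answer: what is a Python decorator?"),
    "paraphrase": gen("paraphrase this answer: {answer}"),
    "row_similarity": metrics.semantic_similarity("answer", "paraphrase"),
})
df = ds.generate(n=3)

# Aggregate metrics: evaluate() expects zero-argument callables, so wrap a
# dataset-bound EvaluationFunction and pass it the generated frame explicitly.
summary = evaluate({
    "edit_distance": lambda: ds.eval.edit_distance("answer", "paraphrase")(df),
})
```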
chatan/generator.py
CHANGED
@@ -3,6 +3,7 @@
 from typing import Dict, Any, Optional, Union, List
 import openai
 import anthropic
+from transformers import pipeline
 from abc import ABC, abstractmethod
 
 
@@ -56,39 +57,77 @@ class AnthropicGenerator(BaseGenerator):
         return response.content[0].text.strip()
 
 
+class TransformersGenerator(BaseGenerator):
+    """Local HuggingFace/transformers generator."""
+
+    def __init__(self, model: str = "gpt2", **kwargs):
+        self.pipeline = pipeline("text-generation", model=model, **kwargs)
+
+    def generate(self, prompt: str, **kwargs) -> str:
+        result = self.pipeline(prompt, **kwargs)[0]["generated_text"]
+        return result.strip()
+
+
 class GeneratorFunction:
     """Callable generator function for use in dataset schemas."""
-
-    def __init__(
+
+    def __init__(
+        self,
+        generator: BaseGenerator,
+        prompt_template: str,
+        variables: Optional[Dict[str, Any]] = None,
+    ):
         self.generator = generator
         self.prompt_template = prompt_template
-
+        self.variables = variables or {}
+
     def __call__(self, context: Dict[str, Any]) -> str:
         """Generate content with context substitution."""
-
+        merged = dict(context)
+        for key, value in self.variables.items():
+            merged[key] = value(context) if callable(value) else value
+
+        prompt = self.prompt_template.format(**merged)
         result = self.generator.generate(prompt)
         return result.strip() if isinstance(result, str) else result
 
 
 class GeneratorClient:
     """Main interface for creating generators."""
-
-    def __init__(self, provider: str, api_key: str, **kwargs):
-
+
+    def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs):
+        provider_lower = provider.lower()
+        if provider_lower == "openai":
+            if api_key is None:
+                raise ValueError("API key is required for OpenAI")
             self._generator = OpenAIGenerator(api_key, **kwargs)
-        elif
+        elif provider_lower == "anthropic":
+            if api_key is None:
+                raise ValueError("API key is required for Anthropic")
             self._generator = AnthropicGenerator(api_key, **kwargs)
+        elif provider_lower in {"huggingface", "transformers", "hf"}:
+            self._generator = TransformersGenerator(**kwargs)
         else:
             raise ValueError(f"Unsupported provider: {provider}")
 
-    def __call__(self, prompt_template: str) -> GeneratorFunction:
-        """Create a generator function.
-
+    def __call__(self, prompt_template: str, **variables) -> GeneratorFunction:
+        """Create a generator function.
+
+        Parameters
+        ----------
+        prompt_template:
+            Template string for the prompt.
+        **variables:
+            Optional variables to include when formatting the prompt. If a value
+            is callable it will be invoked with the row context when the
+            generator function is executed.
+        """
+        return GeneratorFunction(self._generator, prompt_template, variables)
 
 
 # Factory function
 def generator(provider: str = "openai", api_key: Optional[str] = None, **kwargs) -> GeneratorClient:
     """Create a generator client."""
-    if api_key is None:
+    if provider.lower() in {"openai", "anthropic"} and api_key is None:
         raise ValueError("API key is required")
     return GeneratorClient(provider, api_key, **kwargs)
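A minimal sketch of the two additions in this diff, a local transformers backend that needs no API key and per-call prompt variables; the model name and schema are illustrative:

```
from chatan import dataset, generator, sample

# Extra kwargs are forwarded to transformers.pipeline("text-generation", ...).
local_gen = generator("huggingface", model="distilgpt2")

ds = dataset({
    "language": sample.choice(["Python", "Rust"]),
    # Keyword arguments become template variables; callables receive the row
    # context, so "level" here is derived from the sampled "language" column.
    "prompt": local_gen(
        "Write a {level} exercise about {language}.",
        level=lambda row: "beginner" if row["language"] == "Python" else "advanced",
    ),
})
df = ds.generate(n=2)
```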
chatan/viewer.py
ADDED
@@ -0,0 +1,581 @@
"""Live HTML viewer for dataset generation."""

import json
import os
import tempfile
import webbrowser
from pathlib import Path
from typing import Dict, Any, Optional, Callable
import threading
import time
from http.server import HTTPServer, SimpleHTTPRequestHandler
import atexit


class LiveViewer:
    """Live HTML viewer for streaming dataset generation results."""

    def __init__(self, title: str = "Dataset Generation", auto_open: bool = True):
        self.title = title
        self.auto_open = auto_open
        self.temp_dir = None
        self.html_file = None
        self.data_file = None
        self.server = None
        self.server_thread = None
        self.port = 8000
        self._active = False

    def start(self, schema: Dict[str, Any]) -> str:
        """Start the viewer and return the URL."""
        self.temp_dir = tempfile.mkdtemp()
        self.html_file = Path(self.temp_dir) / "viewer.html"
        self.data_file = Path(self.temp_dir) / "data.json"

        # Initialize empty data file
        with open(self.data_file, 'w') as f:
            json.dump({"rows": [], "completed": False, "current_row": None}, f)

        # Create HTML file
        html_content = self._generate_html(list(schema.keys()))
        with open(self.html_file, 'w') as f:
            f.write(html_content)

        # Start local server
        self._start_server()

        # Open in browser
        url = f"http://localhost:{self.port}/viewer.html"
        if self.auto_open:
            webbrowser.open(url)

        self._active = True
        return url

    def add_row(self, row: Dict[str, Any]):
        """Add a new row to the viewer."""
        if not self._active or not self.data_file:
            return

        try:
            with open(self.data_file, 'r') as f:
                data = json.load(f)
        except:
            data = {"rows": [], "completed": False, "current_row": None}

        data["rows"].append(row)
        # Keep current_row so we can update the UI with final values
        data["completed_row"] = {"index": len(data["rows"]) - 1, "data": row}
        data["current_row"] = None  # Clear current row when row is complete

        with open(self.data_file, 'w') as f:
            json.dump(data, f)

    def start_row(self, row_index: int):
        """Start a new row with empty cells."""
        if not self._active or not self.data_file:
            return

        try:
            with open(self.data_file, 'r') as f:
                data = json.load(f)
        except:
            data = {"rows": [], "completed": False, "current_row": None}

        data["current_row"] = {"index": row_index, "cells": {}}

        with open(self.data_file, 'w') as f:
            json.dump(data, f)

    def update_cell(self, column: str, value: Any):
        """Update a single cell in the current row."""
        if not self._active or not self.data_file:
            return

        try:
            with open(self.data_file, 'r') as f:
                data = json.load(f)
        except:
            data = {"rows": [], "completed": False, "current_row": None}

        if data.get("current_row"):
            data["current_row"]["cells"][column] = value

        with open(self.data_file, 'w') as f:
            json.dump(data, f)

    def complete(self):
        """Mark generation as complete."""
        if not self._active or not self.data_file:
            return

        try:
            with open(self.data_file, 'r') as f:
                data = json.load(f)
        except:
            data = {"rows": [], "completed": False, "current_row": None}

        data["completed"] = True

        with open(self.data_file, 'w') as f:
            json.dump(data, f)

    def stop(self):
        """Stop the viewer and cleanup resources."""
        self._active = False
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def _start_server(self):
        """Start a local HTTP server."""
        os.chdir(self.temp_dir)

        # Find available port
        for port in range(8000, 8100):
            try:
                self.server = HTTPServer(("localhost", port), SimpleHTTPRequestHandler)
                self.port = port
                break
            except OSError:
                continue

        if self.server:
            self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
            self.server_thread.start()

        # Register cleanup on exit
        atexit.register(self.stop)

    def _generate_html(self, columns) -> str:
        """Generate the HTML content."""
        return f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{self.title}</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, sans-serif;
            margin: 0;
            padding: 20px;
            background: #f8fafc;
            color: #1e293b;
        }}

        .header {{
            background: white;
            padding: 20px;
            border-radius: 12px;
            box-shadow: 0 1px 3px rgba(0,0,0,0.1);
            margin-bottom: 20px;
        }}

        .title {{
            font-size: 24px;
            font-weight: 600;
            margin: 0 0 10px 0;
        }}

        .status {{
            display: flex;
            align-items: center;
            gap: 8px;
            font-size: 14px;
            color: #64748b;
        }}

        .status-dot {{
            width: 8px;
            height: 8px;
            border-radius: 50%;
            background: #10b981;
            animation: pulse 1.5s infinite;
        }}

        .status-dot.complete {{
            background: #6366f1;
            animation: none;
        }}

        @keyframes pulse {{
            0%, 100% {{ opacity: 1; }}
            50% {{ opacity: 0.5; }}
        }}

        .table-container {{
            background: white;
            border-radius: 12px;
            box-shadow: 0 1px 3px rgba(0,0,0,0.1);
            overflow: hidden;
            max-height: 70vh;
            overflow-y: auto;
        }}

        table {{
            width: 100%;
            border-collapse: collapse;
            table-layout: fixed;
        }}

        th {{
            background: #f1f5f9;
            padding: 16px;
            text-align: left;
            font-weight: 600;
            border-bottom: 1px solid #e2e8f0;
            position: sticky;
            top: 0;
            z-index: 10;
            position: relative;
        }}

        th:not(:last-child), td:not(:last-child) {{
            border-right: 1px solid #e2e8f0;
        }}

        .col-resizer {{
            position: absolute;
            right: 0;
            top: 0;
            height: 100%;
            width: 5px;
            cursor: col-resize;
            user-select: none;
        }}

        td {{
            padding: 12px 16px;
            border-bottom: 1px solid #f1f5f9;
            vertical-align: top;
        }}

        tr:hover {{
            background: #f8fafc;
        }}

        .row-number {{
            color: #64748b;
            font-size: 12px;
            font-weight: 500;
            width: 60px;
        }}

        .cell-content {{
            max-width: 300px;
            word-wrap: break-word;
            white-space: pre-wrap;
        }}

        .cell-generating {{
            background: linear-gradient(90deg, #f1f5f9, #e2e8f0, #f1f5f9);
            background-size: 200% 200%;
            animation: shimmer 1.5s ease-in-out infinite;
        }}

        @keyframes shimmer {{
            0% {{ background-position: -200% 0; }}
            100% {{ background-position: 200% 0; }}
        }}

        .new-row {{
            animation: slideIn 0.3s ease-out;
        }}

        @keyframes slideIn {{
            from {{
                opacity: 0;
                transform: translateY(-10px);
            }}
            to {{
                opacity: 1;
                transform: translateY(0);
            }}
        }}

        .loading {{
            text-align: center;
            padding: 40px;
            color: #64748b;
        }}
    </style>
</head>
<body>
    <div class="header">
        <div class="title">{self.title}</div>
        <div class="status">
            <div class="status-dot" id="statusDot"></div>
            <span id="statusText">Generating...</span>
            <span id="rowCount">0 rows</span>
        </div>
    </div>

    <div class="table-container">
        <table>
            <thead>
                <tr>
                    <th class="row-number">#</th>
                    {' '.join(f'<th>{col}</th>' for col in columns)}
                </tr>
            </thead>
            <tbody id="tableBody">
                <tr>
                    <td colspan="{len(columns) + 1}" class="loading">
                        Waiting for data...
                    </td>
                </tr>
            </tbody>
        </table>
    </div>

    <script>
        let rowCount = 0;
        let currentRowElement = null;

        function makeColumnsResizable(table) {{
            const headers = table.querySelectorAll('th');
            headers.forEach((th, index) => {{

                if (index === headers.length - 1) return;
                const resizer = document.createElement('div');
                resizer.className = 'col-resizer';
                th.appendChild(resizer);

                let startX, startWidth;

                resizer.addEventListener('mousedown', (e) => {{

                    startX = e.clientX;
                    startWidth = th.offsetWidth;
                    document.addEventListener('mousemove', doDrag);
                    document.addEventListener('mouseup', stopDrag);
                }});

                function doDrag(e) {{
                    const width = startWidth + e.clientX - startX;
                    th.style.width = width + 'px';
                }}

                function stopDrag() {{
                    document.removeEventListener('mousemove', doDrag);
                    document.removeEventListener('mouseup', stopDrag);
                }}
            }});
        }}


        async function fetchData() {{
            try {{
                const response = await fetch('data.json?' + new Date().getTime());
                const data = await response.json();

                // Handle current row updates
                if (data.current_row) {{
                    updateCurrentRow(data.current_row);
                }}

                // Handle completed rows
                if (data.completed_row) {{
                    completeRow(data.completed_row);
                }}

                if (data.rows.length > rowCount) {{
                    rowCount = data.rows.length;
                    updateStatus(data.completed);
                }}

                if (data.completed) {{
                    document.getElementById('statusDot').classList.add('complete');
                    document.getElementById('statusText').textContent = 'Complete';
                    return;
                }}
            }} catch (error) {{
                console.error('Error fetching data:', error);
            }}

            setTimeout(fetchData, 200); // Faster polling for cell updates
        }}

        function updateCurrentRow(currentRow) {{
            const tbody = document.getElementById('tableBody');

            // Remove loading message if present
            if (tbody.children.length === 1 && tbody.children[0].cells.length === {len(columns) + 1}) {{
                tbody.innerHTML = '';
            }}

            // Create new row element only if we don't have one for this index
            if (!currentRowElement || parseInt(currentRowElement.dataset.rowIndex) !== currentRow.index) {{
                currentRowElement = document.createElement('tr');
                currentRowElement.className = 'new-row';
                currentRowElement.dataset.rowIndex = currentRow.index;

                // Row number
                const numCell = document.createElement('td');
                numCell.className = 'row-number';
                numCell.textContent = currentRow.index + 1;
                currentRowElement.appendChild(numCell);

                // Create empty cells for all columns
                {json.dumps(columns)}.forEach(col => {{
                    const td = document.createElement('td');
                    td.className = 'cell-content cell-generating';
                    td.textContent = '...';
                    td.id = `cell-${{currentRow.index}}-${{col}}`;
                    currentRowElement.appendChild(td);
                }});

                tbody.appendChild(currentRowElement);
            }}

            // Update cells with values
            Object.entries(currentRow.cells).forEach(([col, value]) => {{
                const cell = document.getElementById(`cell-${{currentRow.index}}-${{col}}`);
                if (cell) {{
                    cell.textContent = value || '';
                    cell.classList.remove('cell-generating');
                }}
            }});
        }}

        function completeRow(completedRow) {{
            // Find the row element and update it with final data
            const rowElement = document.querySelector(`tr[data-row-index="${{completedRow.index}}"]`);
            if (rowElement) {{
                {json.dumps(columns)}.forEach((col, colIndex) => {{
                    const cell = rowElement.cells[colIndex + 1]; // +1 for row number column
                    if (cell) {{
                        cell.textContent = completedRow.data[col] || '';
                        cell.classList.remove('cell-generating');
                    }}
                }});
            }}

            // Clear current row if this was it
            if (currentRowElement && parseInt(currentRowElement.dataset.rowIndex) === completedRow.index) {{
                currentRowElement = null;
            }}
        }}

        function addRows(rows) {{
            const tbody = document.getElementById('tableBody');

            rows.forEach((row, index) => {{
                const tr = document.createElement('tr');
                tr.className = 'new-row';

                const numCell = document.createElement('td');
                numCell.className = 'row-number';
                numCell.textContent = rowCount - rows.length + index + 1;
                tr.appendChild(numCell);

                {json.dumps(columns)}.forEach(col => {{
                    const td = document.createElement('td');
                    td.className = 'cell-content';
                    td.textContent = row[col] || '';
                    tr.appendChild(td);
                }});

                tbody.appendChild(tr);
            }});
        }}

        function updateStatus(completed) {{
            document.getElementById('rowCount').textContent = `${{rowCount}} rows`;
            if (completed) {{
                document.getElementById('statusText').textContent = 'Complete';
                document.getElementById('statusDot').classList.add('complete');
            }}
        }}

        makeColumnsResizable(document.querySelector('table'));
        fetchData();
    </script>
</body>
</html>"""


def create_viewer_callback(viewer: LiveViewer) -> Callable[[Dict[str, Any]], None]:
    """Create a callback function for dataset generation progress."""
    def callback(row: Dict[str, Any]):
        viewer.add_row(row)
    return callback


def generate_with_viewer(
    dataset_instance,
    n: Optional[int] = None,
    progress: bool = True,
    viewer_title: str = "Dataset Generation",
    auto_open: bool = True,
    stream_delay: float = 0.05,
    cell_delay: float = 0.3
):
    """Generate dataset with live viewer showing cell-by-cell generation.

    Args:
        dataset_instance: The Dataset instance
        n: Number of samples to generate
        progress: Show progress bar (ignored when using viewer)
        viewer_title: Title for the HTML viewer
        auto_open: Whether to auto-open browser
        stream_delay: Delay between rows for streaming effect
        cell_delay: Delay between individual cell generations

    Returns:
        pd.DataFrame: Generated dataset
    """
    viewer = LiveViewer(title=viewer_title, auto_open=auto_open)
    num_samples = n or dataset_instance.n

    try:
        # Start viewer
        url = viewer.start(dataset_instance.schema)
        print(f"Live viewer started at: {url}")

        # Build dependency graph (copied from dataset logic)
        dependencies = dataset_instance._build_dependency_graph()
        execution_order = dataset_instance._topological_sort(dependencies)

        # Generate data with live updates
        data = []

        for i in range(num_samples):
            # Start new row
            viewer.start_row(i)

            row = {}
            for column in execution_order:
                # Generate cell value
                value = dataset_instance._generate_value(column, row)
                row[column] = value

                # Update viewer with new cell value
                viewer.update_cell(column, value)

                # Delay to show cell generation effect
                if cell_delay > 0:
                    time.sleep(cell_delay)

            # Complete the row
            data.append(row)
            viewer.add_row(row)

            # Small delay between rows
            if stream_delay > 0:
                time.sleep(stream_delay)

        viewer.complete()

        # Import pandas dynamically to avoid circular imports
        import pandas as pd
        dataset_instance._data = pd.DataFrame(data)
        return dataset_instance._data

    except Exception as e:
        viewer.stop()
        raise e
    finally:
        # Keep server running for a bit so user can see final state
        time.sleep(1)
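A minimal sketch of driving the live viewer, based on the `generate_with_viewer` signature above (it is also re-exported from the package root); the schema is illustrative:

```
import chatan

gen = chatan.generator("openai", "YOUR_API_KEY")
ds = chatan.dataset({
    "topic": chatan.sample.choice(["Python", "JavaScript", "Rust"]),
    "prompt": gen("write a programming question about {topic}"),
})

# Serves a local page on the first free port in 8000-8099 and, by default,
# opens it in the browser; cells fill in as values are generated.
df = chatan.generate_with_viewer(ds, n=5, viewer_title="Demo run", cell_delay=0.0)
```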
chatan-0.1.3.dist-info/METADATA
ADDED
@@ -0,0 +1,124 @@
Metadata-Version: 2.4
Name: chatan
Version: 0.1.3
Summary: Create synthetic datasets with LLM generators and samplers
Project-URL: Documentation, https://github.com/cdreetz/chatan#readme
Project-URL: Issues, https://github.com/cdreetz/chatan/issues
Project-URL: Source, https://github.com/cdreetz/chatan
Author-email: Christian Reetz <cdreetz@gmail.com>
License-Expression: MIT
License-File: LICENSE
Keywords: dataset generation,llm,machine learning,synthetic data
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Requires-Dist: anthropic>=0.7.0
Requires-Dist: datasets>=2.0.0
Requires-Dist: nltk>=3.9.1
Requires-Dist: numpy>=1.20.0
Requires-Dist: openai>=1.0.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: pydantic>=2.0.0
Requires-Dist: scikit-learn>=1.3.2
Requires-Dist: sentence-transformers>=3.2.1
Requires-Dist: torch>=2.5.1
Requires-Dist: tqdm>=4.0.0
Requires-Dist: transformers>=4.0.0
Description-Content-Type: text/markdown

# Chatan

Create diverse, synthetic datasets. Start from scratch or augment an existing dataset. Simply define your dataset schema as a set of generators, typically LLMs with a prompt describing what kind of examples you want.

## Installation

```
pip install chatan
```

## Getting Started

```
import chatan

# Create a generator
gen = chatan.generator("openai", "YOUR_API_KEY")

# Define a dataset schema
ds = chatan.dataset({
    "topic": chatan.sample.choice(["Python", "JavaScript", "Rust"]),
    "prompt": gen("write a programming question about {topic}"),
    "response": gen("answer this question: {prompt}")
})

# Generate the data with a progress bar
df = ds.generate(n=10)
```

## Examples

Create Data Mixes

```
from chatan import dataset, generator, sample
import uuid

gen = generator("openai", "YOUR_API_KEY")

mix = [
    "san antonio, tx",
    "marfa, tx",
    "paris, fr"
]

ds = dataset({
    "id": sample.uuid(),
    "topic": sample.choice(mix),
    "prompt": gen("write an example question about the history of {topic}"),
    "response": gen("respond to: {prompt}"),
})
```

Augment datasets

```
from chatan import generator, dataset, sample
from datasets import load_dataset

gen = generator("openai", "YOUR_API_KEY")
hf_data = load_dataset("some/dataset")

ds = dataset({
    "original_prompt": sample.from_dataset(hf_data, "prompt"),
    "variation": gen("rewrite this prompt: {original_prompt}"),
    "response": gen("respond to: {variation}")
})

```

## Citation

If you use this code in your research, please cite:

```
@software{reetz2025chatan,
  author = {Reetz, Christian},
  title = {chatan: Create synthetic datasets with LLM generators.},
  url = {https://github.com/cdreetz/chatan},
  year = {2025}
}
```

## Contributing

Community contributions are more than welcome: bug reports, bug fixes, feature requests, and feature additions. Please refer to the Issues tab.
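The Getting Started section above mentions the progress bar; judging from the `generate` signature added in `chatan/dataset.py`, it can be turned off per call. A minimal sketch:

```
import chatan

gen = chatan.generator("openai", "YOUR_API_KEY")
ds = chatan.dataset({"prompt": gen("write a short question about SQL")})

# tqdm progress bar is shown by default; pass progress=False to silence it.
df = ds.generate(n=3, progress=False)
```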
chatan-0.1.3.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
chatan/__init__.py,sha256=AEW7uf7P7t_BAkKHEUl9Eq02R6C8VAFfYAc5zM4ECEc,355
chatan/dataset.py,sha256=R0sa2m7LM_BABi2B1NJpGxiSFuh0h1cFMGUCw-9ZKhM,5716
chatan/evaluate.py,sha256=m_2zTEJUNfA3Jqt195dd4LqOTTMZOUVtoNDjOhwtHHM,12543
chatan/generator.py,sha256=syapQLWPvaLJ_Jw3Vn9egTvUzIkgotEFut1_qpqxHEk,4764
chatan/sampler.py,sha256=0X6AVQK20py4SwKnsppZC2yAZnP_jBhRxt9MfT1e-k4,4812
chatan/viewer.py,sha256=tnQPoOYz_c22BouPPC7zDjdjHox0SETlf9XeXYs82FA,18885
chatan-0.1.3.dist-info/METADATA,sha256=30X5jEMcs7XjAvTqbYaH1flCZ-OaWmmXsTbQ9iL5vPQ,3445
chatan-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
chatan-0.1.3.dist-info/licenses/LICENSE,sha256=QplUnrABTIgjBDFCKoGgdjnZndIBrRCS2WJr6gm9kX4,1072
chatan-0.1.3.dist-info/RECORD,,
chatan-0.1.3.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Christian Reetz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
chatan-0.1.0.dist-info/METADATA
DELETED
@@ -1,83 +0,0 @@
Metadata-Version: 2.4
Name: chatan
Version: 0.1.0
Summary: Create synthetic datasets with LLM generators and samplers
Project-URL: Documentation, https://github.com/cdreetz/chatan#readme
Project-URL: Issues, https://github.com/cdreetz/chatan/issues
Project-URL: Source, https://github.com/cdreetz/chatan
Author-email: Christian Reetz <cdreetz@gmail.com>
License-Expression: MIT
Keywords: dataset generation,llm,machine learning,synthetic data
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Requires-Dist: anthropic>=0.7.0
Requires-Dist: datasets>=2.0.0
Requires-Dist: numpy>=1.20.0
Requires-Dist: openai>=1.0.0
Requires-Dist: pandas>=1.3.0
Requires-Dist: pydantic>=2.0.0
Description-Content-Type: text/markdown

## Examples

Prompt a dataset

```
import chatan

gen = chatan.generator.client("YOUR_OPENAI_API_KEY")
ds = chatan.dataset("create a QA dataset for finetuning an LLM on pharmacology")
```

Creating datasets with different data mixes

```
import uuid
from chatan import dataset, generator, mix

gen = generator.client("YOUR_OPENAI_API_KEY")
#generator.client("anthropic", "YOUR_ANTHROPIC_API_KEY")

mix = {
    "implementation": "Can you implement a matmul kernel in Triton",
    "conversion": "Convert this pytorch model to Triton",
    "explanation": "What memory access optimizations are being used here?"
}

ds = dataset({
    "id": uuid,
    "task": sample.choice(mix),
    "prompt": gen("write a prompt for {task}"),
    "response": gen("write a response to {prompt}"),
)}
```

Augment datasets

```
import uuid
from chatan import dataset, generator
from dataset import load_dataset

gen = generator.client("YOUR_OPENAI_API_KEY")
hf_dataset = load_dataset("GPU_MODE/KernelBook")

ds = dataset({
    "id": sample.from_dataset(hf_data, "id", default=sample.uuid()),
    "prompt": sample.from_dataset(hf_data, "prompt", aug=gen("provide a variation of this prompt")),
    "response": gen("write a response to {prompt}")

})

```
chatan-0.1.0.dist-info/RECORD
DELETED
@@ -1,7 +0,0 @@
chatan/__init__.py,sha256=GdCxiObHomq4-2TyqS2dfEd2EuDYGpi2AkvC5QIb3mU,233
chatan/dataset.py,sha256=t6RrrQchsD2dROD886IfnlXWnn-F-IAWMcEIK0SS3xg,4358
chatan/generator.py,sha256=kj8axCWI00F8R0DJL831AUtK41Bvv2VL4nOweTgGfGc,3286
chatan/sampler.py,sha256=0X6AVQK20py4SwKnsppZC2yAZnP_jBhRxt9MfT1e-k4,4812
chatan-0.1.0.dist-info/METADATA,sha256=A5L6T8zlmoO5tKG30kmjQKQnOcCivSypaw8TFd4fS_k,2549
chatan-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
chatan-0.1.0.dist-info/RECORD,,
{chatan-0.1.0.dist-info → chatan-0.1.3.dist-info}/WHEEL
File without changes