camel-ai 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of camel-ai has been flagged as possibly problematic.
- camel/__init__.py +1 -1
- camel/datasets/__init__.py +7 -5
- camel/datasets/base_generator.py +335 -0
- camel/datasets/models.py +61 -0
- camel/datasets/static_dataset.py +346 -0
- camel/embeddings/openai_compatible_embedding.py +4 -4
- camel/environments/__init__.py +11 -2
- camel/environments/models.py +111 -0
- camel/environments/multi_step.py +271 -0
- camel/environments/single_step.py +293 -0
- camel/logger.py +56 -0
- camel/models/openai_compatible_model.py +4 -2
- camel/toolkits/browser_toolkit.py +59 -1
- camel/toolkits/search_toolkit.py +70 -0
- camel/utils/commons.py +1 -1
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.30.dist-info}/METADATA +2 -1
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.30.dist-info}/RECORD +19 -15
- camel/datasets/base.py +0 -639
- camel/environments/base.py +0 -509
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.30.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.29.dist-info → camel_ai-0.2.30.dist-info}/licenses/LICENSE +0 -0
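Read together, the file list shows a refactor rather than isolated fixes: the monolithic camel/datasets/base.py (639 lines) and camel/environments/base.py (509 lines) are deleted, and their contents are split into focused modules (static_dataset.py, base_generator.py, and models.py under camel/datasets; models.py, single_step.py, and multi_step.py under camel/environments). The sketch below shows the import change this implies for downstream code; the re-exported names are an assumption inferred from the new file names and the +7 -5 change to camel/datasets/__init__.py, not something this listing confirms.

# Hypothetical migration sketch. Assumes the split modules are re-exported
# from the camel.datasets package in 0.2.30; verify against the actual
# __init__.py before relying on these paths.

# 0.2.29:
#   from camel.datasets.base import DataPoint, StaticDataset
# 0.2.30:
from camel.datasets import DataPoint, StaticDataset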
camel/datasets/base.py
DELETED
@@ -1,639 +0,0 @@
-# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-
-import json
-import random
-from datetime import datetime
-from pathlib import Path
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Sized,
-    Union,
-)
-
-from datasets import Dataset as HFDataset
-from pydantic import BaseModel, Field, ValidationError
-from torch.utils.data import Dataset
-
-from camel.agents import ChatAgent
-from camel.logger import get_logger
-from camel.verifiers import BaseVerifier
-from camel.verifiers.models import VerifierInput
-
-logger = get_logger(__name__)
-
-
-class DataPoint(BaseModel):
-    r"""A single data point in the dataset.
-
-    Attributes:
-        question (str): The primary question or issue to be addressed.
-        rationale (Optional[str]): Logical reasoning or explanation behind the
-            answer. (default: :obj:`None`)
-        final_answer (str): The final answer.
-        metadata Optional[Dict[str, Any]]: Additional metadata about the data
-            point. (default: :obj:`None`)
-    """
-
-    question: str = Field(
-        ..., description="The primary question or issue to be addressed."
-    )
-    rationale: Optional[str] = Field(
-        default=None,
-        description="Logical reasoning or explanation behind the answer.",
-    )
-    final_answer: str = Field(..., description="The final answer.")
-
-    metadata: Optional[Dict[str, Any]] = Field(
-        default=None, description="Additional metadata about the data point."
-    )
-
-    def to_dict(self) -> Dict[str, Any]:
-        r"""Convert DataPoint to a dictionary.
-
-        Returns:
-            Dict[str, Any]: Dictionary representation of the DataPoint.
-        """
-        return self.dict()
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
-        r"""Create a DataPoint from a dictionary.
-
-        Args:
-            data (Dict[str, Any]): Dictionary containing DataPoint fields.
-
-        Returns:
-            DataPoint: New DataPoint instance.
-        """
-        return cls(**data)
-
-
-class StaticDataset(Dataset):
-    r"""A static dataset containing a list of datapoints.
-    Ensures that all items adhere to the DataPoint schema.
-    This dataset extends :obj:`Dataset` from PyTorch and should
-    be used when its size is fixed at runtime.
-
-    This class can initialize from Hugging Face Datasets,
-    PyTorch Datasets, JSON file paths, or lists of dictionaries,
-    converting them into a consistent internal format.
-    """
-
-    def __init__(
-        self,
-        data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
-        seed: int = 42,
-        min_samples: int = 1,
-        strict: bool = False,
-        **kwargs,
-    ):
-        r"""Initialize the static dataset and validate integrity.
-
-        Args:
-            data (:obj:`Union[HFDataset, Dataset,
-                Path, List[Dict[str, Any]]]`):
-                Input data, which can be one of the following:
-                - A Hugging Face Dataset (:obj:`HFDataset`).
-                - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
-                - A :obj:`Path` object representing a JSON file.
-                - A list of dictionaries with :obj:`DataPoint`-compatible
-                  fields.
-            seed (:obj:`int`): Random seed for reproducibility.
-                Default is :obj:`42`.
-            min_samples (:obj:`int`): Minimum required number of samples.
-                Default is :obj:`1`.
-            strict (:obj:`bool`): Whether to raise an error on invalid
-                datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
-                Default is :obj:`False`.
-            **kwargs: Additional dataset parameters.
-
-        Raises:
-            TypeError: If the input data type is unsupported.
-            ValueError: If the dataset contains fewer than :obj:`min_samples`
-                datapoints or if validation fails.
-            FileNotFoundError: If the specified JSON file path does not exist.
-            json.JSONDecodeError: If the JSON file contains invalid formatting.
-        """
-
-        # Store all parameters in metadata dict for compatibility
-        self._metadata = {
-            **kwargs,
-        }
-        self._rng = random.Random(seed)
-        self._strict = strict
-
-        self.data: List[DataPoint] = self._init_data(data)
-        self._length = len(self.data)
-
-        if self._length < min_samples:
-            raise ValueError(
-                "The dataset does not contain enough samples. "
-                f"Need {max(0, min_samples)}, got {self._length}"
-            )
-
-    def _init_data(
-        self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
-    ) -> List[DataPoint]:
-        r"""Convert input data from various formats into a list of
-        :obj:`DataPoint` instances.
-
-        Args:
-            data (:obj:`Union[
-                HFDataset,
-                Dataset,
-                Path,
-                List[Dict[str, Any]]
-            ]`):
-                Input dataset in one of the supported formats.
-
-        Returns:
-            :obj:`List[DataPoint]`: A list of validated :obj:`DataPoint`
-            instances.
-
-        Raises:
-            TypeError: If the input data type is unsupported.
-        """
-
-        if isinstance(data, HFDataset):
-            raw_data = self._init_from_hf_dataset(data)
-        elif isinstance(data, Dataset):
-            raw_data = self._init_from_pytorch_dataset(data)
-        elif isinstance(data, Path):
-            raw_data = self._init_from_json_path(data)
-        elif isinstance(data, list):
-            raw_data = self._init_from_list(data)
-        else:
-            raise TypeError("Unsupported data type")
-
-        def create_datapoint(
-            item: Dict[str, Any], idx: int
-        ) -> Optional[DataPoint]:
-            # Add type checks for required fields to make mypy happy
-            question = item.get('question')
-            if not isinstance(question, str):
-                if self._strict:
-                    raise ValueError(
-                        f"Sample at index {idx} has invalid 'question': "
-                        f"expected str, got {type(question)}"
-                    )
-                else:
-                    logger.warning(
-                        f"Skipping sample at index {idx}: invalid 'question'"
-                    )
-                    return None
-
-            rationale = item.get('rationale')
-            if not isinstance(rationale, str):
-                if self._strict:
-                    raise ValueError(
-                        f"Sample at index {idx} has invalid 'rationale': "
-                        f"expected str, got {type(rationale)}"
-                    )
-                else:
-                    logger.warning(
-                        f"Skipping sample at index {idx}: invalid 'rationale'"
-                    )
-                    return None
-
-            final_answer = item.get('final_answer')
-            if not isinstance(final_answer, str):
-                if self._strict:
-                    raise ValueError(
-                        f"Sample at index {idx} has invalid 'final_answer': "
-                        f"expected str, got {type(final_answer)}"
-                    )
-                else:
-                    logger.warning(
-                        f"Skipping sample at index {idx}: "
-                        "invalid 'final_answer'"
-                    )
-                    return None
-
-            try:
-                return DataPoint(
-                    question=question,
-                    rationale=rationale,
-                    final_answer=final_answer,
-                    metadata=item.get('metadata'),
-                )
-            except ValidationError as e:
-                if self._strict:
-                    raise ValueError(
-                        f"Sample at index {idx} validation error: {e}"
-                    )
-                else:
-                    logger.warning(
-                        f"Skipping invalid sample at index {idx} "
-                        f"due to validation error: {e}"
-                    )
-                    return None
-
-        unfiltered_data = [
-            create_datapoint(item, i) for i, item in enumerate(raw_data)
-        ]
-        return [dp for dp in unfiltered_data if dp is not None]
-
-    def __len__(self) -> int:
-        r"""Return the size of the dataset."""
-        return self._length
-
-    def __getitem__(self, idx: int) -> DataPoint:
-        r"""Retrieve a datapoint by index.
-
-        Args:
-            idx (:obj:`int`): Index of the datapoint.
-
-        Returns:
-            :obj:`DataPoint`: The datapoint corresponding to the given index.
-
-        Raises:
-            IndexError: If :obj:`idx` is out of bounds (negative or greater
-                than dataset length - 1).
-        """
-
-        if idx < 0 or idx >= self._length:
-            raise IndexError(
-                f"Index {idx} out of bounds for dataset of size {self._length}"
-            )
-        return self.data[idx]
-
-    def sample(self) -> DataPoint:
-        r"""Sample a random datapoint from the dataset.
-
-        Returns:
-            :obj:`DataPoint`: A randomly sampled :obj:`DataPoint`.
-
-        Raises:
-            RuntimeError: If the dataset is empty and no samples can be drawn.
-        """
-
-        if self._length == 0:
-            raise RuntimeError("Dataset is empty, cannot sample.")
-        idx = self._rng.randint(0, self._length - 1)
-        return self[idx]
-
-    @property
-    def metadata(self) -> Dict[str, Any]:
-        r"""Retrieve dataset metadata.
-
-        Returns:
-            :obj:`Dict[str, Any]`: A copy of the dataset metadata dictionary.
-        """
-
-        return self._metadata.copy()
-
-    def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
-        r"""Convert a Hugging Face dataset into a list of dictionaries.
-
-        Args:
-            data (:obj:`HFDataset`): A Hugging Face dataset.
-
-        Returns:
-            :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
-            the dataset, where each dictionary corresponds to a datapoint.
-        """
-        return [dict(item) for item in data]
-
-    def _init_from_pytorch_dataset(
-        self, data: Dataset
-    ) -> List[Dict[str, Any]]:
-        r"""Convert a PyTorch dataset into a list of dictionaries.
-
-        Args:
-            data (:obj:`Dataset`): A PyTorch dataset.
-
-        Returns:
-            :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
-            the dataset.
-
-        Raises:
-            TypeError: If the dataset does not implement :obj:`__len__()`
-                or contains non-dictionary elements.
-        """
-        if not isinstance(data, Sized):
-            raise TypeError(
-                f"{type(data).__name__} does not implement `__len__()`."
-            )
-        raw_data = []
-
-        for i in range(len(data)):
-            item = data[i]
-            if not isinstance(item, dict):
-                raise TypeError(
-                    f"Item at index {i} is not a dict: "
-                    f"got {type(item).__name__}"
-                )
-            raw_data.append(dict(item))
-        return raw_data
-
-    def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
-        r"""Load and parse a dataset from a JSON file.
-
-        Args:
-            data (:obj:`Path`): Path to the JSON file.
-
-        Returns:
-            :obj:`List[Dict[str, Any]]`: A list of datapoint dictionaries.
-
-        Raises:
-            FileNotFoundError: If the specified JSON file does not exist.
-            ValueError: If the JSON content is not a list of dictionaries.
-            json.JSONDecodeError: If the JSON file has invalid formatting.
-        """
-
-        if not data.exists():
-            raise FileNotFoundError(f"JSON file not found: {data}")
-        try:
-            logger.debug(f"Loading JSON from {data}")
-            with data.open('r', encoding='utf-8') as f:
-                loaded_data = json.load(f)
-            logger.info(
-                f"Successfully loaded {len(loaded_data)} items from {data}"
-            )
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Invalid JSON in file {data}: {e}")
-        if not isinstance(loaded_data, list):
-            raise ValueError("JSON file must contain a list of dictionaries")
-        for i, item in enumerate(loaded_data):
-            if not isinstance(item, dict):
-                raise ValueError(
-                    f"Expected a dictionary at index {i}, "
-                    f"got {type(item).__name__}"
-                )
-        return loaded_data
-
-    def _init_from_list(
-        self, data: List[Dict[str, Any]]
-    ) -> List[Dict[str, Any]]:
-        r"""Validate and convert a list of dictionaries into a dataset.
-
-        Args:
-            data (:obj:`List[Dict[str, Any]]`): A list of dictionaries where
-                each dictionary must be a valid :obj:`DataPoint`.
-
-        Returns:
-            :obj:`List[Dict[str, Any]]`: The validated list of dictionaries.
-
-        Raises:
-            ValueError: If any item in the list is not a dictionary.
-        """
-        for i, item in enumerate(data):
-            if not isinstance(item, dict):
-                raise ValueError(
-                    f"Expected a dictionary at index {i}, "
-                    f"got {type(item).__name__}"
-                )
-        return data
-
-
-class GenerativeDataset(Dataset):
-    r"""A dataset for generating synthetic datapoints using external agents and
-    verifiers.
-
-    This class leverages a seed dataset and external components to generate
-    new synthetic datapoints on demand.
-    """
-
-    def __init__(
-        self,
-        seed_dataset: StaticDataset,
-        verifier: BaseVerifier,
-        agent: ChatAgent,
-        seed: int = 42,
-        **kwargs,
-    ):
-        r"""Initialize the generative dataset.
-
-        Args:
-            seed_dataset (StaticDataset): Validated static dataset to
-                use for examples.
-            verifier (BaseVerifier): Verifier to validate generated content.
-            agent (ChatAgent): Agent to generate new datapoints.
-            seed (int): Random seed for reproducibility. (default: :obj:`42`)
-            **kwargs: Additional dataset parameters.
-        """
-
-        self.seed_dataset = seed_dataset
-        self.verifier = verifier
-        self.agent = agent
-
-        self.seed = seed
-        random.seed(self.seed)
-
-        self._data: List[DataPoint] = []
-
-    def _construct_prompt(self, examples: List[DataPoint]) -> str:
-        r"""Construct a prompt for generating new datapoints
-        using a fixed sample of 3 examples from the seed dataset.
-
-        Args:
-            examples (List[DataPoint]): Examples to include in the prompt.
-
-        Returns:
-            str: Formatted prompt with examples.
-        """
-        prompt = (
-            "Generate a new datapoint similar to the following examples:\n\n"
-        )
-        for i, example in enumerate(examples, 1):
-            prompt += f"Example {i}:\n"
-            prompt += f"Question: {example.question}\n"
-            prompt += f"Rationale: {example.rationale}\n"
-            prompt += f"Final Answer: {example.final_answer}\n\n"
-        prompt += "New datapoint:"
-        return prompt
-
-    async def generate_new(
-        self, n: int, max_retries: int = 10
-    ) -> List[DataPoint]:
-        r"""Generates and validates `n` new datapoints through
-        few-shot prompting, with a retry limit.
-
-        Steps:
-            1. Samples 3 examples from the seed dataset.
-            2. Constructs a prompt using the selected examples.
-            3. Uses an agent to generate a new datapoint.
-            4. Verifies the datapoint using a verifier.
-            5. Stores valid datapoints in memory.
-
-        Args:
-            n (int): Number of valid datapoints to generate.
-            max_retries (int): Maximum number of retries before stopping.
-
-        Returns:
-            List[DataPoint]: A list of newly generated valid datapoints.
-
-        Raises:
-            TypeError: If the agent's output is not a dictionary (or does not
-                match the expected format).
-            KeyError: If required keys are missing from the response.
-            AttributeError: If the verifier response lacks attributes.
-            ValidationError: If a datapoint fails schema validation.
-            RuntimeError: If retries are exhausted before `n` valid datapoints
-                are generated.
-
-        Notes:
-            - Retries on validation failures until `n` valid datapoints exist
-              or `max_retries` is reached, whichever comes first.
-            - If retries are exhausted before reaching `n`, a `RuntimeError`
-              is raised.
-            - Metadata includes a timestamp for tracking datapoint creation.
-            - This method can be overridden to implement custom generation.
-        """
-        valid_data_points: List[DataPoint] = []
-        retries = 0
-
-        while len(valid_data_points) < n and retries < max_retries:
-            try:
-                examples = [self.seed_dataset.sample() for _ in range(3)]
-                prompt = self._construct_prompt(examples)
-
-                try:
-                    agent_output = (
-                        self.agent.step(prompt, response_format=DataPoint)
-                        .msgs[0]
-                        .parsed
-                    )
-                    if not isinstance(agent_output, dict):
-                        raise TypeError("Agent output must be a dictionary")
-                    if (
-                        "question" not in agent_output
-                        or "rationale" not in agent_output
-                    ):
-                        raise KeyError(
-                            "Missing 'question' or 'rationale' in agent output"
-                        )
-                except (TypeError, KeyError) as e:
-                    logger.warning(
-                        f"Agent output issue: {e}, retrying... "
-                        f"({retries + 1}/{max_retries})"
-                    )
-                    retries += 1
-                    continue
-
-                rationale = agent_output["rationale"]
-
-                try:
-                    verifier_response = await self.verifier.verify(
-                        VerifierInput(
-                            llm_response=rationale, ground_truth=None
-                        )
-                    )
-                    if not verifier_response or not verifier_response.result:
-                        raise ValueError(
-                            "Verifier unsuccessful, response: "
-                            f"{verifier_response}"
-                        )
-                except (ValueError, AttributeError) as e:
-                    logger.warning(
-                        f"Verifier issue: {e}, "
-                        f"retrying... ({retries + 1}/{max_retries})"
-                    )
-                    retries += 1
-                    continue
-
-                try:
-                    new_datapoint = DataPoint(
-                        question=agent_output["question"],
-                        rationale=rationale,
-                        final_answer=verifier_response.result,
-                        metadata={
-                            "synthetic": str(True),
-                            "created": datetime.now().isoformat(),
-                        },
-                    )
-                except ValidationError as e:
-                    logger.warning(
-                        f"Datapoint validation failed: {e}, "
-                        f"retrying... ({retries + 1}/{max_retries})"
-                    )
-                    retries += 1
-                    continue
-
-                valid_data_points.append(new_datapoint)
-
-            except Exception as e:
-                logger.warning(
-                    f"Unexpected error: {e}, retrying..."
-                    f" ({retries + 1}/{max_retries})"
-                )
-                retries += 1
-
-        if len(valid_data_points) < n:
-            raise RuntimeError(
-                f"Failed to generate {n} valid datapoints "
-                f"after {max_retries} retries."
-            )
-
-        self._data.extend(valid_data_points)
-        return valid_data_points
-
-    def __len__(self) -> int:
-        r"""Return the size of the dataset."""
-        return len(self._data)
-
-    def __getitem__(self, idx: int) -> DataPoint:
-        r"""Retrieve a datapoint by index.
-
-        Args:
-            idx (int): Index of the datapoint.
-
-        Returns:
-            DataPoint: The datapoint corresponding to the given index.
-
-        Raises:
-            IndexError: If idx is out of bounds.
-        """
-        if idx < 0 or idx >= len(self._data):
-            raise IndexError(
-                f"Index {idx} out of bounds for dataset of "
-                f"size {len(self._data)}"
-            )
-        return self._data[idx]
-
-    def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
-        r"""Saves the dataset to a JSONL (JSON Lines) file.
-
-        Each datapoint is stored as a separate JSON object on a new line.
-
-        Args:
-            file_path (Union[str, Path]): Path to save the JSONL file.
-
-        Raises:
-            ValueError: If the dataset is empty.
-            IOError: If there is an issue writing to the file.
-
-        Notes:
-            - Uses `self._data`, which contains the generated datapoints.
-            - Overwrites the file if it already exists.
-            - Ensures compatibility with large datasets by using JSONL format.
-        """
-        if not self._data:
-            raise ValueError("Dataset is empty. No data to save.")
-
-        file_path = Path(file_path)
-
-        try:
-            with file_path.open("w", encoding="utf-8") as f:
-                for datapoint in self._data:
-                    json.dump(datapoint.to_dict(), f)
-                    f.write("\n")  # Ensure each entry is on a new line
-            logger.info(f"Dataset saved successfully to {file_path}")
-        except IOError as e:
-            logger.error(f"Error writing to file {file_path}: {e}")
-            raise
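For reference, below is a minimal sketch of the API that disappears with this file, written against the deleted code above. The sample data is invented for illustration, and in 0.2.30 the equivalent functionality is presumably provided by the new static_dataset.py and base_generator.py modules.

# Usage of the removed classes, as defined in the deleted base.py above.
# The import path camel.datasets.base no longer exists in 0.2.30.
from camel.datasets.base import StaticDataset

# Build a fixed-size dataset from a list of DataPoint-compatible dicts.
ds = StaticDataset(
    data=[
        {
            "question": "What is 2 + 2?",
            "rationale": "Two plus two equals four.",
            "final_answer": "4",
        }
    ],
    seed=42,        # seeds the private RNG used by sample()
    min_samples=1,  # the constructor raises ValueError below this size
    strict=True,    # raise on malformed rows instead of skipping them
)
print(len(ds))                   # 1
print(ds.sample().final_answer)  # "4" (seeded, so reproducible)

# GenerativeDataset wires a seed dataset, a verifier, and an agent together;
# generate_new() is async and retries until n valid datapoints exist:
#   gen = GenerativeDataset(seed_dataset=ds, verifier=verifier, agent=agent)
#   new_points = await gen.generate_new(n=5, max_retries=10)
#   gen.save_to_jsonl("synthetic.jsonl")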