camel-ai 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.

Note: this version of camel-ai has been flagged as potentially problematic.

camel/datasets/base.py CHANGED
@@ -13,28 +13,26 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
15
  import json
16
- import os
17
16
  import random
17
+ from datetime import datetime
18
18
  from pathlib import Path
19
19
  from typing import (
20
20
  Any,
21
- Callable,
22
21
  Dict,
23
22
  List,
24
23
  Optional,
25
24
  Sized,
26
- TypeVar,
27
25
  Union,
28
26
  )
29
27
 
30
- import torch
31
28
  from datasets import Dataset as HFDataset
32
29
  from pydantic import BaseModel, Field, ValidationError
33
- from torch.utils.data import DataLoader, Dataset
30
+ from torch.utils.data import Dataset
34
31
 
35
32
  from camel.agents import ChatAgent
36
33
  from camel.logger import get_logger
37
34
  from camel.verifiers import BaseVerifier
35
+ from camel.verifiers.models import VerifierInput
38
36
 
39
37
  logger = get_logger(__name__)
40
38
 
@@ -44,13 +42,9 @@ class DataPoint(BaseModel):
44
42
 
45
43
  Attributes:
46
44
  question (str): The primary question or issue to be addressed.
47
- rationale (str): Logical reasoning or explanation behind the
48
- answer.
45
+ rationale (Optional[str]): Logical reasoning or explanation behind the
46
+ answer. (default: :obj:`None`)
49
47
  final_answer (str): The final answer.
50
- raw_markdown (Optional[str]): Raw markdown content for generating
51
- rewards/hints. (default: :obj:`None`)
52
- difficulty (Optional[str]): Difficulty level of the question.
53
- (default: :obj:`None`)
54
48
  metadata (Optional[Dict[str, Any]]): Additional metadata about the data
55
49
  point. (default: :obj:`None`)
56
50
  """
@@ -58,13 +52,12 @@ class DataPoint(BaseModel):
58
52
  question: str = Field(
59
53
  ..., description="The primary question or issue to be addressed."
60
54
  )
61
- rationale: str = Field(
62
- ..., description="Logical reasoning or explanation behind the answer."
55
+ rationale: Optional[str] = Field(
56
+ default=None,
57
+ description="Logical reasoning or explanation behind the answer.",
63
58
  )
64
59
  final_answer: str = Field(..., description="The final answer.")
65
- difficulty: Optional[str] = Field(
66
- None, description="Difficulty level of the question."
67
- )
60
+
68
61
  metadata: Optional[Dict[str, Any]] = Field(
69
62
  default=None, description="Additional metadata about the data point."
70
63
  )
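
With this change, rationale becomes optional and the difficulty field is removed, leaving question and final_answer as the only required fields. A minimal construction sketch against the updated model (the import path follows the file shown in this diff; the sample values are made up):

```python
# Sketch only: the updated DataPoint schema after this change.
from camel.datasets.base import DataPoint

dp = DataPoint(
    question="What is 2 + 2?",
    final_answer="4",
    # rationale may now be omitted; it defaults to None
    metadata={"source": "example"},
)
print(dp.rationale)               # None
print(dp.to_dict()["final_answer"])  # "4"
```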
@@ -90,244 +83,11 @@ class DataPoint(BaseModel):
90
83
  return cls(**data)
91
84
 
92
85
 
93
- class BaseDataset(Dataset):
94
- r"""A dataset contains questions and ground truth data for training.
95
- It can be either static (e.g., MATH dataset) or generative
96
- (using an LLM to generate questions).
97
- """
98
-
99
- def __init__(
100
- self,
101
- data: List[Dict[str, str]],
102
- cache_dir: Optional[str] = None,
103
- **kwargs,
104
- ):
105
- r"""Initialize the dataset.
106
-
107
- Args:
108
- data (List[Dict[str, str]]): List of dictionary items to
109
- create the dataset from.
110
- cache_dir (Optional[str]): Directory to cache dataset files.
111
- (default: :obj:`None`)
112
- **kwargs: Additional dataset parameters.
113
-
114
- Note:
115
- The dataset must be initialized by calling setup() before use.
116
- """
117
- self._is_setup = False
118
- self._raw_data: List[Dict[str, str]] = data if data is not None else []
119
- self._cache_dir = str(cache_dir) if cache_dir is not None else None
120
-
121
- # Store all parameters in metadata dict for compatibility
122
- self._metadata = {
123
- 'cache_dir': self._cache_dir,
124
- **kwargs,
125
- }
126
-
127
- self.data: List[DataPoint] = [] # Will be populated in setup()
128
-
129
- async def setup(self) -> None:
130
- r"""Set up the dataset with necessary resources.
131
-
132
- This method:
133
- 1. Creates cache directory if needed
134
- 2. Processes raw data into DataPoint objects using vectorized
135
- operations
136
- 3. Validates dataset integrity
137
-
138
- Raises:
139
- OSError: If cache directory creation fails.
140
- Exception: If dataset initialization fails.
141
- """
142
- if self._is_setup:
143
- logger.debug(f"{self.__class__.__name__} already initialized")
144
- return
145
-
146
- try:
147
- # Create cache directory if specified
148
- if self._cache_dir:
149
- try:
150
- os.makedirs(self._cache_dir, exist_ok=True)
151
- logger.debug(f"Created cache directory: {self._cache_dir}")
152
- except OSError as e:
153
- logger.error(
154
- f"Failed to create cache directory "
155
- f"{self._cache_dir}: {e}"
156
- )
157
- raise
158
-
159
- # Process raw data into DataPoint objects using vectorized
160
- # operations
161
- if not self._raw_data:
162
- self.data = []
163
- logger.debug("No raw data to process")
164
- else:
165
- try:
166
- # Helper function for validation that can be used with map
167
- def create_datapoint(item, idx=None):
168
- try:
169
- return DataPoint(
170
- question=item.get('question', ''),
171
- rationale=item.get('rationale', ''),
172
- final_answer=item.get('final_answer', ''),
173
- metadata=item.get('metadata', {})
174
- if isinstance(item.get('metadata'), dict)
175
- else {},
176
- raw_markdown='',
177
- difficulty='',
178
- )
179
- except ValidationError as e:
180
- idx_str = (
181
- f" at index {idx}" if idx is not None else ""
182
- )
183
- error_msg = (
184
- f"Sample{idx_str} validation error: {e}"
185
- )
186
- logger.error(error_msg)
187
- raise ValueError(error_msg)
188
-
189
- # If raw_data is already a HF dataset, use its map function
190
- if hasattr(self._raw_data, 'map') and callable(
191
- self._raw_data.map
192
- ):
193
- # Using HF dataset's map for vectorized processing
194
- processed_data = self._raw_data.map(
195
- lambda example, idx: {
196
- 'datapoint': create_datapoint(example, idx)
197
- },
198
- with_indices=True,
199
- )
200
- self.data = [
201
- item['datapoint'] for item in processed_data
202
- ]
203
- else:
204
- # Bulk create datapoints
205
- self.data = [
206
- create_datapoint(item, i)
207
- for i, item in enumerate(self._raw_data)
208
- ]
209
-
210
- logger.debug(f"Processed {len(self.data)} data points")
211
- except Exception as e:
212
- logger.error(f"Error processing data: {e}")
213
- raise
214
-
215
- self._is_setup = True
216
- logger.info(f"{self.__class__.__name__} initialized successfully")
217
-
218
- except Exception as e:
219
- logger.error(f"Error during {self.__class__.__name__} setup: {e}")
220
- await self.cleanup()
221
- raise
222
-
223
- async def cleanup(self) -> None:
224
- r"""Clean up dataset resources.
225
-
226
- This method handles cleanup of resources and resets the dataset state.
227
- """
228
- if not self._is_setup:
229
- return
230
-
231
- try:
232
- # Clear metadata while preserving init config
233
- init_config = {
234
- 'cache_dir': self._cache_dir,
235
- }
236
- self._metadata = init_config
237
-
238
- logger.info(f"{self.__class__.__name__} cleaned up successfully")
239
-
240
- except Exception as e:
241
- logger.error(
242
- f"Error during {self.__class__.__name__} cleanup: {e}"
243
- )
244
- raise
245
-
246
- finally:
247
- # Always mark as uninitialized, even if cleanup fails
248
- self._is_setup = False
249
-
250
- def sample(self) -> DataPoint:
251
- r"""Sample a random datapoint from the dataset.
252
-
253
- Returns:
254
- DataPoint: A randomly sampled DataPoint.
255
-
256
- Raises:
257
- RuntimeError: If dataset is not initialized.
258
- """
259
- if not self._is_setup:
260
- raise RuntimeError(
261
- f"{self.__class__.__name__} must be initialized "
262
- "before sampling"
263
- )
264
-
265
- idx = random.randint(0, len(self) - 1)
266
- return self[idx]
267
-
268
- def __len__(self) -> int:
269
- r"""Return the size of the dataset."""
270
- return len(self.data)
271
-
272
- def __getitem__(self, idx: int) -> DataPoint:
273
- r"""Get an item from the dataset.
274
-
275
- Args:
276
- idx (int): Index of the item to get.
277
-
278
- Returns:
279
- DataPoint: DataPoint from the dataset with the given index.
280
-
281
- Raises:
282
- IndexError: If idx is out of bounds.
283
- """
284
- if idx < 0 or idx >= len(self):
285
- raise IndexError(
286
- f"Index {idx} out of bounds for dataset of size {len(self)}"
287
- )
288
-
289
- return self.data[idx]
290
-
291
- @property
292
- def metadata(self) -> Dict[str, Any]:
293
- r"""Get dataset metadata."""
294
- return self._metadata.copy()
295
-
296
- def to_pytorch_dataset(
297
- self,
298
- transform: Optional[Callable] = None,
299
- target_transform: Optional[Callable] = None,
300
- batch_size: Optional[int] = None,
301
- ) -> Union["PyTorchDataset", "DataLoader"]:
302
- r"""Convert to a PyTorch dataset or DataLoader.
303
-
304
- Args:
305
- transform (Optional[Callable]): Transform to apply to samples.
306
- target_transform (Optional[Callable]): Transform to apply to
307
- targets.
308
- batch_size (Optional[int]): If provided, returns a DataLoader with
309
- the specified batch size instead of a PyTorchDataset.
310
- (default: :obj:`None`)
311
-
312
- Returns:
313
- Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
314
- PyTorch format or DataLoader if batch_size is provided.
315
- """
316
- dataset = PyTorchDataset.from_datapoints(
317
- self.data,
318
- transform=transform,
319
- target_transform=target_transform,
320
- )
321
-
322
- if batch_size is not None:
323
- return dataset.get_dataloader(batch_size=batch_size)
324
-
325
- return dataset
326
-
327
-
328
- class SeedDataset(BaseDataset):
329
- r"""A dataset containing validated seed examples for data generation.
86
+ class StaticDataset(Dataset):
87
+ r"""A static dataset containing a list of datapoints.
330
88
  Ensures that all items adhere to the DataPoint schema.
89
+ This dataset extends :obj:`Dataset` from PyTorch and should
90
+ be used when its size is fixed at runtime.
331
91
 
332
92
  This class can initialize from Hugging Face Datasets,
333
93
  PyTorch Datasets, JSON file paths, or lists of dictionaries,
@@ -337,46 +97,46 @@ class SeedDataset(BaseDataset):
337
97
  def __init__(
338
98
  self,
339
99
  data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
340
- cache_dir: Optional[str] = None,
341
- seed: Optional[int] = None,
100
+ seed: int = 42,
342
101
  min_samples: int = 1,
343
102
  strict: bool = False,
344
103
  **kwargs,
345
104
  ):
346
- r"""Initialize the seed dataset and validate integrity.
105
+ r"""Initialize the static dataset and validate integrity.
347
106
 
348
107
  Args:
349
- data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
350
- Input data, which can be:
351
- - A Hugging Face Dataset (HFDataset)
352
- - A PyTorch Dataset (torch.utils.data.Dataset)
353
- - A Path object representing the path to a JSON file
354
- - A list of dictionaries with DataPoint-compatible fields
355
- seed (Optional[int]): Seed for reproducibility.
356
- (default: :obj:`1`)
357
- min_samples (int): Minimum number of samples required.
358
- (default: :obj:`1`)
359
- strict (bool): Whether to raise an error on invalid datapoints
360
- (True) or skip/filter them (False). (default: False)
108
+ data (:obj:`Union[HFDataset, Dataset,
109
+ Path, List[Dict[str, Any]]]`):
110
+ Input data, which can be one of the following:
111
+ - A Hugging Face Dataset (:obj:`HFDataset`).
112
+ - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
113
+ - A :obj:`Path` object representing a JSON file.
114
+ - A list of dictionaries with :obj:`DataPoint`-compatible
115
+ fields.
116
+ seed (:obj:`int`): Random seed for reproducibility.
117
+ Default is :obj:`42`.
118
+ min_samples (:obj:`int`): Minimum required number of samples.
119
+ Default is :obj:`1`.
120
+ strict (:obj:`bool`): Whether to raise an error on invalid
121
+ datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
122
+ Default is :obj:`False`.
361
123
  **kwargs: Additional dataset parameters.
362
124
 
363
125
  Raises:
364
- TypeError: If the data type is not supported.
365
- ValueError: If dataset size is less than min_samples or
366
- if sample validation fails.
367
- FileNotFoundError: If the JSON file path doesn't exist.
368
- json.JSONDecodeError: If the JSON file is invalid.
126
+ TypeError: If the input data type is unsupported.
127
+ ValueError: If the dataset contains fewer than :obj:`min_samples`
128
+ datapoints or if validation fails.
129
+ FileNotFoundError: If the specified JSON file path does not exist.
130
+ json.JSONDecodeError: If the JSON file contains invalid formatting.
369
131
  """
370
- # Initialize BaseDataset with empty data, we'll populate it ourselves
371
- super().__init__(data=[], cache_dir=cache_dir, **kwargs)
372
132
 
133
+ # Store all parameters in metadata dict for compatibility
134
+ self._metadata = {
135
+ **kwargs,
136
+ }
373
137
  self._rng = random.Random(seed)
374
138
  self._strict = strict
375
139
 
376
- # Type checking and conversion into list of dicts to have a
377
- # consistent internal format. Since Seed Dataset should be
378
- # small, we can load it entirely into memory
379
-
380
140
  self.data: List[DataPoint] = self._init_data(data)
381
141
  self._length = len(self.data)
382
142
 
@@ -389,6 +149,26 @@ class SeedDataset(BaseDataset):
389
149
  def _init_data(
390
150
  self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
391
151
  ) -> List[DataPoint]:
152
+ r"""Convert input data from various formats into a list of
153
+ :obj:`DataPoint` instances.
154
+
155
+ Args:
156
+ data (:obj:`Union[
157
+ HFDataset,
158
+ Dataset,
159
+ Path,
160
+ List[Dict[str, Any]]
161
+ ]`):
162
+ Input dataset in one of the supported formats.
163
+
164
+ Returns:
165
+ :obj:`List[DataPoint]`: A list of validated :obj:`DataPoint`
166
+ instances.
167
+
168
+ Raises:
169
+ TypeError: If the input data type is unsupported.
170
+ """
171
+
392
172
  if isinstance(data, HFDataset):
393
173
  raw_data = self._init_from_hf_dataset(data)
394
174
  elif isinstance(data, Dataset):
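
For reference, a short usage sketch of the new StaticDataset built from a list of DataPoint-compatible dictionaries; seed, min_samples, and strict follow the signature above, and the records here are invented for illustration:

```python
from camel.datasets.base import StaticDataset

# Illustrative records; any DataPoint-compatible dicts work.
raw = [
    {"question": "What is 2 + 2?", "final_answer": "4"},
    {
        "question": "What is the capital of France?",
        "rationale": "Paris is the capital of France.",
        "final_answer": "Paris",
    },
]

# strict=False skips records that fail DataPoint validation instead of raising.
ds = StaticDataset(raw, seed=42, min_samples=1, strict=False)
print(len(ds), ds[0].question)
```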
@@ -450,7 +230,6 @@ class SeedDataset(BaseDataset):
450
230
  rationale=rationale,
451
231
  final_answer=final_answer,
452
232
  metadata=item.get('metadata'),
453
- difficulty=item.get('difficulty'),
454
233
  )
455
234
  except ValidationError as e:
456
235
  if self._strict:
@@ -474,17 +253,19 @@ class SeedDataset(BaseDataset):
474
253
  return self._length
475
254
 
476
255
  def __getitem__(self, idx: int) -> DataPoint:
477
- r"""Get an item from the dataset.
256
+ r"""Retrieve a datapoint by index.
478
257
 
479
258
  Args:
480
- idx (int): Index of the item to get.
259
+ idx (:obj:`int`): Index of the datapoint.
481
260
 
482
261
  Returns:
483
- DataPoint: DataPoint from the dataset with the given index.
262
+ :obj:`DataPoint`: The datapoint corresponding to the given index.
484
263
 
485
264
  Raises:
486
- IndexError: If idx is out of bounds.
265
+ IndexError: If :obj:`idx` is out of bounds (negative or greater
266
+ than dataset length - 1).
487
267
  """
268
+
488
269
  if idx < 0 or idx >= self._length:
489
270
  raise IndexError(
490
271
  f"Index {idx} out of bounds for dataset of size {self._length}"
@@ -495,11 +276,12 @@ class SeedDataset(BaseDataset):
495
276
  r"""Sample a random datapoint from the dataset.
496
277
 
497
278
  Returns:
498
- DataPoint: A randomly sampled DataPoint.
279
+ :obj:`DataPoint`: A randomly sampled :obj:`DataPoint`.
499
280
 
500
281
  Raises:
501
- RuntimeError: If the dataset is empty.
282
+ RuntimeError: If the dataset is empty and no samples can be drawn.
502
283
  """
284
+
503
285
  if self._length == 0:
504
286
  raise RuntimeError("Dataset is empty, cannot sample.")
505
287
  idx = self._rng.randint(0, self._length - 1)
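
Because sampling goes through a per-instance random.Random(seed) rather than the global RNG, two datasets constructed with the same seed draw the same sequence. A small sketch (the data is invented for illustration):

```python
from camel.datasets.base import StaticDataset

raw = [{"question": f"Q{i}", "final_answer": str(i)} for i in range(10)]

# Equal seeds give identical sampling sequences.
a = StaticDataset(raw, seed=7)
b = StaticDataset(raw, seed=7)
assert a.sample().question == b.sample().question
```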
@@ -507,15 +289,42 @@ class SeedDataset(BaseDataset):
507
289
 
508
290
  @property
509
291
  def metadata(self) -> Dict[str, Any]:
510
- r"""Get dataset metadata."""
292
+ r"""Retrieve dataset metadata.
293
+
294
+ Returns:
295
+ :obj:`Dict[str, Any]`: A copy of the dataset metadata dictionary.
296
+ """
297
+
511
298
  return self._metadata.copy()
512
299
 
513
300
  def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
301
+ r"""Convert a Hugging Face dataset into a list of dictionaries.
302
+
303
+ Args:
304
+ data (:obj:`HFDataset`): A Hugging Face dataset.
305
+
306
+ Returns:
307
+ :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
308
+ the dataset, where each dictionary corresponds to a datapoint.
309
+ """
514
310
  return [dict(item) for item in data]
515
311
 
516
312
  def _init_from_pytorch_dataset(
517
313
  self, data: Dataset
518
314
  ) -> List[Dict[str, Any]]:
315
+ r"""Convert a PyTorch dataset into a list of dictionaries.
316
+
317
+ Args:
318
+ data (:obj:`Dataset`): A PyTorch dataset.
319
+
320
+ Returns:
321
+ :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
322
+ the dataset.
323
+
324
+ Raises:
325
+ TypeError: If the dataset does not implement :obj:`__len__()`
326
+ or contains non-dictionary elements.
327
+ """
519
328
  if not isinstance(data, Sized):
520
329
  raise TypeError(
521
330
  f"{type(data).__name__} does not implement `__len__()`."
@@ -533,6 +342,20 @@ class SeedDataset(BaseDataset):
533
342
  return raw_data
534
343
 
535
344
  def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
345
+ r"""Load and parse a dataset from a JSON file.
346
+
347
+ Args:
348
+ data (:obj:`Path`): Path to the JSON file.
349
+
350
+ Returns:
351
+ :obj:`List[Dict[str, Any]]`: A list of datapoint dictionaries.
352
+
353
+ Raises:
354
+ FileNotFoundError: If the specified JSON file does not exist.
355
+ ValueError: If the JSON content is not a list of dictionaries.
356
+ json.JSONDecodeError: If the JSON file has invalid formatting.
357
+ """
358
+
536
359
  if not data.exists():
537
360
  raise FileNotFoundError(f"JSON file not found: {data}")
538
361
  try:
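
The JSON loader above expects the file to contain a top-level list of dictionaries. A hedged sketch, assuming a Path input is dispatched to _init_from_json_path as the class docstring describes (the file name is illustrative):

```python
import json
from pathlib import Path

from camel.datasets.base import StaticDataset

# Hypothetical file written only for this sketch.
path = Path("seed_data.json")
path.write_text(
    json.dumps([
        {"question": "What is 2 + 2?", "final_answer": "4"},
        {"question": "What is 3 * 3?", "final_answer": "9"},
    ]),
    encoding="utf-8",
)

# A Path input is loaded as JSON; the file must hold a list of dicts.
ds = StaticDataset(path)
print(len(ds))  # 2
```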
@@ -557,6 +380,18 @@ class SeedDataset(BaseDataset):
557
380
  def _init_from_list(
558
381
  self, data: List[Dict[str, Any]]
559
382
  ) -> List[Dict[str, Any]]:
383
+ r"""Validate and convert a list of dictionaries into a dataset.
384
+
385
+ Args:
386
+ data (:obj:`List[Dict[str, Any]]`): A list of dictionaries where
387
+ each dictionary must be a valid :obj:`DataPoint`.
388
+
389
+ Returns:
390
+ :obj:`List[Dict[str, Any]]`: The validated list of dictionaries.
391
+
392
+ Raises:
393
+ ValueError: If any item in the list is not a dictionary.
394
+ """
560
395
  for i, item in enumerate(data):
561
396
  if not isinstance(item, dict):
562
397
  raise ValueError(
@@ -566,111 +401,7 @@ class SeedDataset(BaseDataset):
566
401
  return data
567
402
 
568
403
 
569
- class SyntheticDataset(BaseDataset):
570
- r"""A dataset for storing synthetically generated data points.
571
-
572
- This class is used to store datapoints that are generated through
573
- a generative process, such as using an agent.
574
- """
575
-
576
- def __init__(
577
- self,
578
- data: Optional[List[Dict[str, str]]] = None,
579
- cache_dir: Optional[str] = None,
580
- **kwargs,
581
- ):
582
- r"""Initialize the synthetic dataset.
583
-
584
- Args:
585
- data (Optional[List[Dict[str, str]]]): List of dictionary items to
586
- create the dataset from. (default: :obj:`None`)
587
- cache_dir (Optional[str]): Directory to cache dataset files.
588
- (default: :obj:`None`)
589
- **kwargs: Additional dataset parameters.
590
- """
591
- super().__init__(
592
- data=data if data is not None else [],
593
- cache_dir=cache_dir,
594
- **kwargs,
595
- )
596
- self.data: List[DataPoint] = []
597
-
598
- def add(self, item: DataPoint) -> None:
599
- r"""Add a new data point to the dataset.
600
-
601
- Args:
602
- item (DataPoint): The datapoint to add to the dataset.
603
- """
604
- self.data.append(item)
605
-
606
- def add_batch(self, items: List[DataPoint]) -> None:
607
- r"""Add multiple data points to the dataset.
608
-
609
- Args:
610
- items (List[DataPoint]): The datapoints to add to the dataset.
611
- """
612
- self.data.extend(items)
613
-
614
- def to_pytorch_dataset(
615
- self,
616
- transform: Optional[Callable] = None,
617
- target_transform: Optional[Callable] = None,
618
- batch_size: Optional[int] = None,
619
- ) -> Union["PyTorchDataset", "DataLoader"]:
620
- r"""Convert to a PyTorch dataset or DataLoader.
621
-
622
- Args:
623
- transform (Optional[Callable]): Transform to apply to samples.
624
- target_transform (Optional[Callable]): Transform to apply to
625
- targets.
626
- batch_size (Optional[int]): If provided, returns a DataLoader with
627
- the specified batch size instead of a PyTorchDataset.
628
- (default: :obj:`None`)
629
-
630
- Returns:
631
- Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
632
- PyTorch format or DataLoader if batch_size is provided.
633
- """
634
- return convert_synthetic_to_pytorch(
635
- self,
636
- transform=transform,
637
- target_transform=target_transform,
638
- batch_size=batch_size,
639
- )
640
-
641
- def save_pytorch_format(self, path: str, compression: bool = True) -> None:
642
- r"""Save the dataset to disk in PyTorch format.
643
-
644
- Args:
645
- path (str): Path to save the dataset to.
646
- compression (bool): Whether to use compression to reduce file size.
647
- (default: :obj:`True`)
648
- """
649
- save_synthetic_dataset(self, path, compression=compression)
650
-
651
- def filter(
652
- self, predicate: Callable[[DataPoint], bool]
653
- ) -> 'SyntheticDataset':
654
- r"""Filter the dataset using a predicate function.
655
-
656
- Args:
657
- predicate (Callable[[DataPoint], bool]): Function that takes a
658
- DataPoint and returns True if it should be kept, False
659
- otherwise.
660
-
661
- Returns:
662
- SyntheticDataset: A new dataset containing only the filtered items.
663
- """
664
- filtered_data = [dp for dp in self.data if predicate(dp)]
665
-
666
- # Create a new dataset with the filtered data
667
- new_dataset = SyntheticDataset()
668
- new_dataset.add_batch(filtered_data)
669
-
670
- return new_dataset
671
-
672
-
673
- class GenerativeDataset(BaseDataset):
404
+ class GenerativeDataset(Dataset):
674
405
  r"""A dataset for generating synthetic datapoints using external agents and
675
406
  verifiers.
676
407
 
@@ -680,30 +411,22 @@ class GenerativeDataset(BaseDataset):
680
411
 
681
412
  def __init__(
682
413
  self,
683
- seed_dataset: SeedDataset,
414
+ seed_dataset: StaticDataset,
684
415
  verifier: BaseVerifier,
685
416
  agent: ChatAgent,
686
- cache_dir: Optional[str] = None,
687
417
  seed: int = 42,
688
418
  **kwargs,
689
419
  ):
690
420
  r"""Initialize the generative dataset.
691
421
 
692
422
  Args:
693
- seed_dataset (SeedDataset): Validated dataset to use for examples.
423
+ seed_dataset (StaticDataset): Validated static dataset to
424
+ use for examples.
694
425
  verifier (BaseVerifier): Verifier to validate generated content.
695
426
  agent (ChatAgent): Agent to generate new datapoints.
696
- cache_dir (Optional[str]): Directory to cache dataset files.
697
- (default: :obj:`None`)
698
427
  seed (int): Random seed for reproducibility. (default: :obj:`42`)
699
428
  **kwargs: Additional dataset parameters.
700
429
  """
701
- # Initialize with empty data since we'll generate content dynamically
702
- super().__init__(
703
- data=[],
704
- cache_dir=cache_dir,
705
- **kwargs,
706
- )
707
430
 
708
431
  self.seed_dataset = seed_dataset
709
432
  self.verifier = verifier
@@ -712,8 +435,11 @@ class GenerativeDataset(BaseDataset):
712
435
  self.seed = seed
713
436
  random.seed(self.seed)
714
437
 
438
+ self._data: List[DataPoint] = []
439
+
715
440
  def _construct_prompt(self, examples: List[DataPoint]) -> str:
716
- r"""Construct a prompt for generating new datapoints.
441
+ r"""Construct a prompt for generating new datapoints
442
+ using a fixed sample of 3 examples from the seed dataset.
717
443
 
718
444
  Args:
719
445
  examples (List[DataPoint]): Examples to include in the prompt.
@@ -732,440 +458,182 @@ class GenerativeDataset(BaseDataset):
732
458
  prompt += "New datapoint:"
733
459
  return prompt
734
460
 
735
- async def generate_new(self, n: int) -> None:
736
- r"""Generate n new datapoints and add them to the dataset.
461
+ async def generate_new(
462
+ self, n: int, max_retries: int = 10
463
+ ) -> List[DataPoint]:
464
+ r"""Generates and validates `n` new datapoints through
465
+ few-shot prompting, with a retry limit.
466
+
467
+ Steps:
468
+ 1. Samples 3 examples from the seed dataset.
469
+ 2. Constructs a prompt using the selected examples.
470
+ 3. Uses an agent to generate a new datapoint.
471
+ 4. Verifies the datapoint using a verifier.
472
+ 5. Stores valid datapoints in memory.
737
473
 
738
474
  Args:
739
475
  n (int): Number of valid datapoints to generate.
476
+ max_retries (int): Maximum number of retries before stopping.
477
+
478
+ Returns:
479
+ List[DataPoint]: A list of newly generated valid datapoints.
740
480
 
741
- This method generates new datapoints by:
742
- 1. Sampling examples from the seed dataset
743
- 2. Constructing a prompt for the agent
744
- 3. Generating a new datapoint using the agent
745
- 4. Verifying the generated datapoint with the verifier
746
- 5. Adding valid datapoints to the dataset
481
+ Raises:
482
+ TypeError: If the agent's output is not a dictionary (or does not
483
+ match the expected format).
484
+ KeyError: If required keys are missing from the response.
485
+ AttributeError: If the verifier response lacks attributes.
486
+ ValidationError: If a datapoint fails schema validation.
487
+ RuntimeError: If retries are exhausted before `n` valid datapoints
488
+ are generated.
489
+
490
+ Notes:
491
+ - Retries on validation failures until `n` valid datapoints exist
492
+ or `max_retries` is reached, whichever comes first.
493
+ - If retries are exhausted before reaching `n`, a `RuntimeError`
494
+ is raised.
495
+ - Metadata includes a timestamp for tracking datapoint creation.
496
+ - This method can be overridden to implement custom generation.
747
497
  """
748
498
  valid_data_points: List[DataPoint] = []
499
+ retries = 0
749
500
 
750
- while len(valid_data_points) < n:
501
+ while len(valid_data_points) < n and retries < max_retries:
751
502
  try:
752
- indices = random.sample(range(len(self.seed_dataset)), 3)
753
- examples = [self.seed_dataset[i] for i in indices]
503
+ examples = [self.seed_dataset.sample() for _ in range(3)]
754
504
  prompt = self._construct_prompt(examples)
755
505
 
756
- # Get agent response
757
- agent_output = (
758
- self.agent.step(prompt, response_format=DataPoint)
759
- .msgs[0]
760
- .parsed
761
- )
762
-
763
- if not isinstance(agent_output, dict):
764
- raise TypeError("Agent output must be a dictionary")
765
- if (
766
- 'question' not in agent_output
767
- or 'rationale' not in agent_output
768
- ):
769
- raise KeyError(
770
- "Agent output missing required keys: "
771
- "'question' or 'rationale'"
506
+ try:
507
+ agent_output = (
508
+ self.agent.step(prompt, response_format=DataPoint)
509
+ .msgs[0]
510
+ .parsed
772
511
  )
773
-
774
- rationale = agent_output['rationale']
775
-
776
- # Verify the generated content
777
- verifier_response = await self.verifier.verify(rationale)
778
- if not hasattr(verifier_response, 'content'):
779
- raise AttributeError(
780
- "Verifier response missing 'content' attribute"
512
+ if not isinstance(agent_output, dict):
513
+ raise TypeError("Agent output must be a dictionary")
514
+ if (
515
+ "question" not in agent_output
516
+ or "rationale" not in agent_output
517
+ ):
518
+ raise KeyError(
519
+ "Missing 'question' or 'rationale' in agent output"
520
+ )
521
+ except (TypeError, KeyError) as e:
522
+ logger.warning(
523
+ f"Agent output issue: {e}, retrying... "
524
+ f"({retries + 1}/{max_retries})"
781
525
  )
782
-
783
- if not verifier_response.result:
526
+ retries += 1
784
527
  continue
785
528
 
786
- final_answer = verifier_response.result
787
-
788
- # Create and validate the new datapoint
789
- new_datapoint = {
790
- 'question': agent_output['question'],
791
- 'rationale': rationale,
792
- 'final_answer': final_answer,
793
- }
794
-
795
- datapoint = DataPoint(**new_datapoint)
796
- valid_data_points.append(datapoint)
797
-
798
- except (TypeError, KeyError, AttributeError, ValidationError) as e:
799
- logger.warning(
800
- f"Error encountered during generation: {e}, retrying..."
801
- )
802
-
803
- # Add all valid datapoints to the dataset
804
- for datapoint in valid_data_points:
805
- self.data.append(datapoint)
806
- logger.debug("Added new datapoint to dataset")
807
-
808
-
809
- # Define a type variable for return type flexibility
810
- T_co = TypeVar('T_co', covariant=True)
811
-
529
+ rationale = agent_output["rationale"]
812
530
 
813
- class PyTorchDataset(Dataset[T_co]):
814
- r"""A PyTorch-compatible dataset implementation that leverages PyTorch's
815
- efficient data handling capabilities.
816
- """
817
-
818
- def __init__(
819
- self,
820
- data: List[Dict[str, Any]],
821
- transform: Optional[Callable] = None,
822
- target_transform: Optional[Callable] = None,
823
- validate: bool = True,
824
- ):
825
- r"""Initialize the PyTorch dataset.
826
-
827
- Args:
828
- data (List[Dict[str, Any]]): List of dictionary items to create
829
- the dataset from.
830
- transform (Optional[Callable]): A function/transform that takes a
831
- sample and returns a transformed version for features.
832
- (default: :obj:`None`)
833
- target_transform (Optional[Callable]): A function/transform that
834
- takes a target and returns a transformed version. (default:
835
- :obj:`None`)
836
- validate (bool): Whether to validate data points using DataPoint
837
- schema. (default: :obj:`True`)
838
-
839
- Raises:
840
- ValidationError: If validation is enabled and data doesn't match
841
- DataPoint schema.
842
- """
843
- self.transform = transform
844
- self.target_transform = target_transform
845
-
846
- # Validate and store data
847
- self._raw_data = data
848
- self.data = []
531
+ try:
532
+ verifier_response = await self.verifier.verify(
533
+ VerifierInput(
534
+ llm_response=rationale, ground_truth=None
535
+ )
536
+ )
537
+ if not verifier_response or not verifier_response.result:
538
+ raise ValueError(
539
+ "Verifier unsuccessful, response: "
540
+ f"{verifier_response}"
541
+ )
542
+ except (ValueError, AttributeError) as e:
543
+ logger.warning(
544
+ f"Verifier issue: {e}, "
545
+ f"retrying... ({retries + 1}/{max_retries})"
546
+ )
547
+ retries += 1
548
+ continue
849
549
 
850
- if validate:
851
- for i, item in enumerate(self._raw_data):
852
550
  try:
853
- # Use DataPoint for validation only
854
- dp = DataPoint(**item)
855
- self.data.append(dp.to_dict())
551
+ new_datapoint = DataPoint(
552
+ question=agent_output["question"],
553
+ rationale=rationale,
554
+ final_answer=verifier_response.result,
555
+ metadata={
556
+ "synthetic": str(True),
557
+ "created": datetime.now().isoformat(),
558
+ },
559
+ )
856
560
  except ValidationError as e:
857
- logger.error(f"Sample {i} validation error: {e}")
858
- raise ValueError(f"Sample {i} validation error: {e}")
859
- else:
860
- # Skip validation and just store the data dictionaries
861
- self.data = [dict(item) for item in self._raw_data]
862
-
863
- def __getitem__(self, index: int) -> T_co:
864
- r"""Get an item from the dataset.
561
+ logger.warning(
562
+ f"Datapoint validation failed: {e}, "
563
+ f"retrying... ({retries + 1}/{max_retries})"
564
+ )
565
+ retries += 1
566
+ continue
865
567
 
866
- Args:
867
- index (int): Index of the item to get.
568
+ valid_data_points.append(new_datapoint)
868
569
 
869
- Returns:
870
- T_co: Item from the dataset, possibly transformed.
871
- """
872
- sample = self.data[index]
570
+ except Exception as e:
571
+ logger.warning(
572
+ f"Unexpected error: {e}, retrying..."
573
+ f" ({retries + 1}/{max_retries})"
574
+ )
575
+ retries += 1
873
576
 
874
- # Apply transformations if provided
875
- if self.transform is not None:
876
- sample = self.transform(sample)
577
+ if len(valid_data_points) < n:
578
+ raise RuntimeError(
579
+ f"Failed to generate {n} valid datapoints "
580
+ f"after {max_retries} retries."
581
+ )
877
582
 
878
- return sample # type: ignore[return-value]
583
+ self._data.extend(valid_data_points)
584
+ return valid_data_points
879
585
 
880
586
  def __len__(self) -> int:
881
- r"""Return the size of the dataset.
882
-
883
- Returns:
884
- int: Number of samples in the dataset.
885
- """
886
- return len(self.data)
587
+ r"""Return the size of the dataset."""
588
+ return len(self._data)
887
589
 
888
- @classmethod
889
- def from_datapoints(
890
- cls, datapoints: List[DataPoint], **kwargs
891
- ) -> 'PyTorchDataset':
892
- r"""Create a PyTorchDataset from a list of DataPoints.
590
+ def __getitem__(self, idx: int) -> DataPoint:
591
+ r"""Retrieve a datapoint by index.
893
592
 
894
593
  Args:
895
- datapoints (List[DataPoint]): List of DataPoint objects.
896
- **kwargs: Additional arguments to pass to the constructor.
594
+ idx (int): Index of the datapoint.
897
595
 
898
596
  Returns:
899
- PyTorchDataset: A new PyTorchDataset instance.
900
- """
901
- data = [dp.to_dict() for dp in datapoints]
902
- # We can skip validation since datapoints are already validated
903
- return cls(data, validate=False, **kwargs)
904
-
905
- def to_hf_dataset(self) -> HFDataset:
906
- r"""Convert to a HuggingFace dataset.
907
-
908
- Returns:
909
- HFDataset: Dataset in HuggingFace format.
910
- """
911
- return HFDataset.from_list(self.data)
597
+ DataPoint: The datapoint corresponding to the given index.
912
598
 
913
- def save_to_disk(self, path: str) -> None:
914
- r"""Save the dataset to disk using PyTorch.
915
-
916
- Args:
917
- path (str): Path to save the dataset to.
599
+ Raises:
600
+ IndexError: If idx is out of bounds.
918
601
  """
919
- torch.save(self.data, path)
920
-
921
- @classmethod
922
- def load_from_disk(
923
- cls,
924
- path: str,
925
- transform: Optional[Callable] = None,
926
- target_transform: Optional[Callable] = None,
927
- ) -> 'PyTorchDataset':
928
- r"""Load a dataset from disk.
929
-
930
- Args:
931
- path (str): Path to load the dataset from.
932
- transform (Optional[Callable]): Transform to apply to samples.
933
- (default: :obj:`None`)
934
- target_transform (Optional[Callable]): Transform to apply to
935
- targets. (default: :obj:`None`)
602
+ if idx < 0 or idx >= len(self._data):
603
+ raise IndexError(
604
+ f"Index {idx} out of bounds for dataset of "
605
+ f"size {len(self._data)}"
606
+ )
607
+ return self._data[idx]
936
608
 
937
- Returns:
938
- PyTorchDataset: Loaded dataset.
939
- """
940
- data = torch.load(path)
941
- return cls(
942
- data,
943
- transform=transform,
944
- target_transform=target_transform,
945
- validate=False,
946
- )
609
+ def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
610
+ r"""Saves the dataset to a JSONL (JSON Lines) file.
947
611
 
948
- @staticmethod
949
- def collate_fn(
950
- batch: List[Dict[str, Any]],
951
- ) -> Dict[str, Union[List[Any], torch.Tensor]]:
952
- r"""Collate function for PyTorch DataLoader.
612
+ Each datapoint is stored as a separate JSON object on a new line.
953
613
 
954
614
  Args:
955
- batch (List[Dict[str, Any]]): Batch of samples from the dataset.
615
+ file_path (Union[str, Path]): Path to save the JSONL file.
956
616
 
957
- Returns:
958
- Dict[str, Union[List[Any], torch.Tensor]]: Collated batch with
959
- tensors for numerical data.
960
- """
961
- if not batch:
962
- return {}
963
-
964
- # Initialize result dictionary with keys from first item - start with
965
- # lists only
966
- result: Dict[str, List[Any]] = {k: [] for k in batch[0].keys()}
967
-
968
- # Collect values by key
969
- for item in batch:
970
- for k, v in item.items():
971
- result[k].append(v)
972
-
973
- # Convert numeric/boolean lists to tensors where possible
974
- result_with_tensors: Dict[str, Union[List[Any], torch.Tensor]] = {}
975
- for k, v in result.items():
976
- if all(isinstance(x, (int, float, bool)) for x in v):
977
- try:
978
- result_with_tensors[k] = torch.tensor(v)
979
- except (ValueError, TypeError):
980
- # Keep as list if tensor conversion fails
981
- result_with_tensors[k] = v
982
- else:
983
- result_with_tensors[k] = v
984
-
985
- return result_with_tensors
986
-
987
- def get_dataloader(
988
- self,
989
- batch_size: int = 32,
990
- shuffle: bool = True,
991
- num_workers: int = 0,
992
- pin_memory: bool = False,
993
- **kwargs,
994
- ) -> "DataLoader":
995
- r"""Create a PyTorch DataLoader for this dataset.
996
-
997
- Args:
998
- batch_size (int): Batch size. (default: :obj:`32`)
999
- shuffle (bool): Whether to shuffle the dataset. (default:
1000
- :obj:`True`)
1001
- num_workers (int): Number of workers for data loading. (default:
1002
- :obj:`0`)
1003
- pin_memory (bool): Whether to pin memory for faster GPU transfer.
1004
- (default: :obj:`False`)
1005
- **kwargs: Additional arguments to pass to DataLoader.
617
+ Raises:
618
+ ValueError: If the dataset is empty.
619
+ IOError: If there is an issue writing to the file.
1006
620
 
1007
- Returns:
1008
- torch.utils.data.DataLoader: DataLoader for this dataset.
621
+ Notes:
622
+ - Uses `self._data`, which contains the generated datapoints.
623
+ - Overwrites the file if it already exists.
624
+ - Ensures compatibility with large datasets by using JSONL format.
1009
625
  """
1010
- from torch.utils.data import DataLoader
1011
-
1012
- return DataLoader(
1013
- self,
1014
- batch_size=batch_size,
1015
- shuffle=shuffle,
1016
- num_workers=num_workers,
1017
- collate_fn=self.collate_fn,
1018
- pin_memory=pin_memory,
1019
- **kwargs,
1020
- )
1021
-
1022
-
1023
- def convert_hf_to_pytorch(
1024
- hf_dataset: HFDataset,
1025
- transform: Optional[Callable] = None,
1026
- target_transform: Optional[Callable] = None,
1027
- column_mapping: Optional[Dict[str, str]] = None,
1028
- validate: bool = True,
1029
- batch_size: Optional[int] = None,
1030
- ) -> Union["PyTorchDataset", "DataLoader"]:
1031
- r"""Convert a HuggingFace dataset to a PyTorchDataset or DataLoader.
1032
-
1033
- This function maps HuggingFace dataset columns to the expected DataPoint
1034
- format, validates the data, and creates a PyTorchDataset or DataLoader.
1035
-
1036
- Args:
1037
- hf_dataset (HFDataset): HuggingFace dataset to convert.
1038
- transform (Optional[Callable]): Transform to apply to samples.
1039
- target_transform (Optional[Callable]): Transform to apply to targets.
1040
- column_mapping (Optional[Dict[str, str]]): Mapping from HuggingFace
1041
- column names to DataPoint field names. If None, assumes columns
1042
- already match DataPoint fields.
1043
- validate (bool): Whether to validate data points using DataPoint
1044
- schema. (default: :obj:`True`)
1045
- batch_size (Optional[int]): If provided, returns a DataLoader with the
1046
- specified batch size instead of a PyTorchDataset. (default:
1047
- :obj:`None`)
1048
-
1049
- Returns:
1050
- Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
1051
- or DataLoader if batch_size is provided.
1052
- """
1053
- # Convert HuggingFace dataset to list of dicts more efficiently
1054
- mapped_dataset = []
1055
-
1056
- for i in range(len(hf_dataset)):
1057
- item = hf_dataset[i]
1058
- if column_mapping is not None:
1059
- # Apply column mapping if provided
1060
- mapped_item = {}
1061
- for hf_col, dp_field in column_mapping.items():
1062
- if hf_col in item:
1063
- mapped_item[dp_field] = item[hf_col]
1064
- mapped_dataset.append(mapped_item)
1065
- else:
1066
- # Otherwise use item directly
1067
- mapped_dataset.append(dict(item))
1068
-
1069
- # Create PyTorchDataset
1070
- dataset: PyTorchDataset = PyTorchDataset(
1071
- mapped_dataset,
1072
- transform=transform,
1073
- target_transform=target_transform,
1074
- validate=validate,
1075
- )
1076
-
1077
- # Return DataLoader if batch_size is provided
1078
- if batch_size is not None:
1079
- return dataset.get_dataloader(batch_size=batch_size)
626
+ if not self._data:
627
+ raise ValueError("Dataset is empty. No data to save.")
1080
628
 
1081
- return dataset
629
+ file_path = Path(file_path)
1082
630
 
1083
-
1084
- def convert_synthetic_to_pytorch(
1085
- synthetic_dataset: 'SyntheticDataset',
1086
- transform: Optional[Callable] = None,
1087
- target_transform: Optional[Callable] = None,
1088
- batch_size: Optional[int] = None,
1089
- ) -> Union["PyTorchDataset", "DataLoader"]:
1090
- r"""Convert a SyntheticDataset to a PyTorchDataset or DataLoader.
1091
-
1092
- Args:
1093
- synthetic_dataset (SyntheticDataset): Synthetic dataset to convert.
1094
- transform (Optional[Callable]): Transform to apply to samples.
1095
- target_transform (Optional[Callable]): Transform to apply to targets.
1096
- batch_size (Optional[int]): If provided, returns a DataLoader with the
1097
- specified batch size instead of a PyTorchDataset. (default:
1098
- :obj:`None`)
1099
-
1100
- Returns:
1101
- Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
1102
- or DataLoader if batch_size is provided.
1103
- """
1104
- dataset = PyTorchDataset.from_datapoints(
1105
- synthetic_dataset.data,
1106
- transform=transform,
1107
- target_transform=target_transform,
1108
- )
1109
-
1110
- # Return DataLoader if batch_size is provided
1111
- if batch_size is not None:
1112
- return dataset.get_dataloader(batch_size=batch_size)
1113
-
1114
- return dataset
1115
-
1116
-
1117
- def save_synthetic_dataset(
1118
- synthetic_dataset: 'SyntheticDataset',
1119
- path: str,
1120
- compression: bool = True,
1121
- ) -> None:
1122
- r"""Save a synthetic dataset to disk using PyTorch format.
1123
-
1124
- Args:
1125
- synthetic_dataset (SyntheticDataset): Dataset to save.
1126
- path (str): Path to save the dataset to.
1127
- compression (bool): Whether to use compression to reduce file size.
1128
- (default: :obj:`True`)
1129
- """
1130
- pytorch_dataset = convert_synthetic_to_pytorch(synthetic_dataset)
1131
-
1132
- # Save with compression if enabled (uses less disk space)
1133
- if compression:
1134
- torch.save(
1135
- pytorch_dataset.data, # type: ignore[union-attr]
1136
- path,
1137
- _use_new_zipfile_serialization=True,
1138
- )
1139
- else:
1140
- pytorch_dataset.save_to_disk(path) # type: ignore[union-attr]
1141
-
1142
-
1143
- def load_pytorch_dataset(
1144
- path: str,
1145
- transform: Optional[Callable] = None,
1146
- target_transform: Optional[Callable] = None,
1147
- batch_size: Optional[int] = None,
1148
- ) -> Union["PyTorchDataset", "DataLoader"]:
1149
- r"""Load a PyTorchDataset from disk.
1150
-
1151
- Args:
1152
- path (str): Path to load the dataset from.
1153
- transform (Optional[Callable]): Transform to apply to samples.
1154
- target_transform (Optional[Callable]): Transform to apply to targets.
1155
- batch_size (Optional[int]): If provided, returns a DataLoader with the
1156
- specified batch size instead of a PyTorchDataset. (default:
1157
- :obj:`None`)
1158
-
1159
- Returns:
1160
- Union[PyTorchDataset, torch.utils.data.DataLoader]: Loaded dataset or
1161
- DataLoader if batch_size is provided.
1162
- """
1163
- dataset = PyTorchDataset.load_from_disk(
1164
- path, transform=transform, target_transform=target_transform
1165
- )
1166
-
1167
- # Return DataLoader if batch_size is provided
1168
- if batch_size is not None:
1169
- return dataset.get_dataloader(batch_size=batch_size)
1170
-
1171
- return dataset
631
+ try:
632
+ with file_path.open("w", encoding="utf-8") as f:
633
+ for datapoint in self._data:
634
+ json.dump(datapoint.to_dict(), f)
635
+ f.write("\n") # Ensure each entry is on a new line
636
+ logger.info(f"Dataset saved successfully to {file_path}")
637
+ except IOError as e:
638
+ logger.error(f"Error writing to file {file_path}: {e}")
639
+ raise
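
To tie the new pieces together, a hedged end-to-end sketch: seed a GenerativeDataset with a StaticDataset, generate a few verified datapoints, and write them out with save_to_jsonl. The ChatAgent and verifier lines are placeholders; their constructors live outside this file and are not part of this diff, so configure them for your own setup.

```python
import asyncio

from camel.agents import ChatAgent
from camel.datasets.base import GenerativeDataset, StaticDataset


async def main() -> None:
    seed = StaticDataset([
        {"question": "What is 2 + 2?", "rationale": "2 + 2 = 4.", "final_answer": "4"},
        {"question": "What is 5 - 3?", "rationale": "5 - 3 = 2.", "final_answer": "2"},
        {"question": "What is 3 * 3?", "rationale": "3 * 3 = 9.", "final_answer": "9"},
    ])

    agent = ChatAgent()  # assumption: default-configured agent; adjust to your model setup
    verifier = ...       # placeholder: substitute any concrete BaseVerifier implementation

    dataset = GenerativeDataset(seed_dataset=seed, verifier=verifier, agent=agent)

    # Retries up to max_retries and raises RuntimeError if it cannot
    # produce n verified datapoints.
    points = await dataset.generate_new(n=5, max_retries=10)
    print(f"generated {len(points)} datapoints")

    # One JSON object per line; overwrites the file if it already exists.
    dataset.save_to_jsonl("generated.jsonl")


asyncio.run(main())
```

Each line of the resulting file is a standalone JSON object, so it can be read back with json.loads line by line and, if desired, fed straight into a new StaticDataset.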