camel-ai 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.

This version of camel-ai might be problematic.

camel/datasets/base.py DELETED
@@ -1,1171 +0,0 @@
- # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-
- import json
- import os
- import random
- from pathlib import Path
- from typing import (
-     Any,
-     Callable,
-     Dict,
-     List,
-     Optional,
-     Sized,
-     TypeVar,
-     Union,
- )
-
- import torch
- from datasets import Dataset as HFDataset
- from pydantic import BaseModel, Field, ValidationError
- from torch.utils.data import DataLoader, Dataset
-
- from camel.agents import ChatAgent
- from camel.logger import get_logger
- from camel.verifiers import BaseVerifier
-
- logger = get_logger(__name__)
-
-
- class DataPoint(BaseModel):
-     r"""A single data point in the dataset.
-
-     Attributes:
-         question (str): The primary question or issue to be addressed.
-         rationale (str): Logical reasoning or explanation behind the
-             answer.
-         final_answer (str): The final answer.
-         raw_markdown (Optional[str]): Raw markdown content for generating
-             rewards/hints. (default: :obj:`None`)
-         difficulty (Optional[str]): Difficulty level of the question.
-             (default: :obj:`None`)
-         metadata Optional[Dict[str, Any]]: Additional metadata about the data
-             point. (default: :obj:`None`)
-     """
-
-     question: str = Field(
-         ..., description="The primary question or issue to be addressed."
-     )
-     rationale: str = Field(
-         ..., description="Logical reasoning or explanation behind the answer."
-     )
-     final_answer: str = Field(..., description="The final answer.")
-     difficulty: Optional[str] = Field(
-         None, description="Difficulty level of the question."
-     )
-     metadata: Optional[Dict[str, Any]] = Field(
-         default=None, description="Additional metadata about the data point."
-     )
-
-     def to_dict(self) -> Dict[str, Any]:
-         r"""Convert DataPoint to a dictionary.
-
-         Returns:
-             Dict[str, Any]: Dictionary representation of the DataPoint.
-         """
-         return self.dict()
-
-     @classmethod
-     def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
-         r"""Create a DataPoint from a dictionary.
-
-         Args:
-             data (Dict[str, Any]): Dictionary containing DataPoint fields.
-
-         Returns:
-             DataPoint: New DataPoint instance.
-         """
-         return cls(**data)
-
-
- class BaseDataset(Dataset):
-     r"""A dataset contains questions and ground truth data for training.
-     It can be either static (e.g., MATH dataset) or generative
-     (using an LLM to generate questions).
-     """
-
-     def __init__(
-         self,
-         data: List[Dict[str, str]],
-         cache_dir: Optional[str] = None,
-         **kwargs,
-     ):
-         r"""Initialize the dataset.
-
-         Args:
-             data (List[Dict[str, str]]): List of dictionary items to
-                 create the dataset from.
-             cache_dir (Optional[str]): Directory to cache dataset files.
-                 (default: :obj:`None`)
-             **kwargs: Additional dataset parameters.
-
-         Note:
-             The dataset must be initialized by calling setup() before use.
-         """
-         self._is_setup = False
-         self._raw_data: List[Dict[str, str]] = data if data is not None else []
-         self._cache_dir = str(cache_dir) if cache_dir is not None else None
-
-         # Store all parameters in metadata dict for compatibility
-         self._metadata = {
-             'cache_dir': self._cache_dir,
-             **kwargs,
-         }
-
-         self.data: List[DataPoint] = []  # Will be populated in setup()
-
-     async def setup(self) -> None:
-         r"""Set up the dataset with necessary resources.
-
-         This method:
-         1. Creates cache directory if needed
-         2. Processes raw data into DataPoint objects using vectorized
-            operations
-         3. Validates dataset integrity
-
-         Raises:
-             OSError: If cache directory creation fails.
-             Exception: If dataset initialization fails.
-         """
-         if self._is_setup:
-             logger.debug(f"{self.__class__.__name__} already initialized")
-             return
-
-         try:
-             # Create cache directory if specified
-             if self._cache_dir:
-                 try:
-                     os.makedirs(self._cache_dir, exist_ok=True)
-                     logger.debug(f"Created cache directory: {self._cache_dir}")
-                 except OSError as e:
-                     logger.error(
-                         f"Failed to create cache directory "
-                         f"{self._cache_dir}: {e}"
-                     )
-                     raise
-
-             # Process raw data into DataPoint objects using vectorized
-             # operations
-             if not self._raw_data:
-                 self.data = []
-                 logger.debug("No raw data to process")
-             else:
-                 try:
-                     # Helper function for validation that can be used with map
-                     def create_datapoint(item, idx=None):
-                         try:
-                             return DataPoint(
-                                 question=item.get('question', ''),
-                                 rationale=item.get('rationale', ''),
-                                 final_answer=item.get('final_answer', ''),
-                                 metadata=item.get('metadata', {})
-                                 if isinstance(item.get('metadata'), dict)
-                                 else {},
-                                 raw_markdown='',
-                                 difficulty='',
-                             )
-                         except ValidationError as e:
-                             idx_str = (
-                                 f" at index {idx}" if idx is not None else ""
-                             )
-                             error_msg = (
-                                 f"Sample{idx_str} validation error: {e}"
-                             )
-                             logger.error(error_msg)
-                             raise ValueError(error_msg)
-
-                     # If raw_data is already a HF dataset, use its map function
-                     if hasattr(self._raw_data, 'map') and callable(
-                         self._raw_data.map
-                     ):
-                         # Using HF dataset's map for vectorized processing
-                         processed_data = self._raw_data.map(
-                             lambda example, idx: {
-                                 'datapoint': create_datapoint(example, idx)
-                             },
-                             with_indices=True,
-                         )
-                         self.data = [
-                             item['datapoint'] for item in processed_data
-                         ]
-                     else:
-                         # Bulk create datapoints
-                         self.data = [
-                             create_datapoint(item, i)
-                             for i, item in enumerate(self._raw_data)
-                         ]
-
-                     logger.debug(f"Processed {len(self.data)} data points")
-                 except Exception as e:
-                     logger.error(f"Error processing data: {e}")
-                     raise
-
-             self._is_setup = True
-             logger.info(f"{self.__class__.__name__} initialized successfully")
-
-         except Exception as e:
-             logger.error(f"Error during {self.__class__.__name__} setup: {e}")
-             await self.cleanup()
-             raise
-
-     async def cleanup(self) -> None:
-         r"""Clean up dataset resources.
-
-         This method handles cleanup of resources and resets the dataset state.
-         """
-         if not self._is_setup:
-             return
-
-         try:
-             # Clear metadata while preserving init config
-             init_config = {
-                 'cache_dir': self._cache_dir,
-             }
-             self._metadata = init_config
-
-             logger.info(f"{self.__class__.__name__} cleaned up successfully")
-
-         except Exception as e:
-             logger.error(
-                 f"Error during {self.__class__.__name__} cleanup: {e}"
-             )
-             raise
-
-         finally:
-             # Always mark as uninitialized, even if cleanup fails
-             self._is_setup = False
-
-     def sample(self) -> DataPoint:
-         r"""Sample a random datapoint from the dataset.
-
-         Returns:
-             DataPoint: A randomly sampled DataPoint.
-
-         Raises:
-             RuntimeError: If dataset is not initialized.
-         """
-         if not self._is_setup:
-             raise RuntimeError(
-                 f"{self.__class__.__name__} must be initialized "
-                 "before sampling"
-             )
-
-         idx = random.randint(0, len(self) - 1)
-         return self[idx]
-
-     def __len__(self) -> int:
-         r"""Return the size of the dataset."""
-         return len(self.data)
-
-     def __getitem__(self, idx: int) -> DataPoint:
-         r"""Get an item from the dataset.
-
-         Args:
-             idx (int): Index of the item to get.
-
-         Returns:
-             DataPoint: DataPoint from the dataset with the given index.
-
-         Raises:
-             IndexError: If idx is out of bounds.
-         """
-         if idx < 0 or idx >= len(self):
-             raise IndexError(
-                 f"Index {idx} out of bounds for dataset of size {len(self)}"
-             )
-
-         return self.data[idx]
-
-     @property
-     def metadata(self) -> Dict[str, Any]:
-         r"""Get dataset metadata."""
-         return self._metadata.copy()
-
-     def to_pytorch_dataset(
-         self,
-         transform: Optional[Callable] = None,
-         target_transform: Optional[Callable] = None,
-         batch_size: Optional[int] = None,
-     ) -> Union["PyTorchDataset", "DataLoader"]:
-         r"""Convert to a PyTorch dataset or DataLoader.
-
-         Args:
-             transform (Optional[Callable]): Transform to apply to samples.
-             target_transform (Optional[Callable]): Transform to apply to
-                 targets.
-             batch_size (Optional[int]): If provided, returns a DataLoader with
-                 the specified batch size instead of a PyTorchDataset.
-                 (default: :obj:`None`)
-
-         Returns:
-             Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
-                 PyTorch format or DataLoader if batch_size is provided.
-         """
-         dataset = PyTorchDataset.from_datapoints(
-             self.data,
-             transform=transform,
-             target_transform=target_transform,
-         )
-
-         if batch_size is not None:
-             return dataset.get_dataloader(batch_size=batch_size)
-
-         return dataset
-
-
- class SeedDataset(BaseDataset):
-     r"""A dataset containing validated seed examples for data generation.
-     Ensures that all items adhere to the DataPoint schema.
-
-     This class can initialize from Hugging Face Datasets,
-     PyTorch Datasets, JSON file paths, or lists of dictionaries,
-     converting them into a consistent internal format.
-     """
-
-     def __init__(
-         self,
-         data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
-         cache_dir: Optional[str] = None,
-         seed: Optional[int] = None,
-         min_samples: int = 1,
-         strict: bool = False,
-         **kwargs,
-     ):
-         r"""Initialize the seed dataset and validate integrity.
-
-         Args:
-             data (Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]):
-                 Input data, which can be:
-                 - A Hugging Face Dataset (HFDataset)
-                 - A PyTorch Dataset (torch.utils.data.Dataset)
-                 - A Path object representing the path to a JSON file
-                 - A list of dictionaries with DataPoint-compatible fields
-             seed (Optional[int]): Seed for reproducibility.
-                 (default: :obj:`1`)
-             min_samples (int): Minimum number of samples required.
-                 (default: :obj:`1`)
-             strict (bool): Whether to raise an error on invalid datapoints
-                 (True) or skip/filter them (False). (default: False)
-             **kwargs: Additional dataset parameters.
-
-         Raises:
-             TypeError: If the data type is not supported.
-             ValueError: If dataset size is less than min_samples or
-                 if sample validation fails.
-             FileNotFoundError: If the JSON file path doesn't exist.
-             json.JSONDecodeError: If the JSON file is invalid.
-         """
-         # Initialize BaseDataset with empty data, we'll populate it ourselves
-         super().__init__(data=[], cache_dir=cache_dir, **kwargs)
-
-         self._rng = random.Random(seed)
-         self._strict = strict
-
-         # Type checking and conversion into list of dicts to have a
-         # consistent internal format. Since Seed Dataset should be
-         # small, we can load it entirely into memory
-
-         self.data: List[DataPoint] = self._init_data(data)
-         self._length = len(self.data)
-
-         if self._length < min_samples:
-             raise ValueError(
-                 "The dataset does not contain enough samples. "
-                 f"Need {max(0, min_samples)}, got {self._length}"
-             )
-
-     def _init_data(
-         self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
-     ) -> List[DataPoint]:
-         if isinstance(data, HFDataset):
-             raw_data = self._init_from_hf_dataset(data)
-         elif isinstance(data, Dataset):
-             raw_data = self._init_from_pytorch_dataset(data)
-         elif isinstance(data, Path):
-             raw_data = self._init_from_json_path(data)
-         elif isinstance(data, list):
-             raw_data = self._init_from_list(data)
-         else:
-             raise TypeError("Unsupported data type")
-
-         def create_datapoint(
-             item: Dict[str, Any], idx: int
-         ) -> Optional[DataPoint]:
-             # Add type checks for required fields to make mypy happy
-             question = item.get('question')
-             if not isinstance(question, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'question': "
-                         f"expected str, got {type(question)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: invalid 'question'"
-                     )
-                     return None
-
-             rationale = item.get('rationale')
-             if not isinstance(rationale, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'rationale': "
-                         f"expected str, got {type(rationale)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: invalid 'rationale'"
-                     )
-                     return None
-
-             final_answer = item.get('final_answer')
-             if not isinstance(final_answer, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'final_answer': "
-                         f"expected str, got {type(final_answer)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: "
-                         "invalid 'final_answer'"
-                     )
-                     return None
-
-             try:
-                 return DataPoint(
-                     question=question,
-                     rationale=rationale,
-                     final_answer=final_answer,
-                     metadata=item.get('metadata'),
-                     difficulty=item.get('difficulty'),
-                 )
-             except ValidationError as e:
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} validation error: {e}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping invalid sample at index {idx} "
-                         f"due to validation error: {e}"
-                     )
-                     return None
-
-         unfiltered_data = [
-             create_datapoint(item, i) for i, item in enumerate(raw_data)
-         ]
-         return [dp for dp in unfiltered_data if dp is not None]
-
-     def __len__(self) -> int:
-         r"""Return the size of the dataset."""
-         return self._length
-
-     def __getitem__(self, idx: int) -> DataPoint:
-         r"""Get an item from the dataset.
-
-         Args:
-             idx (int): Index of the item to get.
-
-         Returns:
-             DataPoint: DataPoint from the dataset with the given index.
-
-         Raises:
-             IndexError: If idx is out of bounds.
-         """
-         if idx < 0 or idx >= self._length:
-             raise IndexError(
-                 f"Index {idx} out of bounds for dataset of size {self._length}"
-             )
-         return self.data[idx]
-
-     def sample(self) -> DataPoint:
-         r"""Sample a random datapoint from the dataset.
-
-         Returns:
-             DataPoint: A randomly sampled DataPoint.
-
-         Raises:
-             RuntimeError: If the dataset is empty.
-         """
-         if self._length == 0:
-             raise RuntimeError("Dataset is empty, cannot sample.")
-         idx = self._rng.randint(0, self._length - 1)
-         return self[idx]
-
-     @property
-     def metadata(self) -> Dict[str, Any]:
-         r"""Get dataset metadata."""
-         return self._metadata.copy()
-
-     def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
-         return [dict(item) for item in data]
-
-     def _init_from_pytorch_dataset(
-         self, data: Dataset
-     ) -> List[Dict[str, Any]]:
-         if not isinstance(data, Sized):
-             raise TypeError(
-                 f"{type(data).__name__} does not implement `__len__()`."
-             )
-         raw_data = []
-
-         for i in range(len(data)):
-             item = data[i]
-             if not isinstance(item, dict):
-                 raise TypeError(
-                     f"Item at index {i} is not a dict: "
-                     f"got {type(item).__name__}"
-                 )
-             raw_data.append(dict(item))
-         return raw_data
-
-     def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
-         if not data.exists():
-             raise FileNotFoundError(f"JSON file not found: {data}")
-         try:
-             logger.debug(f"Loading JSON from {data}")
-             with data.open('r', encoding='utf-8') as f:
-                 loaded_data = json.load(f)
-             logger.info(
-                 f"Successfully loaded {len(loaded_data)} items from {data}"
-             )
-         except json.JSONDecodeError as e:
-             raise ValueError(f"Invalid JSON in file {data}: {e}")
-         if not isinstance(loaded_data, list):
-             raise ValueError("JSON file must contain a list of dictionaries")
-         for i, item in enumerate(loaded_data):
-             if not isinstance(item, dict):
-                 raise ValueError(
-                     f"Expected a dictionary at index {i}, "
-                     f"got {type(item).__name__}"
-                 )
-         return loaded_data
-
-     def _init_from_list(
-         self, data: List[Dict[str, Any]]
-     ) -> List[Dict[str, Any]]:
-         for i, item in enumerate(data):
-             if not isinstance(item, dict):
-                 raise ValueError(
-                     f"Expected a dictionary at index {i}, "
-                     f"got {type(item).__name__}"
-                 )
-         return data
-
-
- class SyntheticDataset(BaseDataset):
-     r"""A dataset for storing synthetically generated data points.
-
-     This class is used to store datapoints that are generated through
-     a generative process, such as using an agent.
-     """
-
-     def __init__(
-         self,
-         data: Optional[List[Dict[str, str]]] = None,
-         cache_dir: Optional[str] = None,
-         **kwargs,
-     ):
-         r"""Initialize the synthetic dataset.
-
-         Args:
-             data (Optional[List[Dict[str, str]]]): List of dictionary items to
-                 create the dataset from. (default: :obj:`None`)
-             cache_dir (Optional[str]): Directory to cache dataset files.
-                 (default: :obj:`None`)
-             **kwargs: Additional dataset parameters.
-         """
-         super().__init__(
-             data=data if data is not None else [],
-             cache_dir=cache_dir,
-             **kwargs,
-         )
-         self.data: List[DataPoint] = []
-
-     def add(self, item: DataPoint) -> None:
-         r"""Add a new data point to the dataset.
-
-         Args:
-             item (DataPoint): The datapoint to add to the dataset.
-         """
-         self.data.append(item)
-
-     def add_batch(self, items: List[DataPoint]) -> None:
-         r"""Add multiple data points to the dataset.
-
-         Args:
-             items (List[DataPoint]): The datapoints to add to the dataset.
-         """
-         self.data.extend(items)
-
-     def to_pytorch_dataset(
-         self,
-         transform: Optional[Callable] = None,
-         target_transform: Optional[Callable] = None,
-         batch_size: Optional[int] = None,
-     ) -> Union["PyTorchDataset", "DataLoader"]:
-         r"""Convert to a PyTorch dataset or DataLoader.
-
-         Args:
-             transform (Optional[Callable]): Transform to apply to samples.
-             target_transform (Optional[Callable]): Transform to apply to
-                 targets.
-             batch_size (Optional[int]): If provided, returns a DataLoader with
-                 the specified batch size instead of a PyTorchDataset.
-                 (default: :obj:`None`)
-
-         Returns:
-             Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
-                 PyTorch format or DataLoader if batch_size is provided.
-         """
-         return convert_synthetic_to_pytorch(
-             self,
-             transform=transform,
-             target_transform=target_transform,
-             batch_size=batch_size,
-         )
-
-     def save_pytorch_format(self, path: str, compression: bool = True) -> None:
-         r"""Save the dataset to disk in PyTorch format.
-
-         Args:
-             path (str): Path to save the dataset to.
-             compression (bool): Whether to use compression to reduce file size.
-                 (default: :obj:`True`)
-         """
-         save_synthetic_dataset(self, path, compression=compression)
-
-     def filter(
-         self, predicate: Callable[[DataPoint], bool]
-     ) -> 'SyntheticDataset':
-         r"""Filter the dataset using a predicate function.
-
-         Args:
-             predicate (Callable[[DataPoint], bool]): Function that takes a
-                 DataPoint and returns True if it should be kept, False
-                 otherwise.
-
-         Returns:
-             SyntheticDataset: A new dataset containing only the filtered items.
-         """
-         filtered_data = [dp for dp in self.data if predicate(dp)]
-
-         # Create a new dataset with the filtered data
-         new_dataset = SyntheticDataset()
-         new_dataset.add_batch(filtered_data)
-
-         return new_dataset
-
-
- class GenerativeDataset(BaseDataset):
-     r"""A dataset for generating synthetic datapoints using external agents and
-     verifiers.
-
-     This class leverages a seed dataset and external components to generate
-     new synthetic datapoints on demand.
-     """
-
-     def __init__(
-         self,
-         seed_dataset: SeedDataset,
-         verifier: BaseVerifier,
-         agent: ChatAgent,
-         cache_dir: Optional[str] = None,
-         seed: int = 42,
-         **kwargs,
-     ):
-         r"""Initialize the generative dataset.
-
-         Args:
-             seed_dataset (SeedDataset): Validated dataset to use for examples.
-             verifier (BaseVerifier): Verifier to validate generated content.
-             agent (ChatAgent): Agent to generate new datapoints.
-             cache_dir (Optional[str]): Directory to cache dataset files.
-                 (default: :obj:`None`)
-             seed (int): Random seed for reproducibility. (default: :obj:`42`)
-             **kwargs: Additional dataset parameters.
-         """
-         # Initialize with empty data since we'll generate content dynamically
-         super().__init__(
-             data=[],
-             cache_dir=cache_dir,
-             **kwargs,
-         )
-
-         self.seed_dataset = seed_dataset
-         self.verifier = verifier
-         self.agent = agent
-
-         self.seed = seed
-         random.seed(self.seed)
-
-     def _construct_prompt(self, examples: List[DataPoint]) -> str:
-         r"""Construct a prompt for generating new datapoints.
-
-         Args:
-             examples (List[DataPoint]): Examples to include in the prompt.
-
-         Returns:
-             str: Formatted prompt with examples.
-         """
-         prompt = (
-             "Generate a new datapoint similar to the following examples:\n\n"
-         )
-         for i, example in enumerate(examples, 1):
-             prompt += f"Example {i}:\n"
-             prompt += f"Question: {example.question}\n"
-             prompt += f"Rationale: {example.rationale}\n"
-             prompt += f"Final Answer: {example.final_answer}\n\n"
-         prompt += "New datapoint:"
-         return prompt
-
-     async def generate_new(self, n: int) -> None:
-         r"""Generate n new datapoints and add them to the dataset.
-
-         Args:
-             n (int): Number of valid datapoints to generate.
-
-         This method generates new datapoints by:
-         1. Sampling examples from the seed dataset
-         2. Constructing a prompt for the agent
-         3. Generating a new datapoint using the agent
-         4. Verifying the generated datapoint with the verifier
-         5. Adding valid datapoints to the dataset
-         """
-         valid_data_points: List[DataPoint] = []
-
-         while len(valid_data_points) < n:
-             try:
-                 indices = random.sample(range(len(self.seed_dataset)), 3)
-                 examples = [self.seed_dataset[i] for i in indices]
-                 prompt = self._construct_prompt(examples)
-
-                 # Get agent response
-                 agent_output = (
-                     self.agent.step(prompt, response_format=DataPoint)
-                     .msgs[0]
-                     .parsed
-                 )
-
-                 if not isinstance(agent_output, dict):
-                     raise TypeError("Agent output must be a dictionary")
-                 if (
-                     'question' not in agent_output
-                     or 'rationale' not in agent_output
-                 ):
-                     raise KeyError(
-                         "Agent output missing required keys: "
-                         "'question' or 'rationale'"
-                     )
-
-                 rationale = agent_output['rationale']
-
-                 # Verify the generated content
-                 verifier_response = await self.verifier.verify(rationale)
-                 if not hasattr(verifier_response, 'content'):
-                     raise AttributeError(
-                         "Verifier response missing 'content' attribute"
-                     )
-
-                 if not verifier_response.result:
-                     continue
-
-                 final_answer = verifier_response.result
-
-                 # Create and validate the new datapoint
-                 new_datapoint = {
-                     'question': agent_output['question'],
-                     'rationale': rationale,
-                     'final_answer': final_answer,
-                 }
-
-                 datapoint = DataPoint(**new_datapoint)
-                 valid_data_points.append(datapoint)
-
-             except (TypeError, KeyError, AttributeError, ValidationError) as e:
-                 logger.warning(
-                     f"Error encountered during generation: {e}, retrying..."
-                 )
-
-         # Add all valid datapoints to the dataset
-         for datapoint in valid_data_points:
-             self.data.append(datapoint)
-             logger.debug("Added new datapoint to dataset")
-
-
- # Define a type variable for return type flexibility
- T_co = TypeVar('T_co', covariant=True)
-
-
- class PyTorchDataset(Dataset[T_co]):
-     r"""A PyTorch-compatible dataset implementation that leverages PyTorch's
-     efficient data handling capabilities.
-     """
-
-     def __init__(
-         self,
-         data: List[Dict[str, Any]],
-         transform: Optional[Callable] = None,
-         target_transform: Optional[Callable] = None,
-         validate: bool = True,
-     ):
-         r"""Initialize the PyTorch dataset.
-
-         Args:
-             data (List[Dict[str, Any]]): List of dictionary items to create
-                 the dataset from.
-             transform (Optional[Callable]): A function/transform that takes a
-                 sample and returns a transformed version for features.
-                 (default: :obj:`None`)
-             target_transform (Optional[Callable]): A function/transform that
-                 takes a target and returns a transformed version. (default:
-                 :obj:`None`)
-             validate (bool): Whether to validate data points using DataPoint
-                 schema. (default: :obj:`True`)
-
-         Raises:
-             ValidationError: If validation is enabled and data doesn't match
-                 DataPoint schema.
-         """
-         self.transform = transform
-         self.target_transform = target_transform
-
-         # Validate and store data
-         self._raw_data = data
-         self.data = []
-
-         if validate:
-             for i, item in enumerate(self._raw_data):
-                 try:
-                     # Use DataPoint for validation only
-                     dp = DataPoint(**item)
-                     self.data.append(dp.to_dict())
-                 except ValidationError as e:
-                     logger.error(f"Sample {i} validation error: {e}")
-                     raise ValueError(f"Sample {i} validation error: {e}")
-         else:
-             # Skip validation and just store the data dictionaries
-             self.data = [dict(item) for item in self._raw_data]
-
-     def __getitem__(self, index: int) -> T_co:
-         r"""Get an item from the dataset.
-
-         Args:
-             index (int): Index of the item to get.
-
-         Returns:
-             T_co: Item from the dataset, possibly transformed.
-         """
-         sample = self.data[index]
-
-         # Apply transformations if provided
-         if self.transform is not None:
-             sample = self.transform(sample)
-
-         return sample  # type: ignore[return-value]
-
-     def __len__(self) -> int:
-         r"""Return the size of the dataset.
-
-         Returns:
-             int: Number of samples in the dataset.
-         """
-         return len(self.data)
-
-     @classmethod
-     def from_datapoints(
-         cls, datapoints: List[DataPoint], **kwargs
-     ) -> 'PyTorchDataset':
-         r"""Create a PyTorchDataset from a list of DataPoints.
-
-         Args:
-             datapoints (List[DataPoint]): List of DataPoint objects.
-             **kwargs: Additional arguments to pass to the constructor.
-
-         Returns:
-             PyTorchDataset: A new PyTorchDataset instance.
-         """
-         data = [dp.to_dict() for dp in datapoints]
-         # We can skip validation since datapoints are already validated
-         return cls(data, validate=False, **kwargs)
-
-     def to_hf_dataset(self) -> HFDataset:
-         r"""Convert to a HuggingFace dataset.
-
-         Returns:
-             HFDataset: Dataset in HuggingFace format.
-         """
-         return HFDataset.from_list(self.data)
-
-     def save_to_disk(self, path: str) -> None:
-         r"""Save the dataset to disk using PyTorch.
-
-         Args:
-             path (str): Path to save the dataset to.
-         """
-         torch.save(self.data, path)
-
-     @classmethod
-     def load_from_disk(
-         cls,
-         path: str,
-         transform: Optional[Callable] = None,
-         target_transform: Optional[Callable] = None,
-     ) -> 'PyTorchDataset':
-         r"""Load a dataset from disk.
-
-         Args:
-             path (str): Path to load the dataset from.
-             transform (Optional[Callable]): Transform to apply to samples.
-                 (default: :obj:`None`)
-             target_transform (Optional[Callable]): Transform to apply to
-                 targets. (default: :obj:`None`)
-
-         Returns:
-             PyTorchDataset: Loaded dataset.
-         """
-         data = torch.load(path)
-         return cls(
-             data,
-             transform=transform,
-             target_transform=target_transform,
-             validate=False,
-         )
-
-     @staticmethod
-     def collate_fn(
-         batch: List[Dict[str, Any]],
-     ) -> Dict[str, Union[List[Any], torch.Tensor]]:
-         r"""Collate function for PyTorch DataLoader.
-
-         Args:
-             batch (List[Dict[str, Any]]): Batch of samples from the dataset.
-
-         Returns:
-             Dict[str, Union[List[Any], torch.Tensor]]: Collated batch with
-                 tensors for numerical data.
-         """
-         if not batch:
-             return {}
-
-         # Initialize result dictionary with keys from first item - start with
-         # lists only
-         result: Dict[str, List[Any]] = {k: [] for k in batch[0].keys()}
-
-         # Collect values by key
-         for item in batch:
-             for k, v in item.items():
-                 result[k].append(v)
-
-         # Convert numeric/boolean lists to tensors where possible
-         result_with_tensors: Dict[str, Union[List[Any], torch.Tensor]] = {}
-         for k, v in result.items():
-             if all(isinstance(x, (int, float, bool)) for x in v):
-                 try:
-                     result_with_tensors[k] = torch.tensor(v)
-                 except (ValueError, TypeError):
-                     # Keep as list if tensor conversion fails
-                     result_with_tensors[k] = v
-             else:
-                 result_with_tensors[k] = v
-
-         return result_with_tensors
-
-     def get_dataloader(
-         self,
-         batch_size: int = 32,
-         shuffle: bool = True,
-         num_workers: int = 0,
-         pin_memory: bool = False,
-         **kwargs,
-     ) -> "DataLoader":
-         r"""Create a PyTorch DataLoader for this dataset.
-
-         Args:
-             batch_size (int): Batch size. (default: :obj:`32`)
-             shuffle (bool): Whether to shuffle the dataset. (default:
-                 :obj:`True`)
-             num_workers (int): Number of workers for data loading. (default:
-                 :obj:`0`)
-             pin_memory (bool): Whether to pin memory for faster GPU transfer.
-                 (default: :obj:`False`)
-             **kwargs: Additional arguments to pass to DataLoader.
-
-         Returns:
-             torch.utils.data.DataLoader: DataLoader for this dataset.
-         """
-         from torch.utils.data import DataLoader
-
-         return DataLoader(
-             self,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             num_workers=num_workers,
-             collate_fn=self.collate_fn,
-             pin_memory=pin_memory,
-             **kwargs,
-         )
-
-
- def convert_hf_to_pytorch(
-     hf_dataset: HFDataset,
-     transform: Optional[Callable] = None,
-     target_transform: Optional[Callable] = None,
-     column_mapping: Optional[Dict[str, str]] = None,
-     validate: bool = True,
-     batch_size: Optional[int] = None,
- ) -> Union["PyTorchDataset", "DataLoader"]:
-     r"""Convert a HuggingFace dataset to a PyTorchDataset or DataLoader.
-
-     This function maps HuggingFace dataset columns to the expected DataPoint
-     format, validates the data, and creates a PyTorchDataset or DataLoader.
-
-     Args:
-         hf_dataset (HFDataset): HuggingFace dataset to convert.
-         transform (Optional[Callable]): Transform to apply to samples.
-         target_transform (Optional[Callable]): Transform to apply to targets.
-         column_mapping (Optional[Dict[str, str]]): Mapping from HuggingFace
-             column names to DataPoint field names. If None, assumes columns
-             already match DataPoint fields.
-         validate (bool): Whether to validate data points using DataPoint
-             schema. (default: :obj:`True`)
-         batch_size (Optional[int]): If provided, returns a DataLoader with the
-             specified batch size instead of a PyTorchDataset. (default:
-             :obj:`None`)
-
-     Returns:
-         Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
-             or DataLoader if batch_size is provided.
-     """
-     # Convert HuggingFace dataset to list of dicts more efficiently
-     mapped_dataset = []
-
-     for i in range(len(hf_dataset)):
-         item = hf_dataset[i]
-         if column_mapping is not None:
-             # Apply column mapping if provided
-             mapped_item = {}
-             for hf_col, dp_field in column_mapping.items():
-                 if hf_col in item:
-                     mapped_item[dp_field] = item[hf_col]
-             mapped_dataset.append(mapped_item)
-         else:
-             # Otherwise use item directly
-             mapped_dataset.append(dict(item))
-
-     # Create PyTorchDataset
-     dataset: PyTorchDataset = PyTorchDataset(
-         mapped_dataset,
-         transform=transform,
-         target_transform=target_transform,
-         validate=validate,
-     )
-
-     # Return DataLoader if batch_size is provided
-     if batch_size is not None:
-         return dataset.get_dataloader(batch_size=batch_size)
-
-     return dataset
-
-
- def convert_synthetic_to_pytorch(
-     synthetic_dataset: 'SyntheticDataset',
-     transform: Optional[Callable] = None,
-     target_transform: Optional[Callable] = None,
-     batch_size: Optional[int] = None,
- ) -> Union["PyTorchDataset", "DataLoader"]:
-     r"""Convert a SyntheticDataset to a PyTorchDataset or DataLoader.
-
-     Args:
-         synthetic_dataset (SyntheticDataset): Synthetic dataset to convert.
-         transform (Optional[Callable]): Transform to apply to samples.
-         target_transform (Optional[Callable]): Transform to apply to targets.
-         batch_size (Optional[int]): If provided, returns a DataLoader with the
-             specified batch size instead of a PyTorchDataset. (default:
-             :obj:`None`)
-
-     Returns:
-         Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
-             or DataLoader if batch_size is provided.
-     """
-     dataset = PyTorchDataset.from_datapoints(
-         synthetic_dataset.data,
-         transform=transform,
-         target_transform=target_transform,
-     )
-
-     # Return DataLoader if batch_size is provided
-     if batch_size is not None:
-         return dataset.get_dataloader(batch_size=batch_size)
-
-     return dataset
-
-
- def save_synthetic_dataset(
-     synthetic_dataset: 'SyntheticDataset',
-     path: str,
-     compression: bool = True,
- ) -> None:
-     r"""Save a synthetic dataset to disk using PyTorch format.
-
-     Args:
-         synthetic_dataset (SyntheticDataset): Dataset to save.
-         path (str): Path to save the dataset to.
-         compression (bool): Whether to use compression to reduce file size.
-             (default: :obj:`True`)
-     """
-     pytorch_dataset = convert_synthetic_to_pytorch(synthetic_dataset)
-
-     # Save with compression if enabled (uses less disk space)
-     if compression:
-         torch.save(
-             pytorch_dataset.data,  # type: ignore[union-attr]
-             path,
-             _use_new_zipfile_serialization=True,
-         )
-     else:
-         pytorch_dataset.save_to_disk(path)  # type: ignore[union-attr]
-
-
- def load_pytorch_dataset(
-     path: str,
-     transform: Optional[Callable] = None,
-     target_transform: Optional[Callable] = None,
-     batch_size: Optional[int] = None,
- ) -> Union["PyTorchDataset", "DataLoader"]:
-     r"""Load a PyTorchDataset from disk.
-
-     Args:
-         path (str): Path to load the dataset from.
-         transform (Optional[Callable]): Transform to apply to samples.
-         target_transform (Optional[Callable]): Transform to apply to targets.
-         batch_size (Optional[int]): If provided, returns a DataLoader with the
-             specified batch size instead of a PyTorchDataset. (default:
-             :obj:`None`)
-
-     Returns:
-         Union[PyTorchDataset, torch.utils.data.DataLoader]: Loaded dataset or
-             DataLoader if batch_size is provided.
-     """
-     dataset = PyTorchDataset.load_from_disk(
-         path, transform=transform, target_transform=target_transform
-     )
-
-     # Return DataLoader if batch_size is provided
-     if batch_size is not None:
-         return dataset.get_dataloader(batch_size=batch_size)
-
-     return dataset
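
The deleted module centered on a small, self-contained API: DataPoint (a pydantic schema), SeedDataset (validated seed examples), SyntheticDataset and GenerativeDataset (generated data), and PyTorchDataset with converter helpers. The sketch below illustrates how that API fit together, based only on the signatures visible in the diff above; it is an illustrative example with made-up sample values, not code shipped in either release, and it assumes the 0.2.28 import path camel.datasets.base.

# Illustrative sketch based on the deleted camel/datasets/base.py (camel-ai 0.2.28).
# Not part of either package release; the sample values below are made up.
from camel.datasets.base import SeedDataset

seed = SeedDataset(
    data=[
        {
            "question": "What is 2 + 2?",
            "rationale": "Add the two integers.",
            "final_answer": "4",
        }
    ],
    min_samples=1,
    strict=True,  # raise on malformed items instead of silently skipping them
)

print(seed.sample().final_answer)  # "4", drawn with the dataset's seeded RNG

# Convert to a PyTorch DataLoader via the inherited helper; string fields are
# collated into per-key lists by PyTorchDataset.collate_fn.
loader = seed.to_pytorch_dataset(batch_size=1)
for batch in loader:
    print(batch["question"])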