camel-ai 0.2.21__py3-none-any.whl → 0.2.23a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of camel-ai might be problematic.

Files changed (106)
  1. camel/__init__.py +1 -1
  2. camel/agents/_types.py +41 -0
  3. camel/agents/_utils.py +188 -0
  4. camel/agents/chat_agent.py +556 -965
  5. camel/agents/knowledge_graph_agent.py +7 -1
  6. camel/agents/multi_hop_generator_agent.py +1 -1
  7. camel/configs/base_config.py +10 -13
  8. camel/configs/deepseek_config.py +4 -30
  9. camel/configs/gemini_config.py +5 -31
  10. camel/configs/openai_config.py +14 -32
  11. camel/configs/qwen_config.py +36 -36
  12. camel/datagen/self_improving_cot.py +79 -1
  13. camel/datagen/self_instruct/filter/instruction_filter.py +19 -3
  14. camel/datagen/self_instruct/self_instruct.py +7 -2
  15. camel/datasets/__init__.py +28 -0
  16. camel/datasets/base.py +969 -0
  17. camel/embeddings/openai_embedding.py +10 -1
  18. camel/environments/__init__.py +16 -0
  19. camel/environments/base.py +503 -0
  20. camel/extractors/__init__.py +16 -0
  21. camel/extractors/base.py +263 -0
  22. camel/interpreters/docker/Dockerfile +12 -0
  23. camel/interpreters/docker_interpreter.py +19 -1
  24. camel/interpreters/subprocess_interpreter.py +42 -17
  25. camel/loaders/__init__.py +2 -0
  26. camel/loaders/mineru_extractor.py +250 -0
  27. camel/memories/agent_memories.py +16 -1
  28. camel/memories/blocks/chat_history_block.py +10 -2
  29. camel/memories/blocks/vectordb_block.py +1 -0
  30. camel/memories/context_creators/score_based.py +20 -3
  31. camel/memories/records.py +10 -0
  32. camel/messages/base.py +8 -8
  33. camel/models/_utils.py +57 -0
  34. camel/models/aiml_model.py +48 -17
  35. camel/models/anthropic_model.py +41 -3
  36. camel/models/azure_openai_model.py +39 -3
  37. camel/models/base_model.py +132 -4
  38. camel/models/cohere_model.py +88 -11
  39. camel/models/deepseek_model.py +107 -63
  40. camel/models/gemini_model.py +133 -15
  41. camel/models/groq_model.py +72 -10
  42. camel/models/internlm_model.py +14 -3
  43. camel/models/litellm_model.py +9 -2
  44. camel/models/mistral_model.py +42 -5
  45. camel/models/model_manager.py +48 -3
  46. camel/models/moonshot_model.py +33 -4
  47. camel/models/nemotron_model.py +32 -3
  48. camel/models/nvidia_model.py +43 -3
  49. camel/models/ollama_model.py +139 -17
  50. camel/models/openai_audio_models.py +7 -1
  51. camel/models/openai_compatible_model.py +37 -3
  52. camel/models/openai_model.py +158 -46
  53. camel/models/qwen_model.py +61 -4
  54. camel/models/reka_model.py +53 -3
  55. camel/models/samba_model.py +209 -4
  56. camel/models/sglang_model.py +153 -14
  57. camel/models/siliconflow_model.py +16 -3
  58. camel/models/stub_model.py +46 -4
  59. camel/models/togetherai_model.py +38 -3
  60. camel/models/vllm_model.py +37 -3
  61. camel/models/yi_model.py +36 -3
  62. camel/models/zhipuai_model.py +38 -3
  63. camel/retrievers/__init__.py +3 -0
  64. camel/retrievers/hybrid_retrival.py +237 -0
  65. camel/toolkits/__init__.py +9 -2
  66. camel/toolkits/arxiv_toolkit.py +2 -1
  67. camel/toolkits/ask_news_toolkit.py +4 -2
  68. camel/toolkits/base.py +22 -3
  69. camel/toolkits/code_execution.py +2 -0
  70. camel/toolkits/dappier_toolkit.py +2 -1
  71. camel/toolkits/data_commons_toolkit.py +38 -12
  72. camel/toolkits/function_tool.py +13 -0
  73. camel/toolkits/github_toolkit.py +5 -1
  74. camel/toolkits/google_maps_toolkit.py +2 -1
  75. camel/toolkits/google_scholar_toolkit.py +2 -0
  76. camel/toolkits/human_toolkit.py +0 -3
  77. camel/toolkits/linkedin_toolkit.py +3 -2
  78. camel/toolkits/meshy_toolkit.py +3 -2
  79. camel/toolkits/mineru_toolkit.py +178 -0
  80. camel/toolkits/networkx_toolkit.py +240 -0
  81. camel/toolkits/notion_toolkit.py +2 -0
  82. camel/toolkits/openbb_toolkit.py +3 -2
  83. camel/toolkits/reddit_toolkit.py +11 -3
  84. camel/toolkits/retrieval_toolkit.py +6 -1
  85. camel/toolkits/semantic_scholar_toolkit.py +2 -1
  86. camel/toolkits/stripe_toolkit.py +8 -2
  87. camel/toolkits/sympy_toolkit.py +44 -1
  88. camel/toolkits/video_toolkit.py +2 -0
  89. camel/toolkits/whatsapp_toolkit.py +3 -2
  90. camel/toolkits/zapier_toolkit.py +191 -0
  91. camel/types/__init__.py +2 -2
  92. camel/types/agents/__init__.py +16 -0
  93. camel/types/agents/tool_calling_record.py +52 -0
  94. camel/types/enums.py +3 -0
  95. camel/types/openai_types.py +16 -14
  96. camel/utils/__init__.py +2 -1
  97. camel/utils/async_func.py +2 -2
  98. camel/utils/commons.py +114 -1
  99. camel/verifiers/__init__.py +23 -0
  100. camel/verifiers/base.py +340 -0
  101. camel/verifiers/models.py +82 -0
  102. camel/verifiers/python_verifier.py +202 -0
  103. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/METADATA +273 -256
  104. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/RECORD +106 -85
  105. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/WHEEL +1 -1
  106. {camel_ai-0.2.21.dist-info → camel_ai-0.2.23a0.dist-info}/LICENSE +0 -0
camel/datasets/base.py ADDED
@@ -0,0 +1,969 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+ import os
+ import random
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     List,
+     Optional,
+     TypeVar,
+     Union,
+ )
+
+ import torch
+ from datasets import Dataset as HFDataset
+ from pydantic import BaseModel, Field, ValidationError
+ from torch.utils.data import DataLoader, Dataset
+
+ from camel.agents import ChatAgent
+ from camel.logger import get_logger
+ from camel.verifiers import BaseVerifier
+
+ logger = get_logger(__name__)
+
+
+ class DataPoint(BaseModel):
+     r"""A single data point in the dataset.
+
+     Attributes:
+         question (str): The primary question or issue to be addressed.
+         rationale (str): Logical reasoning or explanation behind the
+             answer.
+         final_answer (str): The final answer.
+         raw_markdown (Optional[str]): Raw markdown content for generating
+             rewards/hints. (default: :obj:`None`)
+         difficulty (Optional[str]): Difficulty level of the question.
+             (default: :obj:`None`)
+         metadata Optional[Dict[str, Any]]: Additional metadata about the data
+             point. (default: :obj:`None`)
+     """
+
+     question: str = Field(
+         ..., description="The primary question or issue to be addressed."
+     )
+     rationale: str = Field(
+         ..., description="Logical reasoning or explanation behind the answer."
+     )
+     final_answer: str = Field(..., description="The final answer.")
+     difficulty: Optional[str] = Field(
+         None, description="Difficulty level of the question."
+     )
+     metadata: Optional[Dict[str, Any]] = Field(
+         default=None, description="Additional metadata about the data point."
+     )
+
+     def to_dict(self) -> Dict[str, Any]:
+         r"""Convert DataPoint to a dictionary.
+
+         Returns:
+             Dict[str, Any]: Dictionary representation of the DataPoint.
+         """
+         return self.dict()
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
+         r"""Create a DataPoint from a dictionary.
+
+         Args:
+             data (Dict[str, Any]): Dictionary containing DataPoint fields.
+
+         Returns:
+             DataPoint: New DataPoint instance.
+         """
+         return cls(**data)
+
+
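As a quick orientation to the new model, a minimal sketch (illustrative only, not part of the diff; the import path assumes the module is reachable as camel.datasets.base):

    from camel.datasets.base import DataPoint  # assumed import path

    # question, rationale and final_answer are the three required fields
    dp = DataPoint(
        question="What is 2 + 2?",
        rationale="Two plus two equals four.",
        final_answer="4",
    )
    as_dict = dp.to_dict()                  # plain dict via pydantic
    restored = DataPoint.from_dict(as_dict)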
+ class BaseDataset(Dataset):
+     r"""A dataset contains questions and ground truth data for training.
+     It can be either static (e.g., MATH dataset) or generative
+     (using an LLM to generate questions).
+     """
+
+     def __init__(
+         self,
+         data: List[Dict[str, str]],
+         cache_dir: Optional[str] = None,
+         **kwargs,
+     ):
+         r"""Initialize the dataset.
+
+         Args:
+             data (List[Dict[str, str]]): List of dictionary items to
+                 create the dataset from.
+             cache_dir (Optional[str]): Directory to cache dataset files.
+                 (default: :obj:`None`)
+             **kwargs: Additional dataset parameters.
+
+         Note:
+             The dataset must be initialized by calling setup() before use.
+         """
+         self._is_setup = False
+         self._raw_data: List[Dict[str, str]] = data if data is not None else []
+         self._cache_dir = str(cache_dir) if cache_dir is not None else None
+
+         # Store all parameters in metadata dict for compatibility
+         self._metadata = {
+             'cache_dir': self._cache_dir,
+             **kwargs,
+         }
+
+         self.data: List[DataPoint] = []  # Will be populated in setup()
+
+     async def setup(self) -> None:
+         r"""Set up the dataset with necessary resources.
+
+         This method:
+         1. Creates cache directory if needed
+         2. Processes raw data into DataPoint objects using vectorized
+            operations
+         3. Validates dataset integrity
+
+         Raises:
+             OSError: If cache directory creation fails.
+             Exception: If dataset initialization fails.
+         """
+         if self._is_setup:
+             logger.debug(f"{self.__class__.__name__} already initialized")
+             return
+
+         try:
+             # Create cache directory if specified
+             if self._cache_dir:
+                 try:
+                     os.makedirs(self._cache_dir, exist_ok=True)
+                     logger.debug(f"Created cache directory: {self._cache_dir}")
+                 except OSError as e:
+                     logger.error(
+                         f"Failed to create cache directory "
+                         f"{self._cache_dir}: {e}"
+                     )
+                     raise
+
+             # Process raw data into DataPoint objects using vectorized
+             # operations
+             if not self._raw_data:
+                 self.data = []
+                 logger.debug("No raw data to process")
+             else:
+                 try:
+                     # Helper function for validation that can be used with map
+                     def create_datapoint(item, idx=None):
+                         try:
+                             return DataPoint(
+                                 question=item.get('question', ''),
+                                 rationale=item.get('rationale', ''),
+                                 final_answer=item.get('final_answer', ''),
+                                 metadata=item.get('metadata', {})
+                                 if isinstance(item.get('metadata'), dict)
+                                 else {},
+                                 raw_markdown='',
+                                 difficulty='',
+                             )
+                         except ValidationError as e:
+                             idx_str = (
+                                 f" at index {idx}" if idx is not None else ""
+                             )
+                             error_msg = (
+                                 f"Sample{idx_str} validation error: {e}"
+                             )
+                             logger.error(error_msg)
+                             raise ValueError(error_msg)
+
+                     # If raw_data is already a HF dataset, use its map function
+                     if hasattr(self._raw_data, 'map') and callable(
+                         self._raw_data.map
+                     ):
+                         # Using HF dataset's map for vectorized processing
+                         processed_data = self._raw_data.map(
+                             lambda example, idx: {
+                                 'datapoint': create_datapoint(example, idx)
+                             },
+                             with_indices=True,
+                         )
+                         self.data = [
+                             item['datapoint'] for item in processed_data
+                         ]
+                     else:
+                         # Bulk create datapoints
+                         self.data = [
+                             create_datapoint(item, i)
+                             for i, item in enumerate(self._raw_data)
+                         ]
+
+                     logger.debug(f"Processed {len(self.data)} data points")
+                 except Exception as e:
+                     logger.error(f"Error processing data: {e}")
+                     raise
+
+             self._is_setup = True
+             logger.info(f"{self.__class__.__name__} initialized successfully")
+
+         except Exception as e:
+             logger.error(f"Error during {self.__class__.__name__} setup: {e}")
+             await self.cleanup()
+             raise
+
+     async def cleanup(self) -> None:
+         r"""Clean up dataset resources.
+
+         This method handles cleanup of resources and resets the dataset state.
+         """
+         if not self._is_setup:
+             return
+
+         try:
+             # Clear metadata while preserving init config
+             init_config = {
+                 'cache_dir': self._cache_dir,
+             }
+             self._metadata = init_config
+
+             logger.info(f"{self.__class__.__name__} cleaned up successfully")
+
+         except Exception as e:
+             logger.error(
+                 f"Error during {self.__class__.__name__} cleanup: {e}"
+             )
+             raise
+
+         finally:
+             # Always mark as uninitialized, even if cleanup fails
+             self._is_setup = False
+
+     def sample(self) -> DataPoint:
+         r"""Sample a random datapoint from the dataset.
+
+         Returns:
+             DataPoint: A randomly sampled DataPoint.
+
+         Raises:
+             RuntimeError: If dataset is not initialized.
+         """
+         if not self._is_setup:
+             raise RuntimeError(
+                 f"{self.__class__.__name__} must be initialized "
+                 "before sampling"
+             )
+
+         idx = random.randint(0, len(self) - 1)
+         return self[idx]
+
+     def __len__(self) -> int:
+         r"""Return the size of the dataset."""
+         return len(self.data)
+
+     def __getitem__(self, idx: int) -> DataPoint:
+         r"""Get an item from the dataset.
+
+         Args:
+             idx (int): Index of the item to get.
+
+         Returns:
+             DataPoint: DataPoint from the dataset with the given index.
+
+         Raises:
+             IndexError: If idx is out of bounds.
+         """
+         if idx < 0 or idx >= len(self):
+             raise IndexError(
+                 f"Index {idx} out of bounds for dataset of size {len(self)}"
+             )
+
+         return self.data[idx]
+
+     @property
+     def metadata(self) -> Dict[str, Any]:
+         r"""Get dataset metadata."""
+         return self._metadata.copy()
+
+     def to_pytorch_dataset(
+         self,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         batch_size: Optional[int] = None,
+     ) -> Union["PyTorchDataset", "DataLoader"]:
+         r"""Convert to a PyTorch dataset or DataLoader.
+
+         Args:
+             transform (Optional[Callable]): Transform to apply to samples.
+             target_transform (Optional[Callable]): Transform to apply to
+                 targets.
+             batch_size (Optional[int]): If provided, returns a DataLoader with
+                 the specified batch size instead of a PyTorchDataset.
+                 (default: :obj:`None`)
+
+         Returns:
+             Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
+                 PyTorch format or DataLoader if batch_size is provided.
+         """
+         dataset = PyTorchDataset.from_datapoints(
+             self.data,
+             transform=transform,
+             target_transform=target_transform,
+         )
+
+         if batch_size is not None:
+             return dataset.get_dataloader(batch_size=batch_size)
+
+         return dataset
+
+
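A minimal usage sketch for the class above (illustrative, not from the package; it assumes raw items carry the three required DataPoint fields):

    import asyncio
    from camel.datasets.base import BaseDataset  # assumed import path

    raw = [
        {"question": "What is 2 + 2?", "rationale": "Add them.", "final_answer": "4"},
        {"question": "What is 3 * 3?", "rationale": "Multiply.", "final_answer": "9"},
    ]

    async def main():
        ds = BaseDataset(data=raw, cache_dir="./.dataset_cache")
        await ds.setup()                   # validates raw dicts into DataPoint objects
        print(len(ds), ds.sample().question)
        loader = ds.to_pytorch_dataset(batch_size=2)  # DataLoader when batch_size is given

    asyncio.run(main())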
+ class SeedDataset(BaseDataset):
+     r"""A dataset containing validated seed examples for data generation.
+     Ensures that all items adhere to the DataPoint schema.
+
+     This class is used to initialize a dataset from a list of dictionary items,
+     validating each against the DataPoint schema.
+     """
+
+     def __init__(
+         self,
+         data: List[Dict[str, str]],
+         cache_dir: Optional[str] = None,
+         min_samples: int = 1,
+         **kwargs,
+     ):
+         r"""Initialize the seed dataset.
+
+         Args:
+             data (List[Dict[str, str]]): List of dictionary items to create the
+                 dataset from.
+             cache_dir (Optional[str]): Directory to cache dataset files.
+                 (default: :obj:`None`)
+             min_samples (int): Minimum number of samples required.
+                 (default: :obj:`1`)
+             **kwargs: Additional dataset parameters.
+
+         Raises:
+             ValueError: If dataset size is less than min_samples or if sample
+                 validation fails.
+         """
+         if len(data) < min_samples:
+             raise ValueError(
+                 f"Seed dataset must contain at least {min_samples} samples."
+             )
+
+         super().__init__(
+             data=data,
+             cache_dir=cache_dir,
+             **kwargs,
+         )
+
+
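The seed variant only adds the min_samples guard at construction time; per the code above, the per-item schema validation still happens in setup(). A small sketch (illustrative, not part of the diff):

    seed = SeedDataset(
        data=[{"question": "Q", "rationale": "R", "final_answer": "A"}],
        min_samples=1,                     # raises ValueError if fewer rows are given
    )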
+ class SyntheticDataset(BaseDataset):
+     r"""A dataset for storing synthetically generated data points.
+
+     This class is used to store datapoints that are generated through
+     a generative process, such as using an agent.
+     """
+
+     def __init__(
+         self,
+         data: Optional[List[Dict[str, str]]] = None,
+         cache_dir: Optional[str] = None,
+         **kwargs,
+     ):
+         r"""Initialize the synthetic dataset.
+
+         Args:
+             data (Optional[List[Dict[str, str]]]): List of dictionary items to
+                 create the dataset from. (default: :obj:`None`)
+             cache_dir (Optional[str]): Directory to cache dataset files.
+                 (default: :obj:`None`)
+             **kwargs: Additional dataset parameters.
+         """
+         super().__init__(
+             data=data if data is not None else [],
+             cache_dir=cache_dir,
+             **kwargs,
+         )
+         self.data: List[DataPoint] = []
+
+     def add(self, item: DataPoint) -> None:
+         r"""Add a new data point to the dataset.
+
+         Args:
+             item (DataPoint): The datapoint to add to the dataset.
+         """
+         self.data.append(item)
+
+     def add_batch(self, items: List[DataPoint]) -> None:
+         r"""Add multiple data points to the dataset.
+
+         Args:
+             items (List[DataPoint]): The datapoints to add to the dataset.
+         """
+         self.data.extend(items)
+
+     def to_pytorch_dataset(
+         self,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         batch_size: Optional[int] = None,
+     ) -> Union["PyTorchDataset", "DataLoader"]:
+         r"""Convert to a PyTorch dataset or DataLoader.
+
+         Args:
+             transform (Optional[Callable]): Transform to apply to samples.
+             target_transform (Optional[Callable]): Transform to apply to
+                 targets.
+             batch_size (Optional[int]): If provided, returns a DataLoader with
+                 the specified batch size instead of a PyTorchDataset.
+                 (default: :obj:`None`)
+
+         Returns:
+             Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
+                 PyTorch format or DataLoader if batch_size is provided.
+         """
+         return convert_synthetic_to_pytorch(
+             self,
+             transform=transform,
+             target_transform=target_transform,
+             batch_size=batch_size,
+         )
+
+     def save_pytorch_format(self, path: str, compression: bool = True) -> None:
+         r"""Save the dataset to disk in PyTorch format.
+
+         Args:
+             path (str): Path to save the dataset to.
+             compression (bool): Whether to use compression to reduce file size.
+                 (default: :obj:`True`)
+         """
+         save_synthetic_dataset(self, path, compression=compression)
+
+     def filter(
+         self, predicate: Callable[[DataPoint], bool]
+     ) -> 'SyntheticDataset':
+         r"""Filter the dataset using a predicate function.
+
+         Args:
+             predicate (Callable[[DataPoint], bool]): Function that takes a
+                 DataPoint and returns True if it should be kept, False
+                 otherwise.
+
+         Returns:
+             SyntheticDataset: A new dataset containing only the filtered items.
+         """
+         filtered_data = [dp for dp in self.data if predicate(dp)]
+
+         # Create a new dataset with the filtered data
+         new_dataset = SyntheticDataset()
+         new_dataset.add_batch(filtered_data)
+
+         return new_dataset
+
+
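A short sketch of the add/filter workflow (illustrative, not part of the diff):

    synthetic = SyntheticDataset()
    synthetic.add(DataPoint(question="Q", rationale="R", final_answer="4"))
    short_answers = synthetic.filter(lambda dp: len(dp.final_answer) <= 3)
    print(len(short_answers))              # 1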
+ class GenerativeDataset(BaseDataset):
+     r"""A dataset for generating synthetic datapoints using external agents and
+     verifiers.
+
+     This class leverages a seed dataset and external components to generate
+     new synthetic datapoints on demand.
+     """
+
+     def __init__(
+         self,
+         seed_dataset: SeedDataset,
+         verifier: BaseVerifier,
+         agent: ChatAgent,
+         cache_dir: Optional[str] = None,
+         seed: int = 42,
+         **kwargs,
+     ):
+         r"""Initialize the generative dataset.
+
+         Args:
+             seed_dataset (SeedDataset): Validated dataset to use for examples.
+             verifier (BaseVerifier): Verifier to validate generated content.
+             agent (ChatAgent): Agent to generate new datapoints.
+             cache_dir (Optional[str]): Directory to cache dataset files.
+                 (default: :obj:`None`)
+             seed (int): Random seed for reproducibility. (default: :obj:`42`)
+             **kwargs: Additional dataset parameters.
+         """
+         # Initialize with empty data since we'll generate content dynamically
+         super().__init__(
+             data=[],
+             cache_dir=cache_dir,
+             **kwargs,
+         )
+
+         self.seed_dataset = seed_dataset
+         self.verifier = verifier
+         self.agent = agent
+
+         self.seed = seed
+         random.seed(self.seed)
+
+     def _construct_prompt(self, examples: List[DataPoint]) -> str:
+         r"""Construct a prompt for generating new datapoints.
+
+         Args:
+             examples (List[DataPoint]): Examples to include in the prompt.
+
+         Returns:
+             str: Formatted prompt with examples.
+         """
+         prompt = (
+             "Generate a new datapoint similar to the following examples:\n\n"
+         )
+         for i, example in enumerate(examples, 1):
+             prompt += f"Example {i}:\n"
+             prompt += f"Question: {example.question}\n"
+             prompt += f"Rationale: {example.rationale}\n"
+             prompt += f"Final Answer: {example.final_answer}\n\n"
+         prompt += "New datapoint:"
+         return prompt
+
+     async def generate_new(self, n: int) -> None:
+         r"""Generate n new datapoints and add them to the dataset.
+
+         Args:
+             n (int): Number of valid datapoints to generate.
+
+         This method generates new datapoints by:
+         1. Sampling examples from the seed dataset
+         2. Constructing a prompt for the agent
+         3. Generating a new datapoint using the agent
+         4. Verifying the generated datapoint with the verifier
+         5. Adding valid datapoints to the dataset
+         """
+         valid_data_points: List[DataPoint] = []
+
+         while len(valid_data_points) < n:
+             try:
+                 indices = random.sample(range(len(self.seed_dataset)), 3)
+                 examples = [self.seed_dataset[i] for i in indices]
+                 prompt = self._construct_prompt(examples)
+
+                 # Get agent response
+                 agent_output = (
+                     self.agent.step(prompt, response_format=DataPoint)
+                     .msgs[0]
+                     .parsed
+                 )
+
+                 if not isinstance(agent_output, dict):
+                     raise TypeError("Agent output must be a dictionary")
+                 if (
+                     'question' not in agent_output
+                     or 'rationale' not in agent_output
+                 ):
+                     raise KeyError(
+                         "Agent output missing required keys: "
+                         "'question' or 'rationale'"
+                     )
+
+                 rationale = agent_output['rationale']
+
+                 # Verify the generated content
+                 verifier_response = await self.verifier.verify(rationale)
+                 if not hasattr(verifier_response, 'content'):
+                     raise AttributeError(
+                         "Verifier response missing 'content' attribute"
+                     )
+
+                 if not verifier_response.result:
+                     continue
+
+                 final_answer = verifier_response.result
+
+                 # Create and validate the new datapoint
+                 new_datapoint = {
+                     'question': agent_output['question'],
+                     'rationale': rationale,
+                     'final_answer': final_answer,
+                 }
+
+                 datapoint = DataPoint(**new_datapoint)
+                 valid_data_points.append(datapoint)
+
+             except (TypeError, KeyError, AttributeError, ValidationError) as e:
+                 logger.warning(
+                     f"Error encountered during generation: {e}, retrying..."
+                 )
+
+         # Add all valid datapoints to the dataset
+         for datapoint in valid_data_points:
+             self.data.append(datapoint)
+             logger.debug("Added new datapoint to dataset")
+
+
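The call pattern this class is built around, sketched with placeholder components (illustrative only; the concrete ChatAgent and verifier setup, and the seed_rows list, are assumptions, and the seed dataset must be set up first so it can be sampled):

    import asyncio
    from camel.agents import ChatAgent
    from camel.datasets.base import GenerativeDataset, SeedDataset
    from camel.verifiers import PythonVerifier  # any BaseVerifier implementation

    async def main():
        seed = SeedDataset(data=seed_rows)      # seed_rows: >= 3 question/rationale/final_answer dicts
        await seed.setup()                      # populate seed.data before sampling
        gen = GenerativeDataset(
            seed_dataset=seed,
            verifier=PythonVerifier(),          # constructor arguments are an assumption
            agent=ChatAgent("You write new math problems."),
        )
        await gen.generate_new(5)               # keeps only verifier-approved datapoints
        print(len(gen))

    asyncio.run(main())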
+ # Define a type variable for return type flexibility
+ T_co = TypeVar('T_co', covariant=True)
+
+
+ class PyTorchDataset(Dataset[T_co]):
+     r"""A PyTorch-compatible dataset implementation that leverages PyTorch's
+     efficient data handling capabilities.
+     """
+
+     def __init__(
+         self,
+         data: List[Dict[str, Any]],
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         validate: bool = True,
+     ):
+         r"""Initialize the PyTorch dataset.
+
+         Args:
+             data (List[Dict[str, Any]]): List of dictionary items to create
+                 the dataset from.
+             transform (Optional[Callable]): A function/transform that takes a
+                 sample and returns a transformed version for features.
+                 (default: :obj:`None`)
+             target_transform (Optional[Callable]): A function/transform that
+                 takes a target and returns a transformed version. (default:
+                 :obj:`None`)
+             validate (bool): Whether to validate data points using DataPoint
+                 schema. (default: :obj:`True`)
+
+         Raises:
+             ValidationError: If validation is enabled and data doesn't match
+                 DataPoint schema.
+         """
+         self.transform = transform
+         self.target_transform = target_transform
+
+         # Validate and store data
+         self._raw_data = data
+         self.data = []
+
+         if validate:
+             for i, item in enumerate(self._raw_data):
+                 try:
+                     # Use DataPoint for validation only
+                     dp = DataPoint(**item)
+                     self.data.append(dp.to_dict())
+                 except ValidationError as e:
+                     logger.error(f"Sample {i} validation error: {e}")
+                     raise ValueError(f"Sample {i} validation error: {e}")
+         else:
+             # Skip validation and just store the data dictionaries
+             self.data = [dict(item) for item in self._raw_data]
+
+     def __getitem__(self, index: int) -> T_co:
+         r"""Get an item from the dataset.
+
+         Args:
+             index (int): Index of the item to get.
+
+         Returns:
+             T_co: Item from the dataset, possibly transformed.
+         """
+         sample = self.data[index]
+
+         # Apply transformations if provided
+         if self.transform is not None:
+             sample = self.transform(sample)
+
+         return sample  # type: ignore[return-value]
+
+     def __len__(self) -> int:
+         r"""Return the size of the dataset.
+
+         Returns:
+             int: Number of samples in the dataset.
+         """
+         return len(self.data)
+
+     @classmethod
+     def from_datapoints(
+         cls, datapoints: List[DataPoint], **kwargs
+     ) -> 'PyTorchDataset':
+         r"""Create a PyTorchDataset from a list of DataPoints.
+
+         Args:
+             datapoints (List[DataPoint]): List of DataPoint objects.
+             **kwargs: Additional arguments to pass to the constructor.
+
+         Returns:
+             PyTorchDataset: A new PyTorchDataset instance.
+         """
+         data = [dp.to_dict() for dp in datapoints]
+         # We can skip validation since datapoints are already validated
+         return cls(data, validate=False, **kwargs)
+
+     def to_hf_dataset(self) -> HFDataset:
+         r"""Convert to a HuggingFace dataset.
+
+         Returns:
+             HFDataset: Dataset in HuggingFace format.
+         """
+         return HFDataset.from_list(self.data)
+
+     def save_to_disk(self, path: str) -> None:
+         r"""Save the dataset to disk using PyTorch.
+
+         Args:
+             path (str): Path to save the dataset to.
+         """
+         torch.save(self.data, path)
+
+     @classmethod
+     def load_from_disk(
+         cls,
+         path: str,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+     ) -> 'PyTorchDataset':
+         r"""Load a dataset from disk.
+
+         Args:
+             path (str): Path to load the dataset from.
+             transform (Optional[Callable]): Transform to apply to samples.
+                 (default: :obj:`None`)
+             target_transform (Optional[Callable]): Transform to apply to
+                 targets. (default: :obj:`None`)
+
+         Returns:
+             PyTorchDataset: Loaded dataset.
+         """
+         data = torch.load(path)
+         return cls(
+             data,
+             transform=transform,
+             target_transform=target_transform,
+             validate=False,
+         )
+
+     @staticmethod
+     def collate_fn(
+         batch: List[Dict[str, Any]],
+     ) -> Dict[str, Union[List[Any], torch.Tensor]]:
+         r"""Collate function for PyTorch DataLoader.
+
+         Args:
+             batch (List[Dict[str, Any]]): Batch of samples from the dataset.
+
+         Returns:
+             Dict[str, Union[List[Any], torch.Tensor]]: Collated batch with
+                 tensors for numerical data.
+         """
+         if not batch:
+             return {}
+
+         # Initialize result dictionary with keys from first item - start with
+         # lists only
+         result: Dict[str, List[Any]] = {k: [] for k in batch[0].keys()}
+
+         # Collect values by key
+         for item in batch:
+             for k, v in item.items():
+                 result[k].append(v)
+
+         # Convert numeric/boolean lists to tensors where possible
+         result_with_tensors: Dict[str, Union[List[Any], torch.Tensor]] = {}
+         for k, v in result.items():
+             if all(isinstance(x, (int, float, bool)) for x in v):
+                 try:
+                     result_with_tensors[k] = torch.tensor(v)
+                 except (ValueError, TypeError):
+                     # Keep as list if tensor conversion fails
+                     result_with_tensors[k] = v
+             else:
+                 result_with_tensors[k] = v
+
+         return result_with_tensors
+
+     def get_dataloader(
+         self,
+         batch_size: int = 32,
+         shuffle: bool = True,
+         num_workers: int = 0,
+         pin_memory: bool = False,
+         **kwargs,
+     ) -> "DataLoader":
+         r"""Create a PyTorch DataLoader for this dataset.
+
+         Args:
+             batch_size (int): Batch size. (default: :obj:`32`)
+             shuffle (bool): Whether to shuffle the dataset. (default:
+                 :obj:`True`)
+             num_workers (int): Number of workers for data loading. (default:
+                 :obj:`0`)
+             pin_memory (bool): Whether to pin memory for faster GPU transfer.
+                 (default: :obj:`False`)
+             **kwargs: Additional arguments to pass to DataLoader.
+
+         Returns:
+             torch.utils.data.DataLoader: DataLoader for this dataset.
+         """
+         from torch.utils.data import DataLoader
+
+         return DataLoader(
+             self,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=num_workers,
+             collate_fn=self.collate_fn,
+             pin_memory=pin_memory,
+             **kwargs,
+         )
+
+
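A sketch of direct PyTorchDataset use (illustrative, not part of the diff); with the default collate_fn, string fields stay as lists while purely numeric fields become tensors:

    rows = [
        {"question": "Q1", "rationale": "R1", "final_answer": "A1"},
        {"question": "Q2", "rationale": "R2", "final_answer": "A2"},
    ]
    ds = PyTorchDataset(rows)               # validate=True checks each row against DataPoint
    loader = ds.get_dataloader(batch_size=2, shuffle=False)
    for batch in loader:
        print(batch["question"])            # ['Q1', 'Q2']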
+ def convert_hf_to_pytorch(
+     hf_dataset: HFDataset,
+     transform: Optional[Callable] = None,
+     target_transform: Optional[Callable] = None,
+     column_mapping: Optional[Dict[str, str]] = None,
+     validate: bool = True,
+     batch_size: Optional[int] = None,
+ ) -> Union["PyTorchDataset", "DataLoader"]:
+     r"""Convert a HuggingFace dataset to a PyTorchDataset or DataLoader.
+
+     This function maps HuggingFace dataset columns to the expected DataPoint
+     format, validates the data, and creates a PyTorchDataset or DataLoader.
+
+     Args:
+         hf_dataset (HFDataset): HuggingFace dataset to convert.
+         transform (Optional[Callable]): Transform to apply to samples.
+         target_transform (Optional[Callable]): Transform to apply to targets.
+         column_mapping (Optional[Dict[str, str]]): Mapping from HuggingFace
+             column names to DataPoint field names. If None, assumes columns
+             already match DataPoint fields.
+         validate (bool): Whether to validate data points using DataPoint
+             schema. (default: :obj:`True`)
+         batch_size (Optional[int]): If provided, returns a DataLoader with the
+             specified batch size instead of a PyTorchDataset. (default:
+             :obj:`None`)
+
+     Returns:
+         Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
+             or DataLoader if batch_size is provided.
+     """
+     # Convert HuggingFace dataset to list of dicts more efficiently
+     mapped_dataset = []
+
+     for i in range(len(hf_dataset)):
+         item = hf_dataset[i]
+         if column_mapping is not None:
+             # Apply column mapping if provided
+             mapped_item = {}
+             for hf_col, dp_field in column_mapping.items():
+                 if hf_col in item:
+                     mapped_item[dp_field] = item[hf_col]
+             mapped_dataset.append(mapped_item)
+         else:
+             # Otherwise use item directly
+             mapped_dataset.append(dict(item))
+
+     # Create PyTorchDataset
+     dataset: PyTorchDataset = PyTorchDataset(
+         mapped_dataset,
+         transform=transform,
+         target_transform=target_transform,
+         validate=validate,
+     )
+
+     # Return DataLoader if batch_size is provided
+     if batch_size is not None:
+         return dataset.get_dataloader(batch_size=batch_size)
+
+     return dataset
+
+
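For a HuggingFace dataset whose columns do not already match the DataPoint fields, column_mapping renames them during conversion (illustrative sketch; the column names are made up):

    from datasets import Dataset as HFDataset

    hf = HFDataset.from_dict({
        "problem": ["What is 2 + 2?"],
        "solution": ["Add the numbers."],
        "answer": ["4"],
    })
    loader = convert_hf_to_pytorch(
        hf,
        column_mapping={"problem": "question", "solution": "rationale", "answer": "final_answer"},
        batch_size=1,                       # returns a DataLoader instead of a PyTorchDataset
    )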
+ def convert_synthetic_to_pytorch(
+     synthetic_dataset: 'SyntheticDataset',
+     transform: Optional[Callable] = None,
+     target_transform: Optional[Callable] = None,
+     batch_size: Optional[int] = None,
+ ) -> Union["PyTorchDataset", "DataLoader"]:
+     r"""Convert a SyntheticDataset to a PyTorchDataset or DataLoader.
+
+     Args:
+         synthetic_dataset (SyntheticDataset): Synthetic dataset to convert.
+         transform (Optional[Callable]): Transform to apply to samples.
+         target_transform (Optional[Callable]): Transform to apply to targets.
+         batch_size (Optional[int]): If provided, returns a DataLoader with the
+             specified batch size instead of a PyTorchDataset. (default:
+             :obj:`None`)
+
+     Returns:
+         Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
+             or DataLoader if batch_size is provided.
+     """
+     dataset = PyTorchDataset.from_datapoints(
+         synthetic_dataset.data,
+         transform=transform,
+         target_transform=target_transform,
+     )
+
+     # Return DataLoader if batch_size is provided
+     if batch_size is not None:
+         return dataset.get_dataloader(batch_size=batch_size)
+
+     return dataset
+
+
+ def save_synthetic_dataset(
+     synthetic_dataset: 'SyntheticDataset',
+     path: str,
+     compression: bool = True,
+ ) -> None:
+     r"""Save a synthetic dataset to disk using PyTorch format.
+
+     Args:
+         synthetic_dataset (SyntheticDataset): Dataset to save.
+         path (str): Path to save the dataset to.
+         compression (bool): Whether to use compression to reduce file size.
+             (default: :obj:`True`)
+     """
+     pytorch_dataset = convert_synthetic_to_pytorch(synthetic_dataset)
+
+     # Save with compression if enabled (uses less disk space)
+     if compression:
+         torch.save(
+             pytorch_dataset.data,  # type: ignore[union-attr]
+             path,
+             _use_new_zipfile_serialization=True,
+         )
+     else:
+         pytorch_dataset.save_to_disk(path)  # type: ignore[union-attr]
+
+
+ def load_pytorch_dataset(
+     path: str,
+     transform: Optional[Callable] = None,
+     target_transform: Optional[Callable] = None,
+     batch_size: Optional[int] = None,
+ ) -> Union["PyTorchDataset", "DataLoader"]:
+     r"""Load a PyTorchDataset from disk.
+
+     Args:
+         path (str): Path to load the dataset from.
+         transform (Optional[Callable]): Transform to apply to samples.
+         target_transform (Optional[Callable]): Transform to apply to targets.
+         batch_size (Optional[int]): If provided, returns a DataLoader with the
+             specified batch size instead of a PyTorchDataset. (default:
+             :obj:`None`)
+
+     Returns:
+         Union[PyTorchDataset, torch.utils.data.DataLoader]: Loaded dataset or
+             DataLoader if batch_size is provided.
+     """
+     dataset = PyTorchDataset.load_from_disk(
+         path, transform=transform, target_transform=target_transform
+     )
+
+     # Return DataLoader if batch_size is provided
+     if batch_size is not None:
+         return dataset.get_dataloader(batch_size=batch_size)
+
+     return dataset
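A save/load round trip over these helpers might look as follows (illustrative sketch, not part of the diff; the file name is arbitrary):

    synthetic = SyntheticDataset()
    synthetic.add(DataPoint(question="Q", rationale="R", final_answer="A"))
    save_synthetic_dataset(synthetic, "synthetic.pt")

    reloaded = load_pytorch_dataset("synthetic.pt", batch_size=8)  # DataLoader over the saved dicts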