camel-ai 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of camel-ai has been flagged as potentially problematic.

camel/datasets/base.py DELETED
@@ -1,639 +0,0 @@
- # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-
- import json
- import random
- from datetime import datetime
- from pathlib import Path
- from typing import (
-     Any,
-     Dict,
-     List,
-     Optional,
-     Sized,
-     Union,
- )
-
- from datasets import Dataset as HFDataset
- from pydantic import BaseModel, Field, ValidationError
- from torch.utils.data import Dataset
-
- from camel.agents import ChatAgent
- from camel.logger import get_logger
- from camel.verifiers import BaseVerifier
- from camel.verifiers.models import VerifierInput
-
- logger = get_logger(__name__)
-
-
- class DataPoint(BaseModel):
-     r"""A single data point in the dataset.
-
-     Attributes:
-         question (str): The primary question or issue to be addressed.
-         rationale (Optional[str]): Logical reasoning or explanation behind the
-             answer. (default: :obj:`None`)
-         final_answer (str): The final answer.
-         metadata (Optional[Dict[str, Any]]): Additional metadata about the data
-             point. (default: :obj:`None`)
-     """
-
-     question: str = Field(
-         ..., description="The primary question or issue to be addressed."
-     )
-     rationale: Optional[str] = Field(
-         default=None,
-         description="Logical reasoning or explanation behind the answer.",
-     )
-     final_answer: str = Field(..., description="The final answer.")
-
-     metadata: Optional[Dict[str, Any]] = Field(
-         default=None, description="Additional metadata about the data point."
-     )
-
-     def to_dict(self) -> Dict[str, Any]:
-         r"""Convert DataPoint to a dictionary.
-
-         Returns:
-             Dict[str, Any]: Dictionary representation of the DataPoint.
-         """
-         return self.dict()
-
-     @classmethod
-     def from_dict(cls, data: Dict[str, Any]) -> 'DataPoint':
-         r"""Create a DataPoint from a dictionary.
-
-         Args:
-             data (Dict[str, Any]): Dictionary containing DataPoint fields.
-
-         Returns:
-             DataPoint: New DataPoint instance.
-         """
-         return cls(**data)
-
-
- class StaticDataset(Dataset):
-     r"""A static dataset containing a list of datapoints.
-     Ensures that all items adhere to the DataPoint schema.
-     This dataset extends :obj:`Dataset` from PyTorch and should
-     be used when its size is fixed at runtime.
-
-     This class can initialize from Hugging Face Datasets,
-     PyTorch Datasets, JSON file paths, or lists of dictionaries,
-     converting them into a consistent internal format.
-     """
-
-     def __init__(
-         self,
-         data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
-         seed: int = 42,
-         min_samples: int = 1,
-         strict: bool = False,
-         **kwargs,
-     ):
-         r"""Initialize the static dataset and validate integrity.
-
-         Args:
-             data (:obj:`Union[HFDataset, Dataset,
-                 Path, List[Dict[str, Any]]]`):
-                 Input data, which can be one of the following:
-                 - A Hugging Face Dataset (:obj:`HFDataset`).
-                 - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
-                 - A :obj:`Path` object representing a JSON file.
-                 - A list of dictionaries with :obj:`DataPoint`-compatible
-                   fields.
-             seed (:obj:`int`): Random seed for reproducibility.
-                 Default is :obj:`42`.
-             min_samples (:obj:`int`): Minimum required number of samples.
-                 Default is :obj:`1`.
-             strict (:obj:`bool`): Whether to raise an error on invalid
-                 datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
-                 Default is :obj:`False`.
-             **kwargs: Additional dataset parameters.
-
-         Raises:
-             TypeError: If the input data type is unsupported.
-             ValueError: If the dataset contains fewer than :obj:`min_samples`
-                 datapoints or if validation fails.
-             FileNotFoundError: If the specified JSON file path does not exist.
-             json.JSONDecodeError: If the JSON file contains invalid formatting.
-         """
-
-         # Store all parameters in metadata dict for compatibility
-         self._metadata = {
-             **kwargs,
-         }
-         self._rng = random.Random(seed)
-         self._strict = strict
-
-         self.data: List[DataPoint] = self._init_data(data)
-         self._length = len(self.data)
-
-         if self._length < min_samples:
-             raise ValueError(
-                 "The dataset does not contain enough samples. "
-                 f"Need {max(0, min_samples)}, got {self._length}"
-             )
-
-     def _init_data(
-         self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
-     ) -> List[DataPoint]:
-         r"""Convert input data from various formats into a list of
-         :obj:`DataPoint` instances.
-
-         Args:
-             data (:obj:`Union[
-                 HFDataset,
-                 Dataset,
-                 Path,
-                 List[Dict[str, Any]]
-             ]`):
-                 Input dataset in one of the supported formats.
-
-         Returns:
-             :obj:`List[DataPoint]`: A list of validated :obj:`DataPoint`
-                 instances.
-
-         Raises:
-             TypeError: If the input data type is unsupported.
-         """
-
-         if isinstance(data, HFDataset):
-             raw_data = self._init_from_hf_dataset(data)
-         elif isinstance(data, Dataset):
-             raw_data = self._init_from_pytorch_dataset(data)
-         elif isinstance(data, Path):
-             raw_data = self._init_from_json_path(data)
-         elif isinstance(data, list):
-             raw_data = self._init_from_list(data)
-         else:
-             raise TypeError("Unsupported data type")
-
-         def create_datapoint(
-             item: Dict[str, Any], idx: int
-         ) -> Optional[DataPoint]:
-             # Add type checks for required fields to make mypy happy
-             question = item.get('question')
-             if not isinstance(question, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'question': "
-                         f"expected str, got {type(question)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: invalid 'question'"
-                     )
-                     return None
-
-             rationale = item.get('rationale')
-             if not isinstance(rationale, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'rationale': "
-                         f"expected str, got {type(rationale)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: invalid 'rationale'"
-                     )
-                     return None
-
-             final_answer = item.get('final_answer')
-             if not isinstance(final_answer, str):
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} has invalid 'final_answer': "
-                         f"expected str, got {type(final_answer)}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping sample at index {idx}: "
-                         "invalid 'final_answer'"
-                     )
-                     return None
-
-             try:
-                 return DataPoint(
-                     question=question,
-                     rationale=rationale,
-                     final_answer=final_answer,
-                     metadata=item.get('metadata'),
-                 )
-             except ValidationError as e:
-                 if self._strict:
-                     raise ValueError(
-                         f"Sample at index {idx} validation error: {e}"
-                     )
-                 else:
-                     logger.warning(
-                         f"Skipping invalid sample at index {idx} "
-                         f"due to validation error: {e}"
-                     )
-                     return None
-
-         unfiltered_data = [
-             create_datapoint(item, i) for i, item in enumerate(raw_data)
-         ]
-         return [dp for dp in unfiltered_data if dp is not None]
-
-     def __len__(self) -> int:
-         r"""Return the size of the dataset."""
-         return self._length
-
-     def __getitem__(self, idx: int) -> DataPoint:
-         r"""Retrieve a datapoint by index.
-
-         Args:
-             idx (:obj:`int`): Index of the datapoint.
-
-         Returns:
-             :obj:`DataPoint`: The datapoint corresponding to the given index.
-
-         Raises:
-             IndexError: If :obj:`idx` is out of bounds (negative or greater
-                 than dataset length - 1).
-         """
-
-         if idx < 0 or idx >= self._length:
-             raise IndexError(
-                 f"Index {idx} out of bounds for dataset of size {self._length}"
-             )
-         return self.data[idx]
-
-     def sample(self) -> DataPoint:
-         r"""Sample a random datapoint from the dataset.
-
-         Returns:
-             :obj:`DataPoint`: A randomly sampled :obj:`DataPoint`.
-
-         Raises:
-             RuntimeError: If the dataset is empty and no samples can be drawn.
-         """
-
-         if self._length == 0:
-             raise RuntimeError("Dataset is empty, cannot sample.")
-         idx = self._rng.randint(0, self._length - 1)
-         return self[idx]
-
-     @property
-     def metadata(self) -> Dict[str, Any]:
-         r"""Retrieve dataset metadata.
-
-         Returns:
-             :obj:`Dict[str, Any]`: A copy of the dataset metadata dictionary.
-         """
-
-         return self._metadata.copy()
-
-     def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
-         r"""Convert a Hugging Face dataset into a list of dictionaries.
-
-         Args:
-             data (:obj:`HFDataset`): A Hugging Face dataset.
-
-         Returns:
-             :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
-                 the dataset, where each dictionary corresponds to a datapoint.
-         """
-         return [dict(item) for item in data]
-
-     def _init_from_pytorch_dataset(
-         self, data: Dataset
-     ) -> List[Dict[str, Any]]:
-         r"""Convert a PyTorch dataset into a list of dictionaries.
-
-         Args:
-             data (:obj:`Dataset`): A PyTorch dataset.
-
-         Returns:
-             :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
-                 the dataset.
-
-         Raises:
-             TypeError: If the dataset does not implement :obj:`__len__()`
-                 or contains non-dictionary elements.
-         """
-         if not isinstance(data, Sized):
-             raise TypeError(
-                 f"{type(data).__name__} does not implement `__len__()`."
-             )
-         raw_data = []
-
-         for i in range(len(data)):
-             item = data[i]
-             if not isinstance(item, dict):
-                 raise TypeError(
-                     f"Item at index {i} is not a dict: "
-                     f"got {type(item).__name__}"
-                 )
-             raw_data.append(dict(item))
-         return raw_data
-
-     def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
-         r"""Load and parse a dataset from a JSON file.
-
-         Args:
-             data (:obj:`Path`): Path to the JSON file.
-
-         Returns:
-             :obj:`List[Dict[str, Any]]`: A list of datapoint dictionaries.
-
-         Raises:
-             FileNotFoundError: If the specified JSON file does not exist.
-             ValueError: If the JSON content is not a list of dictionaries.
-             json.JSONDecodeError: If the JSON file has invalid formatting.
-         """
-
-         if not data.exists():
-             raise FileNotFoundError(f"JSON file not found: {data}")
-         try:
-             logger.debug(f"Loading JSON from {data}")
-             with data.open('r', encoding='utf-8') as f:
-                 loaded_data = json.load(f)
-             logger.info(
-                 f"Successfully loaded {len(loaded_data)} items from {data}"
-             )
-         except json.JSONDecodeError as e:
-             raise ValueError(f"Invalid JSON in file {data}: {e}")
-         if not isinstance(loaded_data, list):
-             raise ValueError("JSON file must contain a list of dictionaries")
-         for i, item in enumerate(loaded_data):
-             if not isinstance(item, dict):
-                 raise ValueError(
-                     f"Expected a dictionary at index {i}, "
-                     f"got {type(item).__name__}"
-                 )
-         return loaded_data
-
-     def _init_from_list(
-         self, data: List[Dict[str, Any]]
-     ) -> List[Dict[str, Any]]:
-         r"""Validate and convert a list of dictionaries into a dataset.
-
-         Args:
-             data (:obj:`List[Dict[str, Any]]`): A list of dictionaries where
-                 each dictionary must be a valid :obj:`DataPoint`.
-
-         Returns:
-             :obj:`List[Dict[str, Any]]`: The validated list of dictionaries.
-
-         Raises:
-             ValueError: If any item in the list is not a dictionary.
-         """
-         for i, item in enumerate(data):
-             if not isinstance(item, dict):
-                 raise ValueError(
-                     f"Expected a dictionary at index {i}, "
-                     f"got {type(item).__name__}"
-                 )
-         return data
-
-
- class GenerativeDataset(Dataset):
-     r"""A dataset for generating synthetic datapoints using external agents and
-     verifiers.
-
-     This class leverages a seed dataset and external components to generate
-     new synthetic datapoints on demand.
-     """
-
-     def __init__(
-         self,
-         seed_dataset: StaticDataset,
-         verifier: BaseVerifier,
-         agent: ChatAgent,
-         seed: int = 42,
-         **kwargs,
-     ):
-         r"""Initialize the generative dataset.
-
-         Args:
-             seed_dataset (StaticDataset): Validated static dataset to
-                 use for examples.
-             verifier (BaseVerifier): Verifier to validate generated content.
-             agent (ChatAgent): Agent to generate new datapoints.
-             seed (int): Random seed for reproducibility. (default: :obj:`42`)
-             **kwargs: Additional dataset parameters.
-         """
-
-         self.seed_dataset = seed_dataset
-         self.verifier = verifier
-         self.agent = agent
-
-         self.seed = seed
-         random.seed(self.seed)
-
-         self._data: List[DataPoint] = []
-
-     def _construct_prompt(self, examples: List[DataPoint]) -> str:
-         r"""Construct a prompt for generating new datapoints
-         using a fixed sample of 3 examples from the seed dataset.
-
-         Args:
-             examples (List[DataPoint]): Examples to include in the prompt.
-
-         Returns:
-             str: Formatted prompt with examples.
-         """
-         prompt = (
-             "Generate a new datapoint similar to the following examples:\n\n"
-         )
-         for i, example in enumerate(examples, 1):
-             prompt += f"Example {i}:\n"
-             prompt += f"Question: {example.question}\n"
-             prompt += f"Rationale: {example.rationale}\n"
-             prompt += f"Final Answer: {example.final_answer}\n\n"
-         prompt += "New datapoint:"
-         return prompt
-
-     async def generate_new(
-         self, n: int, max_retries: int = 10
-     ) -> List[DataPoint]:
-         r"""Generates and validates `n` new datapoints through
-         few-shot prompting, with a retry limit.
-
-         Steps:
-             1. Samples 3 examples from the seed dataset.
-             2. Constructs a prompt using the selected examples.
-             3. Uses an agent to generate a new datapoint.
-             4. Verifies the datapoint using a verifier.
-             5. Stores valid datapoints in memory.
-
-         Args:
-             n (int): Number of valid datapoints to generate.
-             max_retries (int): Maximum number of retries before stopping.
-
-         Returns:
-             List[DataPoint]: A list of newly generated valid datapoints.
-
-         Raises:
-             TypeError: If the agent's output is not a dictionary (or does not
-                 match the expected format).
-             KeyError: If required keys are missing from the response.
-             AttributeError: If the verifier response lacks attributes.
-             ValidationError: If a datapoint fails schema validation.
-             RuntimeError: If retries are exhausted before `n` valid datapoints
-                 are generated.
-
-         Notes:
-             - Retries on validation failures until `n` valid datapoints exist
-               or `max_retries` is reached, whichever comes first.
-             - If retries are exhausted before reaching `n`, a `RuntimeError`
-               is raised.
-             - Metadata includes a timestamp for tracking datapoint creation.
-             - This method can be overridden to implement custom generation.
-         """
-         valid_data_points: List[DataPoint] = []
-         retries = 0
-
-         while len(valid_data_points) < n and retries < max_retries:
-             try:
-                 examples = [self.seed_dataset.sample() for _ in range(3)]
-                 prompt = self._construct_prompt(examples)
-
-                 try:
-                     agent_output = (
-                         self.agent.step(prompt, response_format=DataPoint)
-                         .msgs[0]
-                         .parsed
-                     )
-                     if not isinstance(agent_output, dict):
-                         raise TypeError("Agent output must be a dictionary")
-                     if (
-                         "question" not in agent_output
-                         or "rationale" not in agent_output
-                     ):
-                         raise KeyError(
-                             "Missing 'question' or 'rationale' in agent output"
-                         )
-                 except (TypeError, KeyError) as e:
-                     logger.warning(
-                         f"Agent output issue: {e}, retrying... "
-                         f"({retries + 1}/{max_retries})"
-                     )
-                     retries += 1
-                     continue
-
-                 rationale = agent_output["rationale"]
-
-                 try:
-                     verifier_response = await self.verifier.verify(
-                         VerifierInput(
-                             llm_response=rationale, ground_truth=None
-                         )
-                     )
-                     if not verifier_response or not verifier_response.result:
-                         raise ValueError(
-                             "Verifier unsuccessful, response: "
-                             f"{verifier_response}"
-                         )
-                 except (ValueError, AttributeError) as e:
-                     logger.warning(
-                         f"Verifier issue: {e}, "
-                         f"retrying... ({retries + 1}/{max_retries})"
-                     )
-                     retries += 1
-                     continue
-
-                 try:
-                     new_datapoint = DataPoint(
-                         question=agent_output["question"],
-                         rationale=rationale,
-                         final_answer=verifier_response.result,
-                         metadata={
-                             "synthetic": str(True),
-                             "created": datetime.now().isoformat(),
-                         },
-                     )
-                 except ValidationError as e:
-                     logger.warning(
-                         f"Datapoint validation failed: {e}, "
-                         f"retrying... ({retries + 1}/{max_retries})"
-                     )
-                     retries += 1
-                     continue
-
-                 valid_data_points.append(new_datapoint)
-
-             except Exception as e:
-                 logger.warning(
-                     f"Unexpected error: {e}, retrying..."
-                     f" ({retries + 1}/{max_retries})"
-                 )
-                 retries += 1
-
-         if len(valid_data_points) < n:
-             raise RuntimeError(
-                 f"Failed to generate {n} valid datapoints "
-                 f"after {max_retries} retries."
-             )
-
-         self._data.extend(valid_data_points)
-         return valid_data_points
-
-     def __len__(self) -> int:
-         r"""Return the size of the dataset."""
-         return len(self._data)
-
-     def __getitem__(self, idx: int) -> DataPoint:
-         r"""Retrieve a datapoint by index.
-
-         Args:
-             idx (int): Index of the datapoint.
-
-         Returns:
-             DataPoint: The datapoint corresponding to the given index.
-
-         Raises:
-             IndexError: If idx is out of bounds.
-         """
-         if idx < 0 or idx >= len(self._data):
-             raise IndexError(
-                 f"Index {idx} out of bounds for dataset of "
-                 f"size {len(self._data)}"
-             )
-         return self._data[idx]
-
-     def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
-         r"""Saves the dataset to a JSONL (JSON Lines) file.
-
-         Each datapoint is stored as a separate JSON object on a new line.
-
-         Args:
-             file_path (Union[str, Path]): Path to save the JSONL file.
-
-         Raises:
-             ValueError: If the dataset is empty.
-             IOError: If there is an issue writing to the file.
-
-         Notes:
-             - Uses `self._data`, which contains the generated datapoints.
-             - Overwrites the file if it already exists.
-             - Ensures compatibility with large datasets by using JSONL format.
-         """
-         if not self._data:
-             raise ValueError("Dataset is empty. No data to save.")
-
-         file_path = Path(file_path)
-
-         try:
-             with file_path.open("w", encoding="utf-8") as f:
-                 for datapoint in self._data:
-                     json.dump(datapoint.to_dict(), f)
-                     f.write("\n")  # Ensure each entry is on a new line
-             logger.info(f"Dataset saved successfully to {file_path}")
-         except IOError as e:
-             logger.error(f"Error writing to file {file_path}: {e}")
-             raise
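
For reference, the removed classes could be exercised as shown below. This is a minimal sketch based only on the deleted source above, written against 0.2.29 (the last release that shipped camel/datasets/base.py); the sample records are invented for illustration.

    from camel.datasets.base import DataPoint, StaticDataset

    # DataPoint-compatible records. Note that StaticDataset's row validation
    # requires 'rationale' to be a string, even though the DataPoint schema
    # itself marks it Optional.
    records = [
        {
            "question": "What is 2 + 2?",
            "rationale": "Adding 2 and 2 gives 4.",
            "final_answer": "4",
        },
        {
            "question": "What is 3 * 3?",
            "rationale": "Multiplying 3 by 3 gives 9.",
            "final_answer": "9",
        },
    ]

    # strict=False (the default) logs and skips invalid rows; strict=True
    # raises ValueError instead. The seed drives sample() reproducibly.
    dataset = StaticDataset(records, seed=42, min_samples=1, strict=False)

    assert len(dataset) == 2
    point: DataPoint = dataset[0]          # bounds-checked __getitem__
    print(dataset.sample().final_answer)   # seeded random sampling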
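
GenerativeDataset wrapped that seed data with an agent and a verifier. A sketch of the generation loop follows, under the same 0.2.29 assumption; BaseVerifier is abstract, so MyVerifier stands in for a hypothetical concrete subclass, and the ChatAgent construction is abbreviated.

    import asyncio

    from camel.agents import ChatAgent
    from camel.datasets.base import GenerativeDataset

    from my_project.verifiers import MyVerifier  # hypothetical BaseVerifier subclass

    agent = ChatAgent("You produce question/rationale/answer triples.")
    generative = GenerativeDataset(
        seed_dataset=dataset,  # the StaticDataset from the previous sketch
        verifier=MyVerifier(),
        agent=agent,
    )

    # generate_new() is a coroutine: each attempt few-shot-prompts the agent
    # with 3 seed examples, verifies the rationale, and retries failures up
    # to max_retries before raising RuntimeError.
    new_points = asyncio.run(generative.generate_new(n=5, max_retries=10))

    # One JSON object per line; raises ValueError if nothing was generated.
    generative.save_to_jsonl("synthetic.jsonl")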