camel-ai 0.2.33__py3-none-any.whl → 0.2.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of camel-ai has been flagged as potentially problematic.

@@ -13,45 +13,63 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 
 import abc
+import asyncio
 import json
 import random
-from datetime import datetime
 from pathlib import Path
-from typing import (
-    List,
-    Union,
-)
+from typing import Any, Dict, List, Union
 
 from pydantic import ValidationError
+from torch.utils.data import IterableDataset
 
-from camel.agents import ChatAgent
 from camel.logger import get_logger
-from camel.verifiers import BaseVerifier
-from camel.verifiers.models import VerifierInput
 
 from .models import DataPoint
-from .static_dataset import StaticDataset
 
 logger = get_logger(__name__)
 
 
-class BaseGenerator(abc.ABC):
+class BaseGenerator(abc.ABC, IterableDataset):
     r"""Abstract base class for data generators.
 
     This class defines the interface for generating synthetic datapoints.
     Concrete implementations should provide specific generation strategies.
     """
 
-    def __init__(self, seed: int = 42, **kwargs):
+    def __init__(
+        self,
+        seed: int = 42,
+        cache: Union[str, Path, None] = None,
+        data_path: Union[str, Path, None] = None,
+        **kwargs,
+    ):
         r"""Initialize the base generator.
 
         Args:
             seed (int): Random seed for reproducibility. (default: :obj:`42`)
+            cache (Union[str, Path, None]): Optional path to save generated
+                datapoints during iteration. If None is provided, datapoints
+                will be discarded every 100 generations.
+            data_path (Union[str, Path, None]): Optional path to a JSONL file
+                to initialize the dataset from.
             **kwargs: Additional generator parameters.
         """
         self._rng = random.Random(seed)
+        self.cache = Path(cache) if cache else None
 
         self._data: List[DataPoint] = []
+        self._batch_to_save: List[DataPoint] = []
+
+        if data_path:
+            file_path = Path(data_path)
+            raw_data = self._init_from_jsonl(file_path)
+            try:
+                data_points = [DataPoint(**item) for item in raw_data]
+                self._data.extend(data_points)
+            except ValidationError as e:
+                raise ValueError(
+                    f"Failed to create DataPoint from JSONL data: {e}"
+                )
 
     @abc.abstractmethod
     async def generate_new(self, n: int, **kwargs) -> List[DataPoint]:
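
For orientation, the hunk above reworks `BaseGenerator.__init__` to accept optional `cache` and `data_path` arguments while keeping `generate_new` abstract. The following is a minimal, hypothetical subclass sketch and is not part of this release; the import paths and the `DataPoint` fields (`question`, `rationale`, `final_answer`) are assumptions taken from elsewhere in this diff.

from typing import List

# Assumed import paths; the actual module layout may differ.
from camel.datagen.base_generator import BaseGenerator
from camel.datagen.models import DataPoint


class CountingGenerator(BaseGenerator):
    """Toy generator producing simple arithmetic datapoints."""

    async def generate_new(self, n: int, **kwargs) -> List[DataPoint]:
        points = []
        for _ in range(n):
            a, b = self._rng.randint(0, 9), self._rng.randint(0, 9)
            points.append(
                DataPoint(
                    question=f"What is {a} + {b}?",
                    rationale=f"Add {a} and {b}.",
                    final_answer=str(a + b),
                )
            )
        return points


# The new constructor arguments are optional persistence paths.
gen = CountingGenerator(seed=7, cache="cache.jsonl", data_path=None)

Note that `generate_new` only needs to return the new datapoints; the iterators added in the next hunk take care of queuing and optional caching.
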
@@ -66,34 +84,112 @@ class BaseGenerator(abc.ABC):
         """
         pass
 
-    def __len__(self) -> int:
-        r"""Return the size of the generated dataset."""
-        return len(self._data)
+    def __aiter__(self):
+        r"""Async iterator that yields datapoints dynamically.
 
-    def __getitem__(self, idx: int) -> DataPoint:
-        r"""Retrieve a datapoint by index.
+        If a `data_path` was provided during initialization, those datapoints
+        are yielded first. When self._data is empty, 20 new datapoints
+        are generated. Every 100 yields, the batch is appended to the
+        JSONL file or discarded if `cache` is None.
 
-        Args:
-            idx (int): Index of the datapoint.
+        Yields:
+            DataPoint: A single datapoint.
+        """
 
-        Returns:
-            DataPoint: The datapoint corresponding to the given index.
+        async def generator():
+            while True:
+                if not self._data:
+                    new_datapoints = await self.generate_new(20)
+                    self._data.extend(new_datapoints)
+                datapoint = self._data.pop(0)
+                yield datapoint
+                self._batch_to_save.append(datapoint)
+                if len(self._batch_to_save) == 100:
+                    if self.cache:
+                        with self.cache.open("a", encoding="utf-8") as f:
+                            for dp in self._batch_to_save:
+                                json.dump(dp.to_dict(), f, ensure_ascii=False)
+                                f.write("\n")
+                    self._batch_to_save = []
+
+        return generator()
+
+    def __iter__(self):
+        r"""Synchronous iterator for PyTorch IterableDataset compatibility.
+
+        If a `data_path` was provided during initialization, those datapoints
+        are yielded first. When self._data is empty, 20 new datapoints
+        are generated. Every 100 yields, the batch is appended to the
+        JSONL file or discarded if `cache` is None.
+
+        Yields:
+            DataPoint: A single datapoint.
+        """
+        try:
+            if asyncio.get_event_loop().is_running():
+                raise RuntimeError(
+                    "Cannot use synchronous iteration (__iter__) in an async "
+                    "context; use 'async for' with __aiter__ instead"
+                )
+        except RuntimeError as e:
+            if "no running event loop" not in str(e):
+                raise
+
+        while True:
+            if not self._data:
+                new_datapoints = asyncio.run(self.generate_new(20))
+                self._data.extend(new_datapoints)
+            datapoint = self._data.pop(0)
+            yield datapoint
+            self._batch_to_save.append(datapoint)
+            if len(self._batch_to_save) == 100:
+                if self.cache:
+                    with self.cache.open("a", encoding="utf-8") as f:
+                        for dp in self._batch_to_save:
+                            json.dump(dp.to_dict(), f, ensure_ascii=False)
+                            f.write("\n")
+                self._batch_to_save = []
+
+    def sample(self) -> DataPoint:
+        r"""Returns the next datapoint from the current dataset
+        synchronously.
 
         Raises:
-            IndexError: If idx is out of bounds.
+            RuntimeError: If called in an async context.
+
+        Returns:
+            DataPoint: The next DataPoint.
+
+        Note:
+            This method is intended for synchronous contexts.
+            Use 'async_sample' in asynchronous contexts to
+            avoid blocking or runtime errors.
         """
-        if idx < 0 or idx >= len(self._data):
-            raise IndexError(
-                f"Index {idx} out of bounds for dataset of "
-                f"size {len(self._data)}"
-            )
-        return self._data[idx]
+        try:
+            if asyncio.get_event_loop().is_running():
+                raise RuntimeError(
+                    "Cannot use synchronous sampling (sample) "
+                    "in an async context; use async_sample instead"
+                )
+        except RuntimeError as e:
+            if "no running event loop" not in str(e):
+                raise
 
-    def sample(self) -> DataPoint:
-        if len(self._data) == 0:
-            raise RuntimeError("Dataset is empty, cannot sample.")
-        idx = self._rng.randint(0, len(self._data) - 1)
-        return self[idx]
+        return next(iter(self))
+
+    async def async_sample(self) -> DataPoint:
+        r"""Returns the next datapoint from the current dataset asynchronously.
+
+        Returns:
+            DataPoint: The next datapoint.
+
+        Note:
+            This method is intended for asynchronous contexts. Use 'sample'
+            in synchronous contexts.
+        """
+
+        async_iter = self.__aiter__()
+        return await async_iter.__anext__()
 
     def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
         r"""Saves the generated datapoints to a JSONL (JSON Lines) file.
@@ -109,7 +205,7 @@ class BaseGenerator(abc.ABC):
 
         Notes:
             - Uses `self._data`, which contains the generated datapoints.
-            - Overwrites the file if it already exists.
+            - Appends to the file if it already exists.
             - Ensures compatibility with large datasets by using JSONL format.
         """
         if not self._data:
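
Worth noting for this hunk and the next: `save_to_jsonl` now opens the file in append mode, so repeated saves accumulate records rather than replacing the file. Continuing the sketch above:

gen.save_to_jsonl("data.jsonl")
gen.save_to_jsonl("data.jsonl")  # data.jsonl now holds each in-memory datapoint twice
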
@@ -118,218 +214,66 @@ class BaseGenerator(abc.ABC):
         file_path = Path(file_path)
 
         try:
-            with file_path.open("w", encoding="utf-8") as f:
+            with file_path.open("a", encoding="utf-8") as f:
                 for datapoint in self._data:
                     json.dump(datapoint.to_dict(), f, ensure_ascii=False)
-                    f.write("\n")  # Ensure each entry is on a new line
+                    f.write("\n")
             logger.info(f"Dataset saved successfully to {file_path}")
         except IOError as e:
             logger.error(f"Error writing to file {file_path}: {e}")
             raise
 
-
-class FewShotGenerator(BaseGenerator):
-    r"""A generator for creating synthetic datapoints using few-shot learning.
-
-    This class leverages a seed dataset, an agent, and a verifier to generate
-    new synthetic datapoints on demand through few-shot prompting.
-    """
-
-    def __init__(
-        self,
-        seed_dataset: StaticDataset,
-        verifier: BaseVerifier,
-        agent: ChatAgent,
-        seed: int = 42,
-        **kwargs,
-    ):
-        r"""Initialize the few-shot generator.
+    def flush(self, file_path: Union[str, Path]) -> None:
+        r"""Flush the current data to a JSONL file and clear the data.
 
         Args:
-            seed_dataset (StaticDataset): Validated static dataset to
-                use for examples.
-            verifier (BaseVerifier): Verifier to validate generated content.
-            agent (ChatAgent): Agent to generate new datapoints.
-            seed (int): Random seed for reproducibility. (default: :obj:`42`)
-            **kwargs: Additional generator parameters.
-        """
-        super().__init__(seed=seed, **kwargs)
-        self.seed_dataset = seed_dataset
-        try:
-            self._validate_seed_dataset()
-        except Exception:
-            raise RuntimeError("Seed Data does not follow Datapoint format")
-        self.verifier = verifier
-        self.agent = agent
-
-    # TODO: Validate that seed dataset contains rationale
-    def _validate_seed_dataset(self) -> None:
-        pass
+            file_path (Union[str, Path]): Path to save the JSONL file.
 
-    def _construct_prompt(self, examples: List[DataPoint]) -> str:
-        r"""Construct a prompt for generating new datapoints
-        using a fixed sample of examples from the seed dataset.
+        Notes:
+            - Uses `save_to_jsonl` to save `self._data`.
+        """
 
-        Args:
-            examples (List[DataPoint]): Examples to include in the prompt.
+        self.save_to_jsonl(file_path)
+        self._data = []
+        logger.info(f"Data flushed to {file_path} and cleared from the memory")
 
-        Returns:
-            str: Formatted prompt with examples.
-        """
-        prompt = (
-            "Generate a new datapoint similar to the following examples:\n\n"
-        )
-        for i, example in enumerate(examples, 1):
-            prompt += f"Example {i}:\n"
-            prompt += f"Question: {example.question}\n"
-            if example.rationale is not None:
-                prompt += f"Rationale: {example.rationale}\n"
-            else:
-                prompt += "Rationale: None\n"
-            prompt += f"Final Answer: {example.final_answer}\n\n"
-        prompt += "New datapoint:"
-        return prompt
-
-    async def generate_new(
-        self,
-        n: int,
-        max_retries: int = 10,
-        num_examples: int = 3,
-        **kwargs,
-    ) -> List[DataPoint]:
-        r"""Generates and validates `n` new datapoints through
-        few-shot prompting, with a retry limit.
-
-        Steps:
-            1. Samples examples from the seed dataset.
-            2. Constructs a prompt using the selected examples.
-            3. Uses an agent to generate a new datapoint,
-               consisting of a question and code to solve the question.
-            4. Executes code using a verifier to get pseudo ground truth.
-            5. Stores valid datapoints in memory.
+    def _init_from_jsonl(self, file_path: Path) -> List[Dict[str, Any]]:
+        r"""Load and parse a dataset from a JSONL file.
 
         Args:
-            n (int): Number of valid datapoints to generate.
-            max_retries (int): Maximum number of retries before stopping.
-                (default: :obj:`10`)
-            num_examples (int): Number of examples to sample from the
-                seed dataset for few shot prompting.
-                (default: :obj:`3`)
-            **kwargs: Additional generation parameters.
+            file_path (Path): Path to the JSONL file.
 
         Returns:
-            List[DataPoint]: A list of newly generated valid datapoints.
+            List[Dict[str, Any]]: A list of datapoint dictionaries.
 
         Raises:
-            TypeError: If the agent's output is not a dictionary (or does not
-                match the expected format).
-            KeyError: If required keys are missing from the response.
-            AttributeError: If the verifier response lacks attributes.
-            ValidationError: If a datapoint fails schema validation.
-            RuntimeError: If retries are exhausted before `n` valid datapoints
-                are generated.
-
-        Notes:
-            - Retries on validation failures until `n` valid datapoints exist
-              or `max_retries` is reached, whichever comes first.
-            - If retries are exhausted before reaching `n`, a `RuntimeError`
-              is raised.
-            - Metadata includes a timestamp for tracking datapoint creation.
+            FileNotFoundError: If the specified JSONL file does not exist.
+            ValueError: If a line contains invalid JSON or is not a dictionary.
         """
-        valid_data_points: List[DataPoint] = []
-        retries = 0
-
-        while len(valid_data_points) < n and retries < max_retries:
-            try:
-                examples = [
-                    self.seed_dataset.sample() for _ in range(num_examples)
-                ]
-                prompt = self._construct_prompt(examples)
-
-                try:
-                    agent_output = (
-                        self.agent.step(prompt, response_format=DataPoint)
-                        .msgs[0]
-                        .parsed
-                    )
-                    if not isinstance(agent_output, dict):
-                        raise TypeError("Agent output must be a dictionary")
-                    if "question" not in agent_output:
-                        raise KeyError(
-                            "Missing 'question' in agent"
-                            f"output {agent_output}"
-                        )
-                    if "rationale" not in agent_output:
-                        raise KeyError(
-                            "Missing 'rationale' in agent"
-                            f"output {agent_output}"
-                        )
-                except (TypeError, KeyError) as e:
-                    logger.warning(
-                        f"Agent output issue: {e}, retrying... "
-                        f"({retries + 1}/{max_retries})"
-                    )
-                    retries += 1
-                    continue
-
-                rationale = agent_output.get("rationale")
-
-                if not isinstance(rationale, str):
-                    raise TypeError(f"Rationale {rationale} is not a string.")
-
-                try:
-                    verifier_response = await self.verifier.verify(
-                        VerifierInput(
-                            llm_response=rationale,
-                            ground_truth=None,
-                        )
-                    )
-                    if not verifier_response or not verifier_response.result:
-                        raise ValueError(
-                            "Verifier unsuccessful, response: "
-                            f"{verifier_response}"
-                        )
-                except (ValueError, AttributeError) as e:
-                    logger.warning(
-                        f"Verifier issue: {e}, "
-                        f"retrying... ({retries + 1}/{max_retries})"
-                    )
-                    retries += 1
-                    continue
-
+        if not file_path.exists():
+            raise FileNotFoundError(f"JSONL file not found: {file_path}")
+
+        raw_data = []
+        logger.debug(f"Loading JSONL from {file_path}")
+        with file_path.open('r', encoding='utf-8') as f:
+            for line_number, line in enumerate(f, start=1):
+                line = line.strip()
+                if not line:
+                    continue  # Skip blank lines
                 try:
-                    new_datapoint = DataPoint(
-                        question=agent_output["question"],
-                        rationale=rationale,
-                        final_answer=verifier_response.result,
-                        metadata={
-                            "synthetic": str(True),
-                            "created": datetime.now().isoformat(),
-                            "generator": "few_shot",
-                        },
+                    record = json.loads(line)
+                except json.JSONDecodeError as e:
+                    raise ValueError(
+                        f"Invalid JSON on line {line_number} "
+                        f"in file {file_path}: {e}"
                     )
-                except ValidationError as e:
-                    logger.warning(
-                        f"Datapoint validation failed: {e}, "
-                        f"retrying... ({retries + 1}/{max_retries})"
+                if not isinstance(record, dict):
+                    raise ValueError(
+                        f"Expected a dictionary at line {line_number}, "
+                        f"got {type(record).__name__}"
                     )
-                    retries += 1
-                    continue
-
-                valid_data_points.append(new_datapoint)
-
-            except Exception as e:
-                logger.warning(
-                    f"Unexpected error: {e}, retrying..."
-                    f" ({retries + 1}/{max_retries})"
-                )
-                retries += 1
-
-        if len(valid_data_points) < n:
-            raise RuntimeError(
-                f"Failed to generate {n} valid datapoints "
-                f"after {max_retries} retries."
-            )
-
-        self._data.extend(valid_data_points)
-        return valid_data_points
+                raw_data.append(record)
+        logger.info(
+            f"Successfully loaded {len(raw_data)} items from {file_path}"
+        )
+        return raw_data
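
The final hunk removes `FewShotGenerator` from this module and adds `flush` and `_init_from_jsonl` to `BaseGenerator`. Below is a hypothetical round-trip sketch, again reusing the `CountingGenerator` example; the exact `DataPoint` field names are assumptions carried over from the rest of this diff.

gen = CountingGenerator(seed=7)
next(iter(gen))          # forces a batch of 20 into self._data and yields the first one
gen.flush("dump.jsonl")  # appends the remaining datapoints to dump.jsonl, then clears memory

# Each JSONL line is one JSON object whose keys match DataPoint's fields, e.g.
# {"question": "What is 2 + 2?", "rationale": "Add 2 and 2.", "final_answer": "4"}

# data_path routes through _init_from_jsonl and DataPoint(**item), so a new
# generator can be warm-started from the dumped file:
warm = CountingGenerator(seed=7, data_path="dump.jsonl")
print(warm.sample().question)  # reloaded datapoints are yielded first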